#!/usr/bin/env python3
"""
Dynamic VNC Session Manager - Creates unique VNC pods for each browser tab
"""

import http.server
import socketserver
import json
import urllib.request
import ssl
import threading
import time
import os
import sys
import uuid
import secrets
from pathlib import Path
import redis

class DynamicSessionManager:
    def __init__(self):
        """Set up session storage and read deployment settings from the environment.

        A Redis client is created from ``REDIS_HOST`` / ``REDIS_PORT`` and
        verified with a ping. If anything in that step raises, the failure is
        printed and ``redis_client`` is left as ``None``. An in-memory dict is
        always created next to it.
        """
        env = os.getenv
        redis_host = env('REDIS_HOST', 'redis')
        redis_port = int(env('REDIS_PORT', '6379'))

        client = None
        try:
            client = redis.Redis(
                host=redis_host,
                port=redis_port,
                db=0,
                decode_responses=True,
                socket_timeout=5,
                socket_connect_timeout=5
            )
            # A successful ping is the connection test.
            client.ping()
            print(f"✅ Connected to Redis at {redis_host}:{redis_port}")
        except Exception as e:
            print(f"❌ Failed to connect to Redis: {e}")
            print("⚠️  Falling back to in-memory session storage")
            client = None
        self.redis_client = client

        # In-memory copy of every session, used whenever Redis has no answer.
        self.active_sessions = {}

        self.k8s_api_url = env('K8S_API_URL', 'https://api.hediabed.com')
        self.domain = env('DOMAIN', 'hediabed.com')
        # Session lifetime in seconds (default: 5 minutes).
        self.session_timeout = int(env('SESSION_TIMEOUT', '300'))
        
    def generate_session_id(self):
        """Return a new random session identifier that is safe in K8s names.

        The ID is ``session-`` followed by 32 lowercase hex characters
        (16 random bytes = 128 bits) from ``secrets.token_hex``. Hex is used
        instead of ``token_urlsafe`` so the ID never contains underscores.
        """
        random_hex = secrets.token_hex(16)
        return "session-" + random_hex

    def _get_session_from_redis(self, session_id):
        """Fetch and decode the value stored under ``session:<session_id>``.

        Returns the decoded JSON value, or ``None`` when no Redis client is
        configured, the key is missing or empty, or the lookup/decoding raises
        (the error is printed, not propagated).
        """
        client = self.redis_client
        if not client:
            return None

        try:
            raw = client.get(f"session:{session_id}")
            return json.loads(raw) if raw else None
        except Exception as e:
            print(f"Error reading session from Redis: {e}")
            return None

    def _save_session_to_redis(self, session_id, session_data):
        """Write the session as JSON under ``session:<session_id>`` with a TTL.

        The TTL is ``self.session_timeout`` seconds. Returns ``True`` when the
        write went through, ``False`` when no Redis client is configured or
        serialising/writing raised (the error is printed).
        """
        client = self.redis_client
        if not client:
            return False

        try:
            payload = json.dumps(session_data)
            client.setex(f"session:{session_id}", self.session_timeout, payload)
        except Exception as e:
            print(f"Error saving session to Redis: {e}")
            return False
        return True

    def _delete_session_from_redis(self, session_id):
        """Remove the ``session:<session_id>`` key from Redis.

        Returns ``True`` when the delete command ran without raising, and
        ``False`` when no Redis client is configured or the command raised
        (the error is printed).
        """
        client = self.redis_client
        if not client:
            return False

        try:
            client.delete(f"session:{session_id}")
        except Exception as e:
            print(f"Error deleting session from Redis: {e}")
            return False
        return True

    def _get_session_storage(self, session_id):
        """Return the session from Redis, falling back to the in-memory copy.

        The in-memory copy is only returned while it is still within
        ``self.session_timeout`` seconds of its ``last_heartbeat``. Without
        this check a session already expired by the Redis TTL would be served
        from memory indefinitely (and written back to Redis by the next
        heartbeat), and sessions would never expire when Redis is unavailable.
        An expired in-memory entry is evicted.

        Returns:
            The session dict, or ``None`` when it is unknown or expired.
        """
        session = self._get_session_from_redis(session_id)
        if session:
            return session

        cached = self.active_sessions.get(session_id)
        if cached is None:
            return None

        # Entries without a numeric heartbeat cannot be aged; return them as-is.
        last_seen = cached.get('last_heartbeat') if isinstance(cached, dict) else None
        if isinstance(last_seen, (int, float)):
            if time.time() - last_seen > self.session_timeout:
                self.active_sessions.pop(session_id, None)
                return None
        return cached

    def _save_session_storage(self, session_id, session_data):
        """Store the session in Redis (with TTL) and in the in-memory dict.

        The Redis write result is not checked here; the in-memory entry is
        written regardless.
        """
        # Redis first: this is the write that (re)starts the TTL.
        self._save_session_to_redis(session_id, session_data)

        # Then the local copy used as cache/fallback.
        self.active_sessions.update({session_id: session_data})

    def _delete_session_storage(self, session_id):
        """Remove the session from Redis and from the in-memory dict.

        Unknown session IDs are ignored.
        """
        self._delete_session_from_redis(session_id)
        self.active_sessions.pop(session_id, None)

    def create_session(self, session_id=None):
        """Create a K8s pod for one browser tab and register the session.

        Posts ``{"username": "user<session_id>"}`` to
        ``<k8s_api_url>/create-session`` and parses the last line of the
        response as JSON. If the request fails, the hardcoded fallback session
        ``userq7e1qs`` is used, and the session is registered under that
        username so its URLs, its pod name and the hardcoded-session check in
        ``delete_session`` all agree.

        Args:
            session_id: Optional caller-supplied ID; a new one is generated
                when it is missing or empty.

        Returns:
            dict: ``status``, ``session_id``, ``username``, ``url``,
            ``websocket_url`` and ``expires_in`` (seconds) on success, or a
            dict with a single ``error`` key on failure.
        """
        if not session_id:
            session_id = self.generate_session_id()

        try:
            # Unique username for this session.
            username = f"user{session_id}"

            print(f"Creating K8s pod for session: {session_id} (username: {username})")

            try:
                data = json.dumps({"username": username}).encode('utf-8')

                # NOTE(security): certificate and hostname verification are
                # disabled for this call, so the K8s API endpoint is not
                # authenticated. Kept as-is; review before exposing this
                # beyond a trusted network.
                ctx = ssl.create_default_context()
                ctx.check_hostname = False
                ctx.verify_mode = ssl.CERT_NONE

                req = urllib.request.Request(
                    f"{self.k8s_api_url}/create-session",
                    data=data,
                    headers={'Content-Type': 'application/json'}
                )

                with urllib.request.urlopen(req, context=ctx, timeout=30) as response:
                    response_text = response.read().decode('utf-8')
                print(f"K8s API response: {response_text}")

                # The JSON document is the last line (after kubectl output).
                json_line = response_text.strip().split('\n')[-1]

                try:
                    result = json.loads(json_line)
                except json.JSONDecodeError:
                    print(f"Failed to parse JSON: {json_line}")
                    return {"error": f"Invalid response: {json_line}"}

                # Valid JSON that is not an object (e.g. a bare string or list)
                # is just as unusable as unparsable text.
                if not isinstance(result, dict):
                    print(f"Failed to parse JSON: {json_line}")
                    return {"error": f"Invalid response: {json_line}"}

                print(f"Pod created successfully: {result}")

                # Strip a trailing /websockify (suffix only). The cleaned value
                # is logged; the stored session always uses the canonical
                # wss://<username>.<domain> form built below.
                suffix = '/websockify'
                api_ws_url = result.get('websocket_url')
                if isinstance(api_ws_url, str) and api_ws_url.endswith(suffix):
                    result['websocket_url'] = api_ws_url[:-len(suffix)]
                    print(f"Cleaned websocket URL: {result['websocket_url']}")

            except Exception as e:
                print(f"Error creating K8s pod: {e}")
                # Fallback to hardcoded working session for reliability
                print("Falling back to stable hardcoded session")
                result = {
                    'status': 'created',
                    'username': 'userq7e1qs',  # Known working session
                    'url': f'https://userq7e1qs.{self.domain}',
                    'websocket_url': f'wss://userq7e1qs.{self.domain}'
                }
                # Register the session under the fallback username; otherwise
                # the WebSocket URL and pod name would point at a pod that was
                # never created while `url` points at the fallback one.
                username = result['username']

            now = time.time()
            websocket_url = f'wss://{username}.{self.domain}'
            session_data = {
                'session_id': session_id,
                'username': username,
                'status': result.get('status', 'created'),
                'url': result.get('url', f'https://{username}.{self.domain}'),
                'websocket_url': websocket_url,
                'pod_name': f'novnc-{username}',
                'created_at': now,
                'last_heartbeat': now
            }

            # Save to Redis (TTL-based expiry) + in-memory cache.
            self._save_session_storage(session_id, session_data)

            print(f"Session ready - ID: {session_id}, WebSocket: {session_data['websocket_url']}")

            # Payload returned to the frontend.
            return {
                "status": "created",
                "session_id": session_id,
                "username": username,
                "url": session_data['url'],
                "websocket_url": websocket_url,
                "expires_in": self.session_timeout
            }

        except Exception as e:
            print(f"Error creating session {session_id}: {e}")
            return {"error": str(e)}
    
    def get_session(self, session_id):
        """Return info and readiness flags for an existing session.

        The session is looked up in storage; if it is missing, recovery via
        ``_try_recover_session`` is attempted. A found session gets its
        ``last_heartbeat`` refreshed and is saved back before readiness is
        checked.

        Returns:
            dict: ``status`` ("active"), ``session_id``, ``websocket_url``,
            ``url``, ``created_at``, ``ready``, ``ssl_ready``, ``vnc_ready``
            and ``details``; or ``{"error": "Session not found"}``.
        """
        session = self._get_session_storage(session_id)

        if not session:
            # Unknown here: see whether the K8s resources still exist.
            session = self._try_recover_session(session_id)
            if not session:
                return {"error": "Session not found"}
            print(f"🔄 Recovered session from K8s resources: {session_id}")

        # Touch the session and persist it again.
        session['last_heartbeat'] = time.time()
        self._save_session_storage(session_id, session)

        readiness = self._check_session_readiness(session['username'])

        info = {
            "status": "active",
            "session_id": session_id,
            "websocket_url": session['websocket_url'],
            "url": session['url'],
            "created_at": session['created_at'],
        }
        for flag in ("ready", "ssl_ready", "vnc_ready", "details"):
            info[flag] = readiness[flag]
        return info

    def get_session_progress(self, session_id):
        """
        Report how far a session is from being connectable.

        IMPORTANT: This is the SINGLE SOURCE OF TRUTH for session progress.
        The frontend MUST use this endpoint to display progress and to decide
        when to open the WebSocket connection (``can_connect``).

        The session is loaded (or recovered), its heartbeat is refreshed and
        saved, then the readiness flags are mapped to one of these states:

        - ``ready`` (100): pod, SSL and VNC are all ready; ``can_connect`` is true
        - ``vnc_starting`` (85): pod and SSL ready, VNC not yet
        - ``ssl_provisioning`` (65): pod ready, SSL not yet
        - ``pod_creating`` (30): pod not ready, session has a ``session_id``
        - ``initializing`` (10): none of the above

        Returns:
            {
                "session_id": "session-xxx",
                "progress": 65,  # one of the values listed above
                "stage": "Securing connection (SSL)...",  # human-readable text
                "state": "ssl_provisioning",  # technical state identifier
                "ready": true,  # pod is ready
                "ssl_ready": false,  # SSL cert is ready
                "vnc_ready": false,  # VNC service is ready
                "can_connect": false,  # frontend should connect when true
                "websocket_url": "wss://userxxx.hediabed.com",
                "url": "https://userxxx.hediabed.com",
                "details": "Pod ready, SSL pending, VNC starting"
            }

            When the session is neither stored nor recoverable, a dict with
            ``error``, ``progress`` 0, ``stage``, ``state`` "error" and
            ``can_connect`` false is returned instead.
        """
        session = self._get_session_storage(session_id) or self._try_recover_session(session_id)
        if not session:
            return {
                "error": "Session not found",
                "progress": 0,
                "stage": "Session not found",
                "state": "error",
                "can_connect": False
            }

        # Touch the session and persist it again.
        session['last_heartbeat'] = time.time()
        self._save_session_storage(session_id, session)

        readiness = self._check_session_readiness(session['username'])
        pod_ok = readiness['ready']
        ssl_ok = readiness['ssl_ready']
        vnc_ok = readiness['vnc_ready']

        # Most advanced matching state wins.
        if pod_ok and ssl_ok and vnc_ok:
            progress, stage, state = 100, "Environment ready!", "ready"
        elif pod_ok and ssl_ok:
            progress, stage, state = 85, "Starting VNC service...", "vnc_starting"
        elif pod_ok:
            progress, stage, state = 65, "Securing connection (SSL)...", "ssl_provisioning"
        elif session.get('session_id'):
            progress, stage, state = 30, "Starting container...", "pod_creating"
        else:
            progress, stage, state = 10, "Initializing session...", "initializing"

        return {
            "session_id": session_id,
            "progress": progress,
            "stage": stage,
            "state": state,
            "ready": pod_ok,
            "ssl_ready": ssl_ok,
            "vnc_ready": vnc_ok,
            "can_connect": state == "ready",
            "websocket_url": session['websocket_url'],
            "url": session['url'],
            "details": readiness['details']
        }
    
    def _try_recover_session(self, session_id):
        """Try to recover a session by checking if K8s resources exist via API"""
        try:
            # Expected username format: user + session_id
            username = f"user{session_id}"

            # Call K8s API to check if pod exists
            ctx = ssl.create_default_context()
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

            req = urllib.request.Request(
                f"{self.k8s_api_url}/check-readiness/{username}",
                headers={'Content-Type': 'application/json'}
            )

            with urllib.request.urlopen(req, context=ctx, timeout=10) as response:
                result = json.loads(response.read().decode('utf-8'))

                if result.get('pod_exists'):
                    # Pod exists, create session entry
                    session_data = {
                        'session_id': session_id,
                        'username': username,
                        'status': 'recovered',
                        'url': f'https://{username}.{self.domain}',
                        'websocket_url': f'wss://{username}.{self.domain}',
                        'pod_name': f'novnc-{username}',
                        'created_at': time.time(),
                        'last_heartbeat': time.time()
                    }

                    # Store the recovered session in Redis + memory
                    self._save_session_storage(session_id, session_data)

                    # Note: Redis TTL will handle expiration automatically

                    print(f"✅ Session recovered: {session_id} -> {username}")
                    return session_data

        except Exception as e:
            print(f"Failed to recover session {session_id}: {e}")

        return None
    
    def _check_session_readiness(self, username):
        """Ask the K8s API whether the session's pod, SSL and VNC are ready.

        Queries ``<k8s_api_url>/check-readiness/<username>``. The username is
        percent-encoded before it is placed in the URL path, because it is
        derived from a session ID that can come straight from a request and
        characters such as ``?``, ``#`` or spaces would otherwise change the
        request target.

        Returns:
            dict: ``ready``, ``ssl_ready``, ``vnc_ready`` (default ``False``
            when absent from the reply) and ``details`` (default "Unknown").
            If the request or decoding fails, all flags are ``False`` and
            ``details`` is ``"Check failed: <error>"``.
        """
        from urllib.parse import quote

        try:
            # NOTE(security): certificate and hostname verification are
            # disabled for this call, so the K8s API is not authenticated.
            ctx = ssl.create_default_context()
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

            req = urllib.request.Request(
                f"{self.k8s_api_url}/check-readiness/{quote(str(username), safe='')}",
                headers={'Content-Type': 'application/json'}
            )

            with urllib.request.urlopen(req, context=ctx, timeout=10) as response:
                result = json.loads(response.read().decode('utf-8'))

            # Same response shape as before, for backward compatibility.
            return {
                "ready": result.get('ready', False),
                "ssl_ready": result.get('ssl_ready', False),
                "vnc_ready": result.get('vnc_ready', False),
                "details": result.get('details', 'Unknown')
            }

        except Exception as e:
            print(f"Error checking session readiness: {e}")
            return {
                "ready": False,
                "ssl_ready": False,
                "vnc_ready": False,
                "details": f"Check failed: {str(e)}"
            }
    
    
    def heartbeat_session(self, session_id):
        """Mark a session as alive by refreshing its ``last_heartbeat``.

        Saving the session again also rewrites the Redis key with a fresh TTL.

        Returns:
            ``{"status": "heartbeat_updated"}``, or
            ``{"error": "Session not found"}`` for an unknown session.
        """
        session = self._get_session_storage(session_id)
        if not session:
            return {"error": "Session not found"}

        session['last_heartbeat'] = time.time()
        self._save_session_storage(session_id, session)
        return {"status": "heartbeat_updated"}

    def delete_session(self, session_id):
        """Remove a session from storage and clean up its K8s resources.

        K8s cleanup is skipped for the hardcoded ``userq7e1qs`` session and is
        best-effort otherwise: a cleanup failure is printed as a warning and
        the session still counts as deleted.

        Returns:
            ``{"status": "deleted", "username": ...}``, or
            ``{"error": "Session not found"}`` for an unknown session.
        """
        session = self._get_session_storage(session_id)
        if not session:
            return {"error": "Session not found"}

        username = session['username']
        print(f"Session {session_id} ({username}) manually deleted")

        # Drop the stored copies (Redis + memory) first.
        self._delete_session_storage(session_id)

        # Only dynamic sessions own K8s resources; the hardcoded one does not.
        is_dynamic = username != 'userq7e1qs'
        if is_dynamic:
            try:
                self._delete_k8s_resources(username)
            except Exception as e:
                print(f"Warning: Failed to cleanup K8s resources for {username}: {e}")

        return {"status": "deleted", "username": username}
    
    def _delete_k8s_resources(self, username):
        """Ask the K8s API to delete the resources belonging to *username*.

        Sends ``DELETE <k8s_api_url>/delete-session/<username>``. The username
        is percent-encoded before it is placed in the URL path, because it is
        derived from a session ID that can be caller-supplied; characters such
        as ``?``, ``#`` or spaces would otherwise change which resource the
        DELETE request targets.

        Raises:
            Exception: whatever the request or response decoding raised; the
                error is printed first and then re-raised unchanged.
        """
        from urllib.parse import quote

        print(f"Cleaning up K8s resources for user: {username}")

        try:
            # NOTE(security): certificate and hostname verification are
            # disabled for this call, so the K8s API is not authenticated.
            ctx = ssl.create_default_context()
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

            req = urllib.request.Request(
                f"{self.k8s_api_url}/delete-session/{quote(str(username), safe='')}",
                method='DELETE',
                headers={'Content-Type': 'application/json'}
            )

            with urllib.request.urlopen(req, context=ctx, timeout=10) as response:
                result = json.loads(response.read().decode('utf-8'))
            print(f"K8s resources cleaned up for user: {username} - {result}")

        except Exception as e:
            print(f"Error cleaning up K8s resources for {username}: {e}")
            raise
    
    def list_sessions(self):
        """List all known sessions from Redis plus the in-memory fallback.

        Redis keys are walked incrementally with ``scan_iter`` rather than the
        blocking ``KEYS`` command. Only the leading ``session:`` prefix is
        stripped from each key, so an ID that itself contains that text is
        kept intact. Each session ID is reported once; SCAN may yield the same
        key more than once, and in-memory entries already seen in Redis are
        skipped.

        A Redis failure is printed and listing continues with whatever was
        collected so far plus the in-memory entries.

        Returns:
            dict: ``active_sessions`` (count) and ``sessions``, a list of
            dicts with ``session_id``, ``username``, ``created_at``,
            ``last_heartbeat`` and ``websocket_url``.
        """
        prefix = "session:"
        sessions = []
        seen = set()

        def summarize(sid, session):
            # Public subset of the stored session record.
            return {
                "session_id": sid,
                "username": session['username'],
                "created_at": session['created_at'],
                "last_heartbeat": session['last_heartbeat'],
                "websocket_url": session['websocket_url']
            }

        if self.redis_client:
            try:
                for key in self.redis_client.scan_iter(match=f"{prefix}*"):
                    sid = key[len(prefix):] if key.startswith(prefix) else key
                    if sid in seen:
                        continue
                    session = self._get_session_storage(sid)
                    if session:
                        seen.add(sid)
                        sessions.append(summarize(sid, session))
            except Exception as e:
                print(f"Error listing sessions from Redis: {e}")

        # Add in-memory sessions not already listed from Redis (fallback).
        for sid, session in list(self.active_sessions.items()):
            if sid not in seen:
                seen.add(sid)
                sessions.append(summarize(sid, session))

        return {
            "active_sessions": len(sessions),
            "sessions": sessions
        }

# Global session manager used by the request handler methods below.
# NOTE: it is constructed at import time, so importing this module runs
# DynamicSessionManager.__init__, including its Redis connection attempt.
session_manager = DynamicSessionManager()

class DynamicSessionHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        """Forward all constructor arguments unchanged to the parent handler."""
        super(DynamicSessionHTTPRequestHandler, self).__init__(*args, **kwargs)
        
    def end_headers(self):
        """Add CORS headers (plus no-cache headers for ``/api/`` paths), then end the header block."""
        # Allow cross-origin requests on every response.
        extra_headers = [
            ('Access-Control-Allow-Origin', '*'),
            ('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, DELETE, PUT'),
            ('Access-Control-Allow-Headers', 'Content-Type, X-Session-ID'),
        ]
        # API responses must never be cached.
        if self.path.startswith('/api/'):
            extra_headers += [
                ('Cache-Control', 'no-cache, no-store, must-revalidate'),
                ('Pragma', 'no-cache'),
                ('Expires', '0'),
            ]
        for name, value in extra_headers:
            self.send_header(name, value)
        super().end_headers()

    def do_OPTIONS(self):
        """Answer OPTIONS (CORS preflight) requests with a body-less 200 response."""
        status_ok = 200
        self.send_response(status_ok)
        self.end_headers()

    def do_POST(self):
        """Route POST requests for the session API.

        Routes:
            ``/api/session/create``: create a session; the optional JSON
                body may carry ``session_id``. Replies 201, or 500 on error.
            ``/api/session/heartbeat``: refresh a session; the JSON body must
                carry ``session_id``. Replies 200, 400 (missing ID), 404
                (unknown session) or 500 (unreadable body or other error).
            ``/api/session/{id}``: sendBeacon cleanup; deletes the session
                when the JSON body is an object with ``"action": "delete"``.
                Replies 200 or 404 from the delete, otherwise 400.

        Any other path gets a JSON 404. The parent class defines no
        ``do_POST`` to delegate to, so delegating would raise
        ``AttributeError`` instead of answering the client.
        """
        if self.path == '/api/session/create':
            content_length = int(self.headers.get('Content-Length', 0))
            post_data = self.rfile.read(content_length)

            try:
                data = json.loads(post_data.decode('utf-8')) if content_length > 0 else {}
                session_id = data.get('session_id')  # Optional: frontend can provide session ID

                result = session_manager.create_session(session_id)
                self._write_json_response(201 if 'error' not in result else 500, result)

            except Exception as e:
                self._write_json_response(500, {"error": str(e)})

        elif self.path == '/api/session/heartbeat':
            content_length = int(self.headers.get('Content-Length', 0))
            post_data = self.rfile.read(content_length)

            try:
                data = json.loads(post_data.decode('utf-8'))
                session_id = data.get('session_id')

                if not session_id:
                    self._write_json_response(400, {"error": "session_id required"})
                    return

                result = session_manager.heartbeat_session(session_id)
                self._write_json_response(200 if 'error' not in result else 404, result)

            except Exception as e:
                self._write_json_response(500, {"error": str(e)})

        elif self.path.startswith('/api/session/') and self.path.count('/') == 3:
            # Handle sendBeacon cleanup requests: POST /api/session/{id}
            session_id = self.path.split('/')[-1]
            content_length = int(self.headers.get('Content-Length', 0))

            if content_length > 0:
                post_data = self.rfile.read(content_length)
                try:
                    data = json.loads(post_data.decode('utf-8'))
                except ValueError:
                    # Covers both invalid JSON and undecodable bytes.
                    data = None

                # A JSON body that is not an object cannot carry an action.
                if isinstance(data, dict) and data.get('action') == 'delete':
                    print(f"🧹 Tab cleanup - deleting session via sendBeacon: {session_id}")
                    result = session_manager.delete_session(session_id)
                    self._write_json_response(200 if 'error' not in result else 404, result)
                    return

            self._write_json_response(400, {"error": "Invalid request"})
        else:
            self._write_json_response(404, {"error": "Not found"})

    def _write_json_response(self, status, payload):
        """Send *payload* as a JSON body with the given HTTP status code."""
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def do_GET(self):
        """Route GET requests: session API under ``/api/session/``, static files otherwise.

        Routes:
            ``/api/session/list``: all sessions (always 200).
            ``/api/session/{id}/progress``: progress report (200 or 404).
            ``/api/session/{id}``: session info (200 or 404); also used for
                any other path shape below ``/api/session/{id}``.
        """
        if not self.path.startswith('/api/session/'):
            super().do_GET()
            return

        def reply(status, payload):
            # Common JSON response writer for the API routes.
            self.send_response(status)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            self.wfile.write(json.dumps(payload).encode('utf-8'))

        segments = self.path.split('/')
        if len(segments) < 4:
            return

        session_id = segments[3]

        if session_id == 'list':
            reply(200, session_manager.list_sessions())
            return

        wants_progress = len(segments) == 5 and segments[4] == 'progress'
        if wants_progress:
            # GET /api/session/{id}/progress
            result = session_manager.get_session_progress(session_id)
        else:
            # GET /api/session/{id}
            result = session_manager.get_session(session_id)

        reply(200 if 'error' not in result else 404, result)

    def do_DELETE(self):
        """Delete the session named by the last segment of ``/api/session/{id}``.

        Replies 200 with the delete result, 404 with the error when the
        session is unknown, or a body-less 404 for any other path.

        No headers are sent before ``send_response``: header lines buffered
        ahead of the status line are written out in front of it and make the
        response malformed. The CORS headers are added by ``end_headers``.
        """
        if not self.path.startswith('/api/session/'):
            self.send_response(404)
            self.end_headers()
            return

        session_id = self.path.split('/')[-1]
        print(f"Deleting session via DELETE: {session_id}")
        result = session_manager.delete_session(session_id)

        self.send_response(200 if 'error' not in result else 404)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write(json.dumps(result).encode('utf-8'))

def main():
    """Serve the session API and static files on port 3003 until interrupted.

    The working directory is changed to the directory containing this script,
    which is where static files are served from. Ctrl-C stops the server
    quietly; a failure to start exits the process with status 1.
    """
    # Local import: only needed to classify bind failures below.
    import errno

    # Change to the directory containing this script
    script_dir = Path(__file__).parent
    os.chdir(script_dir)

    PORT = 3003  # Different port from the original server

    try:
        with socketserver.TCPServer(("", PORT), DynamicSessionHTTPRequestHandler) as httpd:
            print(f"🚀 Starting Dynamic VNC Session Manager on port {PORT}")
            print(f"📁 Serving files from: {script_dir}")
            print(f"🔗 Access the API at: http://localhost:{PORT}")
            print(f"🌐 K8s API endpoint: {session_manager.k8s_api_url}")
            print("✅ Dynamic session manager started successfully!")
            print("\nAPI Endpoints:")
            print("  POST /api/session/create - Create new VNC session")
            print("  GET  /api/session/{id} - Get session info")
            print("  GET  /api/session/{id}/progress - Get session progress (RECOMMENDED)")
            print("  GET  /api/session/list - List all sessions")
            print("  POST /api/session/heartbeat - Update session heartbeat")
            print("  DELETE /api/session/{id} - Delete session")

            httpd.serve_forever()

    except KeyboardInterrupt:
        print("\n🛑 Server stopped by user")
    except OSError as e:
        # errno.EADDRINUSE is 98 on Linux but differs on other platforms, so
        # compare against the symbolic constant rather than the number.
        if e.errno == errno.EADDRINUSE:
            print(f"❌ Port {PORT} is already in use. Try a different port or stop the existing server.")
            sys.exit(1)
        else:
            print(f"❌ Error starting server: {e}")
            sys.exit(1)
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        sys.exit(1)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
