#!/usr/bin/env python3

"""
Claude CLI Control Server - NEOG System
WebSocket + REST API to control Claude Code CLI
Port: 9003 (NEOG Services: 9001=logs, 9002=wifi, 9003=claude)
"""

import json
import os
import secrets
import subprocess
import threading
import time
import uuid
from datetime import datetime

import requests
from flask import Flask, jsonify, request
from flask_cors import CORS
from flask_socketio import SocketIO, emit

# ROI dashboard integration: the roi_* helpers below report each command run
# to this HTTP API. Set the flag to False to turn all reporting into no-ops.
ROI_API_BASE: str = 'http://localhost:9001/api/cc'
ROI_TRACKING_ENABLED: bool = True

def roi_create_run(prompt, session_id, source='cli-server'):
    """Create a new run in the ROI dashboard.

    Best-effort: network errors, non-2xx responses and malformed bodies are
    logged and turned into ``None`` so tracking never breaks command execution.

    Args:
        prompt: Prompt text; only the first 500 characters are sent.
        session_id: Claude session ID the run belongs to.
        source: Free-form label describing where the command came from.

    Returns:
        The ``run_id`` from the dashboard response, or ``None`` when tracking
        is disabled or the run could not be created.
    """
    if not ROI_TRACKING_ENABLED:
        return None
    try:
        resp = requests.post(f'{ROI_API_BASE}/run', json={
            'prompt': prompt[:500],  # Truncate long prompts
            'session_id': session_id,
            'source': source
        }, timeout=2)
        if not resp.ok:
            # Previously dropped silently; surface it like the exception path does.
            print(f"[ROI] Failed to create run: HTTP {resp.status_code}")
            return None
        return resp.json().get('run_id')
    except Exception as e:
        print(f"[ROI] Failed to create run: {e}")
    return None

def roi_log_event(run_id, event_type, event_subtype=None, input_tokens=0, output_tokens=0, success=True, data=None):
    """Log an event against an existing ROI dashboard run.

    Best-effort: does nothing when tracking is disabled or ``run_id`` is
    falsy, and only prints (never raises) on network errors or non-2xx
    responses.

    Args:
        run_id: Run ID returned by ``roi_create_run``.
        event_type: Event category (e.g. ``'tool_call'``).
        event_subtype: Optional finer-grained label (e.g. the tool name).
        input_tokens: Input tokens attributed to this event.
        output_tokens: Output tokens attributed to this event.
        success: Whether the event succeeded.
        data: Optional JSON-serializable payload, sent JSON-encoded as
            ``input_data`` (falsy payloads are sent as ``None``).
    """
    if not ROI_TRACKING_ENABLED or not run_id:
        return
    try:
        resp = requests.post(f'{ROI_API_BASE}/event', json={
            'run_id': run_id,
            'event_type': event_type,
            'event_subtype': event_subtype,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'success': success,
            'input_data': json.dumps(data) if data else None
        }, timeout=2)
        if not resp.ok:
            # Match roi_create_run: a rejected request is a failure too.
            print(f"[ROI] Failed to log event: HTTP {resp.status_code}")
    except Exception as e:
        print(f"[ROI] Failed to log event: {e}")

def roi_complete_run(run_id, status='completed', input_tokens=0, output_tokens=0, error=None):
    """Mark an ROI dashboard run as finished.

    Best-effort: does nothing when tracking is disabled or ``run_id`` is
    falsy, and only prints (never raises) on network errors or non-2xx
    responses.

    Args:
        run_id: Run ID returned by ``roi_create_run``.
        status: Final status string (callers in this file use ``'completed'``
            and ``'failed'``).
        input_tokens: Total input tokens for the run.
        output_tokens: Total output tokens for the run.
        error: Optional error; its string form (first 500 chars) is sent as
            ``error_message``.
    """
    if not ROI_TRACKING_ENABLED or not run_id:
        return
    try:
        payload = {
            'status': status,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens
        }
        if error:
            payload['error_message'] = str(error)[:500]
        resp = requests.post(f'{ROI_API_BASE}/run/{run_id}/complete', json=payload, timeout=2)
        if not resp.ok:
            # Match roi_create_run: a rejected request is a failure too.
            print(f"[ROI] Failed to complete run: HTTP {resp.status_code}")
    except Exception as e:
        print(f"[ROI] Failed to complete run: {e}")

app = Flask(__name__)
# Never hard-code the signing key in source. Use CLAUDE_CLI_SECRET_KEY when
# set; otherwise generate a random per-process key (nothing in this file uses
# Flask's signed session, so a key that changes on restart is harmless here).
app.config['SECRET_KEY'] = os.environ.get('CLAUDE_CLI_SECRET_KEY') or secrets.token_hex(32)
CORS(app)
# NOTE(review): cors_allowed_origins="*" lets a page from any origin open a
# Socket.IO connection and send 'claude_command' — confirm this is intended.
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading')

# Active sessions tracking: session_id -> session info dict.
# Read and write under session_lock; handlers may run on separate threads
# (async_mode='threading').
active_sessions = {}
session_lock = threading.Lock()

# Command queue and execution state
command_queue = []  # NOTE(review): not referenced anywhere else in this file
# Set/cleared by handle_claude_command and reported by /api/status.
execution_in_progress = False

def execute_claude_command(prompt, session_id=None, tools=None, streaming=True, cwd='/Users/neog'):
    """
    Execute a Claude CLI command and yield its output.

    This is a generator in both modes:

    * ``streaming=True`` runs ``claude --print ... --output-format stream-json``.
      It yields one dict per non-empty stdout line: the parsed JSON object, or
      a ``{'type': 'log', ...}`` dict for non-JSON lines. If the process exits
      non-zero, a final ``{'type': 'error', ...}`` dict carrying stderr is
      yielded.
    * ``streaming=False`` runs with ``--output-format json`` and yields exactly
      one dict. That is the parsed JSON result, or a ``{'type': 'error', ...}``
      dict on non-zero exit, unparseable output, or the 5 minute timeout.

    Failures, including failure to start the CLI, are reported as
    ``{'type': 'error', 'message': ..., 'timestamp': ...}`` chunks rather than
    raised. If the consumer stops iterating early, the CLI process is killed.

    Args:
        prompt: The prompt to send to Claude
        session_id: Optional session ID, passed as ``--session-id``
        tools: Optional list of tools to restrict (e.g., ["Bash", "Read", "Edit"])
        streaming: If True, use stream-json format for real-time output
        cwd: Working directory for the CLI process

    Yields:
        dict chunks as described above.
    """
    def _error(message):
        return {
            'type': 'error',
            'message': message,
            'timestamp': datetime.now().isoformat()
        }

    process = None
    try:
        cmd = ['claude', '--print', prompt,
               '--output-format', 'stream-json' if streaming else 'json']
        if session_id:
            cmd.extend(['--session-id', session_id])
        if tools:
            cmd.extend(['--tools', ','.join(tools)])

        print(f"[EXEC] Running: {' '.join(cmd[:4])}...")

        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
            cwd=cwd
        )

        if streaming:
            # Drain stderr concurrently. Without this, once the CLI writes more
            # than the pipe buffer holds while we block on stdout, both sides
            # deadlock.
            stderr_parts = []
            stderr_reader = threading.Thread(
                target=lambda: stderr_parts.append(process.stderr.read()),
                daemon=True,
            )
            stderr_reader.start()

            for line in iter(process.stdout.readline, ''):
                if not line.strip():
                    continue
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    # Non-JSON output (warnings, debug messages)
                    yield {
                        'type': 'log',
                        'message': line.strip(),
                        'timestamp': datetime.now().isoformat()
                    }
                else:
                    yield chunk

            process.wait()
            stderr_reader.join()

            if process.returncode != 0:
                stderr = ''.join(stderr_parts)
                yield _error(stderr or f'Command failed with code {process.returncode}')
        else:
            try:
                stdout, stderr = process.communicate(timeout=300)  # 5 min timeout
            except subprocess.TimeoutExpired:
                process.kill()
                process.communicate()  # reap the killed process and close its pipes
                yield _error('Command timeout (5 minutes)')
                return

            if process.returncode != 0:
                yield _error(stderr or f'Command failed with code {process.returncode}')
                return
            try:
                result = json.loads(stdout)
            except json.JSONDecodeError:
                yield _error(f'Failed to parse JSON output: {stdout[:200]}')
            else:
                # Must be `yield`: in a generator, `return result` only becomes
                # StopIteration.value, so callers would receive nothing.
                yield result

    except Exception as e:
        yield _error(str(e))

    finally:
        # Runs on normal completion, on errors, and on GeneratorExit when the
        # consumer stops early: never leave an orphaned CLI process behind.
        if process is not None and process.poll() is None:
            process.kill()
            process.wait()

# ============= WEBSOCKET EVENT HANDLERS =============

@socketio.on('connect')
def handle_connect():
    """Acknowledge a newly connected client with a ``connection_status`` event."""
    client_sid = request.sid
    print(f"[WebSocket] Client connected: {client_sid}")
    status_payload = {
        'status': 'connected',
        'server': 'Claude CLI Control Server',
        'version': '1.0.0',
        'timestamp': datetime.now().isoformat(),
    }
    emit('connection_status', status_payload)

@socketio.on('disconnect')
def handle_disconnect():
    """Log the Socket.IO session ID of a client that went away."""
    departed_sid = request.sid
    print(f"[WebSocket] Client disconnected: {departed_sid}")

@socketio.on('claude_command')
def handle_claude_command(data):
    """
    Execute a Claude command and stream its results to the calling client.

    Expected data format:
    {
        "prompt": "Your prompt here",
        "session_id": "optional-session-uuid",
        "tools": ["Bash", "Read"],  // optional
        "streaming": true,  // default: true
        "source": "websocket"  // optional label recorded by ROI tracking
    }

    Event sequence:
    - Normal case: ``claude_started``, then one ``claude_output`` per chunk
      yielded by ``execute_claude_command``, then ``claude_completed``.
    - ``claude_error`` is emitted instead when the prompt is missing or the
      handler raises.

    Final session status in ``active_sessions``:
    - ``'completed'`` normally.
    - ``'failed'`` when the CLI produced an error chunk or the handler raised.
    The ROI run is closed with the same status.
    """
    global execution_in_progress

    run_id = None  # ROI tracking ID
    session_id = None
    tracked = False  # True once this call has registered session_id
    total_input_tokens = 0
    total_output_tokens = 0

    try:
        data = data or {}  # tolerate events emitted without a payload
        prompt = data.get('prompt', '')
        session_id = data.get('session_id')
        tools = data.get('tools')
        streaming = data.get('streaming', True)
        source = data.get('source', 'websocket')  # Track source

        if not prompt:
            emit('claude_error', {
                'error': 'No prompt provided',
                'timestamp': datetime.now().isoformat()
            })
            return

        # Generate session ID if not provided
        if not session_id:
            session_id = str(uuid.uuid4())

        print(f"[COMMAND] Prompt: {prompt[:50]}... | Session: {session_id}")

        # Create ROI tracking run
        run_id = roi_create_run(prompt, session_id, source)
        if run_id:
            print(f"[ROI] Created run: {run_id}")

        # Mark execution in progress
        # NOTE(review): one flag shared by all concurrent commands; the first
        # command to finish clears it even if others are still running.
        execution_in_progress = True

        # Track session
        with session_lock:
            active_sessions[session_id] = {
                'id': session_id,
                'prompt': prompt,
                'started_at': datetime.now().isoformat(),
                'status': 'running',
                'client_sid': request.sid,
                'run_id': run_id  # Store ROI run_id
            }
        tracked = True

        emit('claude_started', {
            'session_id': session_id,
            'run_id': run_id,
            'prompt': prompt,
            'timestamp': datetime.now().isoformat()
        })

        # Execute command and stream results
        chunk_count = 0
        cli_error = None  # message of the last error chunk, if any
        for chunk in execute_claude_command(prompt, session_id, tools, streaming):
            chunk_count += 1

            if isinstance(chunk, dict):
                chunk_type = chunk.get('type', '')

                if chunk_type == 'tool_use':
                    # Track tool calls
                    tool_name = chunk.get('name', 'unknown')
                    roi_log_event(run_id, 'tool_call', tool_name, success=True, data={'tool': tool_name})
                elif chunk_type == 'error':
                    # execute_claude_command reports CLI failures as chunks
                    # instead of raising, so detect them here.
                    cli_error = chunk.get('message') or 'Claude CLI reported an error'

                # Track token usage; tolerate a missing or null 'usage' field.
                usage = chunk.get('usage')
                if isinstance(usage, dict):
                    total_input_tokens += usage.get('input_tokens') or 0
                    total_output_tokens += usage.get('output_tokens') or 0

            emit('claude_output', {
                'session_id': session_id,
                'chunk': chunk,
                'chunk_index': chunk_count,
                'timestamp': datetime.now().isoformat()
            })

        status = 'failed' if cli_error else 'completed'
        with session_lock:
            session = active_sessions.get(session_id)
            if session is not None:
                session['status'] = status
                session['completed_at'] = datetime.now().isoformat()
                session['chunk_count'] = chunk_count
                session['input_tokens'] = total_input_tokens
                session['output_tokens'] = total_output_tokens
                if cli_error:
                    session['error'] = cli_error

        # Complete ROI tracking with the real outcome
        roi_complete_run(run_id, status, total_input_tokens, total_output_tokens, error=cli_error)

        emit('claude_completed', {
            'session_id': session_id,
            'run_id': run_id,
            'chunk_count': chunk_count,
            'input_tokens': total_input_tokens,
            'output_tokens': total_output_tokens,
            'timestamp': datetime.now().isoformat()
        })

        print(f"[COMPLETE] Session {session_id} - {chunk_count} chunks | Tokens: {total_input_tokens}/{total_output_tokens}")

    except Exception as e:
        print(f"[ERROR] {e}")
        # Don't leave the session this call registered stuck in 'running'.
        if tracked:
            with session_lock:
                session = active_sessions.get(session_id)
                if session is not None and session.get('status') == 'running':
                    session['status'] = 'failed'
                    session['completed_at'] = datetime.now().isoformat()
                    session['error'] = str(e)
        # Log error to ROI
        roi_complete_run(run_id, 'failed', total_input_tokens, total_output_tokens, error=str(e))
        emit('claude_error', {
            'error': str(e),
            'timestamp': datetime.now().isoformat()
        })

    finally:
        execution_in_progress = False

@socketio.on('list_sessions')
def handle_list_sessions():
    """Emit ``sessions_list`` with a snapshot of every tracked session (any status)."""
    with session_lock:
        # Copy each dict while holding the lock. handle_claude_command mutates
        # these dicts (possibly on another thread under async_mode='threading'),
        # and emit serializes them after the lock is released.
        sessions_list = [dict(session) for session in active_sessions.values()]

    emit('sessions_list', {
        'sessions': sessions_list,
        'count': len(sessions_list),
        'timestamp': datetime.now().isoformat()
    })

@socketio.on('get_session')
def handle_get_session(data):
    """Emit ``session_details`` for ``data['session_id']``, or ``claude_error`` if it is unknown."""
    session_id = (data or {}).get('session_id')  # tolerate a missing payload

    with session_lock:
        session = active_sessions.get(session_id)
        # Snapshot under the lock; the live dict may be mutated concurrently
        # by handle_claude_command while emit serializes it.
        if session is not None:
            session = dict(session)

    if session:
        emit('session_details', {
            'session': session,
            'timestamp': datetime.now().isoformat()
        })
    else:
        emit('claude_error', {
            'error': f'Session {session_id} not found',
            'timestamp': datetime.now().isoformat()
        })

# ============= REST API ENDPOINTS (Fallback) =============

@app.route('/')
def index():
    """Self-describing API summary: Socket.IO events, REST endpoints and access URLs."""
    client_events = {
        'connect': 'Client connection established',
        'claude_command': 'Execute Claude CLI command (streaming)',
        'list_sessions': 'List all active sessions',
        'get_session': 'Get session details'
    }
    server_events = {
        'connection_status': 'Connection confirmation',
        'claude_started': 'Command execution started',
        'claude_output': 'Streaming output chunks',
        'claude_completed': 'Command execution completed',
        'claude_error': 'Error occurred',
        'sessions_list': 'Response to list_sessions',
        'session_details': 'Response to get_session'
    }
    rest_endpoints = {
        '/api/execute': 'POST - Execute command (one-shot, non-streaming)',
        '/api/sessions': 'GET - List active sessions',
        '/api/sessions/<id>': 'GET - Get session details',
        '/api/health': 'GET - Health check',
        '/api/status': 'GET - Server status'
    }
    access_urls = {
        'local': 'http://localhost:9003',
        'remote': 'http://100.75.88.8:9003',
        'websocket': 'ws://100.75.88.8:9003'
    }

    service_info = {
        'service': 'Claude CLI Control Server',
        'version': '1.0.0',
        'port': 9003,
        'websocket': {
            'url': 'ws://localhost:9003',
            'events': client_events,
            'emitted_events': server_events,
        },
        'rest_endpoints': rest_endpoints,
        'tailscale': access_urls,
    }
    return jsonify(service_info)

@app.route('/api/execute', methods=['POST'])
def api_execute():
    """
    Execute Claude command (one-shot, non-streaming)

    POST body:
    {
        "prompt": "Your prompt",
        "session_id": "optional",
        "tools": ["Bash", "Read"],
        "source": "rest-api"
    }

    Responses:
        200 with ``success: true`` and the CLI's JSON ``result``.
        400 when the body is not a JSON object or has no prompt.
        500 when the CLI reports an error (``success: false`` plus the error
            ``result``) or the handler raises.
    """
    run_id = None
    try:
        # silent=True: a missing/invalid JSON body yields None instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        prompt = data.get('prompt', '')
        # `or` also covers an explicit "session_id": null
        session_id = data.get('session_id') or str(uuid.uuid4())
        tools = data.get('tools')
        source = data.get('source', 'rest-api')

        if not prompt:
            return jsonify({
                'success': False,
                'error': 'No prompt provided'
            }), 400

        print(f"[REST API] Execute: {prompt[:50]}...")

        # Create ROI tracking run
        run_id = roi_create_run(prompt, session_id, source)

        # Execute command (non-streaming): take the single result it produces.
        outputs = execute_claude_command(prompt, session_id, tools, streaming=False)
        try:
            result = next(outputs)
        except StopIteration as stop:
            # Tolerate a generator that `return`s its result instead of yielding
            # it; a bare next() would otherwise turn this into a 500.
            result = stop.value
        finally:
            outputs.close()

        if result is None:
            result = {
                'type': 'error',
                'message': 'Claude CLI produced no result',
                'timestamp': datetime.now().isoformat()
            }

        # Extract token usage if available (tolerate a missing/null 'usage')
        input_tokens = 0
        output_tokens = 0
        usage = result.get('usage') if isinstance(result, dict) else None
        if isinstance(usage, dict):
            input_tokens = usage.get('input_tokens') or 0
            output_tokens = usage.get('output_tokens') or 0

        # execute_claude_command reports failures as {'type': 'error', ...}
        # results rather than raising; don't report those as success.
        if isinstance(result, dict) and result.get('type') == 'error':
            error_message = result.get('message') or 'Claude CLI reported an error'
            roi_complete_run(run_id, 'failed', input_tokens, output_tokens, error=error_message)
            return jsonify({
                'success': False,
                'session_id': session_id,
                'run_id': run_id,
                'error': error_message,
                'result': result,
                'timestamp': datetime.now().isoformat()
            }), 500

        # Complete ROI tracking
        roi_complete_run(run_id, 'completed', input_tokens, output_tokens)

        return jsonify({
            'success': True,
            'session_id': session_id,
            'run_id': run_id,
            'result': result,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'timestamp': datetime.now().isoformat()
        })

    except Exception as e:
        roi_complete_run(run_id, 'failed', 0, 0, error=str(e))
        return jsonify({
            'success': False,
            'error': str(e),
            'timestamp': datetime.now().isoformat()
        }), 500

@app.route('/api/sessions', methods=['GET'])
def api_list_sessions():
    """List every tracked session (running and finished) as JSON."""
    with session_lock:
        # Copy each dict while holding the lock. handle_claude_command may
        # mutate them concurrently, and jsonify runs after the lock is released.
        sessions_list = [dict(session) for session in active_sessions.values()]

    return jsonify({
        'success': True,
        'sessions': sessions_list,
        'count': len(sessions_list),
        'timestamp': datetime.now().isoformat()
    })

@app.route('/api/sessions/<session_id>', methods=['GET'])
def api_get_session(session_id):
    """Return one session's details as JSON, or 404 if it is not tracked."""
    with session_lock:
        session = active_sessions.get(session_id)
        # Snapshot under the lock; the live dict may be mutated concurrently
        # by handle_claude_command while jsonify serializes it.
        if session is not None:
            session = dict(session)

    if session:
        return jsonify({
            'success': True,
            'session': session,
            'timestamp': datetime.now().isoformat()
        })
    else:
        return jsonify({
            'success': False,
            'error': f'Session {session_id} not found'
        }), 404

@app.route('/api/health', methods=['GET'])
def health_check():
    """Liveness probe: unconditionally reports ``healthy`` with service name and version."""
    payload = dict(
        status='healthy',
        service='Claude CLI Control Server',
        version='1.0.0',
        timestamp=datetime.now().isoformat(),
    )
    return jsonify(payload)

@app.route('/api/status', methods=['GET'])
def status_check():
    """Report server status: the execution flag plus session counts by status."""
    with session_lock:
        # Take every count from the same locked snapshot so total/active/
        # completed agree (total used to be read after the lock was released).
        total_count = len(active_sessions)
        active_count = sum(1 for s in active_sessions.values() if s['status'] == 'running')
        completed_count = sum(1 for s in active_sessions.values() if s['status'] == 'completed')

    return jsonify({
        'server': 'running',
        # NOTE(review): a single flag; with concurrent commands it is cleared
        # when the first one finishes.
        'execution_in_progress': execution_in_progress,
        'sessions': {
            'total': total_count,
            'active': active_count,
            'completed': completed_count
        },
        'timestamp': datetime.now().isoformat()
    })

# ============= MAIN =============

if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # and this server listens on every interface (0.0.0.0). So debug mode is
    # opt-in: set CLAUDE_CLI_DEBUG=1 (or true/yes) to enable it.
    debug_mode = os.environ.get('CLAUDE_CLI_DEBUG', '').strip().lower() in ('1', 'true', 'yes')

    print("=" * 70)
    print("🤖 Claude CLI Control Server")
    print("=" * 70)
    print("📡 Port: 9003")
    print("🔧 Features: WebSocket Streaming + REST API + Session Management")
    print(f"🐞 Debug mode: {'on' if debug_mode else 'off'}")
    print("\n📊 WebSocket:")
    print("   → ws://localhost:9003")
    print("   → Event: 'claude_command' - Execute command (streaming)")
    print("   → Event: 'list_sessions' - List active sessions")
    print("\n🌐 REST API:")
    print("   → POST http://localhost:9003/api/execute")
    print("   → GET  http://localhost:9003/api/sessions")
    print("   → GET  http://localhost:9003/api/health")
    print("\n🌍 Access via Tailscale:")
    print("   → ws://100.75.88.8:9003")
    print("   → http://100.75.88.8:9003/api/execute")
    print("=" * 70)
    print("\n[Server] Starting on port 9003...")
    print("[WebSocket] Ready for connections\n")

    # Run server
    socketio.run(
        app,
        host='0.0.0.0',
        port=9003,
        debug=debug_mode,
        allow_unsafe_werkzeug=True
    )
