#!/usr/bin/env python3
"""
YouTube Audio Download Server
Flask backend for fetching YouTube metadata, downloading audio (M4A when available) and separating vocal stems
Port: 9005
"""

import os
import re
import json
import tempfile
import subprocess
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS

# Prepend ~/bin to PATH so a locally installed ffmpeg is found (by yt-dlp,
# demucs and shutil.which) ahead of any system copy. os.pathsep is the
# platform's PATH separator.
os.environ['PATH'] = os.pathsep.join([os.path.expanduser('~/bin'), os.environ.get('PATH', '')])

app = Flask(__name__)

# Security: restrict CORS to pages served from localhost / 127.0.0.1 on any port.
# An origin *string* containing regex metacharacters (like "http://localhost:*")
# is treated by flask-cors as a pattern and matched from the start of the Origin
# only, so it would also accept e.g. "http://localhost.attacker.example".
# NOTE(review): matching semantics per flask-cors docs/source; verify against the
# installed version. The anchored, compiled pattern below is safe either way.
_LOCAL_ORIGIN_RE = re.compile(r'^http://(localhost|127\.0\.0\.1)(:\d+)?$', re.IGNORECASE)
CORS(app, origins=[_LOCAL_ORIGIN_RE])

# Temporary directory for downloads (per-process, created at import time)
TEMP_DIR = tempfile.mkdtemp(prefix='youtube_audio_')

# In-memory log buffer for real-time logs (served by /api/logs)
LOG_BUFFER = []
MAX_LOG_ENTRIES = 100

def add_log(message: str, level: str = "info") -> None:
    """Append a structured entry to LOG_BUFFER and echo it to stdout.

    Each entry is a dict with "timestamp" (naive local time, ISO-8601),
    "level" and "message". Only the newest MAX_LOG_ENTRIES entries are kept.

    Args:
        message: Human-readable log text.
        level: Severity label; callers in this file use "info", "warn",
            "error" and "success".
    """
    import datetime
    entry = {
        "timestamp": datetime.datetime.now().isoformat(),
        "level": level,
        "message": message
    }
    LOG_BUFFER.append(entry)
    # Trim with a single slice deletion instead of a len()/pop(0) loop:
    # pop(0) is O(n) per call, and with threaded request handling a concurrent
    # clear() between the length check and pop(0) would raise IndexError.
    # Deleting an out-of-range slice is simply a no-op.
    excess = len(LOG_BUFFER) - MAX_LOG_ENTRIES
    if excess > 0:
        del LOG_BUFFER[:excess]
    print(f"[{level.upper()}] {message}")

def extract_video_id(url: str) -> "str | None":
    """Extract the 11-character video ID from a YouTube URL.

    Recognised forms (any scheme/subdomain, e.g. www. or music.):
        youtube.com/watch?...v=ID...        (``v`` may be any query parameter)
        youtu.be/ID
        youtube.com/embed/ID, /shorts/ID, /live/ID, /v/ID

    Returns:
        The video ID, or None if no recognised form matches. A candidate that
        continues past 11 ID characters is rejected instead of being silently
        truncated to a (different) 11-character ID.
    """
    # ID charset, plus a guard that the ID is not the prefix of a longer token.
    id_group = r'([a-zA-Z0-9_-]{11})(?![a-zA-Z0-9_-])'
    patterns = [
        # "?v=" or "&v=" anywhere in the query string (not only first position).
        r'youtube\.com/watch\?(?:[^#]*&)?v=' + id_group,
        r'(?:youtu\.be/|youtube\.com/(?:embed|shorts|live|v)/)' + id_group,
    ]
    for pattern in patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None

def extract_playlist_id(url: str) -> "str | None":
    """Extract the ``list=`` playlist ID from a YouTube / YouTube Music URL.

    Returns:
        The non-empty ID following ``?list=`` or ``&list=``, or None when the
        URL carries no such parameter (or its value is empty).
    """
    found = re.search(r'[?&]list=([a-zA-Z0-9_-]+)', url)
    if found is None:
        return None
    return found.group(1)

def is_playlist_url(url: str) -> bool:
    """Check if URL is a playlist (includes Radio/Mix playlists with video ID).

    Any URL with a non-empty ``list=`` query parameter is treated as a
    playlist, even if it also names a video. The parameter must start the
    query or follow ``&``, so names such as ``playlist=`` do not count. The
    character class matches what extract_playlist_id accepts, so a URL that
    passes here always yields a playlist ID there.
    """
    return re.search(r'[?&]list=[a-zA-Z0-9_-]', url) is not None

def sanitize_filename(title: str) -> str:
    """Create a safe, single-component filename from a video title.

    Removes path separators, characters invalid on common filesystems,
    control characters and every dot. Dropping dots rules out both '..' and
    hidden '.name' files. The result is at most 100 characters, has no
    surrounding whitespace, and falls back to 'audio' when nothing is left.
    """
    # Security: removing '/' and '\\' here already guarantees a single path
    # component. os.path.basename() is deliberately not used because it would
    # discard everything before the last slash ("AC/DC - x" -> "DC - x").
    safe = re.sub(r'[<>:"/\\|?*.\x00-\x1f\x7f]', '', title)
    # Strip, limit length, then strip again so the cut can't leave trailing spaces.
    safe = safe.strip()[:100].rstrip()
    return safe or 'audio'

@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe: always answers with a fixed 'ok' payload."""
    payload = {'status': 'ok', 'service': 'youtube-audio-server'}
    return jsonify(payload)

@app.route('/api/logs', methods=['GET'])
def get_logs():
    """Return buffered log entries for real-time display.

    If the ``since`` query parameter is given (and non-empty), only entries
    whose timestamp string sorts after it are returned.
    """
    cutoff = request.args.get('since')
    if not cutoff:
        return jsonify({'logs': LOG_BUFFER})
    newer = [entry for entry in LOG_BUFFER if entry['timestamp'] > cutoff]
    return jsonify({'logs': newer})

@app.route('/api/logs/clear', methods=['POST'])
def clear_logs():
    """Empty the in-memory log buffer (in place, keeping the same list object)."""
    del LOG_BUFFER[:]
    return jsonify({'status': 'cleared'})

def _playlist_info_response(url: str):
    """Build the /api/info response for a playlist (or Radio/Mix) URL.

    Runs ``yt-dlp --dump-json --flat-playlist`` on *url*. Returns a Flask
    response, or a ``(response, status)`` tuple on error. May raise
    subprocess.TimeoutExpired, which the caller maps to 504.
    """
    playlist_id = extract_playlist_id(url)
    add_log(f"Detected playlist: {playlist_id}")

    if not playlist_id:
        add_log("Invalid playlist URL", "error")
        return jsonify({'error': 'Invalid playlist URL'}), 400

    add_log("Fetching playlist info with yt-dlp...")

    cmd = ['yt-dlp', '--dump-json', '--flat-playlist', '--no-download']

    # Radio/Mix playlist IDs start with "RD"; cap how many entries are fetched.
    if playlist_id.startswith('RD'):
        add_log("Radio/Mix playlist detected - fetching up to 50 tracks")
        cmd.extend(['--playlist-end', '50'])

    # url was validated by the caller to start with https?://<youtube host>/,
    # so yt-dlp cannot mistake it for an option.
    cmd.append(url)

    add_log(f"Running: {' '.join(cmd[:5])}...")
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)

    if result.returncode != 0:
        add_log(f"yt-dlp failed: {result.stderr[:200]}", "error")
        return jsonify({'error': 'Failed to fetch playlist info', 'details': result.stderr}), 500

    # --dump-json prints one JSON object per line, one per playlist entry.
    lines = [line for line in result.stdout.splitlines() if line.strip()]
    add_log(f"Parsing {len(lines)} entries from yt-dlp output...")

    videos = []
    playlist_title = "Playlist"
    for i, line in enumerate(lines):
        try:
            entry = json.loads(line)
        except json.JSONDecodeError as e:
            add_log(f"JSON parse error on line {i}: {str(e)[:50]}", "warn")
            continue
        if not isinstance(entry, dict):
            continue

        # Entries may repeat the playlist title; ignore null/empty values so a
        # good title is never replaced by null.
        if entry.get('playlist_title'):
            playlist_title = entry['playlist_title']

        video_id = entry.get('id')
        if not video_id:  # Only add if we have a valid ID
            continue

        # `or` fallbacks also cover keys that are present but null.
        videos.append({
            'id': video_id,
            'title': entry.get('title') or 'Unknown',
            'duration': entry.get('duration') or 0,
            'thumbnail': entry.get('thumbnail') or f"https://img.youtube.com/vi/{video_id}/mqdefault.jpg",
            'channel': entry.get('channel') or entry.get('uploader') or 'Unknown',
        })

        if (i + 1) % 10 == 0:
            add_log(f"Processed {i + 1}/{len(lines)} videos...")

    add_log(f"Playlist loaded: '{playlist_title}' with {len(videos)} videos", "success")

    return jsonify({
        'id': playlist_id,
        'title': playlist_title,
        'type': 'playlist',
        'count': len(videos),
        'videos': videos
    })


def _single_video_info_response(url: str):
    """Build the /api/info response for a single-video URL via ``yt-dlp --dump-json``.

    Returns a Flask response, or a ``(response, status)`` tuple on error. May
    raise subprocess.TimeoutExpired or json.JSONDecodeError, which the caller
    maps to 504 and 500 respectively.
    """
    video_id = extract_video_id(url)
    add_log(f"Single video detected: {video_id}")

    if not video_id:
        add_log("Invalid YouTube URL", "error")
        return jsonify({'error': 'Invalid YouTube URL'}), 400

    add_log("Fetching video info...")

    # Rebuild a canonical watch URL so only the extracted ID reaches yt-dlp.
    cmd = [
        'yt-dlp',
        '--dump-json',
        '--no-download',
        '--no-playlist',
        f'https://youtube.com/watch?v={video_id}'
    ]

    result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)

    if result.returncode != 0:
        add_log(f"yt-dlp failed: {result.stderr[:100]}", "error")
        return jsonify({'error': 'Failed to fetch video info', 'details': result.stderr}), 500

    info = json.loads(result.stdout)
    title = info.get('title') or 'Unknown'
    add_log(f"Video loaded: '{title}'", "success")

    # `or` fallbacks (not .get defaults) also cover keys present with a null
    # value, and mirror how playlist entries are normalised.
    return jsonify({
        'id': video_id,
        'title': title,
        'duration': info.get('duration') or 0,
        'thumbnail': info.get('thumbnail') or f'https://img.youtube.com/vi/{video_id}/maxresdefault.jpg',
        'channel': info.get('channel') or info.get('uploader') or 'Unknown',
        'view_count': info.get('view_count') or 0,
        'type': 'video'
    })


@app.route('/api/info', methods=['POST'])
def get_video_info():
    """
    Get video or playlist metadata.

    Request body:
        {"url": "https://youtube.com/watch?v=..." or "https://music.youtube.com/playlist?list=..."}

    Response (video):
        {"id": "...", "title": "...", "duration": 123, "thumbnail": "...",
         "channel": "...", "view_count": 0, "type": "video"}
    Response (playlist):
        {"id": "...", "title": "...", "type": "playlist", "count": N, "videos": [...]}

    Errors are JSON {"error": ...} with status 400 (bad input), 500 (yt-dlp
    failure / unparsable output) or 504 (yt-dlp timeout).
    """
    try:
        # silent=True: a missing or non-JSON body yields None (-> 400) instead
        # of raising and being reported as a 500 by the generic handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'url' not in data:
            add_log("Request without URL", "error")
            return jsonify({'error': 'URL is required'}), 400

        url = data['url']

        # Security: Validate URL format and length
        if not isinstance(url, str) or len(url) > 500:
            add_log("URL too long or invalid type", "error")
            return jsonify({'error': 'Invalid URL format'}), 400

        if not re.match(r'^https?://(www\.)?(youtube\.com|youtu\.be|music\.youtube\.com)/', url):
            add_log(f"Rejected non-YouTube URL: {url[:50]}", "error")
            return jsonify({'error': 'Only YouTube URLs are allowed'}), 400

        add_log(f"Received request for: {url[:80]}...")

        if is_playlist_url(url):
            return _playlist_info_response(url)
        return _single_video_info_response(url)

    except subprocess.TimeoutExpired:
        add_log("Request timeout", "error")
        return jsonify({'error': 'Request timeout'}), 504
    except json.JSONDecodeError:
        add_log("Failed to parse video info", "error")
        return jsonify({'error': 'Failed to parse video info'}), 500
    except Exception as e:
        add_log(f"Error: {str(e)}", "error")
        return jsonify({'error': str(e)}), 500

def _is_mpeg_ts(path: str) -> bool:
    """Return True if the file at *path* looks like an MPEG transport stream.

    TS data is a run of 188-byte packets that each begin with the sync byte
    0x47. When the file is long enough, the byte at offset 188 (the next
    packet's sync byte) is checked too, so a non-TS file that merely starts
    with 0x47 ('G') is not misdetected.
    """
    with open(path, 'rb') as fh:
        head = fh.read(189)
    if not head or head[0] != 0x47:
        return False
    return len(head) < 189 or head[188] == 0x47


def _convert_ts_to_m4a(ts_path: str, ffmpeg_path: str):
    """Re-encode the audio of *ts_path* to AAC in an .m4a file in the same directory.

    Returns the new path on success (and deletes *ts_path*), or None if
    ffmpeg failed, in which case *ts_path* is left untouched. May raise
    subprocess.TimeoutExpired.
    """
    # download_audio always saves as "audio.<ext>", so this fixed name can
    # never coincide with the input file (even if yt-dlp picked ".m4a").
    m4a_path = os.path.join(os.path.dirname(ts_path), 'converted.m4a')
    convert_cmd = [
        ffmpeg_path,
        '-i', ts_path,
        '-vn',  # drop any video stream (the HLS fallback may be muxed A/V) -> audio-only file
        '-c:a', 'aac',
        '-b:a', '192k',
        '-y',  # Overwrite
        m4a_path
    ]
    convert_result = subprocess.run(convert_cmd, capture_output=True, timeout=120)
    if convert_result.returncode != 0 or not os.path.exists(m4a_path):
        return None
    os.remove(ts_path)  # Remove original .ts
    return m4a_path


def _download_with_fallback(video_id: str, work_dir: str):
    """Download *video_id* into *work_dir* as "audio.<ext>", trying each format group in turn.

    Returns the CompletedProcess of the last yt-dlp attempt (returncode 0
    means success). After a failed attempt the directory is emptied so the
    next attempt, and the caller's file lookup, never see partial output.
    May raise subprocess.TimeoutExpired (each attempt is capped at 300 s).
    """
    import shutil

    output_template = os.path.join(work_dir, 'audio.%(ext)s')

    # Try M4A format first (best compatibility), fallback to HLS
    formats_to_try = [
        ['140', 'bestaudio[ext=m4a]'],  # M4A audio
        ['91/92/93'],  # HLS fallback
    ]

    add_log(f"Trying download formats for {video_id}...")
    result = None
    for attempt, fmt_list in enumerate(formats_to_try, start=1):
        fmt = '/'.join(fmt_list)
        add_log(f"Trying format {attempt}: {fmt}")
        cmd = [
            'yt-dlp',
            '-f', fmt,
            '-o', output_template,
            '--no-playlist',
            '--no-warnings',
            # NOTE(review): yt-dlp documents this as suppressing HTTPS
            # certificate validation; keep only if it works around a known
            # local issue.
            '--no-check-certificates',
            f'https://youtube.com/watch?v={video_id}'
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if result.returncode == 0:
            add_log(f"Format {fmt} succeeded")
            break
        add_log(f"Format {fmt} failed, trying next...", "warn")
        shutil.rmtree(work_dir, ignore_errors=True)
        os.makedirs(work_dir, exist_ok=True)
    return result


@app.route('/api/download', methods=['POST'])
def download_audio():
    """
    Download a video's audio track and send it as an attachment.

    Request body:
        {"video_id": "<11-char YouTube ID>", "title": "..."}   (title optional)

    Response:
        The audio file (M4A when that format is available; an MPEG-TS fallback
        is re-encoded to M4A when ffmpeg is on PATH), named after the
        sanitized title. Errors are JSON {"error": ...} with status 400, 500,
        or 504 on a yt-dlp timeout.

    Each request works in its own temporary directory under TEMP_DIR, so
    concurrent downloads (even of identically titled videos) cannot pick up
    each other's files. The directory is removed after the response is sent,
    or immediately on any failure.
    """
    import shutil

    work_dir = None  # owned (and cleaned up) here until handed to the response
    try:
        # silent=True: a missing/non-JSON body yields None -> 400, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'video_id' not in data:
            add_log("Download request without video_id", "error")
            return jsonify({'error': 'video_id is required'}), 400

        video_id = data['video_id']
        # Validate before video_id reaches logs, argv or paths. fullmatch:
        # a '$' anchor would also accept a trailing newline.
        if not isinstance(video_id, str) or not re.fullmatch(r'[a-zA-Z0-9_-]{11}', video_id):
            add_log(f"Invalid video_id format: {str(video_id)[:50]}", "error")
            return jsonify({'error': 'Invalid video_id format'}), 400

        title = data.get('title')
        if not isinstance(title, str) or not title:
            title = video_id
        add_log(f"Starting download: {title[:50]}...")

        # Create safe filename (used only for the attachment name)
        safe_title = sanitize_filename(title)

        os.makedirs(TEMP_DIR, exist_ok=True)  # survive external /tmp cleaners
        work_dir = tempfile.mkdtemp(prefix=f'{video_id}_', dir=TEMP_DIR)

        result = _download_with_fallback(video_id, work_dir)
        if result.returncode != 0:
            add_log(f"All formats failed: {result.stderr[:100]}", "error")
            return jsonify({
                'error': 'Failed to download audio',
                'details': result.stderr
            }), 500

        # The work dir holds only this request's output (could be .m4a, .mp4,
        # .ts, ...); skip any partial/in-progress artefacts.
        produced = sorted(
            name for name in os.listdir(work_dir)
            if not name.endswith(('.part', '.ytdl', '.temp'))
        )
        if not produced:
            add_log("Download completed but file not found", "error")
            return jsonify({'error': 'Download completed but file not found'}), 500

        output_path = os.path.join(work_dir, produced[0])
        add_log(f"File downloaded: {produced[0]}")
        ext = os.path.splitext(output_path)[1].lower()

        # Convert MPEG-TS to M4A using ffmpeg (found via PATH, which includes ~/bin)
        if _is_mpeg_ts(output_path):
            ffmpeg_path = shutil.which('ffmpeg')
            if ffmpeg_path:
                add_log("Converting MPEG-TS to M4A...")
                converted = _convert_ts_to_m4a(output_path, ffmpeg_path)
                if converted:
                    output_path = converted
                    ext = '.m4a'
                    add_log("Conversion to M4A complete")
                else:
                    add_log("MPEG-TS to M4A conversion failed; sending original", "warn")
            else:
                add_log("MPEG-TS detected but ffmpeg not found; sending original", "warn")

        # Determine mimetype based on extension
        mime_by_ext = {
            '.m4a': 'audio/mp4',
            '.mp4': 'audio/mp4',
            '.ts': 'video/MP2T',
            '.mp3': 'audio/mpeg',
            '.webm': 'audio/webm',
        }
        mimetype = mime_by_ext.get(ext, 'audio/mpeg')
        download_name = f'{safe_title}{ext}'

        file_size = os.path.getsize(output_path)
        add_log(f"Sending file: {download_name} ({file_size // 1024} KB)", "success")

        response = send_file(
            output_path,
            mimetype=mimetype,
            as_attachment=True,
            download_name=download_name
        )

        # Hand the work dir to the response: it is removed once the body has
        # been sent, not by the finally-block below.
        cleanup_dir, work_dir = work_dir, None

        @response.call_on_close
        def cleanup():
            try:
                shutil.rmtree(cleanup_dir)
            except OSError as e:
                add_log(f"Cleanup failed for {cleanup_dir}: {e}", "warn")

        return response

    except subprocess.TimeoutExpired:
        add_log("Download timeout", "error")
        return jsonify({'error': 'Download timeout (max 5 minutes)'}), 504
    except Exception as e:
        add_log(f"Download error: {str(e)}", "error")
        return jsonify({'error': str(e)}), 500
    finally:
        # Any path that did not hand the file to send_file: drop partial output.
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)

@app.route('/api/formats', methods=['POST'])
def get_formats():
    """
    Get available formats for a video.

    Request body:
        {"url": "https://youtube.com/watch?v=..."}

    Response:
        {"formats": "<text table printed by `yt-dlp -F`>"}

    Errors are JSON {"error": ...} with status 400 (bad input), 500 (yt-dlp
    failure) or 504 (yt-dlp timeout).
    """
    try:
        # silent=True: a missing/non-JSON body yields None -> 400, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'url' not in data:
            return jsonify({'error': 'URL is required'}), 400

        url = data['url']
        # extract_video_id runs regexes over the value, so reject non-strings
        # (and oversized input, same limit as /api/info) with a 400 up front.
        if not isinstance(url, str) or len(url) > 500:
            return jsonify({'error': 'Invalid URL format'}), 400

        video_id = extract_video_id(url)
        if not video_id:
            return jsonify({'error': 'Invalid YouTube URL'}), 400

        # Only the extracted ID is passed on, inside a canonical watch URL.
        cmd = [
            'yt-dlp',
            '-F',
            '--no-playlist',
            f'https://youtube.com/watch?v={video_id}'
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)

        if result.returncode != 0:
            add_log(f"yt-dlp -F failed: {result.stderr[:100]}", "error")
            return jsonify({'error': 'Failed to fetch formats'}), 500

        return jsonify({'formats': result.stdout})

    except subprocess.TimeoutExpired:
        add_log("Format listing timeout", "error")
        return jsonify({'error': 'Request timeout'}), 504
    except Exception as e:
        add_log(f"Error: {str(e)}", "error")
        return jsonify({'error': str(e)}), 500

# Root folder for separated stems; each video gets its own sub-directory
# named after its video ID.
STEMS_DIR = os.path.join(os.path.expanduser('~'), 'Music', 'YouTubeAudio', 'Stems')

def _is_within(path: str, root: str) -> bool:
    """Return True if *path* equals *root* or lies beneath it.

    Both arguments are expected to be absolute, resolved paths (e.g. from
    os.path.realpath). Unlike a str.startswith check, "/x/Music" is not
    considered to contain "/x/MusicEvil/file".
    """
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # Raised e.g. for paths on different drives or absolute/relative mixes.
        return False


@app.route('/api/separate', methods=['POST'])
def separate_stems():
    """
    Separate audio into a vocals stem and an accompaniment ("no_vocals") stem.
    Uses demucs (htdemucs model, --two-stems vocals).

    Request body:
        {"input_path": "/path/to/audio.m4a", "video_id": "abc123"}
        input_path must resolve to a file under ~/Music/YouTubeAudio;
        video_id may contain only letters, digits, '_' and '-'.

    Response:
        {"status": "success", "stems": {"vocals": "/path/to/vocals.wav",
         "no_vocals": "/path/to/no_vocals.wav"}, "cached": true/false}

    Errors are JSON {"error": ...} with status 400, 404, 500 or 504.
    """
    import shutil

    temp_wav = None       # set once the intermediate WAV may exist
    htdemucs_dir = None   # set once demucs may have written output
    try:
        # silent=True: a missing/non-JSON body yields None -> 400, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'input_path' not in data:
            add_log("Separate request without input_path", "error")
            return jsonify({'error': 'input_path is required'}), 400

        input_path = data['input_path']
        video_id = data.get('video_id', 'unknown')

        if not isinstance(input_path, str) or not input_path:
            add_log("Invalid input_path in separate request", "error")
            return jsonify({'error': 'Invalid input path'}), 400

        # Security: video_id becomes a directory under STEMS_DIR that is written
        # into and partly rmtree'd, so it must be a plain token (same rule as
        # /api/stems/<video_id>): no separators, no '..', no absolute paths.
        if not isinstance(video_id, str) or not re.fullmatch(r'[a-zA-Z0-9_-]+', video_id):
            add_log(f"Invalid video_id: {str(video_id)[:50]}", "error")
            return jsonify({'error': 'Invalid video_id'}), 400

        # Security: confine the input to the music directory *before* touching
        # the filesystem, so 404-vs-400 cannot be used to probe arbitrary paths.
        music_root = os.path.realpath(os.path.expanduser('~/Music/YouTubeAudio'))
        real_input = os.path.realpath(input_path)
        if not _is_within(real_input, music_root):
            add_log(f"Path traversal attempt: {input_path}", "error")
            return jsonify({'error': 'Invalid input path'}), 400

        if not os.path.isfile(real_input):
            add_log(f"Input file not found: {input_path}", "error")
            return jsonify({'error': 'Input file not found'}), 404

        add_log(f"Starting stem separation for: {os.path.basename(input_path)}")

        # Output directory for this track (makedirs also creates STEMS_DIR)
        output_dir = os.path.join(STEMS_DIR, video_id)
        os.makedirs(output_dir, exist_ok=True)

        # demucs is run with `--two-stems vocals` below, which yields exactly
        # these two files, so the cache check must look for the same names.
        stem_names = ['vocals', 'no_vocals']
        cached = {name: os.path.join(output_dir, f'{name}.wav') for name in stem_names}
        if all(os.path.exists(path) for path in cached.values()):
            add_log(f"Stems already exist for {video_id}", "success")
            return jsonify({
                'status': 'success',
                'stems': cached,
                'cached': True
            })

        # Convert to WAV first (demucs has trouble finding ffmpeg)
        add_log("Converting audio to WAV format...")
        ffmpeg_path = shutil.which('ffmpeg') or os.path.expanduser('~/bin/ffmpeg')
        temp_wav = os.path.join(output_dir, 'temp_input.wav')

        convert_cmd = [
            ffmpeg_path,
            '-i', real_input,
            '-ar', '44100',
            '-ac', '2',
            '-y',
            temp_wav
        ]

        convert_result = subprocess.run(convert_cmd, capture_output=True, text=True, timeout=120)
        if convert_result.returncode != 0:
            add_log(f"Audio conversion failed: {convert_result.stderr[:200]}", "error")
            return jsonify({
                'error': 'Audio conversion failed',
                'details': convert_result.stderr
            }), 500

        add_log("Conversion complete. Running demucs separation (this may take a few minutes)...")

        # Demucs outputs to: output_dir/htdemucs/temp_input/vocals.wav, no_vocals.wav
        htdemucs_dir = os.path.join(output_dir, 'htdemucs')
        cmd = [
            'python3', '-m', 'demucs',
            '--two-stems', 'vocals',  # Separate into vocals + instrumental only
            '-n', 'htdemucs',         # Use htdemucs model
            '-o', output_dir,         # Output directory
            temp_wav
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)  # 10 min timeout

        if result.returncode != 0:
            add_log(f"Demucs failed: {result.stderr[:200]}", "error")
            return jsonify({
                'error': 'Stem separation failed',
                'details': result.stderr
            }), 500

        demucs_output = os.path.join(htdemucs_dir, 'temp_input')
        stems = {}
        for stem_name in stem_names:
            stem_file = os.path.join(demucs_output, f'{stem_name}.wav')
            if os.path.exists(stem_file):
                final_path = os.path.join(output_dir, f'{stem_name}.wav')
                # os.replace overwrites a stale stem from an earlier partial run
                # instead of silently keeping it.
                os.replace(stem_file, final_path)
                stems[stem_name] = final_path
                add_log(f"Stem created: {stem_name}")

        add_log(f"Stem separation complete: {len(stems)} stems created", "success")

        return jsonify({
            'status': 'success',
            'stems': stems,
            'cached': False
        })

    except subprocess.TimeoutExpired:
        add_log("Stem separation timeout (max 10 minutes)", "error")
        return jsonify({'error': 'Stem separation timeout'}), 504
    except Exception as e:
        add_log(f"Stem separation error: {str(e)}", "error")
        return jsonify({'error': str(e)}), 500
    finally:
        # Intermediate files are removed on every path: success, failure, timeout.
        if temp_wav is not None and os.path.exists(temp_wav):
            try:
                os.remove(temp_wav)
            except OSError as e:
                add_log(f"Could not remove {temp_wav}: {e}", "warn")
        if htdemucs_dir is not None:
            shutil.rmtree(htdemucs_dir, ignore_errors=True)


@app.route('/api/stems/<video_id>', methods=['GET'])
def get_stems(video_id):
    """
    Check which stem WAV files exist for a video.

    Looks in STEMS_DIR/<video_id>/ for vocals/no_vocals (two-stem output)
    and drums/bass/other.

    Response:
        {"exists": true/false, "stems": {"<stem name>": "<file path>", ...}}
    """
    # Validate video_id format. fullmatch rather than match + '$': '$' also
    # matches just before a trailing newline, which would let "id\n" through.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', video_id):
        return jsonify({'error': 'Invalid video_id'}), 400

    output_dir = os.path.join(STEMS_DIR, video_id)

    if not os.path.isdir(output_dir):
        return jsonify({'exists': False, 'stems': {}})

    stems = {}
    for stem_name in ('vocals', 'no_vocals', 'drums', 'bass', 'other'):
        stem_path = os.path.join(output_dir, f'{stem_name}.wav')
        if os.path.exists(stem_path):
            stems[stem_name] = stem_path

    return jsonify({
        'exists': bool(stems),
        'stems': stems
    })


if __name__ == '__main__':
    # Bind address and port may be overridden via the environment; the
    # defaults keep the original behaviour (all interfaces, port 9005).
    # NOTE(review): 0.0.0.0 exposes these unauthenticated endpoints to the
    # network (CORS only restricts browsers); set YOUTUBE_AUDIO_HOST=127.0.0.1
    # if LAN access is not needed.
    host = os.environ.get('YOUTUBE_AUDIO_HOST', '0.0.0.0')
    port = int(os.environ.get('YOUTUBE_AUDIO_PORT', '9005'))

    print("YouTube Audio Server starting...")
    print(f"Temp directory: {TEMP_DIR}")
    print(f"Stems directory: {STEMS_DIR}")
    print("Endpoints:")
    print("  GET  /api/health       - Health check")
    print("  GET  /api/logs         - Recent logs (optional ?since=<timestamp>)")
    print("  POST /api/logs/clear   - Clear the log buffer")
    print("  POST /api/info         - Get video or playlist info")
    print("  POST /api/download     - Download audio (M4A when available)")
    print("  POST /api/formats      - Get available formats")
    print("  POST /api/separate     - Separate audio into stems")
    print("  GET  /api/stems/<id>   - Check if stems exist")
    print(f"\nListening on http://{host}:{port}")
    # threaded=True for concurrent requests, debug=False to avoid file descriptor leaks
    app.run(host=host, port=port, debug=False, threaded=True)
