#!/usr/bin/env python3
"""
Domain Validator Server
Validates domains via DNS lookup + HTTP request
Optimized for large lists (300k+ domains)
"""

import asyncio
import aiohttp
import socket
import json
import time
import threading
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from dataclasses import dataclass, asdict
from typing import Optional
from concurrent.futures import ThreadPoolExecutor
import queue

# Flask application object that the @app.route handlers below attach to.
app = Flask(__name__)
# CORS is enabled with no options, so flask-cors' default policy applies.
CORS(app)

# ============================================
# Configuration
# ============================================

# Tunable limits read by the validation pipeline.
CONFIG = {
    # Seconds budgeted for a single DNS lookup.
    'DNS_TIMEOUT': 3,
    # Seconds budgeted for a single HTTP(S) request.
    'HTTP_TIMEOUT': 8,
    # Upper bound on validations running in parallel (connection-pool
    # limit and DNS thread-pool size).
    'CONCURRENT_REQUESTS': 100,
    # Number of domains validated between two progress updates.
    'BATCH_SIZE': 100,
}

# ============================================
# Data Classes
# ============================================

@dataclass
class DomainResult:
    """Outcome of validating one domain (DNS lookup followed by an HTTP probe)."""

    # Domain name exactly as it was queued for validation.
    domain: str
    # One of 'online', 'offline', 'pending' (initial value) or 'error'.
    status: str  # 'online', 'offline', 'pending', 'error'
    # True when the DNS lookup succeeded.
    dns_ok: bool = False
    # True when any HTTP(S) response was received.
    http_ok: bool = False
    # Status code of the response, when there was one.
    http_status: Optional[int] = None
    # Seconds taken by the successful request, rounded to 3 decimals.
    response_time: Optional[float] = None
    # Failure description (e.g. "DNS_FAIL: ...", "HTTP_FAIL: ..."), else None.
    error: Optional[str] = None

# ============================================
# Validation State
# ============================================

class ValidationState:
    """Mutable bookkeeping for a single validation run."""

    def __init__(self):
        # A fresh instance starts in exactly the state reset() produces.
        self.reset()

    def reset(self):
        """Return every field to its idle, empty starting value."""
        # Queue of domains to validate and the results gathered so far
        # (results are keyed by domain).
        self.domains = []
        self.results = {}

        # Progress counters.
        self.total = self.processed = 0
        self.online_count = self.offline_count = self.error_count = 0

        # Lifecycle flags polled by the worker.
        self.running = self.paused = False

        # time.time() stamps for the start and end of the run.
        self.start_time = self.end_time = None

        # Domains of the batch currently being validated.
        self.current_batch = []

# Single module-level run state shared by the API handlers and the worker thread.
state = ValidationState()
# Lock taken around most reads and writes of `state`.
state_lock = threading.Lock()

# ============================================
# DNS Validation
# ============================================

def check_dns(domain: str) -> tuple[bool, Optional[str]]:
    """Check whether ``domain`` resolves to at least one address.

    This is a blocking call; run it in a worker thread when used from async
    code. It imposes no time limit of its own: the lookup is bounded by the
    system resolver's configuration.

    Args:
        domain: Host name to resolve.

    Returns:
        ``(True, None)`` when the name resolves, otherwise ``(False, reason)``
        where ``reason`` starts with ``DNS_FAIL``, ``DNS_TIMEOUT`` or
        ``DNS_ERROR``.
    """
    try:
        # getaddrinfo resolves both IPv4 and IPv6 addresses, whereas
        # gethostbyname is IPv4-only and reports IPv6-only hosts as failing.
        #
        # socket.setdefaulttimeout() is deliberately not called here: it
        # changes the default for every socket created afterwards anywhere in
        # the process and does not bound the resolver call itself.
        socket.getaddrinfo(domain, None)
        return True, None
    except socket.gaierror as e:
        return False, f"DNS_FAIL: {str(e)}"
    except socket.timeout:
        return False, "DNS_TIMEOUT"
    except Exception as e:
        # E.g. UnicodeError for names that cannot be IDNA-encoded.
        return False, f"DNS_ERROR: {str(e)}"

# ============================================
# HTTP Validation
# ============================================

async def check_http(session: aiohttp.ClientSession, domain: str) -> tuple[bool, Optional[int], Optional[float], Optional[str]]:
    """Probe ``domain`` over HTTPS first, then plain HTTP.

    Any response at all, whatever its status code, counts as the server
    being reachable.

    Args:
        session: Shared aiohttp session used to issue the requests.
        domain: Host name without a scheme.

    Returns:
        ``(True, status_code, elapsed_seconds, None)`` for the first scheme
        that answers, or ``(False, None, None, "HTTP_FAIL: No response")``
        when neither does.
    """
    request_timeout = aiohttp.ClientTimeout(total=CONFIG['HTTP_TIMEOUT'])

    for scheme in ("https", "http"):
        url = f"{scheme}://{domain}"
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments while the request is in flight.
        start = time.perf_counter()
        try:
            async with session.get(
                url,
                timeout=request_timeout,
                allow_redirects=True,
                ssl=False  # Skip SSL verification for speed
            ) as response:
                elapsed = time.perf_counter() - start
                # Any response means the server is online
                return True, response.status, elapsed, None
        except Exception:
            # Connection failures, timeouts and protocol errors are all
            # treated the same way: fall through to the next scheme.
            continue

    return False, None, None, "HTTP_FAIL: No response"

# ============================================
# Main Validation Logic
# ============================================

async def validate_domain(session: aiohttp.ClientSession, domain: str, executor: ThreadPoolExecutor) -> DomainResult:
    """Validate a single domain with a DNS lookup followed by an HTTP probe.

    Args:
        session: Shared aiohttp session for the HTTP probe.
        domain: Host name to validate.
        executor: Thread pool on which the blocking DNS lookup runs.

    Returns:
        A ``DomainResult`` whose status is ``'online'`` when both steps
        succeed and ``'offline'`` (with ``error`` set) otherwise.
    """
    result = DomainResult(domain=domain, status='pending')

    # Step 1: DNS check, run in the thread pool so the event loop is not
    # blocked. wait_for enforces the configured DNS timeout on the awaited
    # result; the blocking lookup itself cannot be interrupted and simply
    # finishes in its worker thread.
    loop = asyncio.get_running_loop()
    try:
        dns_ok, dns_error = await asyncio.wait_for(
            loop.run_in_executor(executor, check_dns, domain),
            timeout=CONFIG['DNS_TIMEOUT'],
        )
    except asyncio.TimeoutError:
        dns_ok, dns_error = False, "DNS_TIMEOUT"
    result.dns_ok = dns_ok

    if not dns_ok:
        result.status = 'offline'
        result.error = dns_error
        return result

    # Step 2: HTTP check
    http_ok, http_status, response_time, http_error = await check_http(session, domain)
    result.http_ok = http_ok
    result.http_status = http_status
    # Compare against None explicitly so a measured 0.0 is kept rather than
    # being reported as "no measurement".
    result.response_time = round(response_time, 3) if response_time is not None else None

    if http_ok:
        result.status = 'online'
    else:
        result.status = 'offline'
        result.error = http_error

    return result

async def validate_batch(domains: list[str]) -> list[DomainResult]:
    """Validate a batch of domains concurrently.

    A fresh HTTP session and DNS thread pool are created for the batch and
    torn down when it finishes.

    Args:
        domains: Host names to validate.

    Returns:
        One ``DomainResult`` per input domain, in input order. A validation
        that raised is reported with status ``'error'`` instead of
        propagating.
    """
    connector = aiohttp.TCPConnector(
        limit=CONFIG['CONCURRENT_REQUESTS'],
        force_close=True,
        enable_cleanup_closed=True
    )

    # Session-wide ceiling, slightly above the per-request timeout.
    session_timeout = aiohttp.ClientTimeout(total=CONFIG['HTTP_TIMEOUT'] + 5)

    async with aiohttp.ClientSession(connector=connector, timeout=session_timeout) as session:
        with ThreadPoolExecutor(max_workers=CONFIG['CONCURRENT_REQUESTS']) as executor:
            outcomes = await asyncio.gather(
                *(validate_domain(session, domain, executor) for domain in domains),
                return_exceptions=True,
            )

    processed_results = []
    for domain, outcome in zip(domains, outcomes):
        # gather(return_exceptions=True) can also hand back CancelledError,
        # which derives from BaseException rather than Exception.
        if isinstance(outcome, BaseException):
            processed_results.append(DomainResult(
                domain=domain,
                status='error',
                # Some exceptions carry no message; fall back to the class
                # name so the error field is never an empty string.
                error=str(outcome) or type(outcome).__name__
            ))
        else:
            processed_results.append(outcome)

    return processed_results

def run_validation():
    """Run the validation loop to completion; intended as a thread target.

    Walks ``state.domains`` in slices of ``CONFIG['BATCH_SIZE']``, validates
    each slice and folds the results into the shared ``state``. The loop
    idles while ``state.paused`` is set and ends early once ``state.running``
    is cleared. ``state.running`` is always cleared and ``state.end_time``
    stamped on exit, including when a batch raises.
    """

    async def async_validation():
        batch_size = CONFIG['BATCH_SIZE']

        try:
            for i in range(0, len(state.domains), batch_size):
                # Idle while paused. The running flag is checked only after
                # the wait so that a stop issued during a pause takes effect
                # before another batch is started.
                while state.paused and state.running:
                    await asyncio.sleep(0.5)

                if not state.running:
                    break

                batch = state.domains[i:i + batch_size]

                with state_lock:
                    state.current_batch = batch

                results = await validate_batch(batch)

                with state_lock:
                    for result in results:
                        state.results[result.domain] = result
                        state.processed += 1

                        if result.status == 'online':
                            state.online_count += 1
                        elif result.status == 'offline':
                            state.offline_count += 1
                        else:
                            state.error_count += 1
        finally:
            # Runs on normal completion, on stop and on an unexpected error,
            # so a failed batch cannot leave the state stuck at "running".
            with state_lock:
                state.running = False
                state.end_time = time.time()

    asyncio.run(async_validation())

# ============================================
# API Endpoints
# ============================================

def _normalize_domains(raw_domains):
    """Return cleaned, de-duplicated host names in first-seen order.

    Each string entry is trimmed, lower-cased, stripped of a leading
    ``http://`` or ``https://`` and cut at the first ``/`` so that a pasted
    URL reduces to its host part. Non-string entries and entries that end up
    empty are dropped.
    """
    unique = {}
    for entry in raw_domains:
        if not isinstance(entry, str):
            continue
        host = entry.strip().lower()
        host = host.removeprefix('http://').removeprefix('https://')
        host = host.split('/', 1)[0]
        if host:
            # A dict keeps insertion order, giving a stable de-duplication.
            unique[host] = None
    return list(unique)


@app.route('/api/validate/start', methods=['POST'])
def start_validation():
    """Start validating the domains posted as JSON ``{"domains": [...]}``.

    Returns:
        A JSON body with ``status``, ``total`` and ``message`` on success, or
        a JSON ``error`` with HTTP 400 when a run is already active or the
        payload contains no usable domains.
    """
    # Cheap early exit; the authoritative check is repeated under the lock.
    if state.running:
        return jsonify({'error': 'Validation already running'}), 400

    # silent=True yields None for a missing or malformed JSON body, so the
    # caller gets this endpoint's own error message.
    data = request.get_json(silent=True)

    if not isinstance(data, dict) or 'domains' not in data:
        return jsonify({'error': 'No domains provided'}), 400

    domains = data['domains']

    if not isinstance(domains, list) or len(domains) == 0:
        return jsonify({'error': 'Invalid domains list'}), 400

    # Clean and deduplicate domains
    clean_domains = _normalize_domains(domains)

    if not clean_domains:
        return jsonify({'error': 'Invalid domains list'}), 400

    total = len(clean_domains)

    with state_lock:
        # Check and claim the running flag atomically so two concurrent
        # requests cannot both start a worker.
        if state.running:
            return jsonify({'error': 'Validation already running'}), 400

        state.reset()
        state.domains = clean_domains
        state.total = total
        state.running = True
        state.start_time = time.time()

    # Start validation in background thread
    thread = threading.Thread(target=run_validation, daemon=True)
    thread.start()

    return jsonify({
        'status': 'started',
        'total': total,
        'message': f'Validating {total} domains'
    })

@app.route('/api/validate/status', methods=['GET'])
def get_status():
    """Report progress counters, timing and throughput for the current run."""
    with state_lock:
        elapsed = rate = eta = 0

        started_at = state.start_time
        if started_at:
            # A finished run is measured up to its end stamp, a live one up
            # to the present moment.
            finished_at = state.end_time
            elapsed = (finished_at if finished_at else time.time()) - started_at

            if elapsed > 0 and state.processed > 0:
                rate = state.processed / elapsed
                remaining = state.total - state.processed
                eta = remaining / rate if rate > 0 else 0

        if state.total > 0:
            progress = round((state.processed / state.total * 100), 1)
        else:
            progress = 0

        # Only the first few domains of the active batch are exposed.
        batch_preview = state.current_batch[:5] if state.current_batch else []

        payload = {
            'running': state.running,
            'paused': state.paused,
            'total': state.total,
            'processed': state.processed,
            'online': state.online_count,
            'offline': state.offline_count,
            'errors': state.error_count,
            'progress': progress,
            'elapsed': round(elapsed, 1),
            'rate': round(rate, 1),
            'eta': round(eta, 1),
            'current_batch': batch_preview
        }
        return jsonify(payload)

@app.route('/api/validate/results', methods=['GET'])
def get_results():
    """Return one page of validation results.

    Query parameters:
        status: Optional status to filter on (e.g. ``online``).
        offset: Index of the first result to return (default 0).
        limit: Maximum number of results to return (default 1000).

    Returns:
        JSON with ``total`` (count after filtering), ``offset``, ``limit``
        and ``results``, or a JSON ``error`` with HTTP 400 when ``offset`` or
        ``limit`` is not a non-negative integer.
    """
    filter_status = request.args.get('status', None)

    try:
        offset = int(request.args.get('offset', 0))
        limit = int(request.args.get('limit', 1000))
    except ValueError:
        return jsonify({'error': 'offset and limit must be integers'}), 400

    # Negative values would turn into from-the-end slices.
    if offset < 0 or limit < 0:
        return jsonify({'error': 'offset and limit must be non-negative'}), 400

    # Hold the lock only long enough to snapshot the results; filtering and
    # sorting a large list should not stall the worker.
    with state_lock:
        results = list(state.results.values())

    if filter_status:
        results = [r for r in results if r.status == filter_status]

    # Sort by status (online first), then by domain
    results.sort(key=lambda r: (0 if r.status == 'online' else 1, r.domain))

    paginated = results[offset:offset + limit]

    return jsonify({
        'total': len(results),
        'offset': offset,
        'limit': limit,
        'results': [asdict(r) for r in paginated]
    })

@app.route('/api/validate/export', methods=['GET'])
def export_results():
    """Export validated domains as a plain-text attachment, one per line.

    Query parameters:
        status: Optional status to filter on (e.g. ``online``).

    Returns:
        A ``text/plain`` response holding the alphabetically sorted domains,
        offered for download as ``domains_<status>.txt`` (``domains_all.txt``
        when no filter is given).
    """
    filter_status = request.args.get('status', None)

    # Snapshot under the lock; do the sorting and joining outside it.
    with state_lock:
        results = list(state.results.values())

    if filter_status:
        results = [r for r in results if r.status == filter_status]

    content = '\n'.join(sorted(r.domain for r in results))

    # The status value comes straight from the query string, so only plain
    # ASCII letters, digits, '-' and '_' are allowed into the response header.
    if filter_status:
        label = ''.join(
            ch for ch in filter_status
            if (ch.isascii() and ch.isalnum()) or ch in '-_'
        ) or 'filtered'
    else:
        label = 'all'

    filename = f"domains_{label}.txt"

    return Response(
        content,
        mimetype='text/plain',
        headers={'Content-Disposition': f'attachment; filename={filename}'}
    )

@app.route('/api/validate/pause', methods=['POST'])
def pause_validation():
    """Raise the paused flag on the shared state and acknowledge."""
    with state_lock:
        state.paused = True

    acknowledgement = {'status': 'paused'}
    return jsonify(acknowledgement)

@app.route('/api/validate/resume', methods=['POST'])
def resume_validation():
    """Clear the paused flag on the shared state and acknowledge."""
    with state_lock:
        state.paused = False

    acknowledgement = {'status': 'resumed'}
    return jsonify(acknowledgement)

@app.route('/api/validate/stop', methods=['POST'])
def stop_validation():
    """Stop the active validation run, if there is one.

    Always answers ``{"status": "stopped"}``. When nothing is running the
    state is left untouched, so the end time of an already finished run (and
    the elapsed time and rate derived from it) is preserved.
    """
    with state_lock:
        if state.running:
            state.running = False
            state.end_time = time.time()
    return jsonify({'status': 'stopped'})

@app.route('/api/validate/restart', methods=['POST'])
def restart_validation():
    """Resume a stopped run, validating only domains that have no result yet.

    Returns:
        JSON with ``status``, ``remaining``, ``already_processed`` and
        ``message`` on success, or a JSON ``error`` with HTTP 400 when a run
        is active, no domains were uploaded, or every domain already has a
        result.
    """
    # Check, compute and mutate in one critical section so concurrent
    # requests cannot both start a worker and the reported counts are taken
    # before the worker adds new results.
    with state_lock:
        if state.running:
            return jsonify({'error': 'Validation already running'}), 400

        if not state.domains:
            return jsonify({'error': 'No domains to validate. Upload first.'}), 400

        # Filter out already processed domains
        remaining = [d for d in state.domains if d not in state.results]

        if not remaining:
            return jsonify({'error': 'All domains already processed'}), 400

        already_processed = len(state.results)

        state.domains = remaining
        state.total = len(remaining) + already_processed
        state.running = True
        state.paused = False
        state.end_time = None
        if not state.start_time:
            state.start_time = time.time()

    # Start validation in background thread
    thread = threading.Thread(target=run_validation, daemon=True)
    thread.start()

    return jsonify({
        'status': 'restarted',
        'remaining': len(remaining),
        'already_processed': already_processed,
        'message': f'Continuing validation of {len(remaining)} remaining domains'
    })

@app.route('/api/health', methods=['GET'])
def health_check():
    """Answer with a fixed payload identifying the service."""
    payload = {'status': 'ok', 'service': 'domain-validator'}
    return jsonify(payload)

# ============================================
# Main
# ============================================

if __name__ == '__main__':
    # Print a startup banner with the effective settings, then serve.
    rule = "=" * 50
    banner = [
        rule,
        "  DOMAIN VALIDATOR SERVER",
        rule,
        "\n  API running at: http://localhost:5001",
        f"  Concurrent requests: {CONFIG['CONCURRENT_REQUESTS']}",
        f"  DNS timeout: {CONFIG['DNS_TIMEOUT']}s",
        f"  HTTP timeout: {CONFIG['HTTP_TIMEOUT']}s",
        "\n" + rule + "\n",
    ]
    print("\n".join(banner))

    # Listen on all interfaces; threaded so status polling is served while
    # other requests are in flight.
    app.run(host='0.0.0.0', port=5001, debug=False, threaded=True)
