import asyncio
import json
import logging
import traceback
from websockets import ServerConnection

from lms_communication.channel import Channel


def _unregister_channel(channel_map: dict, channel_id: int, channel) -> None:
    """Remove ``channel_id`` from ``channel_map`` only if it still maps to ``channel``."""
    # Identity check: never drop an entry registered by a different Channel
    # instance (e.g. a second channelCreate reusing the same id).
    if channel is not None and channel_map.get(channel_id) is channel:
        del channel_map[channel_id]


async def _handle_on_creation_message(
    *,
    data: dict,
    channel_id: int,
    namespace: str,
    endpoint_handlers: dict,
    websocket: ServerConnection,
    channel_map: dict,
    model_module,
):
    """
    Create a channel, register it, and drive it until it finishes or aborts.

    A new ``Channel`` is stored in ``channel_map[channel_id]``. Two tasks then race:
    - the channel's ``handle_request(creationParameter)`` coroutine, and
    - a wait on the channel's ``abort_signal``.

    If the request task finishes first without raising, the channel is removed
    from ``channel_map`` and a ``channelClose`` message is sent to the client.

    If the abort signal fires, the request raises, or the channel cannot be
    created at all (missing ``endpoint``/``creationParameter`` key, failing
    ``Channel`` constructor), the channel is removed and a ``channelError``
    message with the formatted traceback is sent. The one exception is when the
    ``channelClose`` send itself failed: no error reply follows a close.

    If this coroutine is cancelled, any still-running inner tasks are cancelled
    and the channel registration is dropped.

    Args:
        data: Decoded ``channelCreate`` message; ``endpoint`` and
            ``creationParameter`` are read from it.
        channel_id: Client-chosen id used as the ``channel_map`` key and echoed
            in replies.
        namespace: Passed through to ``Channel``.
        endpoint_handlers: Passed through to ``Channel``.
        websocket: Connection used for the ``channelClose``/``channelError`` reply.
        channel_map: Shared registry of live channels for this connection.
        model_module: Passed through to ``Channel``.
    """
    channel = None
    request_task = None
    abort_task = None
    tasks_settled = False  # True once every inner task is finished or cancelled
    close_sent = False  # True once we committed to sending channelClose
    try:
        channel = Channel(
            endpoint=data["endpoint"],
            channel_id=channel_id,
            namespace=namespace,
            endpoint_handlers=endpoint_handlers,
            websocket=websocket,
            model_module=model_module,
        )
        channel_map[channel_id] = channel
        request_task = asyncio.create_task(
            channel.handle_request(data["creationParameter"])
        )
        abort_task = asyncio.create_task(channel.abort_signal.wait())
        done, pending = await asyncio.wait(
            [request_task, abort_task],
            return_when=asyncio.FIRST_COMPLETED,
        )

        # Stop whichever side lost the race so it does not keep running.
        for task in pending:
            task.cancel()
        tasks_settled = True

        for task in done:
            await task  # Re-raise any exception from a finished task.

        if abort_task in done:
            raise Exception(channel.abort_traceback)

        # Two tasks with FIRST_COMPLETED: if the abort waiter did not finish,
        # the request task did, and it completed without raising.
        _unregister_channel(channel_map, channel_id, channel)
        close_sent = True
        channel_done_msg = {
            "type": "channelClose",
            "channelId": channel_id,
        }
        await websocket.send(json.dumps(channel_done_msg))

    except Exception as e:
        logging.error(f"Exception in channel {channel_id}: {e}")
        # Reply on every failure path, including a channel that never got
        # registered, unless a channelClose was already being sent.
        if not close_sent:
            _unregister_channel(channel_map, channel_id, channel)
            error_msg = {
                "type": "channelError",
                "channelId": channel_id,
                "error": {
                    "cause": traceback.format_exc(),
                },
            }
            await websocket.send(json.dumps(error_msg))
    finally:
        # Reached with unsettled tasks on two paths. One is cancellation of this
        # coroutine: CancelledError is not an Exception, and asyncio.wait does
        # not cancel its awaitables. The other is a failure between the two
        # create_task calls. Either way, make sure nothing keeps running.
        if not tasks_settled:
            for task in (request_task, abort_task):
                if task is not None and not task.done():
                    task.cancel()
        _unregister_channel(channel_map, channel_id, channel)


async def _handle_client_message(
    *, data: str | bytes, channel_id: int, channel_map: dict
):
    """
    Handle a message from the client.

    If the handler returns an exception, log it and set the abort signal for the channel.
    """
    if channel_id not in channel_map:
        raise ValueError(f"Channel ID {channel_id} not found")
    try:
        channel = channel_map[channel_id]
        await channel.on_client_message(data)
    except Exception:
        traceback_str = traceback.format_exc()
        logging.error(f"Exception in channel {channel_id}: {traceback_str}")
        channel.set_abort_with_traceback(traceback_str)


async def _handle_rpc_call_message(
    *,
    data: dict,
    namespace: str,
    endpoint_handlers: dict,
    websocket: ServerConnection,
    model_module,
):
    """
    Handle an ``rpcCall`` message and reply with exactly one ``rpcResult`` or ``rpcError``.

    The endpoint name has its first character upper-cased (``endpoint_pascal``).
    It is used to look up three things:
    - ``{namespace}Rpc{endpoint_pascal}Parameter`` in ``model_module``,
    - ``{namespace}Rpc{endpoint_pascal}Returns`` in ``model_module``, and
    - a handler class in ``endpoint_handlers`` keyed by ``endpoint_pascal``.

    A new handler instance is created for the call. ``data["parameter"]`` is
    validated into the Parameter model and passed to ``handler.handle_request``.
    The result must be an instance of the Returns class.

    Any failure in lookup, validation, handling, type checking, or serialising
    the result is logged and answered with an ``rpcError`` carrying the
    traceback.

    Args:
        data: Decoded ``rpcCall`` message; ``callId``, ``endpoint`` and
            ``parameter`` are read from it.
        namespace: Prefix used to build the model class names.
        endpoint_handlers: Mapping of PascalCase endpoint name to handler class.
        websocket: Connection the reply is sent on.
        model_module: Module whose ``__dict__`` holds the Parameter/Returns models.

    Raises:
        KeyError: if ``data`` has no ``callId``; there is no id to reply to.
    """
    call_id = data["callId"]
    try:
        endpoint = data["endpoint"]
        endpoint_pascal = endpoint[0].upper() + endpoint[1:]
        endpoint_preamble = f"{namespace}Rpc{endpoint_pascal}"

        rpc_parameter_cls = model_module.__dict__[f"{endpoint_preamble}Parameter"]
        rpc_returns_cls = model_module.__dict__[f"{endpoint_preamble}Returns"]

        handler_cls = endpoint_handlers.get(endpoint_pascal)
        if handler_cls is None:
            raise ValueError(f"No handler found for endpoint {endpoint_pascal}")

        handler = handler_cls()

        parameter = rpc_parameter_cls.model_validate_json(json.dumps(data["parameter"]))

        result = await handler.handle_request(parameter)
        if not isinstance(result, rpc_returns_cls):
            raise TypeError(
                f"Expected result of type {rpc_returns_cls}, got {type(result)}"
            )

        # Serialise inside the try. If model_dump/json.dumps fails, the client
        # must still get an rpcError instead of waiting forever.
        reply = json.dumps(
            {
                "type": "rpcResult",
                "callId": call_id,
                "result": result.model_dump(mode="json"),
            }
        )
    except Exception:
        cause = traceback.format_exc()
        logging.error(f"Exception in RPC call {call_id}: {cause}")
        reply = json.dumps(
            {
                "type": "rpcError",
                "callId": call_id,
                "error": {
                    "cause": cause,
                },
            }
        )
    # Sent outside the try: a broken connection propagates instead of
    # triggering a second, equally doomed error send.
    await websocket.send(reply)


async def handle_incoming_message(
    *,
    data: str | bytes,
    channel_map: dict,
    namespace: str,
    endpoint_handlers: dict,
    websocket: ServerConnection,
    model_module,
):
    """
    Decode one raw websocket frame and route it by its ``type`` field.

    Routing:
    - ``rpcCall`` goes to ``_handle_rpc_call_message``.
    - ``channelSend`` goes to ``_handle_client_message``, with ``channelId``.
    - ``channelCreate`` goes to ``_handle_on_creation_message``, with ``channelId``.
    - Any other type is logged as a warning and ignored.

    Several copies of this coroutine are expected to run concurrently, one per
    frame, and each runs start to finish without blocking the others. Frames
    for the same channel may be processed at the same time, so the actions
    they trigger must be idempotent.

    Malformed JSON or a missing ``type``/``channelId`` key raises from here.
    """
    message = json.loads(data)
    kind = message["type"]

    if kind == "rpcCall":
        await _handle_rpc_call_message(
            data=message,
            namespace=namespace,
            endpoint_handlers=endpoint_handlers,
            websocket=websocket,
            model_module=model_module,
        )
        return

    if kind == "channelSend":
        await _handle_client_message(
            data=message,
            channel_id=message["channelId"],
            channel_map=channel_map,
        )
        return

    if kind != "channelCreate":
        # Unrecognised frames are reported but otherwise dropped.
        logging.warning(f"Unknown message type: {kind}")
        return

    await _handle_on_creation_message(
        data=message,
        channel_id=message["channelId"],
        namespace=namespace,
        endpoint_handlers=endpoint_handlers,
        websocket=websocket,
        channel_map=channel_map,
        model_module=model_module,
    )
