"""Integration with a TGI server."""

import json
from functools import singledispatchmethod
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Iterator,
    Optional,
    Union,
)

from outlines.models.base import AsyncModel,Model, ModelTypeAdapter
from outlines.types.dsl import python_types_to_terms, to_regex, JsonSchema, CFG

if TYPE_CHECKING:
    from huggingface_hub import AsyncInferenceClient, InferenceClient

# Names exported by `from <this module> import *`. `TGITypeAdapter` is not
# listed here.
__all__ = ["AsyncTGI", "TGI", "from_tgi"]


class TGITypeAdapter(ModelTypeAdapter):
    """Type adapter for the `TGI` and `AsyncTGI` models.

    It turns the prompt and the output type given by the user into the
    arguments that are forwarded to the client's `text_generation` method.

    """

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the client.

        This is the fallback of the single-dispatch method: it is only reached
        when no implementation is registered for the type of `model_input`.

        Parameters
        ----------
        model_input
            The input passed by the user.

        Returns
        -------
        str
            The formatted input to be passed to the model.

        Raises
        ------
        NotImplementedError
            If the type of `model_input` has no registered implementation
            (only `str` is registered).

        """
        # Report the type of the offending argument; the builtin `input`
        # function must not be interpolated here.
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with TGI. "
            + "The only available type is `str`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str) -> str:
        """Return a `str` prompt unchanged."""
        return model_input

    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
        """Generate the structured output argument to pass to the client.

        Parameters
        ----------
        output_type
            The structured output type provided. `None` means unconstrained
            generation.

        Returns
        -------
        dict
            The structured output argument to pass to the client: an empty
            dict when `output_type` is `None`, otherwise a dict with a single
            `"grammar"` entry of type `"json"` or `"regex"`.

        Raises
        ------
        NotImplementedError
            If the output type resolves to a `CFG` term.

        """
        if output_type is None:
            return {}

        term = python_types_to_terms(output_type)
        if isinstance(term, CFG):
            raise NotImplementedError(
                "TGI does not support CFG-based structured outputs."
            )
        elif isinstance(term, JsonSchema):
            # The schema is decoded with `json.loads`, so the client receives
            # a parsed JSON value rather than the serialized string.
            return {
                "grammar": {
                    "type": "json",
                    "value": json.loads(term.schema),
                }
            }
        else:
            # Every other term is expressed as a regular expression.
            return {
                "grammar": {
                    "type": "regex",
                    "value": to_regex(term),
                }
            }


class TGI(Model):
    """Synchronous Outlines model backed by a `TGI` server.

    The class is a thin layer over a `huggingface_hub.InferenceClient`: it
    translates the prompt and output type supplied by the user into keyword
    arguments for the client's `text_generation` method.

    """

    def __init__(self, client):
        """
        Parameters
        ----------
        client
            A huggingface `InferenceClient` client instance.

        """
        self.type_adapter = TGITypeAdapter()
        self.client = client

    def generate(
        self,
        model_input: str,
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> str:
        """Generate text using TGI.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types except `CFG` are supported provided your server uses
            a backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        str
            The text generated by the model.

        """
        return self.client.text_generation(
            **self._build_client_args(
                model_input, output_type, **inference_kwargs
            )
        )

    def generate_batch(
        self,
        model_input,
        output_type=None,
        **inference_kwargs,
    ):
        """Batch generation is not available; always raises.

        Raises
        ------
        NotImplementedError
            On every call.

        """
        raise NotImplementedError("TGI does not support batch inference.")

    def generate_stream(
        self,
        model_input: str,
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Iterator[str]:
        """Stream text using TGI.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types except `CFG` are supported provided your server uses
            a backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Iterator[str]
            An iterator that yields the text generated by the model.

        """
        request_kwargs = self._build_client_args(
            model_input, output_type, **inference_kwargs
        )
        # `stream=True` is always added on top of the caller's arguments.
        pieces = self.client.text_generation(**request_kwargs, stream=True)

        for piece in pieces:  # pragma: no cover
            yield piece

    def _build_client_args(
        self,
        model_input: str,
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> dict:
        """Build the arguments to pass to the TGI client.

        The prompt comes first, followed by the caller's inference arguments;
        the structured output arguments are merged last and therefore take
        precedence over caller-supplied keys of the same name.

        """
        adapter = self.type_adapter
        formatted_prompt = adapter.format_input(model_input)
        structure_kwargs = adapter.format_output_type(output_type)

        return {
            "prompt": formatted_prompt,
            **inference_kwargs,
            **structure_kwargs,
        }


class AsyncTGI(AsyncModel):
    """Asynchronous Outlines model backed by a `TGI` server.

    The class is a thin layer over a `huggingface_hub.AsyncInferenceClient`:
    it translates the prompt and output type supplied by the user into
    keyword arguments for the client's `text_generation` coroutine.

    """

    def __init__(self, client):
        """
        Parameters
        ----------
        client
            A huggingface `AsyncInferenceClient` client instance.

        """
        self.type_adapter = TGITypeAdapter()
        self.client = client

    async def generate(
        self,
        model_input: str,
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> str:
        """Generate text using TGI.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types except `CFG` are supported provided your server uses
            a backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        str
            The text generated by the model.

        """
        request_kwargs = self._build_client_args(
            model_input, output_type, **inference_kwargs
        )
        return await self.client.text_generation(**request_kwargs)

    async def generate_batch(
        self,
        model_input,
        output_type=None,
        **inference_kwargs,
    ):
        """Batch generation is not available; always raises.

        Raises
        ------
        NotImplementedError
            On every call.

        """
        raise NotImplementedError("TGI does not support batch inference.")

    async def generate_stream( # type: ignore
        self,
        model_input: str,
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream text using TGI.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types except `CFG` are supported provided your server uses
            a backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        AsyncIterator[str]
            An async iterator that yields the text generated by the model.

        """
        request_kwargs = self._build_client_args(
            model_input, output_type, **inference_kwargs
        )
        # The client call is awaited first; its result is then consumed as an
        # async iterator of chunks.
        pieces = await self.client.text_generation(
            **request_kwargs, stream=True
        )

        async for piece in pieces:  # pragma: no cover
            yield piece

    def _build_client_args(
        self,
        model_input: str,
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> dict:
        """Build the arguments to pass to the TGI client.

        The prompt comes first, followed by the caller's inference arguments;
        the structured output arguments are merged last and therefore take
        precedence over caller-supplied keys of the same name.

        """
        adapter = self.type_adapter
        formatted_prompt = adapter.format_input(model_input)
        structure_kwargs = adapter.format_output_type(output_type)

        return {
            "prompt": formatted_prompt,
            **inference_kwargs,
            **structure_kwargs,
        }


def from_tgi(
    client: Union["InferenceClient", "AsyncInferenceClient"],
) -> Union[TGI, AsyncTGI]:
    """Wrap a `huggingface_hub` client in the matching Outlines TGI model.

    Parameters
    ----------
    client
        An `huggingface_hub.InferenceClient` or
        `huggingface_hub.AsyncInferenceClient` instance.

    Returns
    -------
    Union[TGI, AsyncTGI]
        A `TGI` instance for an `InferenceClient`, an `AsyncTGI` instance for
        an `AsyncInferenceClient`.

    Raises
    ------
    ValueError
        If `client` is an instance of neither client class.

    """
    # Imported here so that `huggingface_hub` is only needed when this
    # function is actually called.
    from huggingface_hub import AsyncInferenceClient, InferenceClient

    # Checked in order: the synchronous client is tested first.
    wrappers = (
        (InferenceClient, TGI),
        (AsyncInferenceClient, AsyncTGI),
    )
    for client_cls, model_cls in wrappers:
        if isinstance(client, client_cls):
            return model_cls(client)

    raise ValueError(
        f"Unsupported client type: {type(client)}.\n"
        "Please provide an HuggingFace InferenceClient "
        "or AsyncInferenceClient instance."
    )
