"""Integration with a vLLM server."""

import json
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union

from outlines.inputs import Chat
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
from outlines.models.openai import OpenAITypeAdapter
from outlines.types.dsl import CFG, JsonSchema, python_types_to_terms, to_regex

if TYPE_CHECKING:
    from openai import AsyncOpenAI, OpenAI

__all__ = ["VLLM", "AsyncVLLM", "from_vllm"]


class VLLMTypeAdapter(ModelTypeAdapter):
    """Type adapter for the `VLLM` and `AsyncVLLM` models."""

    def format_input(self, model_input: Union[Chat, str, list]) -> list:
        """Generate the value of the messages argument to pass to the client.

        We rely on the `OpenAITypeAdapter` to format the input, since the
        vLLM server expects the same input format as the OpenAI API.

        Parameters
        ----------
        model_input
            The input passed by the user.

        Returns
        -------
        list
            The formatted input to be passed to the model.
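
        Examples
        --------
        A rough sketch of the formatting; the exact message structure is
        delegated to `OpenAITypeAdapter` (the output shown is illustrative)::

            adapter = VLLMTypeAdapter()
            adapter.format_input("What is the capital of France?")
            # -> an OpenAI-style list of chat messages, roughly:
            # [{"role": "user", "content": "What is the capital of France?"}]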

        """
        return OpenAITypeAdapter().format_input(model_input)

    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
        """Generate the structured output argument to pass to the client.

        Parameters
        ----------
        output_type
            The structured output type provided.

        Returns
        -------
        dict
            The structured output argument to pass to the model.
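
        Examples
        --------
        A hedged sketch of the mapping, assuming `python_types_to_terms`
        turns a Pydantic model into a `JsonSchema` term (the schema content
        shown is elided)::

            from pydantic import BaseModel

            class Answer(BaseModel):
                city: str

            VLLMTypeAdapter().format_output_type(Answer)
            # -> {"guided_json": {...the JSON schema of Answer...}}

            VLLMTypeAdapter().format_output_type(None)
            # -> {}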

        """
        if output_type is None:
            return {}

        term = python_types_to_terms(output_type)
        if isinstance(term, CFG):
            return {"guided_grammar": term.definition}
        elif isinstance(term, JsonSchema):
            extra_body = {"guided_json": json.loads(term.schema)}
            if term.whitespace_pattern:
                extra_body["whitespace_pattern"] = term.whitespace_pattern
            return extra_body
        else:
            return {"guided_regex": to_regex(term)}


class VLLM(Model):
    """Thin wrapper around the `openai.OpenAI` client used to communicate with
    a `vllm` server.

    This wrapper converts the input and output types specified by the user at
    a higher level into arguments for the `openai.OpenAI` client pointed at
    the `vllm` server.
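
    Examples
    --------
    A minimal sketch, assuming a vLLM server is reachable at
    ``http://localhost:8000/v1`` (URL, API key and model name are
    illustrative)::

        import openai

        client = openai.OpenAI(
            base_url="http://localhost:8000/v1", api_key="EMPTY"
        )
        model = VLLM(client, "Qwen/Qwen2.5-1.5B-Instruct")
        print(model.generate("What is the capital of France?"))
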
    """

    def __init__(
        self,
        client: "OpenAI",
        model_name: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        client
            An `openai.OpenAI` client instance.
        model_name
            The name of the model to use. It can also be provided (or
            overridden) through the `model` keyword argument at generation
            time.

        """
        self.client = client
        self.model_name = model_name
        self.type_adapter = VLLMTypeAdapter()

    def generate(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Union[str, list[str]]:
        """Generate text using vLLM.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Union[str, list[str]]
            The text generated by the model.
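
        Examples
        --------
        A hedged sketch of structured generation, given a `VLLM` instance
        `model` (see the class docstring); the Pydantic model and the extra
        keyword arguments are illustrative::

            from pydantic import BaseModel

            class City(BaseModel):
                name: str

            result = model.generate(
                "What is the capital of France? Answer in JSON.",
                City,
                max_tokens=100,
            )
            # `result` is a JSON string that should match the `City` schema.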

        """
        client_args = self._build_client_args(
            model_input,
            output_type,
            **inference_kwargs,
        )

        response = self.client.chat.completions.create(**client_args)

        messages = [choice.message for choice in response.choices]
        for message in messages:
            if message.refusal is not None:  # pragma: no cover
                raise ValueError(
                    f"The vLLM server refused to answer the request: "
                    f"{message.refusal}"
                )

        if len(messages) == 1:
            return messages[0].content
        else:
            return [message.content for message in messages]

    def generate_batch(
        self,
        model_input,
        output_type=None,
        **inference_kwargs,
    ):
        raise NotImplementedError("VLLM does not support batch inference.")

    def generate_stream(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Iterator[str]:
        """Stream text using vLLM.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Iterator[str]
            An iterator that yields the text generated by the model.
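
        Examples
        --------
        A minimal sketch, given a `VLLM` instance `model` (see the class
        docstring); chunks are printed as they arrive::

            for chunk in model.generate_stream("Tell me a short story."):
                print(chunk, end="")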

        """
        client_args = self._build_client_args(
            model_input, output_type, **inference_kwargs,
        )

        stream = self.client.chat.completions.create(
            **client_args, stream=True,
        )

        for chunk in stream:  # pragma: no cover
            if chunk.choices and chunk.choices[0].delta.content is not None:
                yield chunk.choices[0].delta.content

    def _build_client_args(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> dict:
        """Build the arguments to pass to the OpenAI client."""
        messages = self.type_adapter.format_input(model_input)
        output_type_args = self.type_adapter.format_output_type(output_type)
        extra_body = inference_kwargs.pop("extra_body", {})
        extra_body.update(output_type_args)

        if "model" not in inference_kwargs and self.model_name is not None:
            inference_kwargs["model"] = self.model_name

        client_args = {
            "messages": messages,
            **inference_kwargs,
        }
        if extra_body:
            client_args["extra_body"] = extra_body

        return client_args


class AsyncVLLM(AsyncModel):
    """Thin async wrapper around the `openai.OpenAI` client used to communicate
    with a `vllm` server.

    This wrapper is used to convert the input and output types specified by the
    users at a higher level to arguments to the `openai.OpenAI` client for the
    `vllm` server.
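
    Examples
    --------
    A minimal sketch, assuming a vLLM server is reachable at
    ``http://localhost:8000/v1`` (URL, API key and model name are
    illustrative)::

        import asyncio

        import openai

        async def main():
            client = openai.AsyncOpenAI(
                base_url="http://localhost:8000/v1", api_key="EMPTY"
            )
            model = AsyncVLLM(client, "Qwen/Qwen2.5-1.5B-Instruct")
            print(await model.generate("What is the capital of France?"))

        asyncio.run(main())
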
    """

    def __init__(
        self,
        client: "AsyncOpenAI",
        model_name: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        client
            An `openai.AsyncOpenAI` client instance.
        model_name
            The name of the model to use. It can also be provided (or
            overridden) through the `model` keyword argument at generation
            time.

        """
        self.client = client
        self.model_name = model_name
        self.type_adapter = VLLMTypeAdapter()

    async def generate(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Union[str, list[str]]:
        """Generate text using vLLM.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Union[str, list[str]]
            The text generated by the model.
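
        Examples
        --------
        A hedged sketch, given an `AsyncVLLM` instance `model` (see the class
        docstring) and run inside an async function; requesting several
        completions via the OpenAI `n` parameter yields a list of strings::

            results = await model.generate(
                "Suggest a name for a cat.",
                n=2,
                max_tokens=20,
            )
            # `results` is a list of two strings, one per completion.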

        """
        client_args = self._build_client_args(
            model_input, output_type, **inference_kwargs,
        )

        response = await self.client.chat.completions.create(**client_args)

        messages = [choice.message for choice in response.choices]
        for message in messages:
            if message.refusal is not None:  # pragma: no cover
                raise ValueError(
                    f"The vLLM server refused to answer the request: "
                    f"{message.refusal}"
                )

        if len(messages) == 1:
            return messages[0].content
        else:
            return [message.content for message in messages]

    async def generate_batch(
        self,
        model_input,
        output_type=None,
        **inference_kwargs,
    ):
        raise NotImplementedError("VLLM does not support batch inference.")

    async def generate_stream(  # type: ignore
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream text using vLLM.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        AsyncIterator[str]
            An async iterator that yields the text generated by the model.
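
        Examples
        --------
        A minimal sketch, given an `AsyncVLLM` instance `model` (see the
        class docstring) and run inside an async function::

            async for chunk in model.generate_stream("Tell me a short story."):
                print(chunk, end="")
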
        """
        client_args = self._build_client_args(
            model_input, output_type, **inference_kwargs,
        )

        stream = await self.client.chat.completions.create(
            **client_args,
            stream=True,
        )

        async for chunk in stream:  # pragma: no cover
            if chunk.choices and chunk.choices[0].delta.content is not None:
                yield chunk.choices[0].delta.content

    def _build_client_args(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> dict:
        """Build the arguments to pass to the OpenAI client."""
        messages = self.type_adapter.format_input(model_input)
        output_type_args = self.type_adapter.format_output_type(output_type)
        extra_body = inference_kwargs.pop("extra_body", {})
        extra_body.update(output_type_args)

        if "model" not in inference_kwargs and self.model_name is not None:
            inference_kwargs["model"] = self.model_name

        client_args = {
            "messages": messages,
            **inference_kwargs,
        }
        if extra_body:
            client_args["extra_body"] = extra_body

        return client_args


def from_vllm(
    client: Union["OpenAI", "AsyncOpenAI"],
    model_name: Optional[str] = None,
) -> Union[VLLM, AsyncVLLM]:
    """Create an Outlines `VLLM` or `AsyncVLLM` model instance from an
    `openai.OpenAI` or `openai.AsyncOpenAI` instance.

    Parameters
    ----------
    client
        An `openai.OpenAI` or `openai.AsyncOpenAI` instance.
    model_name
        The name of the model to use. It can also be provided (or overridden)
        through the `model` keyword argument at generation time.

    Returns
    -------
    Union[VLLM, AsyncVLLM]
        An Outlines `VLLM` or `AsyncVLLM` model instance.
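
    Examples
    --------
    A minimal sketch; the server URL, API key and model name are
    illustrative::

        import openai

        sync_model = from_vllm(
            openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY"),
            "Qwen/Qwen2.5-1.5B-Instruct",
        )
        async_model = from_vllm(
            openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY"),
            "Qwen/Qwen2.5-1.5B-Instruct",
        )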

    """
    from openai import AsyncOpenAI, OpenAI

    if isinstance(client, OpenAI):
        return VLLM(client, model_name)
    elif isinstance(client, AsyncOpenAI):
        return AsyncVLLM(client, model_name)
    else:
        raise ValueError(
            f"Unsupported client type: {type(client)}.\n"
            "Please provide an OpenAI or AsyncOpenAI instance."
        )
