"""Integration with an SGLang server."""

import json
import warnings
from typing import (
    TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union
)

from outlines.inputs import Chat
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
from outlines.models.openai import OpenAITypeAdapter
from outlines.types.dsl import (
    CFG,
    JsonSchema,
    python_types_to_terms,
    to_regex,
)

if TYPE_CHECKING:
    from openai import AsyncOpenAI, OpenAI

__all__ = ["AsyncSGLang", "SGLang", "from_sglang"]


class SGLangTypeAdapter(ModelTypeAdapter):
    """Type adapter for the `SGLang` and `AsyncSGLang` models."""

    def format_input(self, model_input: Union[Chat, list, str]) -> list:
        """Generate the value of the messages argument to pass to the client.

        We rely on the `OpenAITypeAdapter` because the SGLang server expects
        input in the same format as the OpenAI API.

        Parameters
        ----------
        model_input
            The input passed by the user.

        Returns
        -------
        list
            The formatted input to be passed to the client.
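
        Examples
        --------
        A sketch of the expected shape; a plain string is wrapped into a
        single OpenAI-style user message::

            SGLangTypeAdapter().format_input("Hello")
            # [{"role": "user", "content": "Hello"}]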

        """
        return OpenAITypeAdapter().format_input(model_input)

    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
        """Generate the structured output argument to pass to the client.

        Parameters
        ----------
        output_type
            The structured output type provided.

        Returns
        -------
        dict
            The formatted output type to be passed to the client.
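
        Examples
        --------
        A sketch of the mapping; the exact regex string produced for `int`
        is illustrative only::

            adapter = SGLangTypeAdapter()
            adapter.format_output_type(None)  # {}
            adapter.format_output_type(int)
            # {"extra_body": {"regex": "..."}} -- an integer-matching regex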

        """
        if output_type is None:
            return {}

        term = python_types_to_terms(output_type)
        if isinstance(term, CFG):
            warnings.warn(
                "SGLang's grammar-based structured output expects an EBNF "
                "grammar rather than the Lark grammar generally used in "
                "Outlines. The grammar cannot be used as a structured output "
                "type with the outlines backend; it is only compatible with "
                "the sglang and llguidance backends."
            )
            return {"extra_body": {"ebnf": term.definition}}
        elif isinstance(term, JsonSchema):
            return OpenAITypeAdapter().format_json_output_type(
                json.loads(term.schema)
            )
        else:
            return {"extra_body": {"regex": to_regex(term)}}


class SGLang(Model):
    """Thin wrapper around the `openai.OpenAI` client used to communicate with
    an SGLang server.

    This wrapper is used to convert the input and output types specified by
    the user at a higher level into arguments for the `openai.OpenAI` client
    for the SGLang server.
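
    Examples
    --------
    A minimal sketch; the base URL and model name are placeholders that
    depend on how the SGLang server was launched::

        import openai
        from outlines import from_sglang

        client = openai.OpenAI(base_url="http://localhost:30000/v1")
        model = from_sglang(client, "qwen/qwen2.5-0.5b-instruct")
        answer = model.generate("What is the capital of France?")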

    """

    def __init__(self, client, model_name: Optional[str] = None):
        """
        Parameters
        ----------
        client
            An `openai.OpenAI` client instance.
        model_name
            The name of the model to use.

        """
        self.client = client
        self.model_name = model_name
        self.type_adapter = SGLangTypeAdapter()

    def generate(
        self,
        model_input: Union[Chat, list, str],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Union[str, list[str]]:
        """Generate text using SGLang.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Union[str, list[str]]
            The text generated by the model.
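
        Examples
        --------
        A hedged sketch using a Pydantic model as the output type; assumes
        the server's grammar backend supports JSON schema::

            from pydantic import BaseModel

            class Answer(BaseModel):
                city: str

            result = model.generate("Where is the Eiffel Tower?", Answer)
            # `result` is a JSON string matching the `Answer` schema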

        """
        client_args = self._build_client_args(
            model_input,
            output_type,
            **inference_kwargs,
        )

        response = self.client.chat.completions.create(**client_args)

        messages = [choice.message for choice in response.choices]
        for message in messages:
            if message.refusal is not None:  # pragma: no cover
                raise ValueError(
                    f"The SGLang server refused to answer the request: "
                    f"{message.refusal}"
                )

        if len(messages) == 1:
            return messages[0].content
        else:
            return [message.content for message in messages]

    def generate_batch(
        self,
        model_input,
        output_type=None,
        **inference_kwargs,
    ):
        """Not available; SGLang does not support batch inference."""
        raise NotImplementedError(
            "SGLang does not support batch inference."
        )

    def generate_stream(
        self,
        model_input: Union[Chat, list, str],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Iterator[str]:
        """Stream text using SGLang.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Iterator[str]
            An iterator that yields the text generated by the model.
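
        Examples
        --------
        A minimal sketch of consuming the stream::

            for chunk in model.generate_stream("Tell me a story"):
                print(chunk, end="")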

        """
        client_args = self._build_client_args(
            model_input, output_type, **inference_kwargs,
        )

        stream = self.client.chat.completions.create(
            **client_args, stream=True,
        )

        for chunk in stream:  # pragma: no cover
            if chunk.choices and chunk.choices[0].delta.content is not None:
                yield chunk.choices[0].delta.content

    def _build_client_args(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> dict:
        """Build the arguments to pass to the SGLang client."""
        messages = self.type_adapter.format_input(model_input)
        output_type_args = self.type_adapter.format_output_type(output_type)
        inference_kwargs.update(output_type_args)

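        # Fall back to the model name given at initialization unless the
        # caller passed an explicit `model` inference kwarg.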
        if "model" not in inference_kwargs and self.model_name is not None:
            inference_kwargs["model"] = self.model_name

        client_args = {
            "messages": messages,
            **inference_kwargs,
        }

        return client_args


class AsyncSGLang(AsyncModel):
    """Thin async wrapper around the `openai.OpenAI` client used to communicate
    with an SGLang server.

    This wrapper is used to convert the input and output types specified by the
    users at a higher level to arguments to the `openai.OpenAI` client for the
    SGLang server.
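
    Examples
    --------
    A minimal sketch mirroring the synchronous wrapper; the URL and model
    name are placeholders, and the call must run inside a coroutine::

        import openai
        from outlines import from_sglang

        client = openai.AsyncOpenAI(base_url="http://localhost:30000/v1")
        model = from_sglang(client, "qwen/qwen2.5-0.5b-instruct")
        answer = await model.generate("What is the capital of France?")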

    """

    def __init__(self, client, model_name: Optional[str] = None):
        """
        Parameters
        ----------
        client
            An `openai.AsyncOpenAI` client instance.
        model_name
            The name of the model to use.

        """
        self.client = client
        self.model_name = model_name
        self.type_adapter = SGLangTypeAdapter()

    async def generate(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> Union[str, list[str]]:
        """Generate text using `sglang`.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        Union[str, list[str]]
            The text generated by the model.
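
        Examples
        --------
        A hedged sketch; the call must be awaited from inside a coroutine::

            async def ask(question: str) -> str:
                # `model` is an `AsyncSGLang` built with `from_sglang`.
                return await model.generate(question)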

        """
        client_args = self._build_client_args(
            model_input, output_type, **inference_kwargs,
        )

        response = await self.client.chat.completions.create(**client_args)

        messages = [choice.message for choice in response.choices]
        for message in messages:
            if message.refusal is not None:  # pragma: no cover
                raise ValueError(
                    f"The sglang server refused to answer the request: "
                    f"{message.refusal}"
                )

        if len(messages) == 1:
            return messages[0].content
        else:
            return [message.content for message in messages]

    async def generate_batch(
        self,
        model_input,
        output_type=None,
        **inference_kwargs,
    ):
        """Not available; SGLang does not support batch inference."""
        raise NotImplementedError(
            "SGLang does not support batch inference."
        )

    async def generate_stream(  # type: ignore
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> AsyncIterator[str]:
        """Return a text generator.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The desired format of the response generated by the model. All
            output types available in Outlines are supported provided your
            server uses a structured generation backend that supports them.
        inference_kwargs
            Additional keyword arguments to pass to the client.

        Returns
        -------
        AsyncIterator[str]
            An async iterator that yields the text generated by the model.
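
        Examples
        --------
        A minimal sketch of consuming the async stream::

            async for chunk in model.generate_stream("Tell me a story"):
                print(chunk, end="")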

        """
        client_args = self._build_client_args(
            model_input, output_type, **inference_kwargs,
        )

        stream = await self.client.chat.completions.create(
            **client_args,
            stream=True,
        )

        async for chunk in stream:  # pragma: no cover
            if chunk.choices and chunk.choices[0].delta.content is not None:
                yield chunk.choices[0].delta.content

    def _build_client_args(
        self,
        model_input: Union[Chat, str, list],
        output_type: Optional[Any] = None,
        **inference_kwargs: Any,
    ) -> dict:
        """Build the arguments to pass to the SGLang client."""
        messages = self.type_adapter.format_input(model_input)
        output_type_args = self.type_adapter.format_output_type(output_type)
        inference_kwargs.update(output_type_args)

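        # Fall back to the model name given at initialization unless the
        # caller passed an explicit `model` inference kwarg.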
        if "model" not in inference_kwargs and self.model_name is not None:
            inference_kwargs["model"] = self.model_name

        client_args = {
            "messages": messages,
            **inference_kwargs,
        }

        return client_args


def from_sglang(
    client: Union["OpenAI", "AsyncOpenAI"],
    model_name: Optional[str] = None,
) -> Union[SGLang, AsyncSGLang]:
    """Create a `SGLang` or `AsyncSGLang` instance from an `openai.OpenAI` or
    `openai.AsyncOpenAI` instance.

    Parameters
    ----------
    client
        An `openai.OpenAI` or `openai.AsyncOpenAI` instance.
    model_name
        The name of the model to use.

    Returns
    -------
    Union[SGLang, AsyncSGLang]
        An Outlines `SGLang` or `AsyncSGLang` model instance.
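
    Examples
    --------
    The wrapper type follows the client type; the base URL is a
    placeholder::

        import openai

        sync_model = from_sglang(
            openai.OpenAI(base_url="http://localhost:30000/v1")
        )  # -> SGLang
        async_model = from_sglang(
            openai.AsyncOpenAI(base_url="http://localhost:30000/v1")
        )  # -> AsyncSGLang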

    """
    from openai import AsyncOpenAI, OpenAI

    if isinstance(client, OpenAI):
        return SGLang(client, model_name)
    elif isinstance(client, AsyncOpenAI):
        return AsyncSGLang(client, model_name)
    else:
        raise ValueError(
            f"Unsupported client type: {type(client)}.\n"
            "Please provide an OpenAI or AsyncOpenAI instance."
        )
