"""Integration with the `mlx_lm` library."""

from functools import singledispatchmethod
from typing import TYPE_CHECKING, Iterator, List, Optional

from outlines.models.base import Model, ModelTypeAdapter
from outlines.models.transformers import TransformerTokenizer
from outlines.processors import OutlinesLogitsProcessor

if TYPE_CHECKING:
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

# Names exported by `from <this module> import *`: the model wrapper class and
# its factory function.
__all__ = ["MLXLM", "from_mlxlm"]


class MLXLMTypeAdapter(ModelTypeAdapter):
    """Type adapter for the `MLXLM` model.

    Converts the user-provided prompt and output type into the positional
    prompt argument and the `logits_processors` keyword argument that the
    `MLXLM` model forwards to `mlx_lm`.

    """

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the model.

        This is the fallback of the single-dispatch method: it is only reached
        when no implementation is registered for the type of `model_input`.

        Parameters
        ----------
        model_input
            The input provided by the user.

        Returns
        -------
        str
            The formatted input to be passed to the model.

        Raises
        ------
        NotImplementedError
            If the type of `model_input` has no registered implementation.

        """
        # Report the type of the offending argument. The message previously
        # interpolated the builtin `input` function instead of the argument.
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with "
            "mlx-lm. The only available type is `str`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        """Return a `str` prompt unchanged."""
        return model_input

    def format_output_type(
        self, output_type: Optional[OutlinesLogitsProcessor] = None,
    ) -> Optional[List[OutlinesLogitsProcessor]]:
        """Generate the logits processor argument to pass to the model.

        Parameters
        ----------
        output_type
            The logits processor provided.

        Returns
        -------
        Optional[list[OutlinesLogitsProcessor]]
            `None` when no logits processor is provided, otherwise a
            single-element list containing the logits processor.

        """
        if not output_type:
            return None
        return [output_type]


class MLXLM(Model):
    """Thin wrapper around an `mlx_lm` model.

    The wrapper translates the input and output types that users specify at a
    higher level into the arguments expected by the `mlx_lm` library.

    """

    tensor_library_name = "mlx"

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        """
        Parameters
        ----------
        model
            An instance of an `mlx_lm` model.
        tokenizer
            An instance of an `mlx_lm` tokenizer or of a compatible
            `transformers` tokenizer.

        """
        self.model = model
        # Tokenizer object handed as-is to the `mlx_lm` generation functions.
        self.mlx_tokenizer = tokenizer
        # Wrapped view of the underlying tokenizer, for the logits processor.
        inner_tokenizer = tokenizer._tokenizer
        self.tokenizer = TransformerTokenizer(inner_tokenizer)
        self.type_adapter = MLXLMTypeAdapter()

    def generate(
        self,
        model_input: str,
        output_type: Optional[OutlinesLogitsProcessor] = None,
        **kwargs,
    ) -> str:
        """Generate text using `mlx-lm`.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The logits processor the model will use to constrain the format of
            the generated text.
        kwargs
            Additional keyword arguments to pass to the `mlx-lm` library.

        Returns
        -------
        str
            The text generated by the model.

        """
        # Imported lazily so that `mlx_lm` is only required at call time.
        from mlx_lm import generate as mlx_generate

        adapter = self.type_adapter
        prompt = adapter.format_input(model_input)
        processors = adapter.format_output_type(output_type)

        return mlx_generate(
            self.model,
            self.mlx_tokenizer,
            prompt,
            logits_processors=processors,
            **kwargs,
        )

    def generate_batch(
        self,
        model_input,
        output_type=None,
        **kwargs,
    ):
        """Batch generation entry point; always raises.

        Raises
        ------
        NotImplementedError
            Always raised, regardless of the arguments.

        """
        raise NotImplementedError(
            "The `mlx_lm` library does not support batch inference."
        )

    def generate_stream(
        self,
        model_input: str,
        output_type: Optional[OutlinesLogitsProcessor] = None,
        **kwargs,
    ) -> Iterator[str]:
        """Stream text using `mlx-lm`.

        Parameters
        ----------
        model_input
            The prompt based on which the model will generate a response.
        output_type
            The logits processor the model will use to constrain the format of
            the generated text.
        kwargs
            Additional keyword arguments to pass to the `mlx-lm` library.

        Returns
        -------
        Iterator[str]
            An iterator that yields the text generated by the model.

        """
        # Imported lazily so that `mlx_lm` is only required at call time.
        from mlx_lm import stream_generate as mlx_stream_generate

        adapter = self.type_adapter
        prompt = adapter.format_input(model_input)
        processors = adapter.format_output_type(output_type)

        responses = mlx_stream_generate(
            self.model,
            self.mlx_tokenizer,
            prompt,
            logits_processors=processors,
            **kwargs,
        )
        # Only the `text` attribute of each response object is exposed.
        for response in responses:
            yield response.text


def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM:
    """Build an Outlines `MLXLM` model from an `mlx_lm` model and a tokenizer.

    Parameters
    ----------
    model
        An instance of an `mlx_lm` model.
    tokenizer
        An instance of an `mlx_lm` tokenizer or of a compatible
        transformers tokenizer.

    Returns
    -------
    MLXLM
        An Outlines `MLXLM` model instance.

    """
    outlines_model = MLXLM(model, tokenizer)
    return outlines_model
