# Copyright © 2023-2024 Apple Inc.

from typing import Callable, List, Optional, Tuple, Union

import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten


class Optimizer:
    """The base class for all optimizers. It allows us to implement an
    optimizer on a per-parameter basis and apply it to a parameter tree.

    Derived classes implement :meth:`init_single` and :meth:`apply_single`;
    this class maps them over the parameter tree, maintains the ``step``
    counter and re-evaluates any scheduled values before each update.

    Args:
        schedulers (dict, optional): A mapping from a state key to a callable
            that receives the current step and returns the value to store
            under that key. Default: ``None``.
    """

    def __init__(self, schedulers=None):
        self._initialized = False
        self._state = {"step": mx.array(0, mx.uint64)}
        # Keep a private copy so later edits to the caller's mapping do not
        # change which values are scheduled.
        self._schedulers = dict(schedulers or {})

    def update(self, model: Module, gradients: dict):
        """Apply the gradients to the parameters of the model and update the
        model with the new parameters.

        Args:
            model (mlx.nn.Module): An mlx module to be updated.
            gradients (dict): A Python tree of gradients, most likely computed
                              via :func:`mlx.nn.value_and_grad`.
        """
        model.update(self.apply_gradients(gradients, model))

    def init(self, parameters: dict):
        """Initialize the optimizer's state

        This function can be used to initialize optimizers which have state
        (like momentum in :class:`SGD`). Using this method is optional as the
        optimizer will initialize itself if the state is not yet set. However,
        there are some cases where explicit initialization is useful in order
        to have access to the :attr:`Optimizer.state` before the first call to
        :meth:`Optimizer.update`.

        State that is already present is kept; only the missing entries are
        created and initialized.

        Args:
            parameters (dict): A Python tree of parameters.

        Example:
            >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
            >>> model = nn.Linear(2, 2)
            >>> optimizer.init(model.trainable_parameters())
            >>> optimizer.state.keys()
            dict_keys(['step', 'learning_rate', 'weight', 'bias'])
        """

        # Grow the optimizer state so that it mirrors the parameter tree,
        # adding an empty dict for every leaf that has no state yet.
        def update_state(params, state):
            if isinstance(params, (list, tuple)):
                state = list(state)
                for i, existing in enumerate(state):
                    state[i] = update_state(params[i], existing)
                if len(state) != len(params):
                    state.extend(tree_map(lambda _: {}, params[len(state) :]))
                return type(params)(state)
            elif isinstance(params, dict):
                for k, v in params.items():
                    if k not in state:
                        state[k] = tree_map(lambda _: {}, v)
                    else:
                        state[k] = update_state(v, state[k])
                return state
            else:
                return state

        update_state(parameters, self._state)
        # A non-empty state dict is truthy and short-circuits, so only the
        # leaves whose state is still empty get initialized.
        tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
        self._initialized = True

    def init_single(self, parameter: mx.array, state: dict):
        """To be extended by the children classes to implement each optimizer's
        state initialization.

        Args:
            parameter (mx.array): A single parameter that will be optimized.
            state (dict): The optimizer's state.
        """
        raise NotImplementedError()

    def apply_gradients(self, gradients: dict, parameters: dict):
        """Apply the gradients to the parameters and return the updated parameters.

        Can be used to update a model via
        ``model.update(opt.apply_gradients(grads, model))`` which is precisely
        how :meth:`Optimizer.update` is implemented.

        Args:
            gradients (dict): A Python tree of gradients.
            parameters (dict): A Python tree of parameters. It can be a
              superset of the gradients. In that case the returned python
              tree will be of the same structure as the gradients.

        Returns:
            A Python tree with the structure of ``gradients`` holding the
            values returned by :meth:`apply_single`.
        """
        if not self._initialized:
            self.init(gradients)

        # Update any scheduled variables (evaluated at the step value from
        # before the increment below)
        for param, scheduler in self._schedulers.items():
            self.state[param] = scheduler(self.step)

        # Increment the step
        self.state["step"] = self.step + 1

        # Apply the update
        return tree_map(self.apply_single, gradients, parameters, self.state)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """To be extended by derived classes to implement the optimizer's update.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state.
        """
        raise NotImplementedError()

    @property
    def state(self):
        """The optimizer's state dictionary."""
        return self._state

    @state.setter
    def state(self, state: dict):
        # Mark the optimizer as uninitialized so that the next call to
        # ``apply_gradients`` runs ``init`` and fills in any missing entries.
        self._initialized = False
        self._state = state

    @property
    def step(self):
        """The step counter stored in the state under ``"step"``."""
        return self.state["step"]

    @property
    def learning_rate(self):
        """The learning rate stored in the state under ``"learning_rate"``."""
        return self.state["learning_rate"]

    @learning_rate.setter
    def learning_rate(self, learning_rate: Union[float, mx.array]):
        self.state["learning_rate"] = mx.array(learning_rate)

    def _maybe_schedule(
        self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
    ):
        """
        To be used by derived classes to optionally put a parameter on a schedule.

        Args:
            name (str): The state key under which the value is stored.
            param (float or callable): Either a constant, or a callable that
                maps the current step to the value. A callable is registered
                as a scheduler and evaluated once right away.
        """
        if callable(param):
            self._schedulers[name] = param
            parameter = param(self.step)
        else:
            parameter = mx.array(param)
        self.state[name] = parameter


class MultiOptimizer(Optimizer):
    """Wraps a list of optimizers with corresponding weight predicates/filters
    to make it easy to use different optimizers for different weights.

    The predicates take the full "path" of the weight and the weight itself and
    return True if it should be considered for this optimizer. The last
    optimizer in the list is a fallback optimizer and no predicate should be
    given for it.

    Args:
        optimizers (list[Optimizer]): A list of optimizers to delegate to
        filters (list[Callable[[str, array], bool]], optional): A list of
            predicates that should be one less than the provided optimizers.
            Default: ``None`` which is equivalent to no predicates.

    Raises:
        ValueError: If the number of filters is not exactly one less than the
            number of optimizers.
    """

    def __init__(self, optimizers, filters: Optional[list] = None):
        super().__init__()
        self._state = {}

        # ``None`` replaces the former mutable ``[]`` default. Copying also
        # means the caller's sequence is never aliased by this instance.
        filters = [] if filters is None else list(filters)

        if len(filters) != len(optimizers) - 1:
            raise ValueError(
                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
            )

        self.optimizers = optimizers
        # The fallback optimizer accepts everything the others rejected.
        self.filters = filters + [lambda *args, **kwargs: True]

    def _split_dictionary(self, gradients: dict):
        """Split a tree into one sub-tree per optimizer.

        Every flattened leaf goes to the first optimizer whose predicate
        returns a truthy value for ``(path, leaf)``.

        Args:
            gradients (dict): A Python tree of arrays.

        Returns:
            list: One tree per optimizer, in the order of ``self.optimizers``.
        """
        if len(self.optimizers) == 1:
            return [gradients]

        parts = [[] for _ in self.optimizers]
        for path, leaf in tree_flatten(gradients):
            for bucket, accepts in zip(parts, self.filters):
                if accepts(path, leaf):
                    bucket.append((path, leaf))
                    break

        return [tree_unflatten(p) for p in parts]

    def init(self, parameters: dict):
        """Initialize every wrapped optimizer with its share of ``parameters``.

        Args:
            parameters (dict): A Python tree of parameters.
        """
        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
            o.init(p)

    def apply_gradients(self, gradients: dict, parameters: dict):
        """Apply each optimizer to its share of the gradients and merge the
        resulting trees.

        Args:
            gradients (dict): A Python tree of gradients.
            parameters (dict): A Python tree of parameters; it is passed
                unsplit to every wrapped optimizer.

        Returns:
            The merged tree of updated parameters.
        """
        tree = {}
        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
            tree = tree_merge(tree, o.apply_gradients(g, parameters))
        return tree

    @property
    def state(self):
        """A dictionary holding the list of the wrapped optimizers' states
        under the key ``"states"``."""
        return {"states": [o.state for o in self.optimizers]}

    @state.setter
    def state(self, state: dict):
        if "states" not in state or len(state["states"]) != len(self.optimizers):
            raise ValueError("Invalid state provided")

        for o, s in zip(self.optimizers, state["states"]):
            o.state = s

    @property
    def learning_rate(self):
        """The learning rate of the first wrapped optimizer."""
        return self.optimizers[0].learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate: Union[float, mx.array]):
        # Setting the learning rate applies it to every wrapped optimizer.
        for o in self.optimizers:
            o.learning_rate = learning_rate


class SGD(Optimizer):
    r"""The stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``

    Raises:
        ValueError: If ``nesterov`` is enabled without a positive momentum or
            with a non-zero dampening.
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        dampening: float = 0.0,
        nesterov: bool = False,
    ):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening."
            )
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """

        if self.weight_decay != 0:
            # Build a new array instead of using ``+=`` so that the array
            # object held by the caller's gradient tree is left untouched.
            gradient = gradient + self.weight_decay * parameter

        lr = self.learning_rate.astype(gradient.dtype)

        if self.momentum <= 0:
            return parameter - lr * gradient

        v = self.momentum * state["v"]
        if self.dampening > 0:
            v = v + (1 - self.dampening) * gradient
        else:
            v = v + gradient

        if self.nesterov:
            update = gradient + self.momentum * v
        else:
            update = v

        state["v"] = v
        return parameter - lr * update


class RMSprop(Optimizer):
    r"""The RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

    .. math::

        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        alpha (float, optional): The smoothing constant :math:`\alpha`.
          Default: ``0.99``
        eps (float, optional): The term :math:`\epsilon` added to the denominator
          to improve numerical stability. Default: ``1e-8``

    Raises:
        ValueError: If ``alpha`` or ``eps`` is negative.
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        alpha: float = 0.99,
        eps: float = 1e-8,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.alpha = alpha
        self.eps = eps

        if self.alpha < 0.0:
            raise ValueError(
                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
            )
        if self.eps < 0.0:
            # Zero is accepted by the check, so the message says ">=0".
            raise ValueError(
                f"RMSprop epsilon should be >=0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        lr = self.learning_rate.astype(gradient.dtype)
        alpha = self.alpha
        eps = self.eps

        # Running average of the squared gradient.
        v = alpha * state["v"] + (1 - alpha) * mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)


class Adagrad(Optimizer):
    r"""The Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
    for online learning and stochastic optimization. JMLR 2011.

    .. math::

        v_{t+1} &= v_t + g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``

    Raises:
        ValueError: If ``eps`` is negative.
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        eps: float = 1e-8,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.eps = eps

        if self.eps < 0.0:
            # Zero is accepted by the check, so the message says ">=0".
            raise ValueError(
                f"Adagrad epsilon should be >=0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adagrad parameter update and stores :math:`v` in the
        optimizer state.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        lr = self.learning_rate.astype(gradient.dtype)
        eps = self.eps

        # Accumulated sum of squared gradients.
        v = state["v"] + mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)


class AdaDelta(Optimizer):
    r"""The AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.

    .. math::

        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
        w_{t+1} &= w_t - \lambda \Delta w_{t+1}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        rho (float, optional): The coefficient :math:`\rho` used for computing a
            running average of squared gradients. Default: ``0.9``
        eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
          numerical stability. Default: ``1e-6``

    Raises:
        ValueError: If ``rho`` or ``eps`` is negative.
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        rho: float = 0.9,
        eps: float = 1e-6,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.rho = rho
        self.eps = eps
        if self.rho < 0.0:
            raise ValueError(
                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
            )
        if self.eps < 0.0:
            # Zero is accepted by the check, so the message says ">=0".
            raise ValueError(
                f"AdaDelta epsilon should be >=0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)
        state["u"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the AdaDelta parameter update and stores :math:`v` and
        :math:`u` in the optimizer state.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        lr = self.learning_rate.astype(gradient.dtype)
        rho = self.rho
        eps = self.eps

        v = state["v"]
        u = state["u"]

        # Running average of squared gradients, the rescaled step, and then
        # the running average of squared steps (which uses the new step).
        v = rho * v + (1 - rho) * mx.square(gradient)
        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
        u = rho * u + (1 - rho) * mx.square(d)

        state["v"] = v
        state["u"] = u

        return parameter - lr * d


class Adam(Optimizer):
    r"""The Adam optimizer [1]. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
        bias_correction (bool, optional): If set to ``True``, bias correction
          is applied. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        bias_correction: bool = False,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        # Store a private copy: the default list object is shared by every
        # call, so keeping a reference would let a change to one instance's
        # ``betas`` leak into all other default-constructed instances.
        self.betas = list(betas)
        self.eps = eps
        self.bias_correction = bias_correction

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["m"] = mx.zeros_like(parameter)
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        lr = self.learning_rate.astype(gradient.dtype)
        b1, b2 = self.betas
        eps = self.eps
        bias_correction = self.bias_correction
        step = self.step

        m = state["m"]
        v = state["v"]
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        if bias_correction:
            # c1 folds the learning rate into the first-moment correction;
            # c2 rescales sqrt(v) by 1 / sqrt(1 - b2**step).
            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
            numerator = c1 * m
            denominator = mx.sqrt(v) * c2 + eps
            return parameter - numerator / denominator
        else:
            return parameter - lr * m / (mx.sqrt(v) + eps)


class AdamW(Adam):
    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
    (:math:`\lambda`) value:

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)

    Args:
        learning_rate (float or callable): The learning rate :math:`\alpha`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
        weight_decay (float, optional): The weight decay :math:`\lambda`.
          Default: ``0.01``.
        bias_correction (bool, optional): If set to ``True``, bias correction
          is applied. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        bias_correction: bool = False,
    ):
        # Everything except the decay strength is handled by Adam.
        adam_options = dict(
            learning_rate=learning_rate,
            betas=betas,
            eps=eps,
            bias_correction=bias_correction,
        )
        super().__init__(**adam_options)
        self.weight_decay = weight_decay

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Shrink the parameter by the decoupled weight decay and then let
        Adam perform its update on the shrunk parameter.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        step_size = self.learning_rate.astype(gradient.dtype)
        decayed = parameter * (1 - step_size * self.weight_decay)
        return super().apply_single(gradient, decayed, state)


class Adamax(Adam):
    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
    ):
        super().__init__(learning_rate, betas, eps)
        # Written as a negated ``>=`` so that anything which does not compare
        # as non-negative is rejected.
        if not eps >= 0.0:
            raise ValueError(
                f"Epsilon value should be >=0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Create zero-filled ``m`` and ``v`` entries shaped like the parameter."""
        for key in ("m", "v"):
            state[key] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Run one Adamax step, storing the new :math:`m` and :math:`v` in the
        optimizer state and returning the updated parameter.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.
        """
        step_size = self.learning_rate.astype(gradient.dtype)
        beta1, beta2 = self.betas

        # First moment: running average of the gradient.
        first_moment = beta1 * state["m"] + (1 - beta1) * gradient
        # Second moment: decayed running maximum of the gradient magnitude.
        inf_norm = mx.maximum(beta2 * state["v"], mx.abs(gradient))

        state["m"] = first_moment
        state["v"] = inf_norm

        return parameter - step_size * first_moment / (inf_norm + self.eps)


class Lion(Optimizer):
    r"""The Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper. In
    detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)

    Args:
        learning_rate (float or callable): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing the gradient
          momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.99],
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        # Store a private copy: the default list object is shared by every
        # call, so keeping a reference would let a change to one instance's
        # ``betas`` leak into all other default-constructed instances.
        self.betas = list(betas)
        self.weight_decay = weight_decay

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["m"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        lr = self.learning_rate.astype(gradient.dtype)
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state["m"]
        # The update direction and the stored momentum both interpolate the
        # previous momentum with the gradient, using b1 and b2 respectively.
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)


class Adafactor(Optimizer):
    r"""The Adafactor optimizer.

    Our Adafactor implementation follows the original paper: `Adafactor:
    Adaptive Learning Rates with Sublinear Memory Cost
    <https://arxiv.org/abs/1804.04235>`_

    Args:
        learning_rate (float or callable, optional): The learning rate.
            Default: ``None``.
        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
            parameter scaling if ``scale_parameter`` is set to ``True``.
            Default: ``(1e-30, 1e-3)``.
        clip_threshold (float, optional): Clips the unscaled update at
            ``clip_threshold``. Default: ``1.0``.
        decay_rate (float, optional): Coefficient for the running average
            of the squared gradient. Default: ``-0.8``.
        beta_1 (float, optional): If set to a value bigger than zero
            then first moment will be used. Default: ``None``.
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.0``.
        scale_parameter (bool, optional): If set to ``True`` the learning rate
            will be scaled by :math:`\max(\epsilon_2, \text{RMS}(w_{t-1}))`.
            Default: ``True``.
        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
            will be ignored and relative step size will be computed.
            Default: ``True``.
        warmup_init (bool, optional): If set to ``True`` then the relative
            step size will be calculated by the current step. Default:
            ``False``.
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta_1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
    ):
        super().__init__()
        if learning_rate is not None:
            self._maybe_schedule("learning_rate", learning_rate)
        self.eps = eps
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.beta_1 = beta_1
        self.weight_decay = weight_decay
        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        if parameter.ndim >= 2:
            # Factored second moment: one statistic with the last axis
            # dropped (rows) and one with the second to last axis dropped
            # (columns).
            shape = parameter.shape
            dtype = parameter.dtype
            state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
            state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
        else:
            state["exp_avg_sq"] = mx.zeros_like(parameter)

        if self.beta_1 is not None:
            state["exp_avg"] = mx.zeros_like(parameter)

    def _compute_rms(self, inputs):
        """Return the root mean square over all elements of ``inputs``."""
        return mx.sqrt(mx.mean(mx.square(inputs)))

    def _compute_learning_rate(self, step, parameter_rms):
        """Return the step size used for one parameter.

        Args:
            step (mx.array): The current optimizer step.
            parameter_rms (mx.array): The RMS of the parameter.
        """
        if self.relative_step:
            min_step = 1e-6 * step if self.warmup_init else 1e-2
            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
        else:
            relative_step_size = self.learning_rate

        relative_step_size = relative_step_size.astype(parameter_rms.dtype)
        parameter_scale = 1.0
        if self.scale_parameter:
            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
        return parameter_scale * relative_step_size

    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
        """Combine the row and column statistics into a scaling factor with
        the shape of the parameter.

        Args:
            exp_avg_sq_row (mx.array): Row statistics of shape ``(..., R)``.
            exp_avg_sq_col (mx.array): Column statistics of shape ``(..., C)``.

        Returns:
            mx.array: An array of shape ``(..., R, C)``.
        """
        r_factor = mx.rsqrt(
            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
        )
        c_factor = mx.rsqrt(exp_avg_sq_col)
        # Insert the singleton next to the last axis (``-2``). For a 2-D
        # parameter this equals the former ``axis=0``; for parameters with
        # more than two dimensions it keeps the leading axes aligned with
        # ``r_factor`` so the batched outer product has the right shape.
        return mx.matmul(
            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=-2)
        )

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adafactor parameter and state update.

        Args:
            gradient (mx.array): The ``parameter`` gradient.
            parameter (mx.array): The ``parameter`` to update.
            state (dict): The optimizer's state for this parameter.

        Returns:
            mx.array: The updated parameter.
        """
        factored = gradient.ndim >= 2

        step = self.step
        use_first_moment = self.beta_1 is not None

        parameter_rms = self._compute_rms(parameter)
        learning_rate = self._compute_learning_rate(step, parameter_rms)
        beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
        update = mx.square(gradient) + self.eps[0]

        if factored:
            exp_avg_sq_row = state["exp_avg_sq_row"]
            exp_avg_sq_col = state["exp_avg_sq_col"]
            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
                (1 - beta_2) * mx.mean(update, axis=-1)
            )
            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
                (1 - beta_2) * mx.mean(update, axis=-2)
            )
            state["exp_avg_sq_row"] = exp_avg_sq_row
            state["exp_avg_sq_col"] = exp_avg_sq_col
            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
            update = update * gradient
        else:
            exp_avg_sq = state["exp_avg_sq"]
            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
            state["exp_avg_sq"] = exp_avg_sq
            update = mx.rsqrt(exp_avg_sq) * gradient

        # Clip the update by its RMS relative to ``clip_threshold``.
        update = update / mx.maximum(
            1.0, self._compute_rms(update) / self.clip_threshold
        )
        update = learning_rate * update

        if use_first_moment:
            exp_avg = state["exp_avg"]
            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
            state["exp_avg"] = exp_avg
            update = exp_avg

        if self.weight_decay != 0:
            # Build a new array instead of using ``+=`` so that the array
            # object held by the caller's parameter tree is left untouched.
            parameter = parameter + parameter * (-self.weight_decay * learning_rate)
        return parameter - update


class Muon(Optimizer):
    r"""The Muon optimizer.

    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
    original implementation: `Muon: An optimizer for hidden layers in neural
    networks <https://kellerjordan.github.io/posts/muon/>`_

    Note:
        - Muon may be sub-optimal for the embedding layer, the final fully
          connected layer, or any 0D/1D parameters. Those should be optimized
          by a different method (e.g., :class:`AdamW`).
        - For 4D convolutional filters, it works by flattening their last
          dimensions.

    Args:
        learning_rate (float or callable): The learning rate.
        momentum (float, optional): The momentum strength. Default: ``0.95``
        weight_decay (float, optional): The weight decay (L2 penalty).
            Default: ``0.01``
        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
            better performance.  Default: ``True``
        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
            orthogonalization.  Default: ``5``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        momentum: float = 0.95,
        weight_decay: float = 0.01,
        nesterov: bool = True,
        ns_steps: int = 5,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum, self.weight_decay = momentum, weight_decay
        self.nesterov, self.ns_steps = nesterov, ns_steps

    def init_single(self, parameter: mx.array, state: dict):
        """Create the zero-filled momentum buffer ``v`` for one parameter."""
        state["v"] = mx.zeros_like(parameter)

    def _zeropower_via_newtonschulz5(self, X, steps: int):
        """Run ``steps`` Newton-Schulz iterations on the 2D array ``X``.

        A matrix with more rows than columns is transposed first and
        transposed back before it is returned.
        """
        assert (
            X.ndim == 2
        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
        a, b, c = (3.4445, -4.7750, 2.0315)

        is_tall = X.shape[-2] > X.shape[-1]
        Y = X.T if is_tall else X

        # Normalize by the norm of the whole matrix before iterating.
        Y = Y / (mx.linalg.norm(Y, keepdims=True) + 1e-7)

        for _ in range(steps):
            gram = Y @ Y.T
            poly = mx.addmm(b * gram, gram, gram, beta=1.0, alpha=c)
            Y = mx.addmm(a * Y, poly, Y, beta=1.0, alpha=1.0)

        return Y.T if is_tall else Y

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Run one Muon step for a single parameter and return its new value.

        The momentum buffer ``v`` in ``state`` is updated. Updates with two or
        more dimensions are passed through the Newton-Schulz iteration;
        updates with more than two dimensions are flattened to 2D for that
        and reshaped back afterwards.
        """
        mu = self.momentum

        if self.weight_decay != 0:
            gradient = gradient + self.weight_decay * parameter

        velocity = mu * state["v"] + (1 - mu) * gradient
        state["v"] = velocity

        update = gradient * (1 - mu) + velocity * mu if self.nesterov else velocity

        lr = self.learning_rate.astype(gradient.dtype)

        if update.ndim >= 2:
            full_shape = update.shape
            if update.ndim > 2:
                # Keep the first axis and merge all remaining ones.
                flat = mx.reshape(update, (full_shape[0], -1))
                flat = self._zeropower_via_newtonschulz5(flat, steps=self.ns_steps)
                update = mx.reshape(flat, full_shape)
            else:
                update = self._zeropower_via_newtonschulz5(
                    update, steps=self.ns_steps
                )

            # Scale the step by sqrt(max(1, rows / cols)) of the last two axes.
            lr = lr * max(1, update.shape[-2] / update.shape[-1]) ** 0.5

        return parameter - lr * update


def clip_grad_norm(grads, max_norm):
    """Clips the global norm of the gradients.

    The global norm is the square root of the sum of the squares of every
    element in the gradient tree. If it is larger than ``max_norm`` all the
    gradients are scaled down by the same factor, otherwise they are
    multiplied by one.

    Example:
        >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
        >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
        >>> print(clipped_grads)
        {"w1": mx.array([...]), "w2": mx.array([...])}

    Args:
        grads (dict): A dictionary containing the gradient arrays.
        max_norm (float): The maximum allowed global norm of the gradients.

    Returns:
        (dict, array): The possibly rescaled gradients and the original
        gradient norm as returned by ``mx.sqrt``.
    """

    def _add_squares(running_total, grad):
        return running_total + grad.square().sum()

    total_norm = mx.sqrt(tree_reduce(_add_squares, grads, 0.0))
    # The small constant keeps the division finite when the norm is zero.
    scale = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
    return tree_map(lambda g: g * scale, grads), total_norm
