from __future__ import annotations

from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.numbers import Integer
from sympy.matrices.common import MatrixCommon
from .ndim_array import NDimArray
from .arrayop import derive_by_array
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix
from sympy.matrices.expressions.matexpr import _matrix_derivative


class ArrayDerivative(Derivative):
    """Derivative whose expression and/or variables may be arrays or matrices.

    Instances expose a ``shape``.  It is built from the shapes of every
    array-valued differentiation variable (repeated once per derivative
    order), followed by the shape of the differentiated expression.
    """

    is_scalar = False

    def __new__(cls, expr, *variables, **kwargs):
        obj = super().__new__(cls, expr, *variables, **kwargs)
        # ``Derivative.__new__`` may return something other than an
        # ArrayDerivative (presumably an evaluated result); only genuine
        # ArrayDerivative instances get a cached shape.
        if isinstance(obj, ArrayDerivative):
            obj._shape = obj._get_shape()
        return obj

    def _get_shape(self):
        """Compute the shape of this derivative.

        For each ``(variable, count)`` pair whose variable has a ``shape``,
        that shape is appended ``count`` times.  The shape of ``self.expr``
        (if it has one) is appended last.  Scalars contribute nothing.
        """
        shape = ()
        for v, count in self.variable_count:
            if hasattr(v, "shape"):
                for _ in range(count):
                    shape += v.shape
        if hasattr(self.expr, "shape"):
            shape += self.expr.shape
        return shape

    @property
    def shape(self):
        """Shape tuple cached by ``__new__`` (see ``_get_shape``)."""
        return self._shape

    @classmethod
    def _get_zero_with_shape_like(cls, expr):
        """Return a zero object with the same shape as ``expr``.

        Explicit matrices and N-dim arrays use their own ``zeros``
        constructor; symbolic matrix expressions get a ``ZeroMatrix``.

        Raises:
            RuntimeError: if ``expr`` is none of those types.
        """
        if isinstance(expr, (MatrixCommon, NDimArray)):
            return expr.zeros(*expr.shape)
        elif isinstance(expr, MatrixExpr):
            return ZeroMatrix(*expr.shape)
        else:
            raise RuntimeError("Unable to determine shape of array-derivative.")

    @staticmethod
    def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixCommon) -> MatrixCommon:
        """Differentiate scalar ``expr`` by each entry of explicit matrix ``v``.

        The result is produced by ``v.applyfunc``, so it has ``v``'s layout.
        """
        return v.applyfunc(lambda x: expr.diff(x))

    @staticmethod
    def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr:
        """Differentiate scalar ``expr`` by symbolic matrix ``v``.

        Uses ``_matrix_derivative`` when ``expr`` contains ``v``; otherwise
        returns a ``ZeroMatrix`` with ``v``'s shape.
        """
        if expr.has(v):
            return _matrix_derivative(expr, v)
        else:
            return ZeroMatrix(*v.shape)

    @staticmethod
    def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> NDimArray:
        """Differentiate scalar ``expr`` by each entry of array ``v``."""
        return v.applyfunc(lambda x: expr.diff(x))

    @staticmethod
    def _call_derive_matrix_by_scalar(expr: MatrixCommon, v: Expr) -> Expr:
        """Differentiate explicit matrix ``expr`` by scalar ``v`` via ``_matrix_derivative``."""
        return _matrix_derivative(expr, v)

    @staticmethod
    def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr:
        """Differentiate symbolic matrix ``expr`` by scalar ``v`` via its ``_eval_derivative``."""
        return expr._eval_derivative(v)

    @staticmethod
    def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> NDimArray:
        """Differentiate every entry of array ``expr`` with respect to scalar ``v``."""
        return expr.applyfunc(lambda x: x.diff(v))

    @staticmethod
    def _call_derive_default(expr: Expr, v: Expr) -> Expr | None:
        """Use ``_matrix_derivative`` if ``expr`` contains ``v``; otherwise return ``None``."""
        if expr.has(v):
            return _matrix_derivative(expr, v)
        else:
            return None

    @classmethod
    def _dispatch_eval_derivative_n_times(cls, expr, v, count):
        """Differentiate ``expr`` ``count`` times with respect to ``v``.

        Dispatches on whether ``expr`` and ``v`` are scalars, explicit
        matrices (``MatrixCommon``), symbolic matrix expressions
        (``MatrixExpr``) or N-dimensional arrays (``NDimArray``).  One
        derivative is taken, then the method recurses with ``count - 1``.
        The scalar-by-scalar case is delegated entirely to the
        ``Derivative`` implementation.

        Returns ``None`` (no evaluated result) in three cases:

        * ``count`` is not a positive ``int``/``Integer``;
        * the type combination is not handled here;
        * a partial step yields ``None``.
        """
        # ``count <= 0`` may produce a SymPy boolean rather than a Python
        # bool, hence the explicit ``== True`` comparison.
        if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
            return None

        # TODO: this could be done with multiple-dispatching:
        if expr.is_scalar:
            if isinstance(v, MatrixCommon):
                result = cls._call_derive_scalar_by_matrix(expr, v)
            elif isinstance(v, MatrixExpr):
                result = cls._call_derive_scalar_by_matexpr(expr, v)
            elif isinstance(v, NDimArray):
                result = cls._call_derive_scalar_by_array(expr, v)
            elif v.is_scalar:
                # Scalar by scalar: let the base Derivative handle all orders.
                return super()._dispatch_eval_derivative_n_times(expr, v, count)
            else:
                return None
        elif v.is_scalar:
            if isinstance(expr, MatrixCommon):
                result = cls._call_derive_matrix_by_scalar(expr, v)
            elif isinstance(expr, MatrixExpr):
                result = cls._call_derive_matexpr_by_scalar(expr, v)
            elif isinstance(expr, NDimArray):
                result = cls._call_derive_array_by_scalar(expr, v)
            else:
                return None
        else:
            # Both `expr` and `v` are some array/matrix type.
            # NOTE(review): only `expr` is tested for an explicit matrix here.
            # If only `v` is explicit, the case falls through to the
            # MatrixExpr / NDimArray branches below.  Confirm whether an
            # explicit `v` should also be routed to `derive_by_array`.
            if isinstance(expr, MatrixCommon):
                result = derive_by_array(expr, v)
            elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr):
                result = cls._call_derive_default(expr, v)
            elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr):
                # if one expression is a symbolic matrix expression while the other isn't, don't evaluate:
                return None
            else:
                result = derive_by_array(expr, v)
        if result is None:
            return None
        if count == 1:
            return result
        else:
            return cls._dispatch_eval_derivative_n_times(result, v, count - 1)
