from sympy.core import S
from sympy.core.function import Lambda
from sympy.core.power import Pow
from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter
from .codeprinter import CodePrinter


_not_in_numpy = 'erf erfc factorial gamma loggamma'.split()
_in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy]
_known_functions_numpy = dict(_in_numpy, **{
    'acos': 'arccos',
    'acosh': 'arccosh',
    'asin': 'arcsin',
    'asinh': 'arcsinh',
    'atan': 'arctan',
    'atan2': 'arctan2',
    'atanh': 'arctanh',
    'exp2': 'exp2',
    'sign': 'sign',
    'logaddexp': 'logaddexp',
    'logaddexp2': 'logaddexp2',
})
_known_constants_numpy = {
    'Exp1': 'e',
    'Pi': 'pi',
    'EulerGamma': 'euler_gamma',
    'NaN': 'nan',
    'Infinity': 'PINF',
    'NegativeInfinity': 'NINF'
}

_numpy_known_functions = {k: 'numpy.' + v for k, v in _known_functions_numpy.items()}
_numpy_known_constants = {k: 'numpy.' + v for k, v in _known_constants_numpy.items()}

class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
    """
    Numpy printer which handles vectorized piecewise functions,
    logical operators, etc.
    """

    _module = 'numpy'
    _kf = _numpy_known_functions
    _kc = _numpy_known_constants

    def __init__(self, settings=None):
        """
        `settings` is passed to CodePrinter.__init__()
        `module` specifies the array module to use, currently 'NumPy', 'CuPy'
        or 'JAX'.
        """
        self.language = "Python with {}".format(self._module)
        self.printmethod = "_{}code".format(self._module)

        self._kf = {**PythonCodePrinter._kf, **self._kf}

        super().__init__(settings=settings)


    def _print_seq(self, seq):
        "General sequence printer: converts to tuple"
        # Print tuples here instead of lists because numba supports
        #     tuples in nopython mode.
        delimiter=', '
        return '({},)'.format(delimiter.join(self._print(item) for item in seq))

    def _print_MatMul(self, expr):
        "Matrix multiplication printer"
        if expr.as_coeff_matrices()[0] is not S.One:
            expr_list = expr.as_coeff_matrices()[1]+[(expr.as_coeff_matrices()[0])]
            return '({})'.format(').dot('.join(self._print(i) for i in expr_list))
        return '({})'.format(').dot('.join(self._print(i) for i in expr.args))

    def _print_MatPow(self, expr):
        "Matrix power printer"
        return '{}({}, {})'.format(self._module_format(self._module + '.linalg.matrix_power'),
            self._print(expr.args[0]), self._print(expr.args[1]))

    def _print_Inverse(self, expr):
        "Matrix inverse printer"
        return '{}({})'.format(self._module_format(self._module + '.linalg.inv'),
            self._print(expr.args[0]))

    def _print_DotProduct(self, expr):
        # DotProduct allows any shape order, but numpy.dot does matrix
        # multiplication, so we have to make sure it gets 1 x n by n x 1.
        arg1, arg2 = expr.args
        if arg1.shape[0] != 1:
            arg1 = arg1.T
        if arg2.shape[1] != 1:
            arg2 = arg2.T

        return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
                               self._print(arg1),
                               self._print(arg2))

    def _print_MatrixSolve(self, expr):
        return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
                               self._print(expr.matrix),
                               self._print(expr.vector))

    def _print_ZeroMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.zeros'),
            self._print(expr.shape))

    def _print_OneMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.ones'),
            self._print(expr.shape))

    def _print_FunctionMatrix(self, expr):
        from sympy.abc import i, j
        lamda = expr.lamda
        if not isinstance(lamda, Lambda):
            lamda = Lambda((i, j), lamda(i, j))
        return '{}(lambda {}: {}, {})'.format(self._module_format(self._module + '.fromfunction'),
            ', '.join(self._print(arg) for arg in lamda.args[0]),
            self._print(lamda.args[1]), self._print(expr.shape))

    def _print_HadamardProduct(self, expr):
        func = self._module_format(self._module + '.multiply')
        return ''.join('{}({}, '.format(func, self._print(arg)) \
            for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
            ')' * (len(expr.args) - 1))

    def _print_KroneckerProduct(self, expr):
        func = self._module_format(self._module + '.kron')
        return ''.join('{}({}, '.format(func, self._print(arg)) \
            for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
            ')' * (len(expr.args) - 1))

    def _print_Adjoint(self, expr):
        return '{}({}({}))'.format(
            self._module_format(self._module + '.conjugate'),
            self._module_format(self._module + '.transpose'),
            self._print(expr.args[0]))

    def _print_DiagonalOf(self, expr):
        vect = '{}({})'.format(
            self._module_format(self._module + '.diag'),
            self._print(expr.arg))
        return '{}({}, (-1, 1))'.format(
            self._module_format(self._module + '.reshape'), vect)

    def _print_DiagMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.diagflat'),
            self._print(expr.args[0]))

    def _print_DiagonalMatrix(self, expr):
        return '{}({}, {}({}, {}))'.format(self._module_format(self._module + '.multiply'),
            self._print(expr.arg), self._module_format(self._module + '.eye'),
            self._print(expr.shape[0]), self._print(expr.shape[1]))

    def _print_Piecewise(self, expr):
        "Piecewise function printer"
        from sympy.logic.boolalg import ITE, simplify_logic
        def print_cond(cond):
            """ Problem having an ITE in the cond. """
            if cond.has(ITE):
                return self._print(simplify_logic(cond))
            else:
                return self._print(cond)
        exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
        conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
        # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
        #     it will behave the same as passing the 'default' kwarg to select()
        #     *as long as* it is the last element in expr.args.
        # If this is not the case, it may be triggered prematurely.
        return '{}({}, {}, default={})'.format(
            self._module_format(self._module + '.select'), conds, exprs,
            self._print(S.NaN))

    def _print_Relational(self, expr):
        "Relational printer for Equality and Unequality"
        op = {
            '==' :'equal',
            '!=' :'not_equal',
            '<'  :'less',
            '<=' :'less_equal',
            '>'  :'greater',
            '>=' :'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '{op}({lhs}, {rhs})'.format(op=self._module_format(self._module + '.'+op[expr.rel_op]),
                                               lhs=lhs, rhs=rhs)
        return super()._print_Relational(expr)

    def _print_And(self, expr):
        "Logical And printer"
        # We have to override LambdaPrinter because it uses Python 'and' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
        return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_and'), ','.join(self._print(i) for i in expr.args))

    def _print_Or(self, expr):
        "Logical Or printer"
        # We have to override LambdaPrinter because it uses Python 'or' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
        return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_or'), ','.join(self._print(i) for i in expr.args))

    def _print_Not(self, expr):
        "Logical Not printer"
        # We have to override LambdaPrinter because it uses Python 'not' keyword.
        # If LambdaPrinter didn't define it, we would still have to define our
        #     own because StrPrinter doesn't define it.
        return '{}({})'.format(self._module_format(self._module + '.logical_not'), ','.join(self._print(i) for i in expr.args))

    def _print_Pow(self, expr, rational=False):
        # XXX Workaround for negative integer power error
        if expr.exp.is_integer and expr.exp.is_negative:
            expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
        return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')

    def _print_Min(self, expr):
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amin'), ','.join(self._print(i) for i in expr.args))

    def _print_Max(self, expr):
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amax'), ','.join(self._print(i) for i in expr.args))

    def _print_arg(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0]))

    def _print_im(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0]))

    def _print_Mod(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
            (self._print(arg) for arg in expr.args)))

    def _print_re(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0]))

    def _print_sinc(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi))

    def _print_MatrixBase(self, expr):
        func = self.known_functions.get(expr.__class__.__name__, None)
        if func is None:
            func = self._module_format(self._module + '.array')
        return "%s(%s)" % (func, self._print(expr.tolist()))

    def _print_Identity(self, expr):
        shape = expr.shape
        if all(dim.is_Integer for dim in shape):
            return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0]))
        else:
            raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")

    def _print_BlockMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.block'),
                                 self._print(expr.args[0].tolist()))

    def _print_NDimArray(self, expr):
        if len(expr.shape) == 1:
            return self._module + '.array(' + self._print(expr.args[0]) + ')'
        if len(expr.shape) == 2:
            return self._print(expr.tomatrix())
        # Should be possible to extend to more dimensions
        return CodePrinter._print_not_supported(self, expr)

    _add = "add"
    _einsum = "einsum"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"

    _print_lowergamma = CodePrinter._print_not_supported
    _print_uppergamma = CodePrinter._print_not_supported
    _print_fresnelc = CodePrinter._print_not_supported
    _print_fresnels = CodePrinter._print_not_supported

for func in _numpy_known_functions:
    setattr(NumPyPrinter, f'_print_{func}', _print_known_func)

for const in _numpy_known_constants:
    setattr(NumPyPrinter, f'_print_{const}', _print_known_const)


_known_functions_scipy_special = {
    'Ei': 'expi',
    'erf': 'erf',
    'erfc': 'erfc',
    'besselj': 'jv',
    'bessely': 'yv',
    'besseli': 'iv',
    'besselk': 'kv',
    'cosm1': 'cosm1',
    'powm1': 'powm1',
    'factorial': 'factorial',
    'gamma': 'gamma',
    'loggamma': 'gammaln',
    'digamma': 'psi',
    'polygamma': 'polygamma',
    'RisingFactorial': 'poch',
    'jacobi': 'eval_jacobi',
    'gegenbauer': 'eval_gegenbauer',
    'chebyshevt': 'eval_chebyt',
    'chebyshevu': 'eval_chebyu',
    'legendre': 'eval_legendre',
    'hermite': 'eval_hermite',
    'laguerre': 'eval_laguerre',
    'assoc_laguerre': 'eval_genlaguerre',
    'beta': 'beta',
    'LambertW' : 'lambertw',
}

_known_constants_scipy_constants = {
    'GoldenRatio': 'golden_ratio',
    'Pi': 'pi',
}
_scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()}
_scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()}

class SciPyPrinter(NumPyPrinter):

    _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
    _kc = {**NumPyPrinter._kc, **_scipy_known_constants}

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        self.language = "Python with SciPy and NumPy"

    def _print_SparseRepMatrix(self, expr):
        i, j, data = [], [], []
        for (r, c), v in expr.todok().items():
            i.append(r)
            j.append(c)
            data.append(v)

        return "{name}(({data}, ({i}, {j})), shape={shape})".format(
            name=self._module_format('scipy.sparse.coo_matrix'),
            data=data, i=i, j=j, shape=expr.shape
        )

    _print_ImmutableSparseMatrix = _print_SparseRepMatrix

    # SciPy's lpmv has a different order of arguments from assoc_legendre
    def _print_assoc_legendre(self, expr):
        return "{0}({2}, {1}, {3})".format(
            self._module_format('scipy.special.lpmv'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]))

    def _print_lowergamma(self, expr):
        return "{0}({2})*{1}({2}, {3})".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_uppergamma(self, expr):
        return "{0}({2})*{1}({2}, {3})".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammaincc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_betainc(self, expr):
        betainc = self._module_format('scipy.special.betainc')
        beta = self._module_format('scipy.special.beta')
        args = [self._print(arg) for arg in expr.args]
        return f"({betainc}({args[0]}, {args[1]}, {args[3]}) - {betainc}({args[0]}, {args[1]}, {args[2]})) \
            * {beta}({args[0]}, {args[1]})"

    def _print_betainc_regularized(self, expr):
        return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
            self._module_format('scipy.special.betainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]),
            self._print(expr.args[3]))

    def _print_fresnels(self, expr):
        return "{}({})[0]".format(
                self._module_format("scipy.special.fresnel"),
                self._print(expr.args[0]))

    def _print_fresnelc(self, expr):
        return "{}({})[1]".format(
                self._module_format("scipy.special.fresnel"),
                self._print(expr.args[0]))

    def _print_airyai(self, expr):
        return "{}({})[0]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_airyaiprime(self, expr):
        return "{}({})[1]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_airybi(self, expr):
        return "{}({})[2]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_airybiprime(self, expr):
        return "{}({})[3]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_bernoulli(self, expr):
        # scipy's bernoulli is inconsistent with SymPy's so rewrite
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_harmonic(self, expr):
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_Integral(self, e):
        integration_vars, limits = _unpack_integral_limits(e)

        if len(limits) == 1:
            # nicer (but not necessary) to prefer quad over nquad for 1D case
            module_str = self._module_format("scipy.integrate.quad")
            limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
        else:
            module_str = self._module_format("scipy.integrate.nquad")
            limit_str = "({})".format(", ".join(
                "(%s, %s)" % tuple(map(self._print, l)) for l in limits))

        return "{}(lambda {}: {}, {})[0]".format(
                module_str,
                ", ".join(map(self._print, integration_vars)),
                self._print(e.args[0]),
                limit_str)

    def _print_Si(self, expr):
        return "{}({})[0]".format(
                self._module_format("scipy.special.sici"),
                self._print(expr.args[0]))

    def _print_Ci(self, expr):
        return "{}({})[1]".format(
                self._module_format("scipy.special.sici"),
                self._print(expr.args[0]))

for func in _scipy_known_functions:
    setattr(SciPyPrinter, f'_print_{func}', _print_known_func)

for const in _scipy_known_constants:
    setattr(SciPyPrinter, f'_print_{const}', _print_known_const)


_cupy_known_functions = {k : "cupy." + v for k, v in _known_functions_numpy.items()}
_cupy_known_constants = {k : "cupy." + v for k, v in _known_constants_numpy.items()}

class CuPyPrinter(NumPyPrinter):
    """
    CuPy printer which handles vectorized piecewise functions,
    logical operators, etc.
    """

    _module = 'cupy'
    _kf = _cupy_known_functions
    _kc = _cupy_known_constants

    def __init__(self, settings=None):
        super().__init__(settings=settings)

for func in _cupy_known_functions:
    setattr(CuPyPrinter, f'_print_{func}', _print_known_func)

for const in _cupy_known_constants:
    setattr(CuPyPrinter, f'_print_{const}', _print_known_const)


_jax_known_functions = {k: 'jax.numpy.' + v for k, v in _known_functions_numpy.items()}
_jax_known_constants = {k: 'jax.numpy.' + v for k, v in _known_constants_numpy.items()}

class JaxPrinter(NumPyPrinter):
    """
    JAX printer which handles vectorized piecewise functions,
    logical operators, etc.
    """
    _module = "jax.numpy"

    _kf = _jax_known_functions
    _kc = _jax_known_constants

    def __init__(self, settings=None):
        super().__init__(settings=settings)

    # These need specific override to allow for the lack of "jax.numpy.reduce"
    def _print_And(self, expr):
        "Logical And printer"
        return "{}({}.asarray([{}]), axis=0)".format(
            self._module_format(self._module + ".all"),
            self._module_format(self._module),
            ",".join(self._print(i) for i in expr.args),
        )

    def _print_Or(self, expr):
        "Logical Or printer"
        return "{}({}.asarray([{}]), axis=0)".format(
            self._module_format(self._module + ".any"),
            self._module_format(self._module),
            ",".join(self._print(i) for i in expr.args),
        )

for func in _jax_known_functions:
    setattr(JaxPrinter, f'_print_{func}', _print_known_func)

for const in _jax_known_constants:
    setattr(JaxPrinter, f'_print_{const}', _print_known_const)
