"""Tools and arithmetic for monomials of distributed polynomials."""


from itertools import combinations_with_replacement, product
from textwrap import dedent

from sympy.core import Mul, S, Tuple, sympify
from sympy.polys.polyerrors import ExactQuotientFailed
from sympy.polys.polyutils import PicklableWithSlots, dict_from_expr
from sympy.utilities import public
from sympy.utilities.iterables import is_sequence, iterable

@public
def itermonomials(variables, max_degrees, min_degrees=None):
    r"""
    Generate all monomials in ``variables`` whose degree lies within bounds.

    ``max_degrees`` and ``min_degrees`` are either both integers or both lists.
    Unless otherwise specified, ``min_degrees`` is either ``0`` or
    ``[0, ..., 0]``.

    A generator of all monomials ``monom`` is returned, such that
    either
    ``min_degree <= total_degree(monom) <= max_degree``,
    or
    ``min_degrees[i] <= degree_list(monom)[i] <= max_degrees[i]``,
    for all ``i``.

    A ``ValueError`` is raised if a bound is negative, if the list arguments
    do not have one entry per variable, if ``max_degrees`` is a list but
    ``min_degrees`` is not, or if some ``min_degrees[i] > max_degrees[i]``.
    Because this is a generator function, the arguments are only validated
    once iteration starts.

    Case I. ``max_degrees`` and ``min_degrees`` are both integers
    =============================================================

    Given a set of variables $V$, a max_degree $N$ and a min_degree $M$,
    generate the set of monomials of total degree less than or equal to $N$
    and greater than or equal to $M$. Nothing is generated if $M > N$. The
    total number of monomials in commutative variables is huge and is given
    by the following formula if $M = 0$:

        .. math::
            \frac{(\#V + N)!}{\#V! N!}

    For example if we would like to generate a dense polynomial of
    a total degree $N = 50$ and $M = 0$, which is the worst case, in 5
    variables, assuming that exponents and all of coefficients are 32-bit long
    and stored in an array we would need almost 80 GiB of memory! Fortunately
    most polynomials, that we will encounter, are sparse.

    Consider monomials in commutative variables $x$ and $y$
    and non-commutative variables $a$ and $b$::

        >>> from sympy import symbols
        >>> from sympy.polys.monomials import itermonomials
        >>> from sympy.polys.orderings import monomial_key
        >>> from sympy.abc import x, y

        >>> sorted(itermonomials([x, y], 2), key=monomial_key('grlex', [y, x]))
        [1, x, y, x**2, x*y, y**2]

        >>> sorted(itermonomials([x, y], 3), key=monomial_key('grlex', [y, x]))
        [1, x, y, x**2, x*y, y**2, x**3, x**2*y, x*y**2, y**3]

        >>> a, b = symbols('a, b', commutative=False)
        >>> set(itermonomials([a, b, x], 2))
        {1, a, a**2, b, b**2, x, x**2, a*b, b*a, x*a, x*b}

        >>> sorted(itermonomials([x, y], 2, 1), key=monomial_key('grlex', [y, x]))
        [x, y, x**2, x*y, y**2]

    Case II. ``max_degrees`` and ``min_degrees`` are both lists
    ===========================================================

    If ``max_degrees = [d_1, ..., d_n]`` and
    ``min_degrees = [e_1, ..., e_n]``, the number of monomials generated
    is:

    .. math::
        (d_1 - e_1 + 1) (d_2 - e_2 + 1) \cdots (d_n - e_n + 1)

    Let us generate all monomials ``monom`` in variables $x$ and $y$
    such that ``[1, 2][i] <= degree_list(monom)[i] <= [2, 4][i]``,
    ``i = 0, 1`` ::

        >>> from sympy import symbols
        >>> from sympy.polys.monomials import itermonomials
        >>> from sympy.polys.orderings import monomial_key
        >>> from sympy.abc import x, y

        >>> sorted(itermonomials([x, y], [2, 4], [1, 2]), reverse=True, key=monomial_key('lex', [x, y]))
        [x**2*y**4, x**2*y**3, x**2*y**2, x*y**4, x*y**3, x*y**2]
    """
    n = len(variables)

    if is_sequence(max_degrees):
        # Case II: independent degree bounds for every variable.
        if len(max_degrees) != n:
            raise ValueError('Argument sizes do not match')
        if min_degrees is None:
            min_degrees = [0]*n
        elif not is_sequence(min_degrees):
            raise ValueError('min_degrees is not a list')
        else:
            if len(min_degrees) != n:
                raise ValueError('Argument sizes do not match')
            if any(i < 0 for i in min_degrees):
                raise ValueError("min_degrees cannot contain negative numbers")
        if any(min_d > max_d for min_d, max_d in zip(min_degrees, max_degrees)):
            raise ValueError('min_degrees[i] must be <= max_degrees[i] for all i')

        # One list of admissible powers per variable; their Cartesian
        # product enumerates every monomial exactly once.
        power_lists = [
            [var**i for i in range(min_d, max_d + 1)]
            for var, min_d, max_d in zip(variables, min_degrees, max_degrees)
        ]
        for powers in product(*power_lists):
            yield Mul(*powers)
        return

    # Case I: bounds on the total degree.
    max_degree = max_degrees
    if max_degree < 0:
        raise ValueError("max_degrees cannot be negative")
    if min_degrees is None:
        min_degree = 0
    else:
        if min_degrees < 0:
            raise ValueError("min_degrees cannot be negative")
        min_degree = min_degrees

    if min_degree > max_degree:
        return
    if not variables or max_degree == 0:
        # Only the constant monomial exists here.  It has total degree 0,
        # so it is admissible only when the lower bound allows degree 0.
        if min_degree <= 0:
            yield S.One
        return

    # Force to list in case of passed tuple or other incompatible collection.
    # Padding with 1 lets a selection of exactly ``max_degree`` factors also
    # represent every monomial of lower degree.
    variables = list(variables) + [S.One]
    if all(variable.is_commutative for variable in variables):
        # Order of factors is irrelevant: enumerate multisets of factors.
        selections = combinations_with_replacement(variables, max_degree)
    else:
        # Order of factors matters: enumerate ordered selections.
        selections = product(variables, repeat=max_degree)

    # A set removes the duplicates that arise when different selections
    # multiply out to the same expression.
    monomials = set()
    for item in selections:
        # The total degree is the number of factors that are not the padding 1.
        degree = sum(1 for variable in item if variable != 1)
        if degree >= min_degree:
            monomials.add(Mul(*item))
    yield from monomials

def monomial_count(V, N):
    r"""
    Count the monomials of total degree at most ``N`` in ``V`` variables.

    The count is given by the following formula:

    .. math::

        \frac{(\#V + N)!}{\#V! N!}

    where `N` is the total degree and `\#V` is the number of variables,
    which is what must be passed as ``V``.

    Examples
    ========

    >>> from sympy.polys.monomials import itermonomials, monomial_count
    >>> from sympy.polys.orderings import monomial_key
    >>> from sympy.abc import x, y

    >>> monomial_count(2, 2)
    6

    >>> M = list(itermonomials([x, y], 2))

    >>> sorted(M, key=monomial_key('grlex', [y, x]))
    [1, x, y, x**2, x*y, y**2]
    >>> len(M)
    6

    """
    from sympy.functions.combinatorial.factorials import factorial

    numerator = factorial(V + N)
    per_variables = factorial(V)
    per_degree = factorial(N)
    return numerator / per_variables / per_degree

def monomial_mul(A, B):
    """
    Multiplication of tuples representing monomials.

    The exponents of ``A`` and ``B`` are added position by position. The
    result is as long as the shorter of the two inputs.

    Examples
    ========

    Let's multiply `x**3*y**4*z` with `x*y**2`::

        >>> from sympy.polys.monomials import monomial_mul

        >>> monomial_mul((3, 4, 1), (1, 2, 0))
        (4, 6, 1)

    which gives `x**4*y**6*z`.

    """
    return tuple(a + b for a, b in zip(A, B))

def monomial_div(A, B):
    """
    Exact division of tuples representing monomials.

    Returns the tuple of exponent differences computed by
    ``monomial_ldiv``, or ``None`` when any difference is negative,
    i.e. when ``B`` does not divide ``A``.

    Examples
    ========

    Let's divide `x**3*y**4*z` by `x*y**2`::

        >>> from sympy.polys.monomials import monomial_div

        >>> monomial_div((3, 4, 1), (1, 2, 0))
        (2, 2, 1)

    which gives `x**2*y**2*z`. However::

        >>> monomial_div((3, 4, 1), (1, 2, 2)) is None
        True

    `x*y**2*z**2` does not divide `x**3*y**4*z`.

    """
    quotient = monomial_ldiv(A, B)

    for exponent in quotient:
        # A negative exponent means the division is not exact.
        if not exponent >= 0:
            return None

    return tuple(quotient)

def monomial_ldiv(A, B):
    """
    Division of tuples representing monomials, allowing negative exponents.

    The exponents of ``B`` are subtracted from those of ``A`` position by
    position; no check is made that the result is non-negative.

    Examples
    ========

    Let's divide `x**3*y**4*z` by `x*y**2`::

        >>> from sympy.polys.monomials import monomial_ldiv

        >>> monomial_ldiv((3, 4, 1), (1, 2, 0))
        (2, 2, 1)

    which gives `x**2*y**2*z`.

        >>> monomial_ldiv((3, 4, 1), (1, 2, 2))
        (2, 2, -1)

    which gives `x**2*y**2*z**-1`.

    """
    differences = (numer - denom for numer, denom in zip(A, B))
    return tuple(differences)

def monomial_pow(A, n):
    """Return the ``n``-th power of the monomial ``A``.

    Every exponent of ``A`` is multiplied by ``n``.
    """
    scaled = (exponent*n for exponent in A)
    return tuple(scaled)

def monomial_gcd(A, B):
    """
    Greatest common divisor of tuples representing monomials.

    The result holds, for each position, the smaller of the two exponents.

    Examples
    ========

    Let's compute GCD of `x*y**4*z` and `x**3*y**2`::

        >>> from sympy.polys.monomials import monomial_gcd

        >>> monomial_gcd((1, 4, 1), (3, 2, 0))
        (1, 2, 0)

    which gives `x*y**2`.

    """
    return tuple(map(min, A, B))

def monomial_lcm(A, B):
    """
    Least common multiple of tuples representing monomials.

    The result holds, for each position, the larger of the two exponents.

    Examples
    ========

    Let's compute LCM of `x*y**4*z` and `x**3*y**2`::

        >>> from sympy.polys.monomials import monomial_lcm

        >>> monomial_lcm((1, 4, 1), (3, 2, 0))
        (3, 4, 1)

    which gives `x**3*y**4*z`.

    """
    return tuple(map(max, A, B))

def monomial_divides(A, B):
    """
    Does there exist a monomial X such that XA == B?

    This holds exactly when no exponent of ``A`` exceeds the corresponding
    exponent of ``B``.

    Examples
    ========

    >>> from sympy.polys.monomials import monomial_divides
    >>> monomial_divides((1, 2), (3, 4))
    True
    >>> monomial_divides((1, 2), (0, 2))
    False
    """
    for divisor_exp, dividend_exp in zip(A, B):
        if not divisor_exp <= dividend_exp:
            return False
    return True

def monomial_max(*monoms):
    """
    Returns maximal degree for each variable in a set of monomials.

    At least one monomial must be given; the first one determines the
    length of the result.

    Examples
    ========

    Consider monomials `x**3*y**4*z**5`, `y**5*z` and `x**6*y**3*z**9`.
    We wish to find out what is the maximal degree for each of `x`, `y`
    and `z` variables::

        >>> from sympy.polys.monomials import monomial_max

        >>> monomial_max((3,4,5), (0,5,1), (6,3,9))
        (6, 5, 9)

    """
    highest = list(monoms[0])

    for monom in monoms[1:]:
        for position, exponent in enumerate(monom):
            # Keep the running maximum for this variable.
            if exponent > highest[position]:
                highest[position] = exponent

    return tuple(highest)

def monomial_min(*monoms):
    """
    Returns minimal degree for each variable in a set of monomials.

    At least one monomial must be given; the first one determines the
    length of the result.

    Examples
    ========

    Consider monomials `x**3*y**4*z**5`, `y**5*z` and `x**6*y**3*z**9`.
    We wish to find out what is the minimal degree for each of `x`, `y`
    and `z` variables::

        >>> from sympy.polys.monomials import monomial_min

        >>> monomial_min((3,4,5), (0,5,1), (6,3,9))
        (0, 3, 1)

    """
    lowest = list(monoms[0])

    for monom in monoms[1:]:
        for position, exponent in enumerate(monom):
            # Keep the running minimum for this variable.
            if exponent < lowest[position]:
                lowest[position] = exponent

    return tuple(lowest)

def monomial_deg(M):
    """
    Returns the total degree of a monomial.

    The total degree is the sum of all exponents in ``M``.

    Examples
    ========

    The total degree of `xy^2` is 3:

    >>> from sympy.polys.monomials import monomial_deg
    >>> monomial_deg((1, 2))
    3
    """
    return sum(M, 0)

def term_div(a, b, domain):
    """Division of two terms over a ring/field.

    Each term is a ``(monomial, coefficient)`` pair.  The quotient term is
    returned when the monomial of ``b`` divides the monomial of ``a`` and,
    if ``domain`` is not a field, ``a_lc % b_lc`` is zero as well.
    Otherwise ``None`` is returned.
    """
    (a_lm, a_lc), (b_lm, b_lc) = a, b

    monom = monomial_div(a_lm, b_lm)
    over_field = domain.is_Field

    if monom is None:
        return None
    # Outside a field the coefficients must divide without remainder.
    if not over_field and a_lc % b_lc:
        return None

    return monom, domain.quo(a_lc, b_lc)

class MonomialOps:
    """Code generator of fast monomial arithmetic functions.

    Each public method generates Python source for a function specialised
    to exponent tuples of exactly ``ngens`` entries, executes that source
    and returns the resulting function.  The generated functions unpack
    their monomial arguments into local variables, so passing a monomial of
    any other length raises ``ValueError``.

    ``ngens == 0`` is supported: the generated functions then accept only
    empty monomials and return ``()``.
    """

    # Source shared by every generated function that unpacks the monomials
    # ``A`` and ``B`` and returns one expression per generator.
    _ELEMENTWISE_TEMPLATE = dedent("""\
    def %(name)s(%(params)s):
        %(A)s = A
        %(B)s = B
        return %(result)s
    """)

    def __init__(self, ngens):
        self.ngens = ngens

    def _build(self, code, name):
        """Execute ``code`` in a fresh namespace and return what it binds to ``name``."""
        ns = {}
        exec(code, ns)
        return ns[name]

    def _vars(self, name):
        """Return the variable names ``name0``, ``name1``, ... one per generator."""
        return [ "%s%s" % (name, i) for i in range(self.ngens) ]

    @staticmethod
    def _tuple(items):
        """Return source for a tuple of ``items``, usable as a value or as an unpacking target."""
        # "(,)" is not valid Python, so the empty tuple needs its own spelling.
        return "(%s,)" % ", ".join(items) if items else "()"

    def _elementwise(self, name, expr, params="A, B"):
        """Generate ``name(params)`` returning ``expr`` evaluated per generator.

        ``expr`` is a ``%``-template using the keys ``a`` and ``b``, which
        are replaced by the variables holding the corresponding exponents
        of ``A`` and ``B``.
        """
        A = self._vars("a")
        B = self._vars("b")
        result = [ expr % {"a": a, "b": b} for a, b in zip(A, B) ]
        code = self._ELEMENTWISE_TEMPLATE % {
            "name": name,
            "params": params,
            "A": self._tuple(A),
            "B": self._tuple(B),
            "result": self._tuple(result),
        }
        return self._build(code, name)

    def mul(self):
        """Generate ``monomial_mul(A, B)``: element-wise sum of exponents."""
        return self._elementwise("monomial_mul", "%(a)s + %(b)s")

    def pow(self):
        """Generate ``monomial_pow(A, k)``: every exponent multiplied by ``k``."""
        name = "monomial_pow"
        template = dedent("""\
        def %(name)s(A, k):
            %(A)s = A
            return %(Ak)s
        """)
        A = self._vars("a")
        Ak = [ "%s*k" % a for a in A ]
        code = template % {"name": name, "A": self._tuple(A), "Ak": self._tuple(Ak)}
        return self._build(code, name)

    def mulpow(self):
        """Generate ``monomial_mulpow(A, B, k)``: exponents of ``A`` plus ``k`` times those of ``B``."""
        return self._elementwise(
            "monomial_mulpow", "%(a)s + %(b)s*k", params="A, B, k")

    def ldiv(self):
        """Generate ``monomial_ldiv(A, B)``: element-wise difference of exponents."""
        return self._elementwise("monomial_ldiv", "%(a)s - %(b)s")

    def div(self):
        """Generate ``monomial_div(A, B)``: like ``ldiv`` but ``None`` on a negative exponent."""
        name = "monomial_div"
        template = dedent("""\
        def %(name)s(A, B):
            %(A)s = A
            %(B)s = B
            %(RAB)s
            return %(R)s
        """)
        A = self._vars("a")
        B = self._vars("b")
        R = self._vars("r")
        steps = []
        for a, b, r in zip(A, B, R):
            steps.append("%s = %s - %s" % (r, a, b))
            # Bail out as soon as one exponent shows the division is inexact.
            steps.append("if %s < 0: return None" % r)
        # The joiner re-creates the indentation of the function body;
        # ``pass`` keeps the body well-formed when there are no generators.
        RAB = "\n    ".join(steps) or "pass"
        code = template % {"name": name, "A": self._tuple(A), "B": self._tuple(B), "RAB": RAB, "R": self._tuple(R)}
        return self._build(code, name)

    def lcm(self):
        """Generate ``monomial_lcm(A, B)``: element-wise maximum of exponents."""
        return self._elementwise(
            "monomial_lcm", "%(a)s if %(a)s >= %(b)s else %(b)s")

    def gcd(self):
        """Generate ``monomial_gcd(A, B)``: element-wise minimum of exponents."""
        return self._elementwise(
            "monomial_gcd", "%(a)s if %(a)s <= %(b)s else %(b)s")

@public
class Monomial(PicklableWithSlots):
    """Class representing a monomial, i.e. a product of powers.

    A monomial stores a tuple of integer ``exponents`` and, optionally, the
    generators (``gens``) those exponents refer to.  Arithmetic methods
    accept another ``Monomial`` or a plain tuple / ``Tuple`` of exponents.
    """

    __slots__ = ('exponents', 'gens')

    def __init__(self, monom, gens=None):
        """Create a monomial from an iterable of exponents or from an expression.

        A non-iterable ``monom`` is sympified and converted with
        ``dict_from_expr``; it must consist of a single term with
        coefficient 1, otherwise ``ValueError`` is raised.
        """
        if not iterable(monom):
            rep, gens = dict_from_expr(sympify(monom), gens=gens)
            if len(rep) == 1 and list(rep.values())[0] == 1:
                monom = list(rep.keys())[0]
            else:
                raise ValueError("Expected a monomial got {}".format(monom))

        self.exponents = tuple(map(int, monom))
        self.gens = gens

    @staticmethod
    def _exponents_of(other):
        """Return the exponents of a monomial-like ``other``, else ``None``."""
        if isinstance(other, Monomial):
            return other.exponents
        if isinstance(other, (tuple, Tuple)):
            return other
        return None

    def rebuild(self, exponents, gens=None):
        """Return a new monomial of the same class, reusing ``self.gens`` unless ``gens`` is given."""
        return self.__class__(exponents, gens or self.gens)

    def __len__(self):
        return len(self.exponents)

    def __iter__(self):
        return iter(self.exponents)

    def __getitem__(self, item):
        return self.exponents[item]

    def __hash__(self):
        # NOTE(review): ``__eq__`` compares exponents only, while this hash
        # also depends on ``gens`` and the class name, so monomials (or
        # tuples) that compare equal can hash differently.  Kept unchanged
        # because callers may depend on the current dict/set behaviour.
        return hash((self.__class__.__name__, self.exponents, self.gens))

    def __str__(self):
        if self.gens:
            return "*".join([ "%s**%s" % (gen, exp) for gen, exp in zip(self.gens, self.exponents) ])
        else:
            return "%s(%s)" % (self.__class__.__name__, self.exponents)

    def as_expr(self, *gens):
        """Convert a monomial instance to a SymPy expression.

        Uses ``gens`` if given, otherwise ``self.gens``; raises
        ``ValueError`` when neither provides generators.
        """
        gens = gens or self.gens

        if not gens:
            raise ValueError(
                "Cannot convert %s to an expression without generators" % self)

        return Mul(*[ gen**exp for gen, exp in zip(gens, self.exponents) ])

    def __eq__(self, other):
        exponents = self._exponents_of(other)

        if exponents is None:
            return False

        return self.exponents == exponents

    def __ne__(self, other):
        return not self == other

    def __mul__(self, other):
        exponents = self._exponents_of(other)

        if exponents is None:
            raise NotImplementedError

        return self.rebuild(monomial_mul(self.exponents, exponents))

    def __truediv__(self, other):
        exponents = self._exponents_of(other)

        if exponents is None:
            raise NotImplementedError

        result = monomial_div(self.exponents, exponents)

        if result is None:
            # ``other`` does not divide this monomial exactly.
            raise ExactQuotientFailed(self, Monomial(other))

        return self.rebuild(result)

    __floordiv__ = __truediv__

    def __pow__(self, other):
        """Raise the monomial to a non-negative integer power."""
        n = int(other)

        if not n:
            return self.rebuild([0]*len(self))
        elif n > 0:
            # Scaling every exponent by ``n`` equals ``n - 1`` repeated
            # multiplications, without the intermediate tuples.
            return self.rebuild(monomial_pow(self.exponents, n))
        else:
            raise ValueError("a non-negative integer expected, got %s" % other)

    def gcd(self, other):
        """Greatest common divisor of monomials. """
        exponents = self._exponents_of(other)

        if exponents is None:
            raise TypeError(
                "an instance of Monomial class expected, got %s" % other)

        return self.rebuild(monomial_gcd(self.exponents, exponents))

    def lcm(self, other):
        """Least common multiple of monomials. """
        exponents = self._exponents_of(other)

        if exponents is None:
            raise TypeError(
                "an instance of Monomial class expected, got %s" % other)

        return self.rebuild(monomial_lcm(self.exponents, exponents))
