# -*- coding: utf-8 -*-

from .cartan_type import CartanType
from mpmath import fac
from sympy.core.backend import Matrix, eye, Rational, igcd
from sympy.core.basic import Atom

class WeylGroup(Atom):

    """
    For each semisimple Lie group, we have a Weyl group.  It is a subgroup of
    the isometry group of the root system.  Specifically, it's the subgroup
    that is generated by reflections through the hyperplanes orthogonal to
    the roots.  Therefore, Weyl groups are reflection groups, and so a Weyl
    group is a finite Coxeter group.

    """

    def __new__(cls, cartantype):
        obj = Atom.__new__(cls)
        obj.cartan_type = CartanType(cartantype)
        return obj

    def generators(self):
        """
        This method creates the generating reflections of the Weyl group for
        a given Lie algebra.  For a Lie algebra of rank n, there are n
        different generating reflections.  This function returns them as
        a list.

        Examples
        ========

        >>> from sympy.liealgebras.weyl_group import WeylGroup
        >>> c = WeylGroup("F4")
        >>> c.generators()
        ['r1', 'r2', 'r3', 'r4']
        """
        n = self.cartan_type.rank()
        return ["r" + str(i) for i in range(1, n + 1)]

    def group_order(self):
        """
        This method returns the order of the Weyl group as an exact integer.
        For types A, B, C and D the order depends on the rank of the Lie
        algebra.  For types E, F and G it is one of a few fixed values.

        Examples
        ========

        >>> from sympy.liealgebras.weyl_group import WeylGroup
        >>> c = WeylGroup("D4")
        >>> c.group_order()
        192
        """
        # math.factorial is exact.  mpmath's fac returns a rounded float that
        # silently loses precision once the order exceeds 2**53.
        from math import factorial

        n = self.cartan_type.rank()
        series = self.cartan_type.series
        if series == "A":
            return factorial(n + 1)

        if series in ("B", "C"):
            return factorial(n) * 2**n

        if series == "D":
            return factorial(n) * 2**(n - 1)

        if series == "E":
            # Yields None for a rank other than 6, 7 or 8 (as before).
            return {6: 51840, 7: 2903040, 8: 696729600}.get(n)

        if series == "F":
            return 1152

        if series == "G":
            return 12

    def group_name(self):
        """
        This method returns some general information about the Weyl group for
        a given Lie algebra.  It returns the name of the group and the elements
        it acts on, if relevant.
        """
        n = self.cartan_type.rank()
        if self.cartan_type.series == "A":
            return "S"+str(n+1) + ": the symmetric group acting on " + str(n+1) + " elements."

        if self.cartan_type.series in ("B", "C"):
            return "The hyperoctahedral group acting on " + str(2*n) + " elements."

        if self.cartan_type.series == "D":
            return "The symmetry group of the " + str(n) + "-dimensional demihypercube."

        if self.cartan_type.series == "E":
            if n == 6:
                return "The symmetry group of the 6-polytope."

            if n == 7:
                return "The symmetry group of the 7-polytope."

            if n == 8:
                return "The symmetry group of the 8-polytope."

        if self.cartan_type.series == "F":
            return "The symmetry group of the 24-cell, or icositetrachoron."

        if self.cartan_type.series == "G":
            return "D6, the dihedral group of order 12, and symmetry group of the hexagon."

    def element_order(self, weylelt):
        """
        This method returns the order of a given Weyl group element, which should
        be specified by the user in the form of products of the generating
        reflections, i.e. of the form r1*r2 etc.

        For every type (A-G) this method currently works by taking the exact
        matrix form of the specified element (see ``matrix_form``) and
        multiplying it by itself until the identity matrix is reached.  It
        then returns the number of factors needed.  The empty word gives 1.

        Raises ValueError if the word names a reflection index outside
        ``1..rank``.

        Examples
        ========

        >>> from sympy.liealgebras.weyl_group import WeylGroup
        >>> b = WeylGroup("B4")
        >>> b.element_order('r1*r4*r2')
        4
        """
        mat = self.matrix_form(weylelt)
        # The identity has the same size as the representation, which depends
        # on the series (n+1 for A, 8 for E, 4 for F, 3 for G, n otherwise).
        identity = eye(mat.rows)
        power = mat
        order = 1
        # Terminates because the Weyl group is finite and the matrices are
        # exact (integer/Rational entries).
        while power != identity:
            power = power * mat
            order += 1
        return order

    def delete_doubles(self, reflections):
        """
        Cancel adjacent equal simple reflections in a word.

        Since every simple reflection squares to the identity, ``r_i*r_i`` can
        be removed.  This cancellation is repeated until no adjacent equal pair
        is left.  The result is the same as repeatedly deleting adjacent pairs
        until nothing changes, but it is computed in a single pass.  A new list
        is returned and the input is not modified.

        This helper was originally used for element orders in G2.  It is kept
        for backward compatibility.
        """
        reduced = []
        for elt in reflections:
            if reduced and reduced[-1] == elt:
                # r_i * r_i = 1: the pair cancels.
                reduced.pop()
            else:
                reduced.append(elt)
        return reduced

    def _reflection_indices(self, weylelt):
        """
        Parse a word such as ``'r1*r3*r2'`` into the generator indices
        ``[1, 3, 2]``.

        Every run of ASCII digits is taken as one index.  As a result,
        multi-digit indices (``'r10'``) and whitespace around ``*`` are
        handled, and the empty word yields ``[]``.

        Raises ValueError if an index lies outside ``1..rank``.
        """
        import re

        n = self.cartan_type.rank()
        indices = [int(digits) for digits in re.findall(r'[0-9]+', weylelt)]
        for a in indices:
            if not 1 <= a <= n:
                raise ValueError(
                    "invalid reflection r%d: indices must be between 1 and %d"
                    % (a, n))
        return indices

    @staticmethod
    def _transposition(size, i, j):
        """Return the size x size identity with coordinates i and j swapped."""
        mat = eye(size)
        mat[i, i] = 0
        mat[j, j] = 0
        mat[i, j] = 1
        mat[j, i] = 1
        return mat

    def _generator_matrix(self, a):
        """
        Return the matrix of the generating reflection ``r_a`` (1-based) in the
        standard representation used by ``matrix_form``.
        """
        n = self.cartan_type.rank()
        series = self.cartan_type.series

        if series == 'A':
            return self._transposition(n + 1, a - 1, a)

        if series in ('B', 'C'):
            # NOTE(review): here r1 is the sign change of the first coordinate,
            # so r1*r2 has order 4.  coxeter_diagram() instead draws the double
            # bond between generators n-1 and n.  The labelling is kept as-is
            # for backward compatibility; confirm which convention is intended.
            if a == 1:
                mat = eye(n)
                mat[0, 0] = -1
                return mat
            return self._transposition(n, a - 2, a - 1)

        if series == 'D':
            if a < n:
                return self._transposition(n, a - 1, a)
            # Last generator: reflection that maps x_{n-1} -> -x_n and
            # x_n -> -x_{n-1}.
            mat = eye(n)
            mat[n - 2, n - 2] = 0
            mat[n - 1, n - 1] = 0
            mat[n - 2, n - 1] = -1
            mat[n - 1, n - 2] = -1
            return mat

        if series == 'E':
            if a == 1:
                # Reflection in alpha = (1/2)*s with
                # s = (1, -1, -1, -1, -1, -1, -1, 1).  Since (alpha, alpha) = 2,
                # the matrix is I - alpha*alpha^T, with entries
                # delta_ij - s_i*s_j/4.  The previous hand-typed table had sign
                # errors in three rows, so it was not an involution.
                signs = (1, -1, -1, -1, -1, -1, -1, 1)
                return Matrix([[Rational((4 if i == j else 0) - signs[i]*signs[j], 4)
                                for j in range(8)]
                               for i in range(8)])
            if a == 2:
                mat = eye(8)
                mat[0, 0] = 0
                mat[0, 1] = -1
                mat[1, 0] = -1
                mat[1, 1] = 0
                return mat
            return self._transposition(8, a - 3, a - 2)

        if series == 'F':
            if a == 1:
                return self._transposition(4, 1, 2)
            if a == 2:
                return self._transposition(4, 2, 3)
            if a == 3:
                mat = eye(4)
                mat[3, 3] = -1
                return mat
            return Matrix([[Rational(1, 2), Rational(1, 2), Rational(1, 2), Rational(1, 2)],
                           [Rational(1, 2), Rational(1, 2), Rational(-1, 2), Rational(-1, 2)],
                           [Rational(1, 2), Rational(-1, 2), Rational(1, 2), Rational(-1, 2)],
                           [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(1, 2)]])

        if series == 'G':
            if a == 1:
                return Matrix([[1, 0, 0], [0, 0, 1], [0, 1, 0]])
            return Matrix([[Rational(2, 3), Rational(2, 3), Rational(-1, 3)],
                           [Rational(2, 3), Rational(-1, 3), Rational(2, 3)],
                           [Rational(-1, 3), Rational(2, 3), Rational(2, 3)]])

    def matrix_form(self, weylelt):
        """
        This method takes input from the user in the form of products of the
        generating reflections, and returns the matrix corresponding to the
        element of the Weyl group.  Every element of the Weyl group is a
        product of generating reflections.  Each generating reflection has a
        matrix in the standard representation, so the element's matrix is the
        product of those matrices, taken in the order written.

        The matrix size is n+1 for type A_n, n for types B_n, C_n and D_n,
        8 for type E, 4 for F4 and 3 for G2.  Raises ValueError if the word
        names a reflection index outside ``1..rank``.

        Examples
        ========

        >>> from sympy.liealgebras.weyl_group import WeylGroup
        >>> f = WeylGroup("F4")
        >>> f.matrix_form('r2*r3')
        Matrix([
        [1, 0, 0,  0],
        [0, 1, 0,  0],
        [0, 0, 0, -1],
        [0, 0, 1,  0]])

        """
        n = self.cartan_type.rank()
        sizes = {'A': n + 1, 'B': n, 'C': n, 'D': n, 'E': 8, 'F': 4, 'G': 3}
        size = sizes.get(self.cartan_type.series)
        if size is None:
            return None

        matrixform = eye(size)
        for a in self._reflection_indices(weylelt):
            matrixform = matrixform * self._generator_matrix(a)
        return matrixform

    def coxeter_diagram(self):
        """
        This method returns the Coxeter diagram corresponding to a Weyl group.
        The Coxeter diagram can be obtained from a Lie algebra's Dynkin diagram
        by deleting all arrows; the Coxeter diagram is the undirected graph.
        The vertices of the Coxeter diagram represent the generating reflections
        of the Weyl group, $s_i$.  An edge is drawn between $s_i$ and $s_j$ if the order
        $m(i, j)$ of $s_is_j$ is greater than two.  If there is one edge, the order
        $m(i, j)$ is 3.  If there are two edges, the order $m(i, j)$ is 4, and if there
        are three edges, the order $m(i, j)$ is 6.

        Examples
        ========

        >>> from sympy.liealgebras.weyl_group import WeylGroup
        >>> c = WeylGroup("B3")
        >>> print(c.coxeter_diagram())
        0---0===0
        1   2   3
        """
        n = self.cartan_type.rank()
        if self.cartan_type.series in ("A", "D", "E"):
            # Simply-laced: the Dynkin diagram has no arrows to delete.
            return self.cartan_type.dynkin_diagram()

        if self.cartan_type.series in ("B", "C"):
            # NOTE(review): the double bond is drawn between n-1 and n, but in
            # matrix_form r1 is the sign change, so r1*r2 (not r_{n-1}*r_n)
            # has order 4 there.  Confirm which labelling is intended.
            diag = "---".join("0" for i in range(1, n)) + "===0\n"
            diag += "   ".join(str(i) for i in range(1, n+1))
            return diag

        if self.cartan_type.series == "F":
            diag = "0---0===0---0\n"
            diag += "   ".join(str(i) for i in range(1, 5))
            return diag

        if self.cartan_type.series == "G":
            diag = "0≡≡≡0\n1   2"
            return diag
