from numpy import inner, zeros, inf, finfo
from numpy.linalg import norm
from math import sqrt

from .utils import make_system

__all__ = ['minres']


def minres(A, b, x0=None, shift=0.0, tol=1e-5, maxiter=None,
           M=None, callback=None, show=False, check=False):
    """
    Use MINimum RESidual iteration to solve Ax=b

    MINRES minimizes norm(Ax - b) for a real symmetric matrix A.  Unlike
    the Conjugate Gradient method, A can be indefinite or singular.

    If shift != 0 then the method solves (A - shift*I)x = b

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real symmetric N-by-N matrix of the linear system
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).

    Returns
    -------
    x : ndarray
        The converged solution.
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : illegal input or breakdown

    Other Parameters
    ----------------
    x0 : ndarray
        Starting guess for the solution.
    shift : float
        Value to apply to the system ``(A - shift * I)x = b``. Default is 0.
    tol : float
        Tolerance to achieve. The algorithm terminates when the relative
        residual is below `tol`.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
    M : {sparse matrix, ndarray, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    show : bool
        If ``True``, print out a summary and metrics related to the solution
        during iterations. Default is ``False``.
    check : bool
        If ``True``, run additional input validation to check that `A` and
        `M` (if specified) are symmetric. Default is ``False``.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import minres
    >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
    >>> A = A + A.T
    >>> b = np.array([2, 4, -1], dtype=float)
    >>> x, exitCode = minres(A, b)
    >>> print(exitCode)            # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    References
    ----------
    Solution of sparse indefinite systems of linear equations,
        C. C. Paige and M. A. Saunders (1975),
        SIAM J. Numer. Anal. 12(4), pp. 617-629.
        https://web.stanford.edu/group/SOL/software/minres/

    This file is a translation of the following MATLAB implementation:
        https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip

    """
    A, M, x, b, postprocess = make_system(A, M, x0, b)

    matvec = A.matvec
    psolve = M.matvec

    first = 'Enter minres.   '
    last = 'Exit  minres.   '

    n = A.shape[0]

    if maxiter is None:
        maxiter = 5 * n

    msg = [' beta2 = 0.  If M = I, b and x are eigenvectors    ',   # -1
            ' beta1 = 0.  The exact solution is x0          ',   # 0
            ' A solution to Ax = b was found, given rtol        ',   # 1
            ' A least-squares solution was found, given rtol    ',   # 2
            ' Reasonable accuracy achieved, given eps           ',   # 3
            ' x has converged to an eigenvector                 ',   # 4
            ' acond has exceeded 0.1/eps                        ',   # 5
            ' The iteration limit was reached                   ',   # 6
            ' A  does not define a symmetric matrix             ',   # 7
            ' M  does not define a symmetric matrix             ',   # 8
            ' M  does not define a pos-def preconditioner       ']   # 9

    if show:
        print(first + 'Solution of symmetric Ax = b')
        print(first + f'n      =  {n:3g}     shift  =  {shift:23.14e}')
        print(first + f'itnlim =  {maxiter:3g}     rtol   =  {tol:11.2e}')
        print()

    istop = 0
    itn = 0
    Anorm = 0
    Acond = 0
    rnorm = 0
    ynorm = 0

    xtype = x.dtype

    eps = finfo(xtype).eps

    # Set up y and v for the first Lanczos vector v1.
    # y  =  beta1 P' v1,  where  P = C**(-1).
    # v is really P' v1.

    if x0 is None:
        r1 = b.copy()
    else:
        r1 = b - A@x
    y = psolve(r1)

    beta1 = inner(r1, y)

    if beta1 < 0:
        raise ValueError('indefinite preconditioner')
    elif beta1 == 0:
        return (postprocess(x), 0)

    bnorm = norm(b)
    if bnorm == 0:
        x = b
        return (postprocess(x), 0)

    beta1 = sqrt(beta1)

    if check:
        # are these too strict?

        # see if A is symmetric
        w = matvec(y)
        r2 = matvec(w)
        s = inner(w,w)
        t = inner(y,r2)
        z = abs(s - t)
        epsa = (s + eps) * eps**(1.0/3.0)
        if z > epsa:
            raise ValueError('non-symmetric matrix')

        # see if M is symmetric
        r2 = psolve(y)
        s = inner(y,y)
        t = inner(r1,r2)
        z = abs(s - t)
        epsa = (s + eps) * eps**(1.0/3.0)
        if z > epsa:
            raise ValueError('non-symmetric preconditioner')

    # Initialize other quantities
    oldb = 0
    beta = beta1
    dbar = 0
    epsln = 0
    qrnorm = beta1
    phibar = beta1
    rhs1 = beta1
    rhs2 = 0
    tnorm2 = 0
    gmax = 0
    gmin = finfo(xtype).max
    cs = -1
    sn = 0
    w = zeros(n, dtype=xtype)
    w2 = zeros(n, dtype=xtype)
    r2 = r1

    if show:
        print()
        print()
        print('   Itn     x(1)     Compatible    LS       norm(A)  cond(A) gbar/|A|')

    while itn < maxiter:
        itn += 1

        s = 1.0/beta
        v = s*y

        y = matvec(v)
        y = y - shift * v

        if itn >= 2:
            y = y - (beta/oldb)*r1

        alfa = inner(v,y)
        y = y - (alfa/beta)*r2
        r1 = r2
        r2 = y
        y = psolve(r2)
        oldb = beta
        beta = inner(r2,y)
        if beta < 0:
            raise ValueError('non-symmetric matrix')
        beta = sqrt(beta)
        tnorm2 += alfa**2 + oldb**2 + beta**2

        if itn == 1:
            if beta/beta1 <= 10*eps:
                istop = -1  # Terminate later

        # Apply previous rotation Qk-1 to get
        #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
        #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].

        oldeps = epsln
        delta = cs * dbar + sn * alfa   # delta1 = 0         deltak
        gbar = sn * dbar - cs * alfa   # gbar 1 = alfa1     gbar k
        epsln = sn * beta     # epsln2 = 0         epslnk+1
        dbar = - cs * beta   # dbar 2 = beta2     dbar k+1
        root = norm([gbar, dbar])
        Arnorm = phibar * root

        # Compute the next plane rotation Qk

        gamma = norm([gbar, beta])       # gammak
        gamma = max(gamma, eps)
        cs = gbar / gamma             # ck
        sn = beta / gamma             # sk
        phi = cs * phibar              # phik
        phibar = sn * phibar              # phibark+1

        # Update  x.

        denom = 1.0/gamma
        w1 = w2
        w2 = w
        w = (v - oldeps*w1 - delta*w2) * denom
        x = x + phi*w

        # Go round again.

        gmax = max(gmax, gamma)
        gmin = min(gmin, gamma)
        z = rhs1 / gamma
        rhs1 = rhs2 - delta*z
        rhs2 = - epsln*z

        # Estimate various norms and test for convergence.

        Anorm = sqrt(tnorm2)
        ynorm = norm(x)
        epsa = Anorm * eps
        epsx = Anorm * ynorm * eps
        epsr = Anorm * ynorm * tol
        diag = gbar

        if diag == 0:
            diag = epsa

        qrnorm = phibar
        rnorm = qrnorm
        if ynorm == 0 or Anorm == 0:
            test1 = inf
        else:
            test1 = rnorm / (Anorm*ynorm)    # ||r||  / (||A|| ||x||)
        if Anorm == 0:
            test2 = inf
        else:
            test2 = root / Anorm            # ||Ar|| / (||A|| ||r||)

        # Estimate  cond(A).
        # In this version we look at the diagonals of  R  in the
        # factorization of the lower Hessenberg matrix,  Q @ H = R,
        # where H is the tridiagonal matrix from Lanczos with one
        # extra row, beta(k+1) e_k^T.

        Acond = gmax/gmin

        # See if any of the stopping criteria are satisfied.
        # In rare cases, istop is already -1 from above (Abar = const*I).

        if istop == 0:
            t1 = 1 + test1      # These tests work if tol < eps
            t2 = 1 + test2
            if t2 <= 1:
                istop = 2
            if t1 <= 1:
                istop = 1

            if itn >= maxiter:
                istop = 6
            if Acond >= 0.1/eps:
                istop = 4
            if epsx >= beta1:
                istop = 3
            # if rnorm <= epsx   : istop = 2
            # if rnorm <= epsr   : istop = 1
            if test2 <= tol:
                istop = 2
            if test1 <= tol:
                istop = 1

        # See if it is time to print something.

        prnt = False
        if n <= 40:
            prnt = True
        if itn <= 10:
            prnt = True
        if itn >= maxiter-10:
            prnt = True
        if itn % 10 == 0:
            prnt = True
        if qrnorm <= 10*epsx:
            prnt = True
        if qrnorm <= 10*epsr:
            prnt = True
        if Acond <= 1e-2/eps:
            prnt = True
        if istop != 0:
            prnt = True

        if show and prnt:
            str1 = f'{itn:6g} {x[0]:12.5e} {test1:10.3e}'
            str2 = f' {test2:10.3e}'
            str3 = f' {Anorm:8.1e} {Acond:8.1e} {gbar/Anorm:8.1e}'

            print(str1 + str2 + str3)

            if itn % 10 == 0:
                print()

        if callback is not None:
            callback(x)

        if istop != 0:
            break  # TODO check this

    if show:
        print()
        print(last + f' istop   =  {istop:3g}               itn   ={itn:5g}')
        print(last + f' Anorm   =  {Anorm:12.4e}      Acond =  {Acond:12.4e}')
        print(last + f' rnorm   =  {rnorm:12.4e}      ynorm =  {ynorm:12.4e}')
        print(last + f' Arnorm  =  {Arnorm:12.4e}')
        print(last + msg[istop+1])

    if istop == 6:
        info = maxiter
    else:
        info = 0

    return (postprocess(x),info)
