#! /usr/bin/env python

from __future__ import print_function

import sys,os,argparse

# If pyfplo cannot be found, you can also explicitly
# add the path of the desired pyfplo version:
#sys.path.insert(0,"/home/magru/FPLO/FPLO22.00-62/PYTHON/doc");

import numpy as np
import numpy.linalg as LA
import pyfplo.slabify as sla
import pyfplo.fploio as fploio
import pyfplo.common as com


print('\npyfplo version=: {0}\nfrom: {1}\n'.format(sla.version,sla.__file__))
# To protect against running with a wrong pyfplo version, uncomment:
#if sla.version!='22.00': raise RuntimeError('pyfplo version is incorrect.')


# Running maximum error of each check, keyed by check description;
# updated by recordError() and summarized at the end of work().
maxerrors={}

# ===================================================================
# 
# ===================================================================

def work(gauge='periodic',tol=1e-8,infile='../../=.in',
         hamdata='../../+hamdata',kpoints=None,evtol=1e-8):
    '''
    Check if Hamiltonian symmetry is correct and print eigenvalues
    for little group operations.

    For every k-point and spin the Wannier Hamiltonian is evaluated
    together with the representation matrices of all symmetry operations.
    When the Slabify object provides them, sigma, the xc-field and the
    basis connection are evaluated as well.
    Operations outside the little group of k are checked against the
    operators at the transformed k-point. Little group operations are
    checked for commuting with the operators. Non-time-reversed ones are
    additionally co-diagonalized with H to print the symmetry eigenvalue
    of each energy subspace. The maximum error of every check is printed
    at the end.

    gauge   : gauge passed to Slabify.hamAtKPoint
    tol     : threshold above which check errors are printed
    infile  : FPLO input file whose special_sympoints define the band path
              (only read when kpoints is None)
    hamdata : hamdata file used to prepare the Slabify object
    kpoints : optional explicit list of k-points used instead of the band
              path; each is converted to a float numpy array and, like the
              band-path points, multiplied by s.kscale before use
    evtol   : tolerance for grouping eigenvalues into degenerate subspaces
    '''
    if kpoints is None:
        # band path from the special symmetry points of the input file
        sympoints=_readSpecialSymPoints(infile)

    # prepare Slabify
    s=sla.Slabify()
    s.object='3d'
    s.printStructureSettings()
    s.prepare(hamdata)

    if kpoints is None:
        bp=sla.BandPlot()
        bp.points=sympoints
        bp.ndiv=10
        bp.calculateBandPlotMesh(s.dirname)
        kpts=bp.kpnts
    else:
        kpts=[np.asarray(kp,dtype=float) for kp in kpoints]

    sitecenters=s.wannierCenterMatrix()

    # find out which operators are there
    makesigma=s.hassigma
    makexcfield=s.hasxcfield
    makebasisconnection=s.hasbasisconnection

    for ik,k in enumerate(kpts):
        for ms in range(s.nspin):

            print('='*80)
            print('k[{}]={} spin={}'.format(ik,k,ms))

            # get Hamiltonian, optional operators and symmetry reps
            ret=s.hamAtKPoint(k*s.kscale,ms,opindices=None,
                              gauge=gauge,makewfsymops=True,
                              makesigma=makesigma,
                              makexcfield=makexcfield,
                              makebasisconnection=makebasisconnection)
            (Hk,Sk,Bk,Abk,WF)=_unpackHamAtKPoint(ret,makesigma,makexcfield,
                                                 makebasisconnection)

            d=LA.norm(np.conj(Hk.T)-Hk)
            txt=' |Hk-Hk^{+}|'
            recordError(work.__name__+txt,d)
            # NOTE(review): fixed 1e-10 print threshold, independent of tol
            if d>1e-10: print('{}, error ={}'.format(txt,d))

            (EV,C)=LA.eigh(Hk)

            for w in WF: # for each operation
                print('-'*80)

                if not w.isinlittlegroup:
                    print('checking',w)
                    checkHamiltonianSymmetry(s,ms,k,Hk,C,w.Dk,w,gauge,tol)
                    if makebasisconnection:
                        checkPosOpSymmetry(s,ms,k,Abk,sitecenters,
                                           w.Dk,w,gauge,tol)
                    if makesigma:
                        checkSigmaSymmetry(s,ms,k,Sk,w.Dk,w,gauge,isxc=False,
                                           tol=tol)
                    if makexcfield:
                        checkSigmaSymmetry(s,ms,k,Bk,w.Dk,w,gauge,isxc=True,
                                           tol=tol)
                    continue

                print('checking little group op',w)
                checkHamiltonianSymmetryForLittelGroup(s,Hk,w.Dk,w,tol)

                # the co-diagonalization makes no sense for time reversed
                # operations
                if not w.timerev:
                    try:
                        (EU,CZ)=s.coDiagonalize(EV,C,w.Dk,check=True)
                    except Exception:
                        # presumably raised when the internal check fails;
                        # retry without it (never swallow KeyboardInterrupt)
                        (EU,CZ)=s.coDiagonalize(EV,C,w.Dk,check=False)
                    printEnergySubspace(s,EV,EU,evtol)
                    checkResults(s,Hk,EV,CZ,EU,w.Dk,w.timerev,tol)

                if makebasisconnection:
                    checkPosOpSymmetryForLittelGroup(s,Abk,sitecenters,
                                                     w.Dk,w,gauge,tol)
                if makesigma:
                    checkSigmaSymmetryForLittelGroup(s,Sk,w.Dk,w,tol,
                                                     isxc=False)
                if makexcfield:
                    checkSigmaSymmetryForLittelGroup(s,Bk,w.Dk,w,tol,
                                                     isxc=True)

    _printErrorSummary()
    print('done')


def _readSpecialSymPoints(infile):
    '''Return the [label, kpoint] pairs of the special_sympoints of *infile*.'''
    p=fploio.INParser()
    p.parseFile(infile)
    d=p()('special_sympoints')
    return [[d[i]('label').S,d[i]('kpoint').listD] for i in range(d.size())]


def _unpackHamAtKPoint(ret,makesigma,makexcfield,makebasisconnection):
    '''
    Split the result of Slabify.hamAtKPoint into (Hk,Sk,Bk,Abk,WF).

    The optional entries sigma (Sk), xc-field (Bk) and basis connection
    (Abk) are present in *ret* only when requested, in that order between
    Hk and WF; entries that were not requested are returned as None.
    '''
    parts=list(ret)
    Hk=parts.pop(0)
    Sk=parts.pop(0) if makesigma else None
    Bk=parts.pop(0) if makexcfield else None
    Abk=parts.pop(0) if makebasisconnection else None
    WF=parts.pop(0)
    return (Hk,Sk,Bk,Abk,WF)


def _printErrorSummary():
    '''Print the maximum recorded error of every check in maxerrors.'''
    width=max([len(key) for key in maxerrors]+[0])
    for key in maxerrors:
        print('{0:<{1}}, max error: {2:12.4e}'.format(key,width,maxerrors[key]))
                    
# ===================================================================
# 
# ===================================================================
def checkHamiltonianSymmetry(s,ms,k,Hk,C,Dk,w,gauge,tol):
    '''
    Check the Hamiltonian symmetry for an operation that is not in the
    little group of k.

    With ak = alpha.k (or -alpha.k for time-reversed operations) two
    conditions are verified:
      Hk == D^{+} H(ak) D        ((...)^{*} for timerev)
      U  =  C(ak)^{+} D C(k)     (C(k)^{*} for timerev) is unitary
    Each error norm is recorded via recordError and printed when it
    exceeds tol.

    s     : Slabify object (hamAtKPoint, kscale and nvdim are used)
    ms    : spin index
    k     : k-point, multiplied by s.kscale before use
    Hk    : Hamiltonian at k
    C     : eigenvectors of Hk as returned by numpy.linalg.eigh (columns)
    Dk    : representation matrix of operation w
    w     : symmetry operation (alpha and timerev are used)
    gauge : gauge passed to hamAtKPoint
    tol   : print threshold
    '''
    # time reversal maps k to -k
    fac=(-1 if w.timerev else 1)

    # H^{alpha*k}
    Hak=s.hamAtKPoint(fac*w.alpha.dot(k)*s.kscale,ms,gauge=gauge)

    DkH=np.conj(Dk.T)

    tmp=DkH.dot(Hak.dot(Dk))
    if w.timerev:
        tmp=np.conj(tmp)

    d=LA.norm(tmp-Hk)
    txt=' Hk != D^{+} Hak D'
    recordError(checkHamiltonianSymmetry.__name__+txt,d)
    if d>tol: print('{}, error: {}'.format(txt,d))

    #  Hk Ck=Ck Ek
    #  D^{+} Hak D Ck=Ck Ek
    #  Hak D Ck=D Ck Ek
    #  D Ck=Cak U
    #  U=Cak^{+} D Ck must be unitary
    (_,Cak)=LA.eigh(Hak)

    Ck=np.conj(C) if w.timerev else C
    U=np.conj(Cak.T).dot(Dk.dot(Ck))

    # the deviation from the identity is measured, so the label says "!=1"
    d=LA.norm(np.conj(U.T).dot(U)-np.eye(s.nvdim))
    txt=' U^{+}*U!=1 for k not in group'
    recordError(checkHamiltonianSymmetry.__name__+txt,d)
    if d>tol: print('{}, error: {}'.format(txt,d))

    return

# ===================================================================
# 
# ===================================================================
def checkHamiltonianSymmetryForLittelGroup(s,Hk,Dk,w,tol):
    '''
    Assuming that Dk is a rep of the little group, check if
     H=D^{+} H D
    or
     H=(D^{+} H D )^{*}  (for timerev==True)
    is true.

    The norm of the difference is recorded via recordError. If it exceeds
    tol, Hk, Dk, D^{+}HD and the deviation D^{+}HD-H are printed.
    *s* is unused.
    '''
    DkH=np.conj(Dk.T)

    tmp=DkH.dot(Hk.dot(Dk))
    if w.timerev:
        tmp=np.conj(tmp)

    d=LA.norm(tmp-Hk)
    txt=' [D,H]!=0'
    recordError(checkHamiltonianSymmetryForLittelGroup.__name__+txt,d)
    if d>tol:
        print('Hk=')
        print(ndprint(Hk,prec=3,supsmall=True,maxwidth=300))
        print('Dk=')
        print(ndprint(Dk,prec=3,supsmall=True,maxwidth=300))

        print('D^{+}HD=')
        print(ndprint(tmp,prec=3,supsmall=True,maxwidth=300))
        # show the deviation itself rather than repeating D^{+}HD
        print('D^{+}HD-H=')
        print(ndprint(tmp-Hk,prec=3,supsmall=True,maxwidth=300))
        print('{}, error: {}'.format(txt,d))
    return

# ===================================================================
# 
# ===================================================================
def printEnergySubspace(s,EV,EU,evtol):
    '''
    Print "index energy symmetry-eigenvalue" for every state, one group of
    (nearly) degenerate energies at a time, each group followed by an
    empty line.

    A state joins the current group while its distance to the group's
    first energy, |EV[i]-EV[first]|, does not exceed evtol. A distance
    larger than 1 is first divided by |EV[first]|. EV is expected to be
    sorted (the caller passes the output of numpy.linalg.eigh). EU holds
    the symmetry eigenvalues, formatted with CX.
    '''
    n=s.nvdim
    start=0
    for idx in range(n+1):
        if idx<n:
            gap=np.abs(EV[idx]-EV[start])
            if gap>1:
                gap=gap/np.abs(EV[start])
        else:
            # past the last state: always flush the final group
            gap=1

        if gap>evtol or idx>=n:
            for j in range(start,idx):
                print('{0} {1} {2}'.format(j,EV[j],CX(EU[j])))
            print()
            start=idx
    return
    
# ===================================================================
# 
# ===================================================================
def checkResults(s,Hk,EV,CZ,EU,Dk,timerev,tol):
    '''
    Check the result CZ of the co-diagonalization of the Hamiltonian
    and the symmetry rep U=C^{+} D C.

    Verified are:
      CZ^{+} CZ    == 1          (CZ is unitary)
      CZ^{+} Hk CZ == diag(EV)   (CZ diagonalizes Hk)
      CZ^{+} Dk CZ == diag(EU)   (CZ diagonalizes Dk; skipped for
                                  timerev, which would formally need
                                  CZ^{+} Dk CZ^{*})
    Every error norm is recorded via recordError, and details are printed
    when it exceeds tol.
    '''
    CZH=np.conj(CZ.T)

    # CZ is unitary ...
    tmp=CZH.dot(CZ)
    d=LA.norm(tmp-np.eye(s.nvdim))
    txt=' CZ^{+}*CZ != 1'
    recordError(checkResults.__name__+txt,d)
    if d>tol:
        print('CZ^{+}*CZ=')
        print(ndprint(tmp,prec=3,supsmall=True,maxwidth=300))
        print('CZ is not unitary, error={}'.format(d))

    # ... and diagonalizes Hk ...
    tmp=CZH.dot(Hk.dot(CZ))
    d=LA.norm(tmp-np.diag(EV))
    txt=' CZ^{+}*H*CZ != E'
    recordError(checkResults.__name__+txt,d)
    if d>tol:
        print('CZ^{+}*H*CZ=')
        print(ndprint(tmp,prec=3,supsmall=True,maxwidth=300))
        print('CZ^{+}*H*CZ != E,error =',d)

    # ... and the symmetry rep matrix Dk (not checked for time reversed
    # operations)
    if timerev:
        return

    tmp=CZH.dot(Dk.dot(CZ))
    d=LA.norm(tmp-np.diag(EU))
    txt=' CZ^{+}*Dk*CZ != E(Dk)'
    recordError(checkResults.__name__+txt,d)
    if d>tol:
        print('CZ^{+}*Dk*CZ')
        print(ndprint(tmp,prec=3,supsmall=True,maxwidth=300))
        print('{}, error={}'.format(txt,d))
    return

# ===================================================================
def CX(c,tol=1e-6):
    '''
    Format the complex number *c* compactly.

    If |real| < tol, return '<imag>j'.
    Else if |imag| < tol, return '<real>'.
    Otherwise return '<real><+/-imag>j', with both parts in 'g' format.
    '''
    rpart=c.real
    ipart=c.imag
    if not np.abs(rpart)<tol:
        if not np.abs(ipart)<tol:
            return '{0:g}{1:+g}j'.format(rpart,ipart)
        return '{0}'.format(rpart)
    return '{0}j'.format(ipart)

# ===================================================================
# 
# ===================================================================

def ndprint(a, prec=5,supsmall=True,maxwidth=75):
    '''
    Return numpy's string rendering of array *a*, using *prec* digits of
    precision, suppressing tiny values when *supsmall* is true, and
    wrapping lines at *maxwidth* characters.
    '''
    opts=dict(precision=prec,suppress_small=supsmall,max_line_width=maxwidth)
    return np.array_str(a,**opts)

# ===================================================================
# 
# ===================================================================
def checkPosOpSymmetry(s,ms,k,Abk,sitecenters,Dk,w,gauge,tol):
    '''
    Check the position operator symmetry for an operation that is not in
    the little group of k.

    The position operator components are R_i = -A_i + sfac*s_i, built
    from the basis connection A and the site-center matrices s. sfac is 0
    for the periodic gauge and 1 otherwise. The basis connection is
    re-evaluated at ak = alpha.k (-alpha.k for timerev), and per component
    i the check is
        R(k)_i == sum_j [D^{+} (R(ak)_j - s_j) D]_{(*)} alpha_{ji} + s_i
    where (*) is a complex conjugation applied only for time-reversed
    operations. The summed norm of the three deviations is recorded via
    recordError and printed when it exceeds tol.
    '''
    # time reversal maps k to -k
    ksign=-1 if w.timerev else 1

    # only the basis connection at alpha*k is needed
    (_,Abak)=s.hamAtKPoint(ksign*w.alpha.dot(k)*s.kscale,ms,gauge=gauge,
                           makebasisconnection=True)

    sfac=0 if gauge=='periodic' else 1

    Rk=[-Abk[i]+sfac*sitecenters[i] for i in range(3)]
    Rak=[-Abak[i]+sfac*sitecenters[i] for i in range(3)]

    ndim=Abk[0].shape[0]
    DkH=np.conj(Dk.T)

    def sandwich(M):
        # D^{+} M D, conjugated for time reversed operations
        prod=DkH.dot(M.dot(Dk))
        return np.conj(prod) if w.timerev else prod

    tmp=[sandwich(Rak[i]-sitecenters[i]) for i in range(3)]

    res=[]
    for i in range(3):
        acc=np.zeros(shape=(ndim,ndim),dtype=Abk[0].dtype)
        for j in range(3):
            acc+=tmp[j]*w.alpha[j,i]
        acc+=sitecenters[i]
        res.append(acc)

    d=sum(LA.norm(res[i]-Rk[i]) for i in range(3))
    txt=' Rk != D^{+} (Rak-delta*s) D +delta*s'
    recordError(checkPosOpSymmetry.__name__+txt,d)
    if d>tol:
        print('{}, error: {}'.format(txt,d))

    return

# ===================================================================
# 
# ===================================================================
def checkPosOpSymmetryForLittelGroup(s,Abk,sitecenters,Dk,w,gauge,tol):
    '''
    For a little group operation with rep Dk, verify that the position
    operator components R_i = -A_i + sfac*s_i satisfy
        R_i == sum_j [D^{+} (R_j - s_j) D]_{(*)} alpha_{ji} + s_i
    Here A is the basis connection Abk and s are the site-center
    matrices. sfac is 0 for the periodic gauge and 1 otherwise. (*) is a
    complex conjugation applied only for time-reversed operations.

    The summed norm of the three deviations is recorded via recordError.
    If it exceeds tol, Rk, Dk and the deviations are printed.
    *s* is unused.
    '''
    sfac=0 if gauge=='periodic' else 1
    ndim=Abk[0].shape[0]

    Rk=[-Abk[i]+sfac*sitecenters[i] for i in range(3)]

    DkH=np.conj(Dk.T)

    def sandwich(M):
        # D^{+} M D, conjugated for time reversed operations
        prod=DkH.dot(M.dot(Dk))
        return np.conj(prod) if w.timerev else prod

    tmp=[sandwich(Rk[i]-sitecenters[i]) for i in range(3)]

    res=[]
    for i in range(3):
        acc=np.zeros(shape=(ndim,ndim),dtype=Abk[0].dtype)
        for j in range(3):
            acc+=tmp[j]*w.alpha[j,i]
        acc+=sitecenters[i]
        res.append(acc)

    diff=[res[i]-Rk[i] for i in range(3)]
    d=sum(LA.norm(x) for x in diff)

    txt=' D^{+}(Rk-delta*s)D*alpha+delta*s-Rk'
    recordError(checkPosOpSymmetryForLittelGroup.__name__+txt,d)

    if d>tol:
        opts=dict(prec=3,supsmall=True,maxwidth=300)
        print('Rk=')
        for comp in Rk:
            print(ndprint(comp,**opts))
        print('Dk=')
        print(ndprint(Dk,**opts))

        print('D^{+}(Rk-delta*s)D*alpha+delta*s-Rk')
        for comp in diff:
            print(ndprint(comp,**opts))
        print('{}, error: {}'.format(txt,d))
    return
# ===================================================================
# 
# ===================================================================
def checkSigmaSymmetry(s,ms,k,opk,Dk,w,gauge,isxc,tol):
    '''
    Check the symmetry of a three-component operator for an operation
    that is not in the little group of k. The operator is sigma, or the
    xc-field when isxc is True.

    The operator is re-evaluated at ak = alpha.k (-alpha.k for timerev),
    and per component i the check is
        opk_i == osign * sum_j [D^{+} opak_j D]_{(*)} alpha_{ji}
    where (*) is a complex conjugation for time-reversed operations.
    osign is -1 if exactly one of timerev / isimproper is set, +1
    otherwise. The summed norm of the deviations is recorded via
    recordError and printed when it exceeds tol.
    '''
    # time reversal maps k to -k
    ksign=(-1 if w.timerev else 1)
    kak=ksign*w.alpha.dot(k)*s.kscale

    # request the xc-field or sigma together with H^{alpha*k}
    extra={'makexcfield':True} if isxc else {'makesigma':True}
    (_,opak)=s.hamAtKPoint(kak,ms,gauge=gauge,**extra)

    ndim=opk[0].shape[0]
    DkH=np.conj(Dk.T)

    def sandwich(M):
        # D^{+} M D, conjugated for time reversed operations
        prod=DkH.dot(M.dot(Dk))
        return np.conj(prod) if w.timerev else prod

    tmp=[sandwich(opak[i]) for i in range(3)]

    osign=-1. if w.timerev^w.isimproper else 1.

    res=[np.zeros(shape=(ndim,ndim),dtype=opk[0].dtype) for _ in range(3)]
    for i in range(3):
        for j in range(3):
            res[i]+=tmp[j]*w.alpha[j,i]*osign

    d=sum(LA.norm(res[i]-opk[i]) for i in range(3))
    txt=' opk != D^{+} opak D alpha'
    sss='XCField' if isxc else 'Sigma'
    recordError('check{}Symmetry'.format(sss)+txt,d)
    if d>tol:
        print(('{}, error: {}').format(txt,d))

    return
# ===================================================================
# 
# ===================================================================
def checkSigmaSymmetryForLittelGroup(s,opk,Dk,w,tol,isxc):
    '''
    For a little group operation with rep Dk, verify per component i
        op_i == osign * sum_j [D^{+} op_j D]_{(*)} alpha_{ji}
    where (*) is a complex conjugation for time-reversed operations.
    osign is -1 if exactly one of timerev / isimproper is set (not both),
    +1 otherwise. *opk* is sigma, or the xc-field when isxc is True.

    The summed norm of the deviations is recorded via recordError. If it
    exceeds tol, opk, Dk and the deviations are printed.
    '''
    ndim=opk[0].shape[0]
    DkH=np.conj(Dk.T)

    def sandwich(M):
        # D^{+} M D, conjugated for time reversed operations
        prod=DkH.dot(M.dot(Dk))
        return np.conj(prod) if w.timerev else prod

    tmp=[sandwich(opk[i]) for i in range(3)]

    osign=-1. if w.timerev^w.isimproper else 1.
    res=[np.zeros(shape=(ndim,ndim),dtype=opk[0].dtype) for _ in range(3)]
    for i in range(3):
        for j in range(3):
            res[i]+=tmp[j]*w.alpha[j,i]*osign

    diff=[res[i]-opk[i] for i in range(3)]
    d=sum(LA.norm(x) for x in diff)
    txt=' opk != D^{+} opak D alpha'
    sss='XCField' if isxc else 'Sigma'
    recordError('check{}SymmetryForLittelGroup'.format(sss)+txt,d)

    if d>tol:
        opts=dict(prec=3,supsmall=True,maxwidth=300)
        print('opk=')
        for i in range(3):
            print(ndprint(opk[i],**opts))
        print('Dk=')
        print(ndprint(Dk,**opts))

        print('D^{+} opk D*alpha-opk')
        for comp in diff:
            print(ndprint(comp,**opts))
        print('{}: [D,op]!=0, error: {}'.format(w.symbol,d))
    return
# ===================================================================
# 
# ===================================================================
def recordError(key,e,errors=None):
    '''
    Record *e* as the running maximum error for check *key*.

    key    : description of the check (dictionary key)
    e      : error value of the current evaluation
    errors : dict to update; defaults to the module-level maxerrors

    A key seen for the first time starts from 0. NaN errors are kept
    (and stay NaN) instead of being dropped. Plain max() would drop them,
    because every comparison with NaN is False, so a broken check would
    otherwise be reported as having zero error.
    '''
    if errors is None:
        errors=maxerrors
    e0=errors.get(key,0)
    if np.isnan(e0) or np.isnan(e):
        errors[key]=float('nan')
    else:
        errors[key]=max(e0,e)
        
# ===================================================================
# 
# ===================================================================


if __name__ == '__main__':

    parser=argparse.ArgumentParser(
        description='Check the symmetry of the Wannier Hamiltonian (and of '
                    'sigma, the xc-field and the position operator when '
                    'available) and print symmetry eigenvalues for little '
                    'group operations.',
        conflict_handler='resolve',
        epilog='')
    parser.add_argument('--gauge', action='store',dest='gauge',
                        default='periodic',
                        help="gauge passed to Slabify.hamAtKPoint; "
                             "'relative' is translated to 'forcerelative' "
                             "(default: %(default)s)")
    parser.add_argument('--tol', action='store',dest='tol',type=float,
                        default=1.0e-8,
                        help='error threshold above which deviations are '
                             'printed (default: %(default)s)')

    args=parser.parse_args()

    # 'relative' is accepted as a shorthand for 'forcerelative'
    if args.gauge=='relative': args.gauge='force'+args.gauge

    work(gauge=args.gauge,tol=args.tol)

