#! /usr/bin/env python

from __future__ import print_function
import sys,os
import argparse

# If pyfplo is not found on the default module path, you can also
# specify the directory that contains it explicitly, e.g.:
#sys.path.insert(0,"/home/magru/FPLO/FPLO22.00-62/PYTHON/doc");

import numpy as np
import numpy.linalg as LA
import pyfplo.slabify as sla
import pyfplo.fploio as fploio


print('\npyfplo version=: {0}\nfrom: {1}\n'.format(sla.version, sla.__file__))
# Optional guard against an unexpected pyfplo release (uncomment to enable):
#if sla.version!='22.00': raise RuntimeError('pyfplo version is incorrect.')


# Python 2 compatibility: rebind input() to raw_input() so the y/n prompts
# return the typed text instead of evaluating it. Python 3 has no
# raw_input and keeps its own input().
if sys.version_info[0] < 3:
    input = raw_input  # noqa: F821 (only defined on Python 2)

# ===================================================================
# 
# ===================================================================

def _sympoints_from_infile(infile):
    """Return the 'special_sympoints' of an FPLO input file as [[label, kpoint], ...]."""
    p=fploio.INParser()
    p.parseFile(infile)
    d=p()('special_sympoints')
    return [[d[i]('label').S,d[i]('kpoint').listD] for i in range(d.size())]


def work(gauge,Ef,mode,use_infile_path=False,infile='../../=.in',
         hamdata='../../+hamdata'):
    """Prepare a 3d Slabify model and run both curvature calculations.

    The Slabify object is prepared from *hamdata*. Then
    curvature_along_path is run on a k-path, followed by
    curvature_density on a k-plane.

    Parameters
    ----------
    gauge : str
        Gauge name forwarded to both calculations.
    Ef : float
        Energy cut-off forwarded to both calculations. They sum over
        states with E<=Ef.
    mode : str
        Selects what curvature_density writes ('all', 'DD', 'DDsym').
    use_infile_path : bool, optional
        If True, take the k-path from the 'special_sympoints' section of
        *infile*. By default the fixed ~G-H-P path from Wang2006 is used,
        and *infile* is not read at all.
    infile : str, optional
        FPLO input file that holds 'special_sympoints'.
    hamdata : str, optional
        Hamiltonian data file handed to Slabify.prepare.
    """
    if use_infile_path:
        # bandplot points as listed in the FPLO input file
        l=_sympoints_from_infile(infile)
    else:
        # take path from Wang2006
        l=[
            ['~G',[0.,0.,0.]],
            ['H',[0.,1.,0.]],
            ['P',[0.5,0.5,0.5]],
            ]

    # prepare Slabify
    s=sla.Slabify()
    s.object='3d'
    s.printStructureSettings()
    s.prepare(hamdata)

    curvature_along_path(l,s,Ef,gauge=gauge)

    curvature_density(s,Ef,gauge=gauge,mode=mode)
# ===================================================================
# 
# ===================================================================
def curvature_along_path(l,s,Ef,gauge,ask=False,ndiv=100):
    """Write band-summed Berry curvature along a k-path to <s.dirname>/+F.

    For every path k-point and every spin channel, the Hamiltonian is
    built with s.hamAtKPoint (with dHk and the basis connection Abk),
    then diagonalized. The curvature of all states with E<=Ef is then
    summed with Fbandsum.

    Each record written to +F holds five 3-vectors (15 numbers), in
    this order:

      FTDD, FTDa, FTaccurv -- the 'DD', 'Da' and 'f' entries of the
                              details from berryCurvature(...,
                              basisconnection=Abk, returndetails=True)
      FT                   -- the total curvature from that same call
      FTDDonly_sym         -- the curvature from berryCurvature(E,C,dHk)
                              called without the basis connection

    Parameters
    ----------
    l : list
        Path points as [label, [k1,k2,k3]] entries, assigned to
        BandPlot.points.
    s : Slabify
        An already prepared Slabify object.
    Ef : float
        Occupation cut-off: only states with E<=Ef are summed.
    gauge : str
        Passed to s.hamAtKPoint.
    ask : bool, optional
        If True, prompt first and skip the calculation unless the answer
        starts with 'y'. Default False (no prompt).
    ndiv : int, optional
        Value for BandPlot.ndiv (default 100).
    """
    if ask:
        ret=input('Calculate along path? [y/n]')
        if not ret.startswith('y'): return

    # prepare BandPlot
    bp=sla.BandPlot()
    bp.points=l
    bp.ndiv=ndiv
    bp.calculateBandPlotMesh(s.dirname)

    dists=bp.kdists
    kpts=bp.kpnts

    with bp.openBandFile(s.dirname+'/+F',s.nspin,len(kpts),progress='F') as fw:
        for ik,k in enumerate(kpts):
            for ms in range(s.nspin):
                # path points are multiplied by s.kscale before being
                # passed on (unlike the plane mesh, which is built with
                # fso.mesh(s.kscale))
                (Hk,dHk,Abk)=s.hamAtKPoint(k*s.kscale,ms,makedhk=True,
                                          makebasisconnection=True,
                                          gauge=gauge)
                (E,C)=s.diagonalize(Hk)

                # weight 1 for states at or below Ef, 0 above
                mask=[1 if e<=Ef else 0 for e in E]

                (FDDsym,_)=s.berryCurvature(E,C,dHk)
                FTDDonly_sym=Fbandsum(mask,FDDsym)

                (F,Fdetails,_)=s.berryCurvature(E,C,dHk,
                                                basisconnection=Abk,
                                                returndetails=True)
                FTDD    =Fbandsum(mask,Fdetails['DD'])
                FTDa    =Fbandsum(mask,Fdetails['Da'])
                FTaccurv=Fbandsum(mask,Fdetails['f'])
                FT      =Fbandsum(mask,F)

                fw.write(ms,dists[ik],k,
                         list(FTDD)+list(FTDa)+list(FTaccurv)
                         +list(FT)+list(FTDDonly_sym))
# ===================================================================
# 
# ===================================================================
def curvature_density(s,Ef,gauge,mode,ask=False):
    """Write band-summed Berry curvature on a 2d k-plane as a density plot.

    The plane is spanned by xaxis=[1,0,0] and yaxis=[0,0,1] through the
    origin. It is sampled on a 50x50 mesh over [0,1] in each direction
    (presumably fractions of the axis vectors -- confirm in the pyfplo
    docs). For each spin channel ms, one file is opened with
    fso.openDensPlotFile(s.dirname+'/+curvature_density', ms, ...).

    Parameters
    ----------
    s : Slabify
        An already prepared Slabify object.
    Ef : float
        Occupation cut-off: only states with E<=Ef are summed. It is also
        stored as fso.fermienergy.
    gauge : str
        Passed to s.hamAtKPoint.
    mode : str
        'all'   -- total curvature (with basis connection Abk)
        'DD'    -- the 'DD' part of that curvature
        'DDsym' -- curvature from berryCurvature without basis connection
    ask : bool, optional
        If True, prompt first and skip unless the answer starts with 'y'.

    Raises
    ------
    ValueError
        If *mode* is not one of 'all', 'DD', 'DDsym'. This is checked
        before any work is done, so a typo no longer silently produces
        'all'.
    """
    if mode not in ('all','DD','DDsym'):
        raise ValueError(
            "mode must be 'all', 'DD' or 'DDsym', got {0!r}".format(mode))

    if ask:
        ret=input('Calculate curvature density plot? [y/n]')
        if not ret.startswith('y'): return

    fso=sla.FermiSurfaceOptions()
    fso.setMesh(50,[0,1],50,[0,1])
    fso.setPlane(xaxis=[1,0,0],yaxis=[0,0,1],origin=[0,0,0])
    fso.fermienergy=Ef
    # the mesh is built with s.kscale, so its points are passed to
    # hamAtKPoint unchanged
    mesh=fso.mesh(s.kscale)

    for ms in range(s.nspin):

        # For displaying the plane with proper angle between the axes
        # in xfbp we need to define an x- and y-axis, which should have
        # the same angle as fso.xaxis and fso.yaxis.
        # They must lie in the x,y-plane!
        # We can also give an origin in the x,y plotting plane.
        with fso.openDensPlotFile(s.dirname+'/+curvature_density',ms,
                                  plotorigin=[0,0],
                                  plotxaxis=[1,0],
                                  plotyaxis=[0,1],progress='curvature') as fw:

            for k in mesh:
                (Hk,dHk,Abk)=s.hamAtKPoint(k,
                                          ms,makedhk=True,
                                          makebasisconnection=True,
                                          gauge=gauge)
                (E,C)=s.diagonalize(Hk)

                # weight 1 for states at or below Ef, 0 above
                mask=[1 if e<=Ef else 0 for e in E]

                # Only the berryCurvature variant the chosen mode needs is
                # evaluated; the original always evaluated both.
                if mode=='DDsym':
                    (FDDsym,_)=s.berryCurvature(E,C,dHk)
                    res=Fbandsum(mask,FDDsym)
                else:
                    (F,Fdetails,_)=s.berryCurvature(E,C,dHk,
                                                    basisconnection=Abk,
                                                    returndetails=True)
                    if mode=='DD':
                        res=Fbandsum(mask,Fdetails['DD'])
                    else:
                        res=Fbandsum(mask,F)

                fw.write(res)

# ===================================================================
# 
# ===================================================================

def Fbandsum(mask,F):
    """Sum the first three rows of F over the bands selected by mask.

    Parameters
    ----------
    mask : sequence of numbers
        Per-band weights (the callers pass 1 for kept and 0 for dropped
        bands).
    F : 2d array
        Indexed as F[component, band]; only rows 0, 1 and 2 are used.
        Presumably a numpy array, given the F[c,:] indexing.

    Returns
    -------
    numpy.ndarray
        Length-3 array whose entry c is sum over b of F[c,b]*mask[b].
    """
    totals=[]
    for comp in (0,1,2):
        weighted=F[comp,:]*mask
        totals.append(sum(weighted))
    return np.array(totals)


# ===================================================================
# 
# ===================================================================


if __name__ == '__main__':

    parser=argparse.ArgumentParser(description='',
                                   conflict_handler='resolve',
                                   epilog='')
    parser.add_argument('--gauge', action='store',dest='gauge',
                        default='periodic',
                        help="gauge passed on to Slabify.hamAtKPoint; "
                             "'relative' is translated to 'forcerelative' "
                             "(default: %(default)s)")
    parser.add_argument('--Ef', action='store',dest='Ef',
                        type=float,default=0.,
                        help='cut-off energy: states with E<=Ef are summed '
                             '(default: %(default)s)')
    # Restrict --mode to the values curvature_density understands, so a
    # typo is rejected instead of silently producing 'all'.
    parser.add_argument('--mode', action='store',dest='mode',
                        type=str,default='all',
                        choices=('all','DD','DDsym'),
                        help='quantity written to the curvature density '
                             'file (default: %(default)s)')

    args=parser.parse_args()

    # 'relative' is accepted as shorthand for the gauge name 'forcerelative'.
    if args.gauge=='relative': args.gauge='force'+args.gauge

    work(gauge=args.gauge,Ef=args.Ef,mode=args.mode)

