#!/usr/bin/python
import sys,os

def transpose_h(h_matrix):
    """Return the transpose of a 3x3 matrix stored as a flat list.

    The matrix is given as nine values in row-major order; the result is a
    new nine-element list.  Vasp has the H-matrix in the transposed form,
    so this is applied to every lattice read from a VASP file.
    """
    return [h_matrix[3*col+row] for row in range(3) for col in range(3)]


def matvec(h,x):
    """Multiply the flat row-major 3x3 matrix h by the 3-vector x.

    Returns a new three-element list; only x[0], x[1] and x[2] are read.
    """
    x0,x1,x2=x[0],x[1],x[2]
    product=[]
    # Each matrix row starts at offset 0, 3 or 6 in the flat list.
    for start in (0,3,6):
        product.append(h[start]*x0+h[start+1]*x1+h[start+2]*x2)
    return product

# Input-file selection flags.  XDATCAR is the default source.
read_contcar = read_outcar = read_poscar = False
read_xdatcar = True
# When True the h-matrix is scaled together with the replicated atoms.
expandcell = True

# Replication count along each cell vector, and the displacement applied
# to the atom positions along the cell vectors.
expand = [1] * 3
translate = [0] * 3

def show_help():
    """Print the command-line usage summary, one line per option."""
    usage_lines = (
        "Help for vaspymol",
        "[-c] [-p] [-x] [-1] [a b c [a0 b0 c0]]",
        "-c : Read CONTCAR",
        "-p : Read POSCAR",
        "-x : Read XDATCAR",
        "-1 : When expanding the cell, only expand atom positions, not the cell",
        "[a b c] : Expand cell along cell vectors. a,b,c integers",
        "[a0 b0 c0] : Translate atom positions along cell vectors. a0, b0, c0 fractional displacements (0-1)",
    )
    for text in usage_lines:
        # A single parenthesised string prints identically under the
        # Python 2 print statement and the Python 3 print function.
        print(text)

# Command-line handling: leading "-" options first, then up to six
# positional numbers (three expansion counts, three translations).
argv=sys.argv[1:]

# startswith() instead of argv[0][0] so an empty argument ends option
# parsing instead of raising IndexError.
while argv and argv[0].startswith("-"):
    option=argv[0]
    if option=="-c":
        read_contcar=True
        read_xdatcar=False
    elif option=="-p":
        read_poscar=True
        read_xdatcar=False
    elif option=="-x":
        # The original set read_xdatcar to True and then straight back to
        # False, leaving no input selected at all.  Select XDATCAR and
        # clear the other sources so only one file is opened.
        read_xdatcar=True
        read_poscar=False
        read_contcar=False
    elif option=="-1":
        expandcell=False
    elif option=="-help":
        show_help()
        sys.exit(0)
    else:
        # Two spaces after "option" reproduce the spacing of the original
        # comma-separated print statement.
        print("Unknown option  "+option)
        show_help()
        sys.exit(1)
    argv=argv[1:]

# Positional arguments: the expansion needs all three integers; the
# translation is only read when three further values follow.
if len(argv)>=3:
    expand=[int(token) for token in argv[:3]]
    argv=argv[3:]
    if len(argv)>=3:
        translate=[float(token) for token in argv[:3]]
        argv=argv[3:]


# Open the selected input file as f.  XDATCAR is the default; when it
# cannot be opened the script quietly falls back to POSCAR.
if read_xdatcar:
    try:
        f=open("XDATCAR")
    except IOError:
        # Only a failure to open the file triggers the fallback; a bare
        # "except:" would also have swallowed KeyboardInterrupt and
        # SystemExit.
        read_poscar=True
        read_xdatcar=False

# A failure to open any of these files is fatal and propagates.
if read_poscar:
    f=open("POSCAR")
elif read_contcar:
    f=open("CONTCAR")
elif read_outcar:
    f=open("OUTCAR")

def _read_lattice(fh):
    """Read the title, scale factor and three lattice-vector lines from fh.

    Returns a fresh nine-element h-matrix: the lattice components are
    multiplied by the scale factor and passed through transpose_h().  Only
    the first three tokens of each lattice line are used.
    """
    fh.readline()  # title line, ignored
    scale=float(fh.readline().split()[0])
    components=[]
    for _ in range(3):
        components+=[float(token)*scale for token in fh.readline().split()[:3]]
    return transpose_h(components)


def _read_atom_counts(fh):
    """Return the list of integers on the next line of fh.

    When that line does not parse as integers it is skipped and the line
    after it is parsed instead.
    """
    try:
        return [int(token) for token in fh.readline().split()]
    except ValueError:
        return [int(token) for token in fh.readline().split()]


coordinates=[]
h_matrix=[]
atom_numbers=[]
natom_types=[]
nframes=1

if read_xdatcar:
    print("Reading XDATCAR")
    print("Obtaining atom types and h-matrix from POSCAR")
    fp=open("POSCAR")
    try:
        h_matrix=_read_lattice(fp)
        n_atom_types=_read_atom_counts(fp)
    finally:
        fp.close()
    print("Reading XDATCAR")
    nframes=0
    while True:
        line=f.readline()
        if not line:
            break
        if line[:6]!="Direct":
            continue
        # A "Direct" line starts a run of frames.
        a_new_frame=[]
        while True:
            line=f.readline()
            if not line:
                break
            d=line.split()
            has_more_coordinates=False
            if len(d)>=3:
                try:
                    a_new_frame.append([float(token) for token in d[:3]])
                    has_more_coordinates=True
                except ValueError:
                    # Three or more tokens that are not numbers, e.g.
                    # "Direct configuration= 2".
                    pass
            if not has_more_coordinates and len(a_new_frame)>0:
                nframes+=1
                coordinates+=a_new_frame
                a_new_frame=[]
                # A blank line or another "Direct" line only separates
                # frames, so keep reading.  Breaking on a "Direct" line
                # (as the original did) made the outer loop skip the whole
                # following frame while searching for the next "Direct".
                if len(d)>0 and line[:6]!="Direct":
                    break
        if len(a_new_frame)>0:
            # Last frame ended at end-of-file without a separator.
            nframes+=1
            coordinates+=a_new_frame
            a_new_frame=[]
    # Vasp 5.2 has a different XDATCAR format without "Direct", only
    # blank lines, so if no frames are found we try to parse in a different way...
    if nframes==0:
        f.close()
        f=open("XDATCAR")
        # Replace the POSCAR h-matrix with the one from the XDATCAR header.
        # The original appended to the existing list, so transpose_h()
        # re-transposed the first nine (POSCAR) values and produced an
        # un-transposed cell.
        h_matrix=_read_lattice(f)
        n_atom_types=_read_atom_counts(f)
        while True:
            line=f.readline()
            if not line:
                break
            if len(line.split())>0:
                continue
            # Found a blank line: everything after it is frame data.
            new_frame=False
            while True:
                line=f.readline()
                if not line:
                    break
                d=line.split()
                if len(d)>=3:
                    coordinates.append([float(token) for token in d[:3]])
                    new_frame=True
                elif new_frame:
                    # Count a frame only when coordinates were read since
                    # the previous separator, so repeated separator lines
                    # do not inflate the frame count.
                    nframes+=1
                    new_frame=False
            if new_frame:
                nframes+=1
    # XDATCAR positions are treated as fractional; convert with the h-matrix.
    coordinates=[matvec(h_matrix,position) for position in coordinates]
    # Spacing matches the original comma-separated print statement.
    print("Found  "+str(nframes)+"  frames")
elif read_outcar:
    print("Reading OUTCAR")
    print("Obtaining atom types from POSCAR")
    fp=open("POSCAR")
    try:
        # Skip title, scale factor and the three lattice lines.
        for _ in range(5):
            fp.readline()
        n_atom_types=_read_atom_counts(fp)
    finally:
        fp.close()
    print("Reading OUTCAR")
    while True:
        line=f.readline()
        if not line:
            break
        if line.find("position of ions in cartesian coordinates")>=0:
            # Only read first frame
            if len(coordinates):
                break
            while True:
                line=f.readline()
                if not line:
                    print("Cannot read atomic positions")
                    sys.exit(1)
                d=line.split()
                if len(d)!=3:
                    break
                coordinates.append([float(token) for token in d[:3]])
        # Only the first lattice block is used, matching the first-frame
        # coordinates.  The original appended every later block to the
        # same list and re-transposed the first nine values each time.
        if line.find("direct lattice vectors")>=0 and not h_matrix:
            components=[]
            for _ in range(3):
                components+=[float(token) for token in f.readline().split()[:3]]
            h_matrix=transpose_h(components)
elif read_poscar or read_contcar:
    if read_poscar:
        print("Reading POSCAR")
    if read_contcar:
        print("Reading CONTCAR")
    h_matrix=_read_lattice(f)
    n_atom_types=_read_atom_counts(f)
    line=f.readline()
    # line[:1] instead of line[0] so a truncated file does not raise
    # IndexError.  A line starting with S/s is skipped; the line after it
    # gives the coordinate mode.
    if line[:1] in ("S","s"):
        line=f.readline()
    # C/c/K/k selects Cartesian coordinates; anything else is fractional.
    fractional=line[:1] not in ("C","c","K","k")
    while True:
        line=f.readline()
        if not line:
            break
        d=line.split()
        if len(d)<3:
            break
        coordinates.append([float(token) for token in d[:3]])
    if fractional:
        coordinates=[matvec(h_matrix,position) for position in coordinates]
    # NOTE(review): Cartesian positions are used as read, without the
    # scale factor -- confirm against the POSCAR format specification.
f.close()

if nframes==0:
    print("Found no frames.")
    sys.exit(1)
    

# Element symbols indexed by atomic number.  Index 0 is an empty
# placeholder so that periodic_table[Z] is the symbol of element Z.
periodic_table=[""]+(
    "H He "
    "Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar "
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
    "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
    "Cs Ba "
    "La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
    "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
    "Fr Ra "
    "Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
    "Rf Db Sg Bh Hs Mt"
).split()

# Reverse lookup: element symbol -> atomic number.
periodic_hash=dict(zip(periodic_table,range(len(periodic_table))))

# Build the per-atom list of atomic numbers for one frame.  For each
# entry of n_atom_types, one POTCAR dataset is scanned for its TITEL line.
atom_numbers=[]
f=open("POTCAR")
print("Reading POTCAR")
for count in n_atom_types:
    # iter() with a "" sentinel stops at end-of-file like "if not line".
    for line in iter(f.readline,""):
        if "End of Dataset" in line:
            break
        if "TITEL  =" in line:
            # Second word after "=", with any "_suffix" removed.
            symbol=line.split("=")[1].split()[1].split("_")[0]
            atom_numbers.extend([periodic_hash[symbol]]*count)

f.close()

# Pair every coordinate with its atomic number.  The coordinate list holds
# all frames back to back, so the atomic numbers repeat once per frame.
atoms=[]
natoms=len(atom_numbers)
if len(coordinates)!=natoms*nframes:
    # Joined with single spaces, as the original print statement did.
    print(" ".join(["Internal consistency error: Number of coordinates not the same as number of atom numbers times the number of frames: ",
                    str(len(coordinates)),"!=",str(natoms),"*",str(nframes)]))
    sys.exit(1)

atoms=[[atom_numbers[index % natoms]]+position
       for index,position in enumerate(coordinates)]

# Number of atoms per frame before replication
nframeatoms=len(atom_numbers)

# Replicate every frame expand[0] x expand[1] x expand[2] times along the
# cell vectors.  Each atom is [Z, x, y, z].
print("Expanding...")
newatoms=[]
for iframe in range(nframes):
    for i in range(expand[0]):
        for j in range(expand[1]):
            for k in range(expand[2]):
                # Cartesian offset of the (i,j,k) image of the cell.
                lvec=matvec(h_matrix,[i,j,k])
                for iatom in range(nframeatoms):
                    atom=atoms[iframe*nframeatoms+iatom]
                    newatoms.append([atom[0],atom[1]+lvec[0],atom[2]+lvec[1],atom[3]+lvec[2]])
atoms=newatoms

# New number of atoms per frame
nframeatoms=len(atom_numbers)*expand[0]*expand[1]*expand[2]

# The translation is given along the original (unexpanded) cell vectors,
# so it is converted before the h-matrix is rescaled below.
print("Translating...")
tvec=matvec(h_matrix,translate)
for atom in atoms:
    atom[1]+=tvec[0]
    atom[2]+=tvec[1]
    atom[3]+=tvec[2]

if expandcell:
    # Update h-matrix.  matvec(h_matrix,[i,j,k]) = i*a + j*b + c*k means
    # cell vector number `col` occupies h_matrix[col], h_matrix[3+col] and
    # h_matrix[6+col], so each *column* is scaled by its expansion count.
    # The original scaled rows instead, which only agrees for diagonal
    # (axis-aligned orthogonal) cells or isotropic expansion.
    h_matrix=[h_matrix[3*row+col]*expand[col]
              for row in range(3) for col in range(3)]

# Split the flat atom list into frames, dropping within each frame any
# atom whose squared distance to an already accepted atom is below 0.5.
print("Removing duplicated atoms and building frames...")
allatoms=[]
for iframe in range(nframes):
    kept=[]
    offset=iframe*nframeatoms
    for iatom in range(nframeatoms):
        candidate=atoms[offset+iatom]
        for accepted in kept:
            dx=candidate[1]-accepted[1]
            dy=candidate[2]-accepted[2]
            dz=candidate[3]-accepted[3]
            if dx*dx+dy*dy+dz*dz<0.5:
                # Too close to an accepted atom: discard the candidate.
                break
        else:
            # No accepted atom was too close.
            kept.append(candidate)
    allatoms.append(kept)

# Output ymol file: the frame count, then for every frame one line with
# the nine h-matrix values, one line with the atom count, and one line
# "index Z x y z" per atom.  repr() is what the original backticks meant.
print("Writing vaspymol.mol file")
with open("vaspymol.mol","w") as molfile:
    molfile.write(repr(nframes)+"\n")
    # The same h-matrix line is written for every frame.
    cell_line="".join([" "+repr(value) for value in h_matrix])+"\n"
    for frame in allatoms:
        molfile.write(cell_line)
        molfile.write(repr(len(frame))+"\n")
        for index,atom in enumerate(frame):
            molfile.write(repr(index)+" "+repr(atom[0])+" "+repr(atom[1])+" "+repr(atom[2])+" "+repr(atom[3])+"\n")

print("Writing ymol_import_options file")
with open("ymol_import_options","w") as optfile:
    optfile.write("""1
1 2
1
0
0
0
0
0
""")

print("Launching ymol")
# exec replaces this process without flushing Python's buffers, so
# anything still buffered on stdout (e.g. when piped) would be lost.
sys.stdout.flush()
os.execlp("ymol","ymol","vaspymol.mol")


