#!/usr/bin/env python2.7

import numpy as np
import scipy.signal as sig
import scipy.interpolate as itp
import pyproj
import sys
import SacLib

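"""
Usage example (Python):

    SacUtils('Input.sac',
             'Output.sac',
             Filter = [0.2, 20., 3.],
             Cut = [[2017, 319, 12, 30, 0., 0.],
                    [2017, 319, 13, 30, 0., 0.]],
             Dt = 0.01,
             Utm = 33)

Filter is [low corner (Hz), high corner (Hz), Butterworth order]; each Cut
time is [year, day of year, hour, minute, second, millisecond]; Dt is the
output sampling interval in seconds; Utm is the UTM zone used to project the
station coordinates (STLO/STLA) into USER7/USER8.
"""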


def SacUtils(InpFile, OutFile, Filter=None, Cut=None, Dt=None, Utm=None):
    """Band-pass, cut/resample and UTM-project a SAC trace, then save it.

    Every step is optional and runs only when its argument is given; they
    are applied in the order filter -> cut/resample -> UTM projection.

    Args:
        InpFile: path of the SAC file to read.
        OutFile: path of the SAC file to write (overwritten if present).
        Filter: [LowCorner, HighCorner, Order] -- band-pass corner
            frequencies in Hz and Butterworth order.
        Cut: [[Year, Jday, Hour, Min, Sec, Msec], [...]] -- absolute start
            and end time of the output window.
        Dt: output sampling interval in seconds; only used together with
            Cut, defaults to the input DELTA.
        Utm: UTM zone passed to pyproj; when set, STLO/STLA are projected
            and the resulting X/Y stored in USER7/USER8.
    """
    sac = SacLib.Sac(InpFile)

    #--------------------------------------------------------------------------
    # Apply filter

    if Filter:
        sac.Data[0] = _BandPass(sac.Data[0], 1./sac.Head['DELTA'], Filter)

    #--------------------------------------------------------------------------
    # Cut and resample

    if Cut:
        _CutResample(sac, Cut, Dt)

    #--------------------------------------------------------------------------
    # Converting from WGS to UTM cartesian projection

    if Utm:
        # NOTE(review): no 'south' flag is given to pyproj, so stations in
        # the southern hemisphere get negative northings -- confirm intended.
        P = pyproj.Proj(proj='utm', zone=Utm, ellps='WGS84')
        X, Y = P(sac.Head['STLO'], sac.Head['STLA'])

        sac.Head['USER7'] = X
        sac.Head['USER8'] = Y

    #--------------------------------------------------------------------------
    # Export modified Sac

    sac.Write(OutFile, OverWrite=True)


def _BandPass(Data, FS, Filter):
    """Return Data band-passed with a causal (single-pass) Butterworth filter.

    Filter is [LowCorner, HighCorner, Order] with corners in Hz; FS is the
    sampling rate in Hz.
    """
    LowCorner, HighCorner = Filter[0], Filter[1]
    # The order may arrive as a float (e.g. parsed from the command line);
    # scipy.signal.butter documents it as an integer.
    Order = int(Filter[2])

    # Corner frequencies normalised to the Nyquist frequency (FS/2)
    Corners = [2.*LowCorner/FS, 2.*HighCorner/FS]
    b, a = sig.butter(Order, Corners, btype='band')

    # Start in the steady state for a constant input equal to the first
    # sample, so a DC offset does not produce a start-up step transient.
    zi = sig.lfilter_zi(b, a)
    Out, _ = sig.lfilter(b, a, Data, zi=zi*Data[0])
    return Out


def _CutResample(sac, Cut, Dt=None):
    """Cut sac in place to the Cut window, resampled every Dt seconds.

    Cut holds two [Year, Jday, Hour, Min, Sec, Msec] times. Output samples
    start at Cut[0] and are Dt apart (input DELTA when Dt is not given);
    the last one lies at (or within a thousandth of a sample after) Cut[1].
    Values are linearly interpolated and set to 0 outside the input record.
    NOTE: no anti-alias filter is applied when Dt > DELTA.
    The header reference time is moved to Cut[0]; NPTS/DELTA/B/E follow.

    Raises:
        ValueError: if Cut[1] precedes Cut[0].
    """
    Head = sac.Head
    DT = Head['DELTA']
    TN = Head['NPTS']
    if not Dt:
        Dt = DT

    Sref = Date2Sec(Head['NZYEAR'], Head['NZJDAY'], Head['NZHOUR'],
                    Head['NZMIN'], Head['NZSEC'], Head['NZMSEC'])
    S0 = Date2Sec(*Cut[0][:6])
    S1 = Date2Sec(*Cut[1][:6])

    # Sample count from the window length. np.arange(S0, S1+Dt, Dt) on
    # absolute times (~6e10 s) can gain or lose an end sample through
    # rounding; the small tolerance keeps the end sample when the window
    # is a multiple of Dt.
    N = int(np.floor((S1 - S0)/Dt + 1e-3)) + 1
    if N < 1:
        raise ValueError('Cut end time must not precede its start time')

    # The first input sample lies B seconds after the reference time;
    # -12345 is SAC's "undefined" header value.
    B = Head['B']
    if B is None or B == -12345.:
        B = 0.

    # Time axes relative to the input reference time (keeps magnitudes small)
    TaxIn = B + np.arange(TN)*DT
    TaxOut = (S0 - Sref) + np.arange(N)*Dt

    DataI = itp.interp1d(TaxIn, sac.Data[0],
                         bounds_error=False,
                         fill_value=0.)
    sac.Data[0] = DataI(TaxOut)

    Head['NPTS'] = N
    Head['DELTA'] = Dt
    Head['NZYEAR'] = Cut[0][0]
    Head['NZJDAY'] = Cut[0][1]
    Head['NZHOUR'] = Cut[0][2]
    Head['NZMIN'] = Cut[0][3]
    Head['NZSEC'] = Cut[0][4]
    Head['NZMSEC'] = Cut[0][5]
    Head['B'] = 0.
    # E is the time of the last sample actually produced
    Head['E'] = (N - 1)*Dt


#------------------------------------------------------------------------------
# Utility functions

def Date2Sec(Year, Jday, Hour, Min, Sec, Msec):
    """Return the seconds elapsed from 0001-01-01 00:00:00 to the given time.

    The date is a year plus day-of-year (Jday 1 = first day) in the
    proleptic Gregorian calendar; Hour, Min, Sec and Msec give the time of
    day. Leap seconds are not accounted for.
    """
    Day = 86400.

    # Summed left to right: elapsed whole years, their leap days, elapsed
    # days of the current year, then the time of day.
    Parts = [(Year - 1)*365.*Day,
             LeapNum(Year)*Day,
             (Jday - 1)*Day,
             Hour*3600.,
             Min*60.,
             Sec*1.,
             Msec/1e3]

    return sum(Parts)

def LeapNum(Year):
    """Return how many Gregorian leap years precede Year, counting from 1.

    Every 4th year is leap, except century years, which are leap only
    when divisible by 400.
    """
    Before = Year - 1

    EveryFourth = Before // 4
    Centuries = Before // 100
    FourCenturies = Before // 400

    return EveryFourth + FourCenturies - Centuries

def WgsToXY(Lon, Lat, Lat0=None):
    """Map longitude/latitude (degrees) to planar X/Y on a spherical Earth.

    Y = Lat * R and X = Lon * R * cos(LatR), with R = 6371009 m and both
    angles in radians. LatR is Lat0 when given (a fixed reference
    latitude, which may be 0), otherwise each point's own latitude.
    Works element-wise on numpy arrays.

    Returns:
        (X, Y) in the units of R (metres).
    """
    EARTHRADIUS = 6371009.

    # Compare with None so an equator reference (Lat0 = 0) is honoured
    LatR = Lat if Lat0 is None else Lat0

    # Convert to radians exactly once (the scale factor was previously
    # converted twice, giving cos of a wrong angle)
    Y = np.deg2rad(Lat) * EARTHRADIUS
    X = np.deg2rad(Lon) * EARTHRADIUS * np.cos(np.deg2rad(LatR))

    return X, Y

#------------------------------------------------------------------------------
# Main

def main(argv):
    """Command-line entry point.

    Usage:
        <script> Input.sac Output.sac [-filter Low,High,Order]
                 [-cut Y,J,H,M,S,MS Y,J,H,M,S,MS] [-dt Dt] [-utm Zone]

    argv is the argument list without the program name. Options follow the
    two file names; '--option' is accepted as well as '-option'. On
    malformed input an error and the usage are written to stderr and no
    file is processed.
    """
    Usage = ('Usage: <script> Input.sac Output.sac'
             ' [-filter Low,High,Order]'
             ' [-cut Y,J,H,M,S,MS Y,J,H,M,S,MS]'
             ' [-dt Dt] [-utm Zone]\n')

    def Fail(Msg):
        sys.stderr.write('Error: ' + Msg + '\n' + Usage)

    def Numbers(Text, Count):
        # Parse exactly Count comma-separated floats
        Values = [float(V) for V in Text.split(',')]
        if len(Values) != Count:
            raise ValueError('expected %d comma-separated values, got %r'
                             % (Count, Text))
        return Values

    if len(argv) < 2:
        Fail('input and output files are required')
        return

    # Number of values each option consumes
    NARGS = {'filter': 1, 'cut': 2, 'dt': 1, 'utm': 1}

    Filter = None
    Cut = None
    Dt = None
    Utm = None

    # Scan only after the file names and match flags exactly, so that a
    # file named e.g. 'data-cut.sac' is not mistaken for the -cut option.
    i = 2
    while i < len(argv):
        Flag = argv[i]
        if not Flag.startswith('-'):
            Fail('unexpected argument %r' % Flag)
            return

        Name = Flag.lstrip('-')
        Nargs = NARGS.get(Name)
        if Nargs is None:
            Fail('unknown option %r' % Flag)
            return

        Values = argv[i+1:i+1+Nargs]
        if len(Values) < Nargs:
            Fail('missing value for %s' % Flag)
            return

        try:
            if Name == 'filter':
                Filter = Numbers(Values[0], 3)
            elif Name == 'cut':
                Cut = [Numbers(Values[0], 6),
                       Numbers(Values[1], 6)]
            elif Name == 'dt':
                Dt = float(Values[0])
            else:
                Utm = Values[0]
        except ValueError as Err:
            Fail('invalid value for %s: %s' % (Flag, Err))
            return

        i += 1 + Nargs

    SacUtils(InpFile=argv[0],
             OutFile=argv[1],
             Filter=Filter,
             Cut=Cut,
             Dt=Dt,
             Utm=Utm)

if __name__ == '__main__':
    # Script mode: hand every argument after the program name to main()
    main(argv=sys.argv[1:])

