# Version 2021/07/26

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

import OQCatk.Catalogue as Cat
import OQCatk.Smoothing as Smt
import OQCatk.CatUtils as CU

import AreaSourceTool as AST
import CentralAsiaPath as Cap

from SmoothLen import Sigma

# OQ-specific libraries
from openquake.hazardlib.scalerel import WC1994
from openquake.hazardlib.tom import PoissonTOM

from openquake.hazardlib.geo.mesh import Mesh
from openquake.hazardlib.source import MultiPointSource
from openquake.hazardlib.sourcewriter import write_source_model

import sys
import warnings

warnings.filterwarnings("ignore")

#-----------------------------------------------------------------------------
# Settings

# Model version and depth layer come from the shared CentralAsiaPath module
VID = Cap.VID
ZID = Cap.ZID

# Resolution passed as the Delta argument of Smt.SmoothMFD
# (presumably grid spacing in degrees -- confirm)
Delta = 0.1

# Per-layer configuration, selected below with Cfg[ZID]:
#   ModelName              name written into the OpenQuake source-model XML
#   UpperSeismogenicDepth  upper_seismogenic_depth of every MultiPointSource
#   LowerSeismogenicDepth  lower_seismogenic_depth of every MultiPointSource
#   DepthRange             catalogue depth window used for event selection
#                          (Depth > DepthRange[0] and Depth <= DepthRange[1])
Cfg = {
    'Shallow': dict(ModelName='Central Asia - Shallow Grid Sources',
                    UpperSeismogenicDepth=0.,
                    LowerSeismogenicDepth=65.,
                    DepthRange=[0, 50]),
    'Intermediate': dict(ModelName='Central Asia - Intermediate Grid Sources',
                         UpperSeismogenicDepth=35.,
                         LowerSeismogenicDepth=200.,
                         DepthRange=[50, 170]),
    'Deep': dict(ModelName='Central Asia - Deep Grid Sources',
                 UpperSeismogenicDepth=150.,
                 LowerSeismogenicDepth=350.,
                 DepthRange=[170, 300]),
}

# Lower magnitude handed to AST.FormatMultiMFD for every source
Mmin = 4.5

# Output files: OpenQuake XML and a flat CSV dump of the gridded rates
OutTemplate = 'SourceDB/{0}/{1}/GridMultiSources_{2}.{3}'
OQFile = OutTemplate.format(VID, 'XML', ZID, 'xml')
CsvFile = OutTemplate.format(VID, 'CSV', ZID, 'csv')

#-----------------------------------------------------------------------------
# Import catalogue

# Declustered (Gardner-Knopoff) catalogue, located under the catalogue
# database root defined in CentralAsiaPath
DbDir = (Cap.CAT_DB
         + '/WB-CentralAsia/Declustered/GardnerKnopoff/')
DbFile = DbDir + 'WBCA_Backbone_Mw_Locals.dec.purge.csv'

Db0 = Cat.Database()
Db0.Import(DbFile)

#-----------------------------------------------------------------------------
# Import Zones

# Area-source zonation for the selected model version and depth layer
ZoneFile = 'SourceDB/{0}/{1}/SZ.05.geojson'.format(VID, ZID)
AS = AST.AreaSource()
AS.Import(ZoneFile)

#-----------------------------------------------------------------------------
# Filtering depth

# Set aside the events with undetermined depth, then drop them from Db0
# (presumably Owrite=0 returns the selection without modifying Db0, while
# Owrite=1 applies it in place -- confirm against OQCatk)
DbNone = Db0.Filter('Depth', None, Opr='=', Owrite=0)
Db0.Filter('Depth', None, Opr='!=', Owrite=1)

# Keep only events inside the layer's depth window:
#   Depth > DepthRange[0]  and  Depth <= DepthRange[1]
# NOTE(review): with the strict '>' the Shallow layer (lower bound 0) also
# discards events located exactly at 0 km depth -- confirm this is intended.
Zmin, Zmax = Cfg[ZID]['DepthRange']
for Bound, Op in ((Zmin, '>'), (Zmax, '<=')):
    Db0.Filter('Depth', Bound, Opr=Op)

# Events with undetermined depth are added back, whatever the layer
Db0.Append(DbNone)

#-----------------------------------------------------------------------------
# Loop over sources

SourceData = []
X = []
Y = []
A = []

# Cells whose weighted log10 rate is below this value are dropped; this
# removes the -inf entries produced by np.log10 for zero-rate cells.
Threshold = -100

# Absolute tolerance (log10 units) for the rate-balance check below.
# Comparing values rounded to 4 decimals can reject two values that differ
# only by floating-point noise but straddle a rounding boundary.
RateTolerance = 1e-4

for Id in AS.GetIdList():

    print('Source: {0}'.format(Id))

    # Total log10 rate assigned to the zone; the smoothed grid must add
    # back up to it (checked below)
    ck0 = AS.GetById(Id, 'Sa')

    # Smoothing parameter set for the zone's group, or the default one
    Grp = AS.GetById(Id, 'Group')
    if Grp in Sigma:
        SigmaSel = Sigma[Grp]
    else:
        SigmaSel = Sigma['Default']

    # Weighted sum (in linear rate space) of the smoothed grids; each entry
    # I is used as (smoothing parameter, weight). For the balance check to
    # pass, the weights presumably need to sum to 1 -- confirm in SmoothLen.
    PaWS = 0.
    for I in SigmaSel:
        Px, Py, Pa = Smt.SmoothMFD(Db0, ck0, AS.GetPolygon(Id),
                                   Par=I[0], Delta=Delta, SphereGrid=True,
                                   Buffer=2.0, ZeroRates=True)
        PaWS += I[1]*(10**np.array(Pa))
    # Px/Py are taken from the last call: every call uses the same polygon,
    # Delta and Buffer, so the grid is assumed identical across calls.
    PaWS = np.log10(PaWS)

    # Drop empty (or negligible) cells
    Idx = [i for i, a in enumerate(PaWS) if a >= Threshold]
    PaWS = [PaWS[i] for i in Idx]
    Px = [Px[i] for i in Idx]
    Py = [Py[i] for i in Idx]

    # Rate balance: the retained cells must reproduce the zone total
    ck1 = np.log10(np.sum(10**np.array(PaWS)))
    if not np.isclose(ck0, ck1, rtol=0., atol=RateTolerance):
        print('Rate balance error. Exit...')
        print('Source {0}: expected {1}, got {2}'.format(Id, ck0, ck1))
        # Non-zero status so that calling scripts see the failure
        sys.exit(1)

    # MFD per cell (0.1 is presumably the magnitude bin width -- confirm
    # against AST.FormatMultiMFD), nodal planes and hypocentral depths
    Mfd = AST.FormatMultiMFD(Mmin, AS.GetById(Id, 'SMmax'), 0.1,
                             PaWS, AS.GetById(Id, 'Sb'))
    Npd = AST.FormatPlane(AS.GetById(Id, 'Nodal'))
    Hdd = AST.FormatDepth(AS.GetById(Id, 'Depth'))

    TRT = 'Tectonic_Region_{0}'.format(AS.GetById(Id, 'Tecto'))

    src = MultiPointSource(
            source_id = Id,
            name = AS.GetById(Id, 'Name'),
            tectonic_region_type = TRT,
            mfd = Mfd,
            magnitude_scaling_relationship = WC1994(),
            rupture_aspect_ratio = 2.0,
            temporal_occurrence_model = PoissonTOM(1.0),
            nodal_plane_distribution = Npd,
            hypocenter_distribution = Hdd,
            upper_seismogenic_depth = Cfg[ZID]['UpperSeismogenicDepth'],
            lower_seismogenic_depth = Cfg[ZID]['LowerSeismogenicDepth'],
            mesh=Mesh(np.array(Px),np.array(Py)))

    SourceData.append(src)

    # Flattened grid of all sources, used for plotting and CSV export
    X += Px
    Y += Py
    A += PaWS

#-----------------------------------------------------------------------------
# Plotting normalised rates (CHECK)
# Set to False to skip the diagnostic plot
PlotCheck = True

if PlotCheck:

    plt.close('all')

    # Fixed colour-scale limits in log10 rate units (A holds log10 rates),
    # used instead of the data-driven min(A) / max(A)
    ma = -3.5
    Ma = 3.5

    # Zone outlines
    for Id in AS.GetIdList():
        Poly = AS.GetPolygon(Id)
        plt.plot(Poly.x, Poly.y, 'k')

    # Cell rates; values outside [ma, Ma] take the end colours of the map,
    # exactly as the former manual normalisation did
    plt.scatter(X, Y, s=2, c=A, cmap='coolwarm', vmin=ma, vmax=Ma)

    # grid() expects a bool; string values such as 'on' are no longer
    # accepted by current matplotlib
    plt.grid(True)
    plt.axis('equal')

    # Save before showing so the file is written regardless of backend
    plt.savefig('Pictures/{0}/{1}/Smooth_Collapsed'.format(VID, ZID), dpi=150)
    plt.show(block=False)

#-----------------------------------------------------------------------------
# Export to XML

# Write all grid sources to the OpenQuake source-model XML
write_source_model(OQFile, SourceData, Cfg[ZID]['ModelName'])

#-----------------------------------------------------------------------------
# Export to CSV

# One row per retained grid cell: coordinates and log10 rate (from A)
CsvRows = ('{0},{1},{2}\n'.format(x, y, a) for x, y, a in zip(X, Y, A))
with open(CsvFile, 'w') as Out:
    Out.write('Longitude,Latitude,Occurrence\n')
    Out.writelines(CsvRows)

