#!/usr/bin/python

"""
Taylor Green 3D : see paper van Rees 2011.

All parameters are set and defined in python module dataTG.

"""

from hysop import Box
from hysop.f2py import fftw2py
import numpy as np
import cPickle
from scitools.NumPyDB import NumPyDB_cPickle as hysopPickle
from hysop.fields.continuous import Field
from hysop.fields.variable_parameter import VariableParameter
from hysop.mpi.topology import Cartesian
from hysop.operator.advection import Advection
from hysop.operator.stretching import Stretching
from hysop.operator.absorption_BC import AbsorptionBC
from hysop.operator.poisson import Poisson
from hysop.operator.diffusion import Diffusion
from hysop.operator.adapt_timestep import AdaptTimeStep
from hysop.operator.redistribute_intra import RedistributeIntra
from hysop.operator.hdf_io import HDF_Writer
from hysop.operator.energy_enstrophy import EnergyEnstrophy
from hysop.operator.profiles import Profiles
from hysop.problem.simulation import Simulation
from hysop.methods_keys import Scales, TimeIntegrator, Interpolation,\
    Remesh, Support, Splitting, dtCrit, SpaceDiscretisation
from hysop.numerics.integrators.runge_kutta2 import RK2 as RK2
from hysop.numerics.integrators.runge_kutta3 import RK3 as RK3
from hysop.numerics.integrators.runge_kutta4 import RK4 as RK4
from hysop.numerics.finite_differences import FD_C_4, FD_C_2
from hysop.numerics.interpolation import Linear
from hysop.numerics.remeshing import L6_4 as rmsh
import hysop.tools.io_utils as io
import hysop.tools.numpywrappers as npw
from hysop.mpi import main_rank, MPI
from hysop.tools.parameters import Discretization, IOParams

print " ========= Start Navier-Stokes 3D (Taylor Green benchmark) ========="

# ====== pi constant and trigonometric functions ======
pi = np.pi
cos = np.cos
sin = np.sin

# ====== Flow constants =======
uinf = 1.0
VISCOSITY = 1. / 300.

# ======= Domain =======
dim = 3
Nx = 129
Ny = Nz = 65
#Nx = 257
#Ny = Nz = 129
#Nx = 513
#Ny = Nz = 257
#Nx = 1025
#Ny = Nz = 513
g = 2
boxlength = npw.asrealarray([10.24, 5.12, 5.12])
boxorigin = npw.asrealarray([-2.0, -2.56, -2.56])
box = Box(length=boxlength, origin=boxorigin)

# A global discretization with ghost points
d3dg = Discretization([Nx, Ny, Nz], [g, g, g])
# A global discretization, without ghost points
d3d = Discretization([Nx, Ny, Nz])

# ====== Sphere inside the domain ======
RADIUS = 0.5
pos = [0., 0., 0.]
from hysop.domain.subsets import Sphere, HemiSphere
sphere = HemiSphere(origin=pos, radius=RADIUS, parent=box)


# ======= Function to compute initial velocity  =======
def computeVel(res, x, y, z, t):
    """Initial velocity formula: a uniform stream of speed ``uinf`` along x.

    ``x``, ``y``, ``z`` and ``t`` are part of the formula signature expected
    by Field and are ignored. The three components of ``res`` are overwritten
    in place and ``res`` itself is returned.
    """
    for component, value in enumerate((uinf, 0., 0.)):
        res[component][...] = value
    return res


# ======= Function to compute initial vorticity =======
def computeVort(res, x, y, z, t):
    """Initial vorticity formula: every component is zero.

    ``x``, ``y``, ``z`` and ``t`` are part of the formula signature expected
    by Field and are ignored. The first three components of ``res`` are
    zeroed in place and ``res`` itself is returned.
    """
    for component in range(3):
        res[component][...] = 0.
    return res

# ====== Time-dependent required flow rate (VariableParameter formula) ======
def computeFlowrate(simu):
    """Return the flow rate required at the current simulation time.

    The x-component is the constant rate ``uinf * Ly * Lz`` through the
    (y, z) cross-section of the box. On the window
    ``t_start <= t <= t_start + 1.0`` (t_start = 3.0), a y-component
    ``sin(pi * (t - t_start)) * Ly * Lz`` (one half-sine pulse) is added.
    Outside that window, and for z at all times, the component is zero.
    For a constant flow rate only, drop the y pulse.

    :param simu: simulation object; only its current time ``simu.tk`` is read.
    :returns: numpy array of length 3.
    """
    t_start = 3.0
    now = simu.tk
    required = np.zeros(3)
    required[0] = uinf * box.length[1] * box.length[2]
    if t_start <= now <= t_start + 1.0:
        pulse = sin(pi * (now - t_start))
        required[1] = pulse * box.length[1] * box.length[2]
    return required


# ======= Fields =======
# Vector fields on the whole box; initial values come from the formulas
# computeVel (uniform uinf along x) and computeVort (zero).
velo = Field(domain=box, formula=computeVel,
             name='Velocity', is_vector=True)
vorti = Field(domain=box, formula=computeVort,
              name='Vorticity', is_vector=True)

# ========= Simulation setup =========
# Physical time interval [0, 75]; 0.0125 is the initial time step, which the
# 'dtAdapt' operator updates at each iteration. max_iter is presumably set
# large enough that 'end' is the effective stopping criterion.
simu = Simulation(start=0.0, end=75.0, time_step=0.0125, max_iter=10000000)


# Adaptive time step method: dt = min(values(dtCrit))
# where dtCrit is a list of criteria on which the computation
# of the adaptive time step is based.
# ex : dtCrit = ['gradU', 'cfl', 'stretch'], means :
# dt = min (dtAdv, dtCfl, dtStretch), where dtAdv is equal to LCFL / |gradU|
# For dtAdv, the possible choices are the following:
# 'vort' (infinite norm of vorticity) : dtAdv = LCFL / |vort|
# 'gradU' (infinite norm of velocity gradient), dtAdv = LCFL / |gradU|
# 'deform' (infinite norm of deformation tensor),
# dtAdv = LCFL / (0.5(gradU + gradU^T))
# Computational operators, indexed by name (used by run() below).
op = {}
# Output parameters of the adaptive time-step operator ('time_step').
iop = IOParams("time_step")
# Default topology (i.e. 3D, with ghosts)
topo_with_ghosts = box.create_topology(d3dg)


# Adaptive time step on the ghosted topology, using the gradU, stretch and
# cfl criteria described above (lcfl = 0.125, cfl = 0.5).
op['dtAdapt'] = AdaptTimeStep(velo, vorti, simulation=simu,
                              discretization=topo_with_ghosts,
                              method={TimeIntegrator: RK3,
                                      SpaceDiscretisation: FD_C_4,
                                      dtCrit: ['gradU', 'stretch', 'cfl']},
                              io_params=iop,
                              lcfl=0.125,
                              cfl=0.5)

# Vorticity advection through the Scales library (method 'p_M6',
# 'classic' splitting) on the ghost-less discretization.
op['advection'] = Advection(velo, vorti,
                            discretization=d3d,
                            method={Scales: 'p_M6',
                                    Splitting: 'classic'}
                            )

# Stretching term, computed on the ghosted topology.
op['stretching'] = Stretching(velo, vorti,
                              discretization=topo_with_ghosts)

# Diffusion of the vorticity with viscosity VISCOSITY.
op['diffusion'] = Diffusion(viscosity=VISCOSITY, vorticity=vorti,
                            discretization=d3d)

# Required flow rate given by computeFlowrate(simu) (presumably re-evaluated
# as the simulation advances); shared by the Poisson solver and by the outlet
# absorption operator created further below.
rate = VariableParameter(formula=computeFlowrate)
op['poisson'] = Poisson(velo, vorti, discretization=d3d, flowrate=rate)

# ===== Discretization of computational operators ======
for ope in op.values():
    ope.discretize()

# Topologies selected by the Poisson (FFT) and advection operators at
# discretization; topofft is reused for absorption, monitors and I/O.
# (topoadvec is not used further in this script.)
topofft = op['poisson'].discreteFields[vorti].topology
topoadvec = op['advection'].discreteFields[vorti].topology

# =====  Smooth vorticity absorption at the outlet =====
# Absorption zone x in [7.24, 8.24], i.e. the last unit length before the
# x-max boundary (-2.0 + 10.24 = 8.24), with the same required flow rate as
# the Poisson solver. Created after the loop above, so discretized here.
op['vort_absorption'] = AbsorptionBC(velo, vorti, discretization=topofft, 
                                     req_flowrate=rate, 
                                     x_coords_absorp=[7.24, 8.24])
#                                     x_coords_absorp=[1.56, 2.56])
op['vort_absorption'].discretize()

# =====  Penalization of the vorticity on a sphere inside the domain =====
# Penalization coefficient 1e8 on ``sphere`` (a hemisphere); presumably it
# should match penalisation_coeff of the MomentumForces monitor below.
from hysop.operator.penalization import PenalizeVorticity
op['penalVort'] = PenalizeVorticity(velocity=velo, vorticity=vorti,
                                    discretization=topo_with_ghosts,
                                    obstacles=[sphere], coeff=1e8,
                                    method={SpaceDiscretisation: FD_C_4})
op['penalVort'].discretize()

# ==== Operators to map data between the different computational operators ===
# (i.e. between topologies)
# Each entry copies velo and vorti from the topology of ``source`` to that of
# ``target``; keys read '<source>2<target>' with fft = Poisson,
# str = stretching and advec = advection.
distr = {}
distr['fft2str'] = RedistributeIntra(source=op['poisson'],
                                     target=op['stretching'],
                                     variables=[velo, vorti])
distr['str2fft'] = RedistributeIntra(source=op['stretching'],
                                     target=op['poisson'],
                                     variables=[velo, vorti])
distr['fft2advec'] = RedistributeIntra(source=op['poisson'],
                                       target=op['advection'],
                                       variables=[velo, vorti])
distr['advec2fft'] = RedistributeIntra(source=op['advection'],
                                       target=op['poisson'],
                                       variables=[velo, vorti])
# ========= Monitoring operators =========
monitors = {}
#iop = IOParams('fields', frequency=100)
#monitors['writer'] = HDF_Writer(variables={velo: topofft, vorti: topofft},
#                                io_params=iop)

# Energy and enstrophy on the FFT topology (not normalized).
io_ener = IOParams('energy_enstrophy')
monitors['energy'] = EnergyEnstrophy(velo, vorti, discretization=topofft,
                                     io_params=io_ener, is_normalized=False)

# I/O leader for the profile output: this process's main_rank if its local
# z coordinates contain exactly 0.0, otherwise 0.
# NOTE(review): evaluated independently on each process with an exact float
# test, and only the z coordinates are checked; if several processes own
# z = 0.0 they pass different io_leader values -- confirm this is intended.
rk = 0
if (0.0 in topofft.mesh.coords[2]):
    rk = main_rank
# Profile along direction 1 (y) at prof_coords [0.0, 0.0], output
# frequency 10.
io_prof = IOParams('profile_Y_axis', frequency=10, io_leader=rk)
monitors['profile'] = Profiles(velo, vorti, discretization=topofft,
                               io_params=io_prof, prof_coords=[0.0, 0.0], 
                               direction=1, beginMeanComput=0.1)

from hysop.domain.control_box import ControlBox
from hysop.operator.drag_and_lift import MomentumForces, NocaForces
# Control volume for force computation: the whole box shrunk by 15 grid
# steps on every face.
ref_step = topo_with_ghosts.mesh.space_step
cbpos = npw.zeros(dim)
cblength = npw.zeros(dim)
cbpos[...] = boxorigin[...]
cbpos +=  15 * ref_step
cblength[...] = boxlength[...]
cblength -= 30 * ref_step
cb = ControlBox(parent=box, origin=cbpos, length=cblength)
# Force normalization 1 / (0.5 * uinf^2 * A) with reference area
# A = pi * RADIUS^2.
coeffForce = 1. / (0.5 * uinf ** 2 * pi * RADIUS ** 2)

# Output parameters of the Noca force monitor (currently commented out).
io_forces=IOParams('drag_and_lift_NocaII')
#monitors['forcesNoca'] = NocaForces(velo, vorti, 
#                                    discretization=topo_with_ghosts,
#                                    nu=VISCOSITY, 
#                                    volume_of_control=cb,
#                                    normalization=coeffForce,
#                                    obstacles=[sphere], 
#                                    io_params=io_forces)

# Momentum-based forces on the obstacle, with the same penalization
# coefficient (1e8) as the 'penalVort' operator.
io_forcesPenal=IOParams('drag_and_lift_Mom')
monitors['forcesMom'] = MomentumForces(velocity=velo, 
                                       discretization=topo_with_ghosts,
                                       normalization=coeffForce,
                                       obstacles=[sphere], 
                                       penalisation_coeff=[1e8],
                                       io_params=io_forcesPenal)

#io_forcesPenal=IOParams('drag_and_lift_penal')
#monitors['forcesPenal'] = DragAndLiftPenal(velo, vorti, coeffForce,
#                                           discretization=topofft,
#                                           obstacles=[sphere], factor=[1e8],
#                                           io_params=io_forcesPenal)

# Thin-slab field outputs on the FFT topology. Slabs are 4 grid steps thick,
# centred on z = 0 (XY) or y = 0 (XZ). The x-direction step ref_step[0] is
# used for every direction, which assumes a near-uniform grid spacing.
step_dir = ref_step[0]
io_sliceXY = IOParams('sliceXY', frequency=20)
thickSliceXY = ControlBox(parent=box, origin=[-2.0, -2.56, -2.0 * step_dir], 
                          length=[10.24- step_dir, 5.12- step_dir, 4.0 * step_dir])
#thickSliceXY = ControlBox(parent=box, origin=[-2.56, -2.56, -2.0 * step_dir], 
#                          length=[5.12 - step_dir, 5.12 - step_dir, 4.0 * step_dir])
monitors['writerSliceXY'] = HDF_Writer(variables={velo: topofft, vorti: topofft},
                                      io_params=io_sliceXY, subset=thickSliceXY, 
                                      xmfalways=True)

# NOTE: writerSliceXZ and writerSubBox are created and set up, but their
# apply() calls are commented out in run().
io_sliceXZ = IOParams('sliceXZ', frequency=400)
thickSliceXZ = ControlBox(parent=box, origin=[-2.0, -2.0 * step_dir, -2.56], 
                          length=[10.24- step_dir, 4.0 * step_dir, 5.12- step_dir])
monitors['writerSliceXZ'] = HDF_Writer(variables={velo: topofft, vorti: topofft},
                                       io_params=io_sliceXZ, subset=thickSliceXZ, 
                                       xmfalways=True)

# Sub-box [-0.7, 7.3] x [-2.0, 2.0] x [-2.0, 2.0] around the obstacle wake.
io_subBox = IOParams('subBox', frequency=2000)
subBox = ControlBox(parent=box, origin=[-0.7, -2.0, -2.0], length=[8.0, 4.0, 4.0])
monitors['writerSubBox'] = HDF_Writer(variables={velo: topofft, vorti: topofft},
                                      io_params=io_subBox, subset=subBox, 
                                      xmfalways=True)

# ========= Setup for all declared operators/monitors =========
time_setup = MPI.Wtime()
for ope in op.values():
    ope.setup()
for ope in distr.values():
    ope.setup()

# Monitors were not discretized with the operators above: discretize them
# all first, then set them up.
for monit in monitors.values():
    monit.discretize()
for monit in monitors.values():
    monit.setup()

print '[', main_rank, '] total time for setup:', MPI.Wtime() - time_setup

# ========= Fields initialization =========
# - initialize velo + vort on topo_with_ghosts (stretching topology)
# - penalize vorticity
# - redistribute topo_with_ghosts --> topofft

time_init = MPI.Wtime()
# Discretize the obstacle on the FFT topology. ``ind`` is not used later in
# this script; the call is presumably kept for its effect on ``sphere``.
ind = sphere.discretize(topofft)
def initFields():
    """Initialize velocity and vorticity, penalize, then move them to topofft.

    Both fields are evaluated from their formulas on the ghosted (stretching)
    topology. The vorticity penalization is applied there, and the 'str2fft'
    redistribution then copies the fields to the FFT topology, waiting for
    completion before returning.
    """
    for field in (velo, vorti):
        field.initialize(topo=topo_with_ghosts)
    op['penalVort'].apply(simu)
    to_fft = distr['str2fft']
    to_fft.apply(simu)
    to_fft.wait()

# Run the initialization and report its wall-clock time on each process.
initFields()
print '[', main_rank, '] total time for init :', MPI.Wtime() - time_init

# Placeholder operator sequence handed to run(); run() ignores its argument.
fullseq = []

def run(sequence):
    """Perform one time step of the splitting scheme.

    ``sequence`` is accepted for interface compatibility but is not used:
    the order of the sub-steps is fixed below.

    Sub-steps, separated by topology transfers:
      1. outlet vorticity absorption, Poisson solve + correction, and the
         momentum-force monitor;
      2. after 'fft2str': vorticity penalization, then stretching;
      3. after 'str2fft': diffusion;
      4. after 'fft2advec': advection (Scales);
      5. after 'advec2fft': XY-slice output, energy/enstrophy, profile;
      6. after 'fft2str': adaptive time-step update.
    """
    def transfer(key):
        # Redistribute velo/vorti between topologies and block until done.
        link = distr[key]
        link.apply(simu)
        link.wait()

    op['vort_absorption'].apply(simu)
    op['poisson'].apply(simu)
    # NOTE(review): forcesMom is declared on topo_with_ghosts but runs before
    # the fft2str transfer below, so it sees the velocity redistributed at
    # the end of the previous step -- confirm this is intended.
    monitors['forcesMom'].apply(simu)
    transfer('fft2str')
    op['penalVort'].apply(simu)
    # Disabled variant: extra Poisson solve after penalization, i.e.
    #   transfer('str2fft'); op['poisson'].apply(simu); transfer('fft2str')
    op['stretching'].apply(simu)
    # Disabled: monitors['forcesNoca'].apply(simu)  (Noca forces)
    transfer('str2fft')
    op['diffusion'].apply(simu)
    transfer('fft2advec')
    op['advection'].apply(simu)
    transfer('advec2fft')
    monitors['writerSliceXY'].apply(simu)
    # Disabled: monitors['writerSliceXZ'].apply(simu)
    # Disabled: monitors['writerSubBox'].apply(simu)
    monitors['energy'].apply(simu)
    monitors['profile'].apply(simu)
    transfer('fft2str')
    time_step_op = op['dtAdapt']
    time_step_op.apply(simu)
    time_step_op.wait()

# ==== Serialize the simulation data of the problem to a "restart" file ====
def dump(filename):
    """Serialize the simulation state to a per-process restart file.

    Only the Simulation object ``simu`` is pickled here. The time loop writes
    the fields separately with ``hdf_dump``.

    :param filename: prefix of the output file. The real name is
        ``<filename>_rk_<N>``, N being the current process number
        (main_rank). If None, the module-level default ``dump_filename``
        (built from IOParams('restart')) is used.
    """
    if filename is None:
        # Previously this path left ``filedump`` unbound (UnboundLocalError).
        filename = dump_filename
    filedump = filename + '_rk_' + str(main_rank)
    # The with-block closes the file, so the pickle is flushed to disk before
    # the caller goes on (the handle used to be leaked).
    with open(filedump, 'wb') as db:
        cPickle.dump(simu, db)

# ====== Load the simulation data of the problem from a "restart" file ======
def restart(filename):
    """
    Load serialized data to restart from a previous state.
    self.input variables and simulation are loaded.
    @param  filename : prefix for downloaded file.
    Real name = filename_rk_N, N being current process number.
    If None use default value from problem
    parameters (self.filename)
    """
    if filename is not None:
        filedump = filename + '_rk_' + str(main_rank)
    db = open(filedump, 'r')
    simu = cPickle.load(db)
    simu.start = simu.time - simu.time_step
    ite = simu.current_iteration
    simu.initialize()
    simu.current_iteration = ite
    print 'simu', simu
    print ("load ...", filename)
    return simu

seq = fullseq

# Initialize the simulation object before entering the time loop.
simu.initialize()
# Restart/dump switches: when doDump is set, the time loop dumps every
# dumpFreq iterations; when doRestart is set, state is reloaded below.
doDump = False
doRestart = False
dumpFreq = 10
io_default=IOParams('restart')
# Prefix of the per-process pickle restart files ('<prefix>_rk_<N>').
dump_filename = io.Writer(io_params=io_default).filename
#===== Restart (if needed) =====
if doRestart:
    # NOTE(review): this rebinds the global ``simu``, but operators built
    # earlier with ``simulation=simu`` (e.g. 'dtAdapt') may still hold the
    # previous Simulation object -- confirm they use the restored one.
    simu = restart(dump_filename)
    # NOTE(review): hard-coded file names; presumably they must match the
    # files written by hdf_dump with IOParams('velo') / IOParams('vorti') in
    # the time loop -- confirm the index suffix.
    iop_vel = IOParams('velo_00000.h5')
    velo.hdf_load(topofft, io_params=iop_vel)
    iop_vort = IOParams('vorti_00000.h5')
    vorti.hdf_load(topofft, io_params=iop_vort)
    # Set up for monitors and redistribute
    for ope in distr.values():
        ope.setup()
    for monit in monitors.values():
        monit.setup()

# ======= Time loop =======
time_run = MPI.Wtime()
while not simu.isOver:
    if topofft.rank == 0:
        simu.printState()
    run(seq)
    simu.advance()
    testdump = simu.current_iteration % dumpFreq is 0
    if doDump and testdump:
        print 'dump ...'
        dump(dump_filename)
        iop_vel = IOParams('velo')
        velo.hdf_dump(topofft, io_params=iop_vel)
        iop_vort = IOParams('vorti')
        vorti.hdf_dump(topofft, io_params=iop_vort)
print '[', main_rank, '] total time for run :', MPI.Wtime() - time_run

# ======= Finalize =======
# Clean the FFTW solver, then finalize the redistribution operators followed
# by the monitors. The computational operators in ``op`` are not finalized
# here.
fftw2py.clean_fftw_solver(box.dimension)
for finalizable in list(distr.values()) + list(monitors.values()):
    finalizable.finalize()