import numpy as np
from parmepy.domain.obstacle.controlBox import ControlBox
import parmepy as pp
import parmepy.mpi as mpi
from parmepy.domain.obstacle.planes import SubSpace, SubPlane
from parmepy.domain.obstacle.sphere import Sphere

from parmepy.operator.monitors.printer import Printer
from parmepy.problem.simulation import Simulation
from parmepy.constants import HDF5
from parmepy.operator.monitors.compute_forces import DragAndLift
# Module-level aliases for numpy's pi, cos and sin; cos and sin are used by
# the analytic field formulas defined below.
pi, cos, sin = np.pi, np.cos, np.sin


# Number of grid points in each direction.
nb = 129

# Side length of the computational domain in every direction; with an origin
# at -1 each box spans [-1, 1] along each axis.
Lx, Ly, Lz = 2, 2, 2
resol3D = [nb] * 3
resol2D = [nb] * 2
dom = pp.Box(dimension=3, length=[Lx, Ly, Lz], origin=[-1.] * 3)
dom2 = pp.Box(dimension=2, length=[Lx, Ly], origin=[-1.] * 2)


## Function to compute TG velocity
def computeVel(res, x, y, z, t):
    """Write the Taylor-Green velocity into ``res`` and return it.

    Components:
        u = sin(x) cos(y) cos(z)
        v = -cos(x) sin(y) cos(z)
        w = 0

    res     -- indexable container of three writable arrays, filled in place
    x, y, z -- coordinates at which the field is evaluated
    t       -- time; not used, the field does not depend on it
    """
    # cos(z) is common to both non-zero components.
    cz = cos(z)
    ux = sin(x) * cos(y) * cz
    uy = -cos(x) * sin(y) * cz
    res[0][...] = ux
    res[1][...] = uy
    res[2][...] = 0.
    return res


## Function to compute reference vorticity
def computeVort(res, x, y, z, t):
    """Write the exact vorticity of the Taylor-Green velocity into ``res``.

    With the velocity of computeVel,
        u = sin(x) cos(y) cos(z), v = -cos(x) sin(y) cos(z), w = 0,
    the curl (dw/dy - dv/dz, du/dz - dw/dx, dv/dx - du/dy) is
        wx = -cos(x) sin(y) sin(z)
        wy = -sin(x) cos(y) sin(z)
        wz = 2 sin(x) sin(y) cos(z)

    res     -- indexable container of three writable arrays, filled in place
    x, y, z -- coordinates at which the field is evaluated
    t       -- time; not used, the field does not depend on it

    Returns res.
    """
    # Both x and y components carry a minus sign: they are -dv/dz and du/dz.
    res[0][...] = - cos(x) * sin(y) * sin(z)
    res[1][...] = - sin(x) * cos(y) * sin(z)
    res[2][...] = 2. * sin(x) * sin(y) * cos(z)
    return res

# 2D Field
scal2 = pp.Field(domain=dom2)

# 3D Fields: two scalars, plus the velocity and vorticity fields built from
# the computeVel / computeVort formulas.
scal3 = pp.Field(domain=dom, name='s1')
sc3 = pp.Field(domain=dom, name='s2')
velo = pp.Field(domain=dom, formula=computeVel, is_vector=True, name='v1')
vorti = pp.Field(domain=dom, formula=computeVort, is_vector=True, name='w1')

# Control box parameters passed to ControlBox: boxpos (presumably the lower
# corner -- confirm in ControlBox) and boxl (presumably the side lengths).
# Alternative, non-cubic size kept for experiments (was previously assigned
# and immediately overwritten):
# boxl = np.asarray([0.8, .8, .7])
boxl = np.asarray([1., 1., 1.])
boxpos = np.asarray([-0.5, -0.5, -0.5])

# 2D control box (currently disabled)

#cb2 = ControlBox(dom2, boxpos[:2], boxl[:2])
# Number of ghost points in each direction, shared by both topologies.
ng = 2
topo2 = mpi.topology.Cartesian(dom2, 2, resol2D, ghosts=[ng, ng])

# 3D Control boxes: cb1 and cb2 are built from the same boxpos/boxl; cb2 is
# the one later combined with the sphere obstacle in DragAndLift.
cb1 = ControlBox(dom, boxpos, boxl)
cb2 = ControlBox(dom, boxpos, boxl)
topo3 = mpi.topology.Cartesian(dom, 3, resol3D, ghosts=[ng, ng, ng])

# init fields: discretize each field on its topology, then initialize velo
# and vorti (created with the computeVel / computeVort formulas) on topo3.
scal2.discretize(topo2)
scal3.discretize(topo3)
sc3.discretize(topo3)
velo.discretize(topo3)
vorti.discretize(topo3)
velo.initialize(topo=topo3)
vorti.initialize(topo=topo3)

# Output prefixes, each tagged with the size of the topology it describes:
# the 3D prefix must use topo3, not topo2.
pref2 = './res2_' + str(topo2.size) + '/cb'
pref3 = './res3_' + str(topo3.size) + '/cb'
printer2D = Printer([scal2], topo2, prefix=pref2, frequency=1)
printer2D.setUp()
#printer2HDF5 = Printer([scal2], topo2, frequency=1, formattype=HDF5)
#printer2HDF5.setUp()

# Data arrays of the discrete fields, used below to set/inspect values.
sd2 = scal2.discreteFields[topo2].data
sd3 = scal3.discreteFields[topo3].data
sd4 = sc3.discreteFields[topo3].data
wd = vorti.discreteFields[topo3].data
printer3D = Printer([scal3], topo3, prefix=pref3, frequency=1)
printer3D.setUp()
#printerHDF5 = Printer([scal3], topo3, frequency=1, formattype=HDF5)
#printerHDF5.setUp()
# Simulation object handed to the operators' apply() calls; it is advanced
# in the drag/lift loop further down.
simulation = Simulation()


# Set the first component of each scalar field to 1 over its entire data
# array.
sd2[0][...] = 1.
sd3[0][...] = 1.
sd4[0][...] = 1.

#printer3D.apply(simulation)

# Synchronize the processes of topo2's communicator before discretizing the
# control boxes.
topo2.comm.barrier()

#ib2 = cb2.discretize(topo2)

# Discretize both 3D control boxes on topo3. ib3, ib33 and slice3 are not
# used by the active code below (ib3 and slice3 only appear in commented-out
# debug lines).
ib3 = cb1.discretize(topo3)
ib33 = cb2.discretize(topo3)
slice3 = cb1.slices[topo3]

topo2.comm.barrier()

# Gather the surfaces of the 3D control box cb1: all upperS entries first,
# then all lowerS entries. sl2 is the (disabled) 2D counterpart and stays
# empty.
sl2 = []
#for s in cb2.upperS:
##     sl2.append(s)
## for s in cb2.lowerS:
##     sl2.append(s)
sl3 = list(cb1.upperS) + list(cb1.lowerS)

# Disabled experiments with SubSpace / SubPlane objects follow.
## Subspaces
#subsp = SubSpace(dom2, [0, 1], [-0.8, -0.8], [0.5, 0.5])
#subsp3 = SubSpace(dom, [0, 1, 0], [-0.8, -0.8, -0.8], [0.5, 0.5, 1.0])
#ind = subsp3.discretize(topo3)

## Subplanes

#sp2 = SubPlane(dom2, [0, 1], [-0.8, -0.8], [0.5, 0.5])
#sp3 = SubPlane(dom, [0, 1, 0], [-0.75, -0.75, -0.75], [0.5, 0.5, 1.0])
#ind = sp3.discretize(topo3)


# Integrate scal3 over the 3D control boxes:
#  - integ: cb1 with the default useSlice setting,
#  - integ2: cb1 with useSlice=False,
#  - integ3: cb2 (built from the same boxpos/boxl as cb1) with useSlice=False.
# scal3's data was set to 1 above, so the three values should presumably all
# equal the box volume -- TODO confirm against ControlBox.integrate.
integ = cb1.integrate(scal3, topo3)
integ2 = cb1.integrate(scal3, topo3, useSlice=False)
integ3 = cb2.integrate(scal3, topo3, useSlice=False)

# Report from a single process only.
if topo3.rank == 0:
    print "integ = ", integ, integ2, integ3

#print topo3.rank, cb.slices

#sd3[0][ib3] = 0.0
#sd4[0][slice3] = 0.0

#printer3D.apply(simulation)
cc = topo3.mesh.coords

## for s in sl3:
##  #   print s.slices
##     normal = s.normal
##     ind = np.where(s.normal != 0)[0]
##     resup = cb.integrateOnSurface(scal3, topo3, normalDir=ind, up=True)
##     resdown = cb.integrateOnSurface(scal3, topo3, normalDir=ind, up=False)
##     print topo3.rank, resup, resdown
##     resup = cb.integrateOnSurface(scal3, topo3, normalDir=ind, up=True, useSlice=False)
##     resdown = cb.integrateOnSurface(scal3, topo3, normalDir=ind, up=False, useSlice=False)
##     print 'v2', topo3.rank, resup, resdown
##     #print topo3.rank, ind, s.slices[topo3]
##     #print topo3.rank, ind, cc[2].flat[s.slices[topo3][2]]

    
##     ## if topo3.rank == 0:
##     ##     print 'int ...', topo3.rank, resup, resdown
##     #sd3[0][s.slices[topo3]] = 0.

## integ = cb2.integrate(scal2, topo2)

## print "integ = ", integ

## for s in sl2:
##     normal = s.normal
##     ind = np.where(s.normal != 0)[0]
##     resup = cb2.integrateOnSurface(scal2, topo2, normalDir=ind, up=True)
##     resdown = cb2.integrateOnSurface(scal2, topo2, normalDir=ind, up=False)

##     if topo2.rank == 0:
##         print 'int ...', topo2.rank, resup, resdown
sphere = Sphere(dom, position=[0., 0., 0.], radius=0.3)
sphere.discretize(topo3)

## wd[0][sphere.ind[topo3][0]] *= 1e7
#wd[1][sphere.ind[topo3][0]] = 1e6
#print 
## wd[2][sphere.ind[topo3][0]] *= 1e5

print wd[1].max()
print wd[1][sphere.ind[topo3]].max()
nu = 0.3
dr = DragAndLift(velo, vorti, nu, topo3, cb1, filename=pref3 + 'forces.dat')
dr2 = DragAndLift(velo, vorti, nu, topo3, cb2, obstacles=[sphere])
dr.discretize()
dr2.discretize()
## #import parmepy.tools.numpywrappers as npw
## #res = npw.zeros(3)
## #res = dr._integrateOnBox(res)
## #print 'forces loc ...', res

## #resok= topo3.comm.allreduce(res)

#print cb.coords[topo3]

# Apply the cb1 drag/lift operator (dr, no obstacles argument) 10 times,
# advancing the simulation after each application.
for i in xrange(10):
    dr.apply(simulation)
    simulation.advance()
    
# Apply the sphere-aware operator (dr2, built with obstacles=[sphere]) once,
# at the state reached after the loop.
dr2.apply(simulation)
# Turn scal3 into a marker field: 0 everywhere, 1 on the cb1 slices of
# topo3, 2 on the sphere's points (presumably to inspect the box and the
# sphere in the printer3D output).
sd3[0][...] = 0.0
sd3[0][cb1.slices[topo3]] = 1.
sd3[0][sphere.ind[topo3]] = 2.

printer3D.apply(simulation)

# NOTE(review): the force prints below are not guarded by rank, so every
# process prints its own value of dr.force / dr2.force.
print 'forces 1 ...', dr.force
#simulation.advance()
#wd[0][...] +=12.3
#dr.apply(simulation)

print 'forces 2...', dr2.force

## #
## printer3D.apply(simulation)
## printer2D.apply(simulation)
## #printerHDF5.apply(simulation)

## print topo3.rank, cb.mesh[topo3]

## print 'full', topo3.rank, topo3.mesh
## print topo3.mesh.iCompute