#!/usr/bin/env python

import time
from parmepy.domain.box import Box
from parmepy.fields.continuous import Field
from parmepy.operator.advection import Advection
from parmepy.problem.transport import TransportProblem
from parmepy.gpu import PARMES_REAL_GPU, PARMES_DOUBLE_GPU


def vitesse(x, y, z):
    """Evaluate the analytic velocity at the point (x, y, z).

    Returns a 3-tuple ``(vx, vy, vz)`` with ``vx = 1 + x``,
    ``vy = -x * y`` and ``vz = x * y * z + 10``.
    """
    return (1. + x, -x * y, x * y * z + 10.)


def scalaire(x, y, z):
    """Return 1.0 when x, y and z are all strictly below 0.5, else 0.0."""
    all_below_half = x < 0.5 and y < 0.5 and z < 0.5
    return 1. if all_below_half else 0.


def run():
    """Set up and solve a 3D advection problem, then print its timings.

    Builds a unit box, a scalar field and a vector velocity field on it,
    an ``Advection`` operator discretised on 65 x 65 x 65 points with the
    ``'gpu_2k_m4prime'`` method, and wraps it in a ``TransportProblem``
    solved up to ``finalTime`` with a fixed ``time_step``.

    The total, initialisation and solving durations are measured with
    ``time.time()`` and printed to standard output. Nothing is returned.
    """
    # Parameters
    nb = 65
    nbElem = (nb, nb, nb)
    time_step = 0.02
    finalTime = 1.

    t0 = time.time()

    ## Domain
    box = Box(3, length=[1., 1., 1.], origin=[0., 0., 0.])

    ## Fields
    # NOTE: the fields are created without a formula; the module-level
    # `vitesse` and `scalaire` functions are not wired in here.
    scal = Field(domain=box, name='Scalar')
    velo = Field(domain=box, name='Velocity', is_vector=True)

    ## Operators
    # Other method names previously toggled in this script:
    # 'gpu_1k_m4prime', 'gpu_1k_m6prime', 'gpu_1k_m8prime',
    # 'gpu_2k_m6prime', 'gpu_2k_m8prime', 'scales'.
    # NOTE(review): PARMES_DOUBLE_GPU is imported at file level and is
    # presumably the alternative precision value -- confirm before using.
    advec = Advection(velo, scal,
                      resolutions={velo: nbElem,
                                   scal: nbElem},
                      method='gpu_2k_m4prime',
                      src=['./levelSet3D.cl'],
                      precision=PARMES_REAL_GPU,
                      )

    ## Problem
    pb = TransportProblem([advec])

    ## Setting solver to Problem
    pb.setUp(finalTime, time_step)

    t1 = time.time()
    ## Solve problem
    pb.solve()
    tf = time.time()

    # Single-argument parenthesised prints parse on both Python 2 and 3 and
    # reproduce the output of the former print statements (including the
    # double space after the colon). The durations come from time.time(),
    # i.e. they are wall-clock differences despite the "(CPU)" label.
    print("\n")
    print("Total time :  %s sec (CPU)" % (tf - t0))
    print("Init time :  %s sec (CPU)" % (t1 - t0))
    print("Solving time :  %s sec (CPU)" % (tf - t1))


# Execute the benchmark only when this file is run as a script, so that
# importing the module does not trigger the simulation.
if __name__ == "__main__":
    run()