Skip to content
Snippets Groups Projects
Commit 6b820a30 authored by Jean-Matthieu Etancelin's avatar Jean-Matthieu Etancelin Committed by Franck Pérignon
Browse files

Ajout de décorateur pour les apply des opérateurs continus. Fix gpu_transfer

parent 5b9a399a
No related branches found
No related tags found
No related merge requests found
...@@ -44,7 +44,3 @@ GPU_SRC = os.path.join(__path__[0], "cl_src", '') ...@@ -44,7 +44,3 @@ GPU_SRC = os.path.join(__path__[0], "cl_src", '')
## If use OpenCL profiling events to time computations ## If use OpenCL profiling events to time computations
CL_PROFILE = False CL_PROFILE = False
## Transfert direction
HostToDevice = 9876
DeviceToHost = 6789
from parmepy.operator.computational import Computational from parmepy.operator.computational import Computational
from parmepy.gpu import HostToDevice, DeviceToHost from parmepy.methods_keys import Support
class DataTransfer(Computational): class DataTransfer(Computational):
"""Operator for moving data between CPU and GPU.""" """Operator for moving data between CPU and GPU."""
def __init__(self, way, freq=1, **kwds): def __init__(self, source, target, freq=1, **kwds):
""" """
@param way: HostToDevice or DeviceToHost flag setting @param way: HostToDevice or DeviceToHost flag setting
the data copy direction. the data copy direction.
""" """
super(DataTransfer, self).__init__(**kwds) super(DataTransfer, self).__init__(**kwds)
## Transfer way ## Source operator or topology
self.way = way self.source = source
## Target operator
self.target = target
## Transfer frequency in iteration number ## Transfer frequency in iteration number
self.freq = freq self.freq = freq
source_is_gpu = False
try:
if source.method[Support].find('gpu') >= 0:
source_is_gpu = True
except:
pass
target_is_gpu = False
try:
if target.method[Support].find('gpu') >= 0:
target_is_gpu = True
except:
pass
## Current transfer function ## Current transfer function
if self.way == HostToDevice: if source_is_gpu:
self._transfer = self._apply_toHost
elif target_is_gpu:
self._transfer = self._apply_toDevice self._transfer = self._apply_toDevice
else: else:
self._transfer = self._apply_toHost raise RuntimeError("Source or target mus bu a GPU operator.")
def setup(self): def setup(self):
pass self.target.addRedistributeRequirement(self)
def apply(self, simulation): def apply(self, simulation):
ite = simulation.currentIteration ite = simulation.currentIteration
...@@ -33,10 +50,13 @@ class DataTransfer(Computational): ...@@ -33,10 +50,13 @@ class DataTransfer(Computational):
for v in self.variables: for v in self.variables:
df = self.discreteFields[v] df = self.discreteFields[v]
df.toHost() df.toHost()
df.wait()
def _apply_toDevice(self): def _apply_toDevice(self):
for v in self.variables: for v in self.variables:
df = self.discreteFields[v] df = self.discreteFields[v]
df.toDevice() df.toDevice()
def wait(self):
for v in self.variables:
df = self.discreteFields[v]
df.wait() df.wait()
...@@ -21,3 +21,21 @@ ...@@ -21,3 +21,21 @@
# preset dictionnaries. # preset dictionnaries.
# #
# Keys in methods dict are given in parmepy.method_keys. # Keys in methods dict are given in parmepy.method_keys.
from parmepy.tools.profiler import ftime
def apply_decoration(f):
"""
Decorator for operators apply functions.
"""
def deco(*args, **kwargs):
"""args[0] contains the object"""
assert args[0]._is_discretized
for req in args[0].requirements:
req.wait()
t0 = ftime()
res = f(*args, **kwargs)
prof = args[0].profiler[f.func_name]
prof += ftime() - t0
return res
return deco
...@@ -8,7 +8,7 @@ from parmepy.constants import debug ...@@ -8,7 +8,7 @@ from parmepy.constants import debug
from parmepy.operator.continuous import Operator from parmepy.operator.continuous import Operator
from parmepy.mpi.topology import Cartesian from parmepy.mpi.topology import Cartesian
from parmepy.tools.parameters import Discretization from parmepy.tools.parameters import Discretization
from parmepy.tools.profiler import profile from parmepy.operator import apply_decoration
class Computational(Operator): class Computational(Operator):
...@@ -289,8 +289,7 @@ class Computational(Operator): ...@@ -289,8 +289,7 @@ class Computational(Operator):
if self.discreteOperator is not None: if self.discreteOperator is not None:
self.discreteOperator.finalize() self.discreteOperator.finalize()
@debug @apply_decoration
@profile
def apply(self, simulation=None): def apply(self, simulation=None):
""" """
Apply this operator to its variables. Apply this operator to its variables.
...@@ -298,9 +297,6 @@ class Computational(Operator): ...@@ -298,9 +297,6 @@ class Computational(Operator):
parameters (time, time step, iteration number ...), see parameters (time, time step, iteration number ...), see
parmepy.problem.simulation.Simulation for details. parmepy.problem.simulation.Simulation for details.
""" """
for req in self.requirements:
req.wait()
assert self._is_discretized
self.discreteOperator.apply(simulation) self.discreteOperator.apply(simulation)
def printComputeTime(self): def printComputeTime(self):
......
...@@ -5,6 +5,7 @@ File output for field(s) value on a grid. ...@@ -5,6 +5,7 @@ File output for field(s) value on a grid.
""" """
from parmepy.constants import S_DIR, debug, HDF5, PARMES_REAL from parmepy.constants import S_DIR, debug, HDF5, PARMES_REAL
from parmepy.operator.computational import Computational from parmepy.operator.computational import Computational
from parmepy.operator import apply_decoration
import parmepy.tools.numpywrappers as npw import parmepy.tools.numpywrappers as npw
import parmepy.tools.io_utils as io import parmepy.tools.io_utils as io
from parmepy.tools.parameters import IO_params from parmepy.tools.parameters import IO_params
...@@ -199,8 +200,7 @@ class HDF_Writer(HDF_IO): ...@@ -199,8 +200,7 @@ class HDF_Writer(HDF_IO):
self._get_filename = lambda i: self.io_params.filename + \ self._get_filename = lambda i: self.io_params.filename + \
"_{0:05d}".format(i) + '.h5' "_{0:05d}".format(i) + '.h5'
@debug @apply_decoration
@profile
def apply(self, simulation=None): def apply(self, simulation=None):
if simulation is None: if simulation is None:
raise ValueError("Missing simulation value for monitoring.") raise ValueError("Missing simulation value for monitoring.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment