From ca0c27555b964e9b9c6b6854db702a0828989257 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Mon, 19 Nov 2018 23:02:54 +0100
Subject: [PATCH] added buffers to hdf_io

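HDF_Writer now requests a single host scratch buffer through the
work-properties mechanism and builds one data getter per dataset at
setup time: host-backed arrays are returned directly, while device
(OpenCL) arrays are first copied into the scratch buffer with an
OpenClCopyBufferRectLauncher. HDF_IO additionally enforces a unique
transposition state and topology shape across all dumped fields.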
---
 .../particles_above_salt_bc_3d.py             | 11 ++--
 hysop/operator/hdf_io.py                      | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/examples/particles_above_salt/particles_above_salt_bc_3d.py b/examples/particles_above_salt/particles_above_salt_bc_3d.py
index 014f03bf8..e0c655b8f 100644
--- a/examples/particles_above_salt/particles_above_salt_bc_3d.py
+++ b/examples/particles_above_salt/particles_above_salt_bc_3d.py
@@ -241,8 +241,7 @@ def compute(args):
     io_params = IOParams(filename='fields', frequency=args.dump_freq)
     dump_fields = HDF_Writer(name='dump',
                              io_params=io_params,
-                             variables={velo: npts, 
-                                        vorti: npts,
+                             variables={velo[0]: npts,
                                         C: npts, 
                                         S: npts})
 
@@ -276,12 +275,10 @@ def compute(args):
         )
 
     problem = Problem(method=method)
-    problem.insert(poisson, 
-                   diffuse_W, diffuse_S, diffuse_C,
-                   splitting, 
+    problem.insert(poisson, diffuse_W, diffuse_S, diffuse_C,
                    dump_fields,
-                   min_max_U, min_max_W, 
-                   adapt_dt)
+                   splitting,
+                   min_max_U, min_max_W, adapt_dt)
     problem.build()
     
     # If a visu_rank was provided, and show_graph was set,
diff --git a/hysop/operator/hdf_io.py b/hysop/operator/hdf_io.py
index 009c931cf..7ec7b59eb 100755
--- a/hysop/operator/hdf_io.py
+++ b/hysop/operator/hdf_io.py
@@ -7,6 +7,7 @@
 * :class:`~HDF_IO` abstract interface for hdf io classes
 
 """
+import functools
 from abc import ABCMeta, abstractmethod
 from hysop.deps import h5py, sys
 from hysop.core.graph.graph import discretized
@@ -19,6 +20,8 @@ from hysop.core.graph.graph import op_apply
 from hysop.core.graph.computational_graph import ComputationalGraphOperator
 from hysop.fields.continuous_field import Field
 from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
+from hysop.core.memory.memory_request import MemoryRequest
+from hysop.constants import Backend  # backend enum for the HOST/OPENCL checks below (assumed import; may already be present)
 
 class HDF_IO(ComputationalGraphOperator):
     """
@@ -127,6 +130,16 @@ class HDF_IO(ComputationalGraphOperator):
             (field, td, req) = ireq
             req.axes = (TranspositionState[field.dim].default_axes(),)
         return requirements
+
+    def get_node_requirements(self):
+        node_reqs = super(HDF_IO, self).get_node_requirements()
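+        # all dumped fields must share one transposition state and topology
+        # shape so that they can be written to a single global hdf grid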
+        node_reqs.enforce_unique_transposition_state = True
+        node_reqs.enforce_unique_topology_shape      = True
+        node_reqs.enforce_unique_memory_order        = False
+        node_reqs.enforce_unique_ghosts              = False
+        return node_reqs
 
     def discretize(self):
         super(HDF_IO, self).discretize()
@@ -148,15 +161,15 @@
                 local_compute_slices[field]  = mesh.local_compute_slices
                 global_compute_slices[field] = mesh.global_compute_slices
             else:
-                local_compute_slices[field] = tuple(slice(0, 0) for _ in xrange(self.domain.dim))
+                local_compute_slices[field]  = tuple(slice(0, 0) for _ in xrange(self.domain.dim))
                 global_compute_slices[field] = tuple(slice(0, 0) for _ in xrange(self.domain.dim))
         self._local_compute_slices = local_compute_slices
         self._global_compute_slices = global_compute_slices
         self.refmesh = refmesh
 
     def setup(self, work=None):
         super(HDF_IO, self).setup(work=work)
         # No list of hdf dataset names provided by user ...
 
         name_prefix, name_postfix = self.name_prefix, self.name_postfix
         if (self.var_names is None):
@@ -246,10 +259,54 @@ class HDF_Writer(HDF_IO):
         # if that happens.
         self._last_written_time = None
         self._xmf_file = None
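+        # maps each dataset name to a zero-argument callable returning host data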
+        self._data_getters = {}
+
+    def get_work_properties(self, **kwds):
+        requests = super(HDF_Writer, self).get_work_properties(**kwds)
+
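+        # request a single host scratch buffer, sized for the largest
+        # device-backed dataset; it is shared by all copied fields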
+        max_bytes = 0
+        for (name, data) in self.dataset.iteritems():
+            if (data.backend.kind != Backend.HOST):
+                # we need a host buffer to get the data
+                max_bytes = max(data.nbytes, max_bytes)
+                host_backend = data.backend.host_array_backend
+
+        if (max_bytes > 0):
+            request = MemoryRequest(backend=host_backend, size=max_bytes, dtype=npw.uint8)
+            requests.push_mem_request(request_identifier='buffer', mem_request=request)
+
+        return requests
 
-    def setup(self, **kwds):
-        super(HDF_Writer,self).setup(**kwds)
+    def setup(self, work, **kwds):
+        super(HDF_Writer, self).setup(work=work, **kwds)
         self._setup_grid_template()
+        for (name, data) in self.dataset.iteritems():
+            data = data[self._local_compute_slices[name]]
+            if (data.backend.kind is Backend.HOST):
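+                # host arrays can be handed to h5py directly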
+                def get_data(data=data):
+                    return data.handle
+            elif (data.backend.kind is Backend.OPENCL):
+                from hysop.backend.device.opencl.opencl_copy_kernel_launchers import OpenClCopyBufferRectLauncher
+                buf, = work.get_buffer(self, 'buffer', handle=True)
+                assert buf.dtype == npw.uint8
+                assert buf.size >= data.nbytes
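+                # view the raw uint8 scratch buffer with this dataset's dtype and shape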
+                buf = buf[:data.nbytes].view(dtype=data.dtype).reshape(data.shape)
+                cpy = OpenClCopyBufferRectLauncher.from_slices(varname=name, src=data, dst=buf)
+                cpy = functools.partial(cpy, queue=data.backend.cl_env.default_queue)
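+                # bind cpy and buf as default arguments so each getter captures
+                # its own copy kernel (python closures late-bind loop variables)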
+                def get_data(cpy=cpy, buf=buf):
+                    cpy().wait()
+                    return buf
+            else:
+                msg='Data type not understood or unknown array backend.'
+                raise NotImplementedError(msg)
+            self._data_getters[name] = get_data
 
     def finalize(self):
         if self._xmf_file:
@@ -355,11 +412,10 @@ class HDF_Writer(HDF_IO):
         for name in self.dataset:
             ds = self._hdf_file.create_dataset(name,
                                                self._global_grid_resolution,
-                                               dtype=HYSOP_REAL,
+                                               dtype=npw.float64, # force double precision output regardless of HYSOP_REAL
                                                compression=compression)
-            # In parallel, each proc must write at the right place of the dataset
-            data = self.dataset[name].get()
-            ds[self._global_compute_slices[name]] = npw.asrealarray(data[self._local_compute_slices[name]])
+            # In parallel, each proc must write at the right place of the dataset
+            ds[self._global_compute_slices[name]] = self._data_getters[name]()
         
         # Collect datas required to write the xdmf file
         # --> add tuples (counter, time).
-- 
GitLab