From 68dfc1217fa87ab4ec80d68ae1c7f0f7e141fd35 Mon Sep 17 00:00:00 2001
From: Jean-Matthieu Etancelin <jean-matthieu.etancelin@univ-pau.fr>
Date: Mon, 7 Mar 2022 14:02:05 +0100
Subject: [PATCH] add custom opencl operator based on PyOpenCL's
 ElementwiseKernel

---
 .../backend/device/opencl/operator/custom.py  |  59 ++++++
 hysop/backend/host/python/operator/custom.py  | 102 ++--------
 hysop/fields/cartesian_discrete_field.py      |  10 +-
 hysop/operator/base/custom.py                 | 102 ++++++++++
 hysop/operator/custom.py                      |  27 +--
 hysop/operator/tests/test_custom.py           | 178 ++++++++++++++++++
 6 files changed, 367 insertions(+), 111 deletions(-)
 create mode 100644 hysop/backend/device/opencl/operator/custom.py
 create mode 100644 hysop/operator/base/custom.py
 create mode 100644 hysop/operator/tests/test_custom.py

diff --git a/hysop/backend/device/opencl/operator/custom.py b/hysop/backend/device/opencl/operator/custom.py
new file mode 100644
index 000000000..cde464bb8
--- /dev/null
+++ b/hysop/backend/device/opencl/operator/custom.py
@@ -0,0 +1,59 @@
+from hysop.constants import DirectionLabels
+from hysop.tools.decorators import debug
+from hysop.core.memory.memory_request import MemoryRequest
+from hysop.backend.device.opencl.opencl_operator import OpenClOperator, op_apply
+from hysop.backend.device.opencl.autotunable_kernels.custom_symbolic import OpenClAutotunableCustomSymbolicKernel
+from hysop.backend.device.opencl.opencl_kernel_launcher import OpenClKernelListLauncher
+from hysop.backend.device.opencl.opencl_copy_kernel_launchers import OpenClCopyBufferRectLauncher
+from hysop.operator.base.custom import CustomOperatorBase
+
+from hysop.backend.device.codegen.structs.mesh_info import MeshBaseStruct, MeshInfoStruct
+from hysop.backend.device.codegen.base.variables import dtype_to_ctype
+from hysop.backend.device.codegen.base.opencl_codegen import OpenClCodeGenerator
+
+from pyopencl.elementwise import ElementwiseKernel
+
+
+class OpenClCustomOperator(CustomOperatorBase, OpenClOperator):
+    """Custom operator whose func is used as the operation of a PyOpenCL ElementwiseKernel."""
+
+    @debug
+    def __new__(cls, **kwds):
+        return super().__new__(cls, **kwds)
+
+    @debug
+    def __init__(self, **kwds):
+        super().__init__(**kwds)
+
+    @debug
+    def setup(self, *args, **kwargs):
+        """Declare a mesh struct and a get_<name>i_xyz helper per discrete field, then build the kernel."""
+        super().setup(*args, **kwargs)
+        dim, typegen, prefix = self.domain.dim, self.typegen, self.typegen.fbtype[0]
+        cg = OpenClCodeGenerator('test_generator', typegen)
+        mbs = MeshBaseStruct(typegen, typedef=f'{prefix}MeshBase{dim}D_s', vsize=dim)
+        mis = MeshInfoStruct(typegen, typedef=f'{prefix}MeshInfo{dim}D_s', mbs_typedef=mbs.typedef, vsize=dim)
+        cg.require('mis', mis)
+        arg_decls = []
+        for dfield in self.discrete_fields:
+            fn = dfield.continuous_fields()[0].name
+            mesh_info = MeshInfoStruct.create_from_mesh(fn + "_mesh", typegen, dfield.mesh,
+                                                        storage=OpenClCodeGenerator.default_keywords['constant'])
+            mesh_info[1].declare(cg, _const=True)
+            cg.append(f"""int3 get_{fn}i_xyz(int i) {{
+              int iz = i/({fn}_mesh.local_mesh.resolution.x*{fn}_mesh.local_mesh.resolution.y);
+              int iy = (i-({fn}_mesh.local_mesh.resolution.x*{fn}_mesh.local_mesh.resolution.y)*iz)/({fn}_mesh.local_mesh.resolution.x);
+              return (int3)(i % {fn}_mesh.local_mesh.resolution.x,iy,iz);}}""")
+            arg_decls.append(f"{typegen.fbtype} *{fn}")
+        self.__elementwise = ElementwiseKernel(self.cl_env.context, ",".join(arg_decls), self.func,
+                                               f"__{self.name}_elementwise", preamble=str(cg))
+
+    @op_apply
+    def apply(self, **kwds):
+        """Run the kernel on the sbuffer of every argument, then run all ghost exchangers."""
+        super().apply(**kwds)
+        # NOTE(review): parameters and extra_args are read through .sbuffer too -- confirm they expose it.
+        operands = self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args
+        self.__elementwise(*(operand.sbuffer for operand in operands))
+        for exchanger in self.ghost_exchanger:
+            exchanger.exchange_ghosts()
diff --git a/hysop/backend/host/python/operator/custom.py b/hysop/backend/host/python/operator/custom.py
index feb947cf5..d5c6cf3b0 100644
--- a/hysop/backend/host/python/operator/custom.py
+++ b/hysop/backend/host/python/operator/custom.py
@@ -5,106 +5,44 @@ from hysop.parameters.parameter import Parameter
 from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
 from hysop.backend.host.host_operator import HostOperator
 from hysop.core.graph.graph import op_apply
+from hysop.operator.base.custom import CustomOperatorBase
 
 
-class PythonCustomOperator(HostOperator):
+class PythonCustomOperator(CustomOperatorBase, HostOperator):
 
     @debug
-    def __new__(cls, func, invars=None, outvars=None,
-                 extra_args=None, variables=None, ghosts=None, **kwds):
-        return super(PythonCustomOperator, cls).__new__(cls,
-            input_fields=None, output_fields=None,
-            input_params=None, output_params=None,
-            **kwds)
+    def __new__(cls, **kwds):
+        return super().__new__(cls, **kwds)
 
     @debug
-    def __init__(self, func, invars=None, outvars=None,
-                 extra_args=None, variables=None, ghosts=None, **kwds):
-        check_instance(invars, (tuple, list), values=(Field, Parameter),
-                       allow_none=True)
-        check_instance(outvars, (tuple, list), values=(Field, Parameter),
-                       allow_none=True)
-        check_instance(extra_args, tuple, allow_none=True)
-        check_instance(variables, dict, keys=Field,
-                       values=CartesianTopologyDescriptors,
-                       allow_none=True)
-        check_instance(ghosts, int, allow_none=True)
-        input_fields, output_fields = {}, {}
-        input_params, output_params = {}, {}
+    def __init__(self, func, invars=None, outvars=None,  extra_args=None, **kwds):
+        super().__init__(func, invars=invars, outvars=outvars, extra_args=extra_args, **kwds)
+
+        from inspect import signature
+        nb_args = len(signature(func).parameters)
+        nb_in_f, nb_in_p, nb_out_f, nb_out_p, nb_extra = 0, 0, 0, 0, 0
         if invars is not None:
             for v in invars:
                 if isinstance(v, Field):
-                    input_fields[v] = variables[v]
+                    nb_in_f += v.nb_components
                 elif isinstance(v, Parameter):
-                    input_params[v.name] = v
+                    nb_in_p += 1
         if outvars is not None:
             for v in outvars:
                 if isinstance(v, Field):
-                    output_fields[v] = variables[v]
+                    nb_out_f += v.nb_components
                 elif isinstance(v, Parameter):
-                    output_params[v.name] = v
-        self.invars, self.outvars = invars, outvars
-        self.func = func
-        self.extra_args = tuple()
+                    nb_out_p += 1
         if not extra_args is None:
-            self.extra_args = extra_args
-        self._ghosts = ghosts
-        super(PythonCustomOperator, self).__init__(
-            input_fields=input_fields, output_fields=output_fields,
-            input_params=input_params, output_params=output_params,
-            **kwds)
-
-    @classmethod
-    def supports_mpi(cls):
-        return True
-
-    @debug
-    def get_field_requirements(self):
-        requirements = super(PythonCustomOperator, self).get_field_requirements()
-        if not self._ghosts is None:
-            for it in requirements.iter_requirements():
-                if not it[1] is None:
-                    is_input, (field, td, req) = it
-                    min_ghosts = (max(g, self._ghosts) for g in req.min_ghosts.copy())
-                    max_ghosts = (min(g, self._ghosts) for g in req.max_ghosts.copy())
-                    req.min_ghosts = min_ghosts
-                    req.max_ghosts = max_ghosts
-        return requirements
-
-    @debug
-    def discretize(self):
-        if self.discretized:
-            return
-        super(PythonCustomOperator, self).discretize()
-        dinvar, dinparam = [], []
-        doutvar, doutparam = [], []
-        idf, odf = self.input_discrete_fields, self.output_discrete_fields
-        self.ghost_exchanger = []
-        if self.invars is not None:
-            for v in self.invars:
-                if isinstance(v, Field):
-                    for _v in v if isinstance(v, VectorField) else (v, ):
-                        for vd in idf[_v]:
-                            dinvar.append(vd)
-                elif isinstance(v, Parameter):
-                    dinparam.append(v)
-        if self.outvars is not None:
-            for v in self.outvars:
-                if isinstance(v, Field):
-                    for _v in v if isinstance(v, VectorField) else (v, ):
-                        for vd in self.output_discrete_fields[_v]:
-                            doutvar.append(vd)
-                        gh = self.output_discrete_fields[_v].build_ghost_exchanger()
-                        if gh is not None:
-                            self.ghost_exchanger.append(gh)
-                elif isinstance(v, Parameter):
-                    doutparam.append(v)
-        self.dinvar, self.doutvar = tuple(dinvar), tuple(doutvar)
-        self.dinparam, self.doutparam = tuple(dinparam), tuple(doutparam)
+            nb_extra = len(extra_args)
+        msg = "function arguments ({}) did not match given in/out ".format(signature(func))
+        msg += "fields and parameters ({} input fields, {} input params,".format(nb_in_f, nb_in_p)
+        msg += " {} output fields, {} output params).".format(nb_out_f, nb_out_p)
+        assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p + nb_extra, msg
 
     @op_apply
     def apply(self, **kwds):
-        super(PythonCustomOperator, self).apply(**kwds)
+        super().apply(**kwds)
         self.func(*(self.dinvar + self.dinparam + self.doutvar + self.doutparam + self.extra_args))
         for gh_exch in self.ghost_exchanger:
             gh_exch.exchange_ghosts()
diff --git a/hysop/fields/cartesian_discrete_field.py b/hysop/fields/cartesian_discrete_field.py
index 949d6e767..f91c98c87 100644
--- a/hysop/fields/cartesian_discrete_field.py
+++ b/hysop/fields/cartesian_discrete_field.py
@@ -635,8 +635,7 @@ class CartesianDiscreteScalarFieldView(CartesianDiscreteScalarFieldViewContainer
     @debug
     def __init__(self, dfield, topology_state, **kwds):
         super(CartesianDiscreteScalarFieldView, self).__init__(dfield=dfield,
-            topology_state=topology_state, **kwds)
-
+                                                               topology_state=topology_state, **kwds)
 
     def _compute_data_view(self, data=None):
         """
@@ -1118,7 +1117,7 @@ class CartesianDiscreteScalarFieldView(CartesianDiscreteScalarFieldViewContainer
             tstate._is_read_only = is_read_only
 
         bfield = self._dfield._field
-        btopo  = self._dfield._topology
+        btopo = self._dfield._topology
 
         field = bfield.field_like(name=name, pretty_name=pretty_name,
                                   latex_name=latex_name, var_name=var_name,
@@ -1537,9 +1536,9 @@ class CartesianDiscreteScalarField(CartesianDiscreteScalarFieldView, DiscreteSca
 
     @debug
     def __init__(self, field, topology, init_topology_state=None,
-                allocate_data=True, **kwds):
+                 allocate_data=True, **kwds):
         super(CartesianDiscreteScalarField, self).__init__(field=field, topology=topology,
-                topology_state=None, dfield=None, **kwds)
+                                                           topology_state=None, dfield=None, **kwds)
 
     def _handle_data(self, data):
         assert (self._data is None)
@@ -1590,7 +1589,6 @@ class TmpCartesianDiscreteScalarField(CartesianDiscreteScalarField):
                                                                   register_discrete_field=True, **kwds)
         return obj
 
-
     @debug
     def __init__(self, **kwds):
         super(TmpCartesianDiscreteScalarField, self).__init__(allocate_data=False,
diff --git a/hysop/operator/base/custom.py b/hysop/operator/base/custom.py
new file mode 100644
index 000000000..6f787e507
--- /dev/null
+++ b/hysop/operator/base/custom.py
@@ -0,0 +1,102 @@
+from hysop.tools.decorators import debug
+from hysop.tools.types import check_instance
+from hysop.fields.continuous_field import Field, VectorField
+from hysop.parameters.parameter import Parameter
+from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
+from hysop.core.graph.graph import op_apply
+
+
+class CustomOperatorBase:
+    """Backend independent part of the custom operators, written as a mixin.
+
+    Sorts invars and outvars into fields and parameters for the next class in the MRO, optionally
+    constrains the ghosts of all fields and gathers the discrete variables later passed to func.
+    """
+
+    @debug
+    def __new__(cls, func, invars=None, outvars=None,
+                extra_args=None, variables=None, ghosts=None, **kwds):
+        return super().__new__(cls, input_fields=None, output_fields=None,
+                               input_params=None, output_params=None, **kwds)
+
+    @debug
+    def __init__(self, func, invars=None, outvars=None,
+                 extra_args=None, variables=None, ghosts=None, **kwds):
+        """Check the arguments and register the input/output fields and parameters.
+
+        func is stored as self.func; invars and outvars are tuples or lists of Field and Parameter;
+        extra_args is a tuple (None means no extra argument); variables maps every given Field to
+        its topology descriptor; ghosts is an optional int used to constrain the field ghosts.
+        """
+        check_instance(invars, (tuple, list), values=(Field, Parameter), allow_none=True)
+        check_instance(outvars, (tuple, list), values=(Field, Parameter), allow_none=True)
+        check_instance(extra_args, tuple, allow_none=True)
+        check_instance(variables, dict, keys=Field, values=CartesianTopologyDescriptors, allow_none=True)
+        check_instance(ghosts, int, allow_none=True)
+        input_fields, output_fields = {}, {}
+        input_params, output_params = {}, {}
+        for (varlist, fields, params) in ((invars, input_fields, input_params),
+                                          (outvars, output_fields, output_params)):
+            for v in varlist or ():
+                if isinstance(v, Field):
+                    fields[v] = variables[v]
+                elif isinstance(v, Parameter):
+                    params[v.name] = v
+        self.invars, self.outvars = invars, outvars
+        self.func = func
+        self.extra_args = tuple() if extra_args is None else extra_args
+        self._ghosts = ghosts
+        super().__init__(input_fields=input_fields, output_fields=output_fields,
+                         input_params=input_params, output_params=output_params, **kwds)
+
+    @debug
+    def get_field_requirements(self):
+        """Return the parent field requirements; when ghosts was given, every min ghost is raised
+        to at least that value and every max ghost is lowered to at most that value."""
+        requirements = super().get_field_requirements()
+        if self._ghosts is None:
+            return requirements
+        for it in requirements.iter_requirements():
+            if it[1] is None:
+                continue
+            is_input, (field, td, req) = it
+            # Tuples, not generators: a generator stored on req could only be iterated once.
+            min_ghosts = tuple(max(g, self._ghosts) for g in req.min_ghosts.copy())
+            max_ghosts = tuple(min(g, self._ghosts) for g in req.max_ghosts.copy())
+            req.min_ghosts = min_ghosts
+            req.max_ghosts = max_ghosts
+        return requirements
+
+    @debug
+    def discretize(self):
+        """Gather discrete fields and parameters (self.dinvar, self.dinparam, self.doutvar,
+        self.doutparam, in invars/outvars order) and build the output ghost exchangers."""
+        if self.discretized:
+            return
+        super().discretize()
+        idf, odf = self.input_discrete_fields, self.output_discrete_fields
+        dinvar, dinparam, doutvar, doutparam = [], [], [], []
+        self.ghost_exchanger = []
+        for v in self.invars or ():
+            if isinstance(v, Field):
+                for _v in v if isinstance(v, VectorField) else (v, ):
+                    for vd in idf[_v]:
+                        dinvar.append(vd)
+            elif isinstance(v, Parameter):
+                dinparam.append(v)
+        for v in self.outvars or ():
+            if isinstance(v, Field):
+                for _v in v if isinstance(v, VectorField) else (v, ):
+                    for vd in odf[_v]:
+                        doutvar.append(vd)
+                    gh = odf[_v].build_ghost_exchanger()
+                    if gh is not None:
+                        self.ghost_exchanger.append(gh)
+            elif isinstance(v, Parameter):
+                doutparam.append(v)
+        self.dinvar, self.doutvar = tuple(dinvar), tuple(doutvar)
+        self.dinparam, self.doutparam = tuple(dinparam), tuple(doutparam)
+
+    @classmethod
+    def supports_mpi(cls):
+        return True
diff --git a/hysop/operator/custom.py b/hysop/operator/custom.py
index bc34b45d9..f033f8e87 100644
--- a/hysop/operator/custom.py
+++ b/hysop/operator/custom.py
@@ -20,9 +20,11 @@ class CustomOperator(ComputationalGraphNodeFrontend):
     @classmethod
     def implementations(cls):
         from hysop.backend.host.python.operator.custom import PythonCustomOperator
+        from hysop.backend.device.opencl.operator.custom import OpenClCustomOperator
 
         __implementations = {
             Implementation.PYTHON: PythonCustomOperator,
+            Implementation.OPENCL: OpenClCustomOperator,
         }
         return __implementations
 
@@ -32,8 +34,8 @@ class CustomOperator(ComputationalGraphNodeFrontend):
 
     @debug
     def __new__(cls, func, invars=None, outvars=None, extra_args=None, ghosts=None, **kwds):
-        return super(CustomOperator, cls).__new__(cls,
-                                                  func=func, invars=invars, outvars=outvars, extra_args=extra_args, ghosts=ghosts, **kwds)
+        return super(CustomOperator, cls).__new__(
+            cls, func=func, invars=invars, outvars=outvars, extra_args=extra_args, ghosts=ghosts, **kwds)
 
     @debug
     def __init__(self, func, invars=None, outvars=None, extra_args=None, ghosts=None, **kwds):
@@ -43,27 +45,6 @@ class CustomOperator(ComputationalGraphNodeFrontend):
                        allow_none=True)
         check_instance(extra_args, tuple, allow_none=True)
         check_instance(ghosts, int, allow_none=True)
-        from inspect import signature
-        nb_args = len(signature(func).parameters)
-        nb_in_f, nb_in_p, nb_out_f, nb_out_p, nb_extra = 0, 0, 0, 0, 0
-        if invars is not None:
-            for v in invars:
-                if isinstance(v, Field):
-                    nb_in_f += v.nb_components
-                elif isinstance(v, Parameter):
-                    nb_in_p += 1
-        if outvars is not None:
-            for v in outvars:
-                if isinstance(v, Field):
-                    nb_out_f += v.nb_components
-                elif isinstance(v, Parameter):
-                    nb_out_p += 1
-        if not extra_args is None:
-            nb_extra = len(extra_args)
-        msg = "function arguments ({}) did not match given in/out ".format(signature(func))
-        msg += "fields and parameters ({} input fields, {} input params,".format(nb_in_f, nb_in_p)
-        msg += " {} output fields, {} output params).".format(nb_out_f, nb_out_p)
-        assert nb_args == nb_in_f + nb_in_p + nb_out_f + nb_out_p + nb_extra, msg
 
         super(CustomOperator, self).__init__(
             func=func, invars=invars, outvars=outvars, extra_args=extra_args, ghosts=ghosts, **kwds)
diff --git a/hysop/operator/tests/test_custom.py b/hysop/operator/tests/test_custom.py
new file mode 100644
index 000000000..c1a93cd2a
--- /dev/null
+++ b/hysop/operator/tests/test_custom.py
@@ -0,0 +1,178 @@
+"""Test custom operator."""
+from hysop.constants import HYSOP_REAL
+from hysop.testsenv import __ENABLE_LONG_TESTS__, __HAS_OPENCL_BACKEND__
+from hysop.testsenv import opencl_failed, iter_clenv
+from hysop.tools.contexts import printoptions
+from hysop.tools.types import check_instance, first_not_None
+from hysop.tools.numpywrappers import npw
+from hysop.tools.io_utils import IO
+
+from hysop import Field, Box
+from hysop.operators import CustomOperator
+from hysop.constants import Implementation
+
+
+class TestCustom(object):
+    """Compare the output of every CustomOperator implementation with an analytic 3D field."""
+
+    @classmethod
+    def setup_class(cls, enable_extra_tests=__ENABLE_LONG_TESTS__, enable_debug_mode=False):
+        """Set the default IO path and the bounds used to draw the random grid shape."""
+        IO.set_default_path('/tmp/hysop_tests/test_custom')
+        cls.size_min, cls.size_max = (15, 16) if enable_debug_mode else (23, 87)
+        cls.enable_extra_tests = enable_extra_tests
+        cls.enable_debug_mode = enable_debug_mode
+
+    @classmethod
+    def teardown_class(cls):
+        pass
+
+    def test_3d(self):
+        """Named test_* so that pytest default discovery collects it (perform_tests is not)."""
+        self._test(dim=3, dtype=HYSOP_REAL)
+
+    def perform_tests(self):
+        """Run all the tests (entry point of the script mode)."""
+        self.test_3d()
+
+    @staticmethod
+    def __analytic_init(data, coords, component):
+        """Fill data with x^2 sin(y) (exp(z) + sum of exp(k z) for k in 1..5)."""
+        (x, y, z) = coords
+        data[...] = (x**2)*npw.sin(y)*npw.exp(z)
+        for _ in range(1, 6):
+            data[...] += (x**2)*npw.sin(y)*npw.exp(z*_)
+
+    @staticmethod
+    def __analytic_python(F):
+        """Python func: same formula written to F.data[0] from F.compute_mesh_coords."""
+        (x, y, z) = F.compute_mesh_coords
+        F.data[0][...] = (x**2)*npw.sin(y)*npw.exp(z)
+        for _ in range(1, 6):
+            F.data[0][...] += (x**2)*npw.sin(y)*npw.exp(z*_)
+
+    # OpenCL func: same formula as elementwise kernel source, written for the field named F.
+    __analytic_opencl = """
+    int3 i_xyz = get_Fi_xyz(i);
+    double3 xyz = (double3)(F_mesh.local_mesh.xmin.x+i_xyz.x*F_mesh.dx.x,
+                            F_mesh.local_mesh.xmin.y+i_xyz.y*F_mesh.dx.y,
+                            F_mesh.local_mesh.xmin.z+i_xyz.z*F_mesh.dx.z);
+    double Fi = xyz.x*xyz.x*sin(xyz.y)*exp(xyz.z);
+    for(int k=1; k<6;k++) Fi += xyz.x*xyz.x*sin(xyz.y)*exp(xyz.z*k);
+    F[i] = Fi;
+    """
+
+    def _test(self, dim, dtype, size_min=None, size_max=None):
+        """Build and apply the operator for every implementation on a random shape and check
+        its output against the reference obtained with dF.initialize and the same formula."""
+        size_min = first_not_None(size_min, self.size_min)
+        size_max = first_not_None(size_max, self.size_max)
+
+        shape = tuple(npw.random.randint(low=size_min, high=size_max+1, size=dim).tolist())
+
+        domain = Box(length=(1,)*dim)
+        F = Field(domain=domain, name='F', dtype=dtype, nb_components=1)
+        print(' >Testing all implementations:')
+
+        implementations = CustomOperator.implementations()
+        variables = {F: shape}
+
+        def iter_impl(impl):
+            """Yield one operator for Python, and one per OpenCL environment for OpenCL."""
+            base_kwds = dict(invars=(), outvars=(F,),
+                             variables=variables,
+                             implementation=impl,
+                             name='custom_{}'.format(str(impl).lower()))
+            if impl is Implementation.PYTHON:
+                msg = '   *Python: '
+                print(msg, end=' ')
+                yield CustomOperator(func=self.__analytic_python, **base_kwds)
+                print()
+            elif impl is Implementation.OPENCL:
+                msg = '   *OpenCL: '
+                print(msg)
+                for cl_env in iter_clenv():
+                    print('      *platform {}, device {}: '.format(cl_env.platform.name.strip(),
+                                                                   cl_env.device.name.strip()), end=' ')
+                    yield CustomOperator(cl_env=cl_env,
+                                         func=self.__analytic_opencl, **base_kwds)
+                print()
+            else:
+                msg = 'Unknown implementation to test {}.'.format(impl)
+                raise NotImplementedError(msg)
+
+        print('\nTesting {}D Custom Operator: dtype={} shape={}'.format(
+            dim, dtype.__name__, shape))
+        Fref = None
+        for impl in implementations:
+            for op in iter_impl(impl):
+                op = op.build()
+                dF = op.get_output_discrete_field(F)
+
+                if (Fref is None):
+                    dF.initialize(self.__analytic_init)
+                    Fref = tuple(data.get().handle.copy() for data in dF.data)
+
+                op.apply()
+
+                Fout = tuple(data.get().handle.copy() for data in dF.data)
+                self._check_output(impl, op, Fref, Fout)
+
+    @classmethod
+    def _check_output(cls, impl, op, Fref, Fout):
+        """Raise if Fref is not finite or if ceil(max|Fout - Fref| / eps) is not below 1000."""
+        check_instance(Fref, tuple,   values=npw.ndarray)
+        check_instance(Fout, tuple,   values=npw.ndarray, size=len(Fref))
+
+        msg0 = 'Reference field {} is not finite.'
+        for (i, field) in enumerate(Fref):
+            iname = 'F{}'.format(i)
+            mask = npw.isfinite(field)
+            if not mask.all():
+                print()
+                print(field)
+                print()
+                print(field[~mask])
+                print()
+                raise ValueError(msg0.format(iname))
+
+        for i, (fout, fref) in enumerate(zip(Fout, Fref)):
+            iname = '{}{}'.format('F', i)
+            assert fout.dtype == fref.dtype, iname
+            assert fout.shape == fref.shape, iname
+
+            eps = npw.finfo(fout.dtype).eps
+            dinf = npw.max(npw.abs(fout-fref))
+            # int() raises on nan/inf: keep dinf as is so that the report below is still printed.
+            deps = int(npw.ceil(dinf/eps)) if npw.isfinite(dinf) else dinf
+            if (deps < 1000):
+                print('{}eps, '.format(deps), end=' ')
+                continue
+            has_nan = npw.any(npw.isnan(fout))
+            has_inf = npw.any(npw.isinf(fout))
+
+            print()
+            print()
+            print('Test output comparison for {} failed for component {}:'.format(iname, i))
+            print(' *has_nan: {}'.format(has_nan))
+            print(' *has_inf: {}'.format(has_inf))
+            print(' *dinf={} ({} eps)'.format(dinf, deps))
+            print()
+            print(fout[:, 3, 4])
+            print(fref[:, 3, 4])
+            msg = 'Test failed for {} on component {} for implementation {}.'.format(iname, i, impl)
+            raise RuntimeError(msg)
+
+
+if __name__ == '__main__':  # script mode: run the tests directly, without a test runner
+    TestCustom.setup_class(enable_extra_tests=False,
+                           enable_debug_mode=False)
+
+    test = TestCustom()
+
+    with printoptions(threshold=10000, linewidth=240,
+                      nanstr='nan', infstr='inf',
+                      formatter={'float': lambda x: '{:>6.2f}'.format(x)}):  # floats: right aligned, width 6, 2 decimals
+        test.perform_tests()
+
+    TestCustom.teardown_class()
-- 
GitLab