From d5d04ed7aace90c482de8e679734a7a3f84b4310 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Sun, 1 Mar 2020 17:27:43 +0100
Subject: [PATCH] fix 1D HPTT

---
 hysop/numerics/fft/host_fft.py                |  6 +++--
 .../tests/test_spectral_derivative.py         | 26 +++++++++----------
 2 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/hysop/numerics/fft/host_fft.py b/hysop/numerics/fft/host_fft.py
index a57cfef8d..3a96f03d9 100644
--- a/hysop/numerics/fft/host_fft.py
+++ b/hysop/numerics/fft/host_fft.py
@@ -30,7 +30,7 @@ def can_exec_hptt(src, dst):
         return False
     if src.flags['F_CONTIGUOUS'] != dst.flags['F_CONTIGUOUS']:
         return False
-    if not (src.flags['C_CONTIGUOUS'] or src.flags['F_CONTIGUOUS']):
+    if not (src.flags['C_CONTIGUOUS'] ^ src.flags['F_CONTIGUOUS']):
         return False
     return (src.data is not dst.data)
 
@@ -137,7 +137,9 @@ class HostFFTI(FFTI):
         @static_vars(numba_copy=None)
         def exec_copy(src=src, dst=dst):
             src, dst = src(), dst()
-            if HAS_HPTT and can_exec_hptt(src, dst):
+            if (src.ndim == 1):
+                dst[...] = src
+            elif HAS_HPTT and can_exec_hptt(src, dst):
                 hptt.tensorTransposeAndUpdate(perm=range(src.ndim),
                         alpha=1.0, A=src, beta=0.0, B=dst)
             elif HAS_NUMBA:
diff --git a/hysop/operator/tests/test_spectral_derivative.py b/hysop/operator/tests/test_spectral_derivative.py
index b8f3cb7f0..2bbe4e45c 100644
--- a/hysop/operator/tests/test_spectral_derivative.py
+++ b/hysop/operator/tests/test_spectral_derivative.py
@@ -319,24 +319,24 @@ class TestSpectralDerivative(object):
 
 
 
-    # def test_1d_trigonometric_float32(self, **kwds):
-    #     self._test(dim=1, dtype=npw.float32, polynomial=False, **kwds)
-    # def test_2d_trigonometric_float32(self, **kwds):
-    #     self._test(dim=2, dtype=npw.float32, polynomial=False, **kwds)
+    def test_1d_trigonometric_float32(self, **kwds):
+        self._test(dim=1, dtype=npw.float32, polynomial=False, **kwds)
+    def test_2d_trigonometric_float32(self, **kwds):
+        self._test(dim=2, dtype=npw.float32, polynomial=False, **kwds)
     def test_3d_trigonometric_float32(self, **kwds):
         self._test(dim=3, dtype=npw.float32, polynomial=False, **kwds)
 
-    # def test_1d_trigonometric_float64(self, **kwds):
-    #     self._test(dim=1, dtype=npw.float64, polynomial=False, **kwds)
-    # def test_2d_trigonometric_float64(self, **kwds):
-    #     self._test(dim=2, dtype=npw.float64, polynomial=False, **kwds)
+    def test_1d_trigonometric_float64(self, **kwds):
+        self._test(dim=1, dtype=npw.float64, polynomial=False, **kwds)
+    def test_2d_trigonometric_float64(self, **kwds):
+        self._test(dim=2, dtype=npw.float64, polynomial=False, **kwds)
     def test_3d_trigonometric_float64(self, **kwds):
         self._test(dim=3, dtype=npw.float64, polynomial=False, **kwds)
 
-    # def test_1d_polynomial_float32(self, **kwds):
-    #     self._test(dim=1, dtype=npw.float32, polynomial=True, **kwds)
-    # def test_2d_polynomial_float32(self, **kwds):
-    #     self._test(dim=2, dtype=npw.float32, polynomial=True, **kwds)
+    def test_1d_polynomial_float32(self, **kwds):
+        self._test(dim=1, dtype=npw.float32, polynomial=True, **kwds)
+    def test_2d_polynomial_float32(self, **kwds):
+        self._test(dim=2, dtype=npw.float32, polynomial=True, **kwds)
     def test_3d_polynomial_float32(self, **kwds):
         self._test(dim=3, dtype=npw.float32, polynomial=True, **kwds)
 
@@ -344,7 +344,7 @@ class TestSpectralDerivative(object):
         max_2d_runs = None if __ENABLE_LONG_TESTS__ else 2
         max_3d_runs = None if __ENABLE_LONG_TESTS__ else 2
 
-        # self.test_1d_trigonometric_float32(max_derivative=3)
+        self.test_1d_trigonometric_float32(max_derivative=3)
         # self.test_2d_trigonometric_float32(max_derivative=2, max_runs=max_2d_runs)
         self.test_3d_trigonometric_float32(max_derivative=1, max_runs=max_3d_runs)
 
-- 
GitLab