From 726f883dc83ac0e439fcd36e25e318777a39b717 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Keck <Jean-Baptiste.Keck@imag.fr>
Date: Sat, 31 Aug 2019 01:48:49 +0200
Subject: [PATCH] flint support for stencil generation

---
 hysop/numerics/stencil/stencil_generator.py | 63 +++++++++++++++------
 hysop/tools/sympy_utils.py                  |  2 +
 2 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/hysop/numerics/stencil/stencil_generator.py b/hysop/numerics/stencil/stencil_generator.py
index 1ddafa0b6..7bd1e91f5 100644
--- a/hysop/numerics/stencil/stencil_generator.py
+++ b/hysop/numerics/stencil/stencil_generator.py
@@ -4,11 +4,11 @@
 * :class:`~hysop.numerics.stencil.StencilGenerator`
 
 """
-
+import fractions
 from hysop.deps              import it, np, sp, sm, os, copy, math, gzip, pickle
 from hysop.tools.misc        import prod
 from hysop.tools.io_utils    import IO
-from hysop.tools.numerics    import MPQ, MPZ, MPFR, F2Q, mpqize
+from hysop.tools.numerics    import MPQ, MPZ, MPFR, F2Q, mpqize, mpq, mpz
 from hysop.tools.types       import extend_array
 from hysop.tools.cache       import update_cache, load_data_from_cache
 from hysop.tools.sympy_utils import tensor_symbol, tensor_xreplace, \
@@ -16,6 +16,13 @@ from hysop.tools.sympy_utils import tensor_symbol, tensor_xreplace, \
 
 from hysop.numerics.stencil.stencil import Stencil, CenteredStencil
 
+try:
+    import flint
+    has_flint = True
+except ImportError:
+    flint = None
+    has_flint = False
+
 class StencilGeneratorConfiguration(object):
 
     def __init__(self):
@@ -330,7 +337,10 @@ class StencilGenerator(object):
 
         if (dim!=1):
             raise ValueError('Bad dimension for approximation stencil generation!')
-        if dtype not in [np.float16, np.float32, np.float64]:
+
+        if (has_flint):
+            solve_dtype = flint.fmpq
+        elif dtype not in [np.float16, np.float32, np.float64]:
             solve_dtype = np.float64
         else:
             solve_dtype = dtype
@@ -347,28 +357,47 @@ class StencilGenerator(object):
         
         if k == 0:
             return Stencil([1],[0],0,dx=dx,error=None)
-
-        A = np.zeros((N,N),dtype=solve_dtype)
-        b = np.zeros(N,dtype=solve_dtype)
+        
+        A = np.empty((N,N),dtype=solve_dtype)
+        b = np.empty(N,dtype=solve_dtype)
         for i in xrange(N):
-            b[i] = solve_dtype(i==k)
+            b[i] = solve_dtype(long(i==k))
             for j in xrange(N):
-                A[i,j] = solve_dtype(j-origin)**i
+                A[i,j] = solve_dtype(long((j-origin)**i))
 
         try:
-            S = sp.linalg.solve(A,b,overwrite_a=True,overwrite_b=True)
-            S *= math.factorial(k)
+            if has_flint:
+                coeffs = A.ravel()
+                Afmpq = flint.fmpq_mat(*(A.shape+(coeffs,)))
+                Afmpq_inv = Afmpq.inv()
+                Ainv = np.asarray(Afmpq_inv.entries()).reshape(A.shape)
+                S = Ainv.dot(b)
+            else:
+                S = sp.linalg.solve(A,b,overwrite_a=True,overwrite_b=True)
         except:
             print '\nError: Cannot generate stencil (singular system).\n'
             raise
 
-        if dtype!=solve_dtype:
-            if dtype==MPQ:
-                import fractions
-                from hysop.tools.numerics import mpq
-                def convert(x):
-                    frac = fractions.Fraction(x).limit_denominator((1<<32)-1)
-                    return mpq(frac.numerator, frac.denominator)
+        S *= math.factorial(k)
+        
+        actual_dtype = type(S.ravel()[0])
+        target_dtype = dtype
+        if actual_dtype != target_dtype:
+            if target_dtype in [np.float16, np.float32, np.float64]:
+                if has_flint and (actual_dtype is flint.fmpq):
+                    def convert(x):
+                        return target_dtype(float(long(x.p)) / float(long(x.q)))
+                    S = np.vectorize(convert)(S)
+                else:
+                    S = S.astype(target_dtype)
+            elif target_dtype==MPQ:
+                if has_flint and (actual_dtype is flint.fmpq):
+                    def convert(x):
+                        return mpq(mpz(x.p.str()), mpz(x.q.str()))
+                else:
+                    def convert(x):
+                        frac = fractions.Fraction(x).limit_denominator((1<<32)-1)
+                        return mpq(frac.numerator, frac.denominator)
                 S = np.vectorize(convert)(S)
             else:
                 RuntimeError('Type conversion not implemented yet.')
diff --git a/hysop/tools/sympy_utils.py b/hysop/tools/sympy_utils.py
index d656de1e4..0b7d861fd 100644
--- a/hysop/tools/sympy_utils.py
+++ b/hysop/tools/sympy_utils.py
@@ -266,6 +266,8 @@ def tensor_xreplace(tensor,vars):
                 T[idx] = vars[symbol]
             elif (hasattr(symbol, 'name')) and (symbol.name in vars.keys()):
                 T[idx] = vars[symbol.name]
+            else:
+                T[idx] = symbol.xreplace(vars)
     return T
 
 def non_eval_xreplace(expr, rule):
-- 
GitLab