Vous avez reçu un message "Your GitLab account has been locked ..." ? Pas d'inquiétude : lisez cet article https://docs.gricad-pages.univ-grenoble-alpes.fr/help/unlock/

Commit 5e6d2e83 authored by EXT José Ignacio Requeno Jarabo's avatar EXT José Ignacio Requeno Jarabo
Browse files

Cythonization of Oracles

parent 3700f61a
Pipeline #47841 failed with stages
in 2 minutes and 52 seconds
......@@ -10,6 +10,7 @@ This module introduces a set of operations for parallelizing
the creation of comparable and incomparable rectangles of the space.
"""
from multiprocessing import Pool, cpu_count
import cython
from ParetoLib.Geometry.Rectangle import Rectangle, brect
from ParetoLib.Geometry.Point import dim
......@@ -19,7 +20,11 @@ from ParetoLib.Geometry.Point import dim
# Parallel version for the computation of incomparable rectangles in a space
############################################################################
@cython.ccall
@cython.locals(alpha=tuple, yrectangle=object, xspace=object, rect=object)
@cython.returns(object)
def pbrect(args):
# type: (iter) -> Rectangle
"""
Synonym of Rectangle.brect(alpha, yrectangle, xspace)
"""
......@@ -27,8 +32,9 @@ def pbrect(args):
return brect(alpha, yrectangle, xspace)
@cython.locals(nproc=cython.ushort, pool=object)
def pirect(alphaincomp, yrectangle, xspace):
# type: (list, Rectangle, Rectangle) -> list
# type: (list, Rectangle, Rectangle) -> iter
"""
Synonym of Rectangle.irect(alphaincomp, yrectangle, xspace)
"""
......@@ -61,6 +67,9 @@ def pirect(alphaincomp, yrectangle, xspace):
# i.e., these wrappers.
#############################################################################################
@cython.ccall
@cython.locals(rect=object)
@cython.returns(cython.double)
def pvol(rect):
# type: (Rectangle) -> float
"""
......@@ -69,6 +78,8 @@ def pvol(rect):
return rect.volume()
@cython.ccall
@cython.returns(list)
def pvertices(rect):
# type: (Rectangle) -> list
"""
......@@ -77,7 +88,11 @@ def pvertices(rect):
return rect.vertices()
@cython.ccall
@cython.locals(rect=object, xpoint=tuple)
@cython.returns(cython.bint)
def pinside(args):
# type: (iter) -> boolean
"""
Synonym of Rectangle.inside(xpoint)
"""
......
......@@ -33,12 +33,12 @@ RootOracle = ParetoLib.Oracle
@cython.cclass
class NDTree(object):
# cython.declare(root=object, max_points=cython.ushort, min_children=cython.ushort)
# cython.declare(root=object, max_points=cython.ulong, min_children=cython.ushort)
root = cython.declare(object, visibility='public')
max_points = cython.declare(cython.ushort, visibility='public')
max_points = cython.declare(cython.ulong, visibility='public')
min_children = cython.declare(cython.ushort, visibility='public')
@cython.locals(max_points=cython.ushort, min_children=cython.ushort)
@cython.locals(max_points=cython.ulong, min_children=cython.ushort)
@cython.returns(cython.void)
def __init__(self, max_points=2, min_children=2):
# type: (NDTree, int, int) -> None
......@@ -508,10 +508,10 @@ class Node(object):
parent = cython.declare(object, visibility='public')
nodes = cython.declare(list, visibility='public')
L = cython.declare(list, visibility='public')
max_points = cython.declare(cython.ushort, visibility='public')
max_points = cython.declare(cython.ulong, visibility='public')
min_children = cython.declare(cython.ushort, visibility='public')
@cython.locals(parent=object, max_points=cython.ushort, min_children=cython.ushort)
@cython.locals(parent=object, max_points=cython.ulong, min_children=cython.ushort)
@cython.returns(cython.void)
def __init__(self, parent=None, max_points=2, min_children=2):
# type: (Node, Node, int, int) -> None
......
......@@ -24,6 +24,8 @@ import io
from sortedcontainers import SortedSet
from sympy import simplify, expand, default_sort_key, Expr, Symbol
import cython
# import ParetoLib.Oracle as RootOracle
import ParetoLib.Oracle
from ParetoLib.Oracle.Oracle import Oracle
......@@ -32,8 +34,15 @@ RootOracle = ParetoLib.Oracle
# from ParetoLib._py3k import getoutput, viewvalues, viewitems
# @cython.cclass
class Condition(object):
cython.declare(comparison=list)
cython.declare(op=str)
cython.declare(f=object)
cython.declare(g=object)
@cython.locals(f=str, op=str, g=str)
@cython.returns(cython.void)
def __init__(self, f='x', op='==', g='0'):
# type: (Condition, str, str, str) -> None
"""
......@@ -70,6 +79,9 @@ class Condition(object):
str(
self._get_expression_with_negative_coeff())))
@cython.locals(poly_function=str, op_comp=str, op_regex=str, f_regex=str, g_regex=str, regex=str, regex_comp=object,
result=object)
@cython.returns(cython.void)
def init_from_string(self, poly_function):
# type: (Condition, str) -> None
"""
......@@ -111,6 +123,7 @@ class Condition(object):
str(
self._get_expression_with_negative_coeff())))
@cython.returns(str)
def __repr__(self):
# type: (Condition) -> str
"""
......@@ -118,6 +131,7 @@ class Condition(object):
"""
return self._to_str()
@cython.returns(str)
def __str__(self):
# type: (Condition) -> str
"""
......@@ -125,6 +139,7 @@ class Condition(object):
"""
return self._to_str()
@cython.returns(str)
def _to_str(self):
# type: (Condition) -> str
"""
......@@ -132,6 +147,7 @@ class Condition(object):
"""
return str(self.f) + self.op + str(self.g)
@cython.returns(cython.bint)
def __eq__(self, other):
# type: (Condition, Condition) -> bool
"""
......@@ -141,6 +157,7 @@ class Condition(object):
(self.op == other.op) and \
(self.g == other.g)
@cython.returns(cython.bint)
def __ne__(self, other):
# type: (Condition, Condition) -> bool
"""
......@@ -148,6 +165,7 @@ class Condition(object):
"""
return not self.__eq__(other)
@cython.returns(int)
def __hash__(self):
# type: (Condition) -> int
"""
......@@ -155,6 +173,8 @@ class Condition(object):
"""
return hash((self.f, self.op, self.g))
@cython.locals(p=tuple)
@cython.returns(cython.bint)
def __contains__(self, p):
# type: (Condition, tuple) -> bool
"""
......@@ -163,6 +183,9 @@ class Condition(object):
# Due to a peculiar behaviour, it is required to compare the result with "True"
return self.member(p) is True
@cython.ccall
@cython.locals(coeffs=dict, all_positives=cython.bint, i=object)
@cython.returns(cython.bint)
def all_coeff_are_positive(self):
# type: (Condition) -> bool
coeffs = self.get_coeff_of_expression()
......@@ -171,6 +194,9 @@ class Condition(object):
all_positives = all_positives and (coeffs[i] >= 0)
return all_positives
@cython.ccall
@cython.locals(expr=object, expanded_expr=object, simpl_expr=object, coeffs=dict)
@cython.returns(dict)
def get_coeff_of_expression(self):
# type: (Condition) -> dict
"""
......@@ -194,6 +220,9 @@ class Condition(object):
coeffs = simpl_expr.as_coefficients_dict()
return coeffs
@cython.ccall
@cython.locals(expr=object, expanded_expr=object, simpl_expr=object, coeffs=dict, positive_coeff=dict)
@cython.returns(dict)
def get_positive_coeff_of_expression(self):
# type: (Condition) -> dict
"""
......@@ -219,6 +248,9 @@ class Condition(object):
positive_coeff = {i: coeffs[i] for i in coeffs if coeffs[i] >= 0}
return positive_coeff
@cython.ccall
@cython.locals(expr=object, expanded_expr=object, simpl_expr=object, coeffs=dict, negative_coeff=dict)
@cython.returns(dict)
def get_negative_coeff_of_expression(self):
# type: (Condition) -> dict
"""
......@@ -244,18 +276,26 @@ class Condition(object):
negative_coeff = {i: coeffs[i] for i in coeffs if coeffs[i] < 0}
return negative_coeff
@cython.ccall
@cython.locals(negative_coeff=object, neg_expr=list)
@cython.returns(object)
def _get_expression_with_negative_coeff(self):
# type: (Condition) -> Expr
negative_coeff = self.get_negative_coeff_of_expression()
neg_expr = ['{0} * {1}'.format(negative_coeff[i], i) for i in negative_coeff]
return simplify(''.join(neg_expr))
@cython.ccall
@cython.locals(positive_coeff=object, pos_expr=list)
@cython.returns(object)
def _get_expression_with_positive_coeff(self):
# type: (Condition) -> Expr
positive_coeff = self.get_positive_coeff_of_expression()
pos_expr = ['{0} * {1}'.format(positive_coeff[i], i) for i in positive_coeff]
return simplify('+'.join(pos_expr))
@cython.ccall
@cython.returns(object)
def get_expression(self):
# type: (Condition) -> Expr
"""
......@@ -274,6 +314,9 @@ class Condition(object):
"""
return simplify(self.f - self.g)
@cython.ccall
@cython.locals(expr=object)
@cython.returns(list)
def get_variables(self):
# type: (Condition) -> list
"""
......@@ -293,8 +336,11 @@ class Condition(object):
expr = self.get_expression()
return sorted(expr.free_symbols, key=default_sort_key)
@cython.ccall
@cython.locals(variable=object, val=str, fvset=list, fv=object, expr=object, res=object, ex=str)
@cython.returns(object)
def eval_var_val(self, variable=None, val='0.0'):
# type: (Condition, Symbol, float) -> Expr
# type: (Condition, Symbol, str) -> Expr
"""
Substitutes a variable by a value in the polynomial expression of Condition.
......@@ -322,6 +368,9 @@ class Condition(object):
# RootOracle.logger.debug('Expression ' + str(simplify(ex)))
return simplify(ex)
@cython.ccall
@cython.locals(point=tuple, keys_fv=list, di=dict)
@cython.returns(object)
def eval_tuple(self, point):
# type: (Condition, tuple) -> Expr
"""
......@@ -347,6 +396,9 @@ class Condition(object):
# RootOracle.logger.debug('di ' + str(di))
return self.eval_dict(di)
@cython.ccall
@cython.locals(var_point=list, expr=object, res=object, ex=str)
@cython.returns(object)
def eval_zip_tuple(self, var_point):
# type: (Condition, list) -> Expr
"""
......@@ -372,6 +424,9 @@ class Condition(object):
# RootOracle.logger.debug('Expression ' + str(simplify(ex)))
return simplify(ex)
@cython.ccall
@cython.locals(d=dict, di=dict, keys_fv=list, keys=set, expr=object, res=object, ex=str)
@cython.returns(object)
def eval_dict(self, d=None):
# type: (Condition, dict) -> Expr
"""
......@@ -408,6 +463,8 @@ class Condition(object):
return simplify(ex)
# Membership functions
@cython.locals(point=tuple, di=dict)
@cython.returns(object)
def member(self, point):
# type: (Condition, tuple) -> Expr
"""
......@@ -431,6 +488,7 @@ class Condition(object):
di = {key: point[i] for i, key in enumerate(keys)}
return self.eval_dict(di)
@cython.returns(object)
def membership(self):
# type: (Condition) -> callable
"""
......@@ -453,6 +511,9 @@ class Condition(object):
return lambda xpoint: self.member(xpoint)
# Read/Write file functions
@cython.ccall
@cython.returns(cython.void)
@cython.locals(fname=str, human_readable=cython.bint, mode=str)
def from_file(self, fname='', human_readable=False):
# type: (Condition, str, bool) -> None
"""
......@@ -484,6 +545,9 @@ class Condition(object):
self.from_file_binary(finput)
finput.close()
@cython.ccall
@cython.returns(cython.void)
@cython.locals(finput=object)
def from_file_binary(self, finput=None):
# type: (Condition, io.BinaryIO) -> None
"""
......@@ -508,6 +572,9 @@ class Condition(object):
self.op = pickle.load(finput)
self.g = pickle.load(finput)
@cython.ccall
@cython.returns(cython.void)
@cython.locals(finput=object, poly_function=str)
def from_file_text(self, finput=None):
# type: (Condition, io.BinaryIO) -> None
"""
......@@ -531,6 +598,9 @@ class Condition(object):
poly_function = finput.readline()
self.init_from_string(poly_function)
@cython.ccall
@cython.returns(cython.void)
@cython.locals(fname=str, append=cython.bint, human_readable=cython.bint, mode=str)
def to_file(self, fname='', append=False, human_readable=False):
# type: (Condition, str, bool, bool) -> None
"""
......@@ -569,6 +639,9 @@ class Condition(object):
self.to_file_binary(foutput)
foutput.close()
@cython.ccall
@cython.returns(cython.void)
@cython.locals(foutput=object)
def to_file_binary(self, foutput=None):
# type: (Condition, io.BinaryIO) -> None
"""
......@@ -594,6 +667,9 @@ class Condition(object):
pickle.dump(self.op, foutput, pickle.HIGHEST_PROTOCOL)
pickle.dump(self.g, foutput, pickle.HIGHEST_PROTOCOL)
@cython.ccall
@cython.returns(cython.void)
@cython.locals(foutput=object)
def to_file_text(self, foutput=None):
# type: (Condition, io.BinaryIO) -> None
"""
......@@ -619,7 +695,12 @@ class Condition(object):
foutput.write(str(self) + '\n')
@cython.cclass
class OracleFunction(Oracle):
cython.declare(variables=object)
cython.declare(oracle=set)
@cython.returns(cython.void)
def __init__(self):
# type: (OracleFunction) -> None
"""
......@@ -630,6 +711,7 @@ class OracleFunction(Oracle):
self.variables = SortedSet([], key=default_sort_key)
self.oracle = set()
@cython.returns(str)
def __repr__(self):
# type: (OracleFunction) -> str
"""
......@@ -637,6 +719,7 @@ class OracleFunction(Oracle):
"""
return self._to_str()
@cython.returns(str)
def __str__(self):
# type: (OracleFunction) -> str
"""
......@@ -644,6 +727,8 @@ class OracleFunction(Oracle):
"""
return self._to_str()
@cython.ccall
@cython.returns(str)
def _to_str(self):
# type: (OracleFunction) -> str
"""
......@@ -651,6 +736,7 @@ class OracleFunction(Oracle):
"""
return str(self.oracle)
@cython.returns(cython.bint)
def __eq__(self, other):
# type: (OracleFunction, OracleFunction) -> bool
"""
......@@ -658,6 +744,7 @@ class OracleFunction(Oracle):
"""
return self.oracle == other.oracle
@cython.returns(cython.bint)
def __ne__(self, other):
# type: (OracleFunction, OracleFunction) -> bool
"""
......@@ -665,6 +752,7 @@ class OracleFunction(Oracle):
"""
return not self.__eq__(other)
@cython.returns(int)
def __hash__(self):
# type: (OracleFunction) -> int
"""
......@@ -672,6 +760,9 @@ class OracleFunction(Oracle):
"""
return hash(tuple(self.oracle))
@cython.ccall
@cython.locals(cond=object)
@cython.returns(cython.void)
def add(self, cond):
# type: (OracleFunction, Condition) -> None
"""
......@@ -694,6 +785,8 @@ class OracleFunction(Oracle):
self.variables = self.variables.union(cond.get_variables())
self.oracle.add(cond)
@cython.ccall
@cython.returns(cython.ushort)
def dim(self):
# type: (OracleFunction) -> int
"""
......@@ -701,6 +794,9 @@ class OracleFunction(Oracle):
"""
return len(self.get_variables())
@cython.ccall
@cython.locals(i=object)
@cython.returns(list)
def get_var_names(self):
# type: (OracleFunction) -> list
"""
......@@ -708,6 +804,9 @@ class OracleFunction(Oracle):
"""
return [str(i) for i in self.variables]
@cython.ccall
@cython.locals(variable_list=list)
@cython.returns(list)
def get_variables(self):
# type: (OracleFunction) -> list
"""
......@@ -733,6 +832,9 @@ class OracleFunction(Oracle):
variable_list = list(self.variables)
return variable_list
@cython.ccall
@cython.locals(var=object, val=str, _eval_list=list, _eval=cython.bint)
@cython.returns(cython.bint)
def _eval_var_val(self, var=None, val='0'):
# type: (OracleFunction, Symbol, int) -> bool
_eval_list = [cond.eval_var_val(var, val) for cond in self.oracle]
......@@ -742,6 +844,9 @@ class OracleFunction(Oracle):
# _eval = any(_eval_list)
return _eval
@cython.ccall
@cython.locals(point=tuple, _eval_list=list, _eval=cython.bint)
@cython.returns(cython.bint)
def _eval_tuple(self, point):
# type: (OracleFunction, tuple) -> bool
_eval_list = [cond.eval_tuple(point) for cond in self.oracle]
......@@ -751,6 +856,9 @@ class OracleFunction(Oracle):
# _eval = any(_eval_list)
return _eval
@cython.ccall
@cython.locals(var_point=list, _eval_list=list, _eval=cython.bint)
@cython.returns(cython.bint)
def _eval_zip_tuple(self, var_point):
# type: (OracleFunction, list) -> bool
_eval_list = [cond.eval_zip_tuple(var_point) for cond in self.oracle]
......@@ -760,6 +868,8 @@ class OracleFunction(Oracle):
# _eval = any(_eval_list)
return _eval
@cython.locals(d=dict, _eval_list=list, _eval=cython.bint)
@cython.returns(cython.bint)
def _eval_dict(self, d=None):
# type: (OracleFunction, dict) -> bool
_eval_list = [cond.eval_dict(d) for cond in self.oracle]
......@@ -769,6 +879,8 @@ class OracleFunction(Oracle):
# _eval = any(_eval_list)
return _eval
@cython.locals(point=tuple)
@cython.returns(cython.bint)
def __contains__(self, point):
# type: (OracleFunction, tuple) -> bool
"""
......@@ -777,6 +889,9 @@ class OracleFunction(Oracle):
"""
return self.member(point) is True
@cython.ccall
@cython.returns(cython.bint)
@cython.locals(point=tuple, var_point=list)
def _member_zip_tuple(self, point):
# type: (OracleFunction, tuple) -> bool
# keys = [x, y, z]
......@@ -787,6 +902,9 @@ class OracleFunction(Oracle):
# var_point = zip(keys, point) # Works only in Python 2.7
return self._eval_zip_tuple(var_point)
@cython.ccall
@cython.returns(cython.bint)
@cython.locals(point=tuple, di=dict)
def _member_dict(self, point):
# type: (OracleFunction, tuple) -> bool
# keys = [x, y, z]
......@@ -796,6 +914,9 @@ class OracleFunction(Oracle):
di = {key: point[i] for i, key in enumerate(keys)}
return self._eval_dict(di)
@cython.ccall
@cython.returns(cython.bint)
@cython.locals(point=tuple)
def member(self, point):
# type: (OracleFunction, tuple) -> bool
"""
......@@ -806,6 +927,7 @@ class OracleFunction(Oracle):
return self._member_zip_tuple(point)
# return self.member_dict(point)
@cython.returns(object)
def membership(self):
# type: (OracleFunction) -> callable
"""
......@@ -814,6 +936,9 @@ class OracleFunction(Oracle):
return lambda point: self.member(point)
# Read/Write file functions
@cython.ccall
@cython.returns(cython.void)
@cython.locals(finput=object)
def from_file_binary(self, finput=None):
# type: (OracleFunction, io.BinaryIO) -> None
"""
......@@ -824,6 +949,9 @@ class OracleFunction(Oracle):
self.oracle = pickle.load(finput)
self.variables = pickle.load(finput)
@cython.ccall
@cython.returns(cython.void)
@cython.locals(finput=object, line=str, cond=object)
def from_file_text(self, finput=None):
# type: (OracleFunction, io.BinaryIO) -> None
"""
......@@ -837,6 +965,9 @@ class OracleFunction(Oracle):
cond.init_from_string(line)
self.add(cond)
@cython.ccall
@cython.returns(cython.void)
@cython.locals(foutput=object)
def to_file_binary(self, foutput=None):
# type: (OracleFunction, io.BinaryIO) -> None
"""
......@@ -847,6 +978,9 @@ class OracleFunction(Oracle):
pickle.dump(self.oracle, foutput, pickle.HIGHEST_PROTOCOL)
pickle.dump(self.variables, foutput, pickle.HIGHEST_PROTOCOL)
@cython.ccall
@cython.returns(cython.void)
@cython.locals(foutput=object)
def to_file_text(self, foutput=None):
# type: (OracleFunction, io.BinaryIO) -> None
"""
......
......@@ -15,6 +15,7 @@ import io
import os
import sys
import warnings
import cython
try:
import matlab.engine
except ImportError as e:
......@@ -26,7 +27,15 @@ from ParetoLib.Oracle.Oracle import Oracle
from ParetoLib._py3k import get_stdout_matlab, get_stderr_matlab
# @cython.cclass
class OracleMatlab(Oracle):
cython.declare(out=object)
cython.declare(err=object)
cython.declare(eng=object)
cython.declare(f=object)
cython.declare(d=cython.ushort)
def __init__(self, matlab_model_file=''):
# type: (OracleMatlab, str) -> None
"""
......@@ -122,6 +131,7 @@ class OracleMatlab(Oracle):
self.f = None
self.d = None
@cython.returns(str)
def __repr__(self):
# type: (OracleMatlab) -> str
"""
......@@ -129,6 +139,7 @@ class OracleMatlab(Oracle):
"""
return self.matlab_model_file
@cython.returns(str)
def __str__(self):