From d68438718100fcb67af93593e31bfc9be7376b17 Mon Sep 17 00:00:00 2001 From: JM Etancelin <jean-matthieu.etancelin@univ-pau.fr> Date: Thu, 4 Jun 2020 17:26:52 +0200 Subject: [PATCH] fix mpi bridge --- hysop/core/mpi/bridge.py | 13 +++++++++---- hysop/core/mpi/topo_tools.py | 21 ++++++++------------- hysop/problem.py | 5 ++++- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/hysop/core/mpi/bridge.py b/hysop/core/mpi/bridge.py index fdc200f5b..901e32886 100644 --- a/hysop/core/mpi/bridge.py +++ b/hysop/core/mpi/bridge.py @@ -376,10 +376,15 @@ class BridgeOverlap(Bridge): def _build_send_recv_dict(self): # Compute local intersections : i.e. find which grid points # are on both source and target mesh. - indices_source = TopoTools.gather_global_indices_overlap( - self._source, self.comm, self.domain) - indices_target = TopoTools.gather_global_indices_overlap( - self._target, self.comm, self.domain) + # Filter out the empty slices (due to none topologies) + indices_source = dict([(rk, sl) for rk, sl in + TopoTools.gather_global_indices_overlap( + self._source, self.comm, self.domain).iteritems() + if not all([_ == slice(0, 0) for _ in sl])]) + indices_target = dict([(rk, sl) for rk, sl in + TopoTools.gather_global_indices_overlap( + self._target, self.comm, self.domain).iteritems() + if not all([_ == slice(0, 0) for _ in sl])]) # From now on, we have indices_source[rk] = global indices (slice) # of grid points of the source on process number rk in parent. diff --git a/hysop/core/mpi/topo_tools.py b/hysop/core/mpi/topo_tools.py index 31e605dc7..6867528c0 100644 --- a/hysop/core/mpi/topo_tools.py +++ b/hysop/core/mpi/topo_tools.py @@ -74,9 +74,7 @@ class TopoTools(object): if root is None: comm.Allgather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT]) else: - comm.Gather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT], - root=root) - + comm.Gather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT], root=root) if toslice: return Utils.array_to_dict(iglob_res) else: @@ -133,19 +131,18 @@ class TopoTools(object): size = comm.Get_size() rank = comm.Get_rank() dimension = dom.dim - iglob = npw.integer_zeros((size, dimension * 2)) - iglob_res = npw.integer_zeros((size, dimension * 2)) + iglob = npw.integer_zeros((dimension * 2, size), order='F') + iglob_res = npw.integer_zeros((dimension * 2, size), order='F') + iglob[0::2, rank] = 0 iglob[1::2, rank] = -1 if root is None: - comm.Allgather([iglob[rank, :], MPI.INT], [iglob_res, MPI.INT]) + comm.Allgather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT]) else: - comm.Gather([iglob[rank, :], MPI.INT], [iglob_res, MPI.INT], - root=root) + comm.Gather([iglob[:, rank], MPI.INT], [iglob_res, MPI.INT], root=root) if toslice: - return Utils.array_to_dict(iglob_res.T) + return Utils.array_to_dict(iglob_res) else: - return iglob_res.T - + return iglob_res else: return TopoTools.gather_global_indices(topo, toslice, root, comm) @@ -286,8 +283,6 @@ class TopoTools(object): basetype = dtype_to_mpi_type(dtype) subtype = basetype.Create_subarray(shape, subshape, substart, order=order) subtype.Commit() - # print 'MPI_Create_subarray(shape={}, subshape={}, substart={}, order={})'.format( - # shape, subshape, substart, 'C' if order is MPI.ORDER_C else 'F') return subtype @staticmethod diff --git a/hysop/problem.py b/hysop/problem.py index 7970a071f..840bf45e2 100644 --- a/hysop/problem.py +++ b/hysop/problem.py @@ -100,7 +100,7 @@ class Problem(ComputationalGraph): msg += '\n If this is required, override check_unique_clenv().' raise RuntimeError(msg) - def initialize_field(self, field, **kwds): + def initialize_field(self, field, mpi_params=None, **kwds): """Initialize a field on all its input and output topologies.""" initialized = set() for op in self.nodes: @@ -113,6 +113,9 @@ class Problem(ComputationalGraph): if all((df in initialized) for df in dfield.discrete_fields()): # all contained scalar fields were already initialized continue + elif mpi_params and mpi_params.task_id != dfield.topology.task_id: + # Topology task does not matches given mpi_params task + continue else: components = () for (component, scalar_dfield) in dfield.nd_iter(): -- GitLab