Skip to content
Snippets Groups Projects
Commit 28f7b2ac authored by EXT Jean-Matthieu Etancelin
Browse files

Fix tasks overlapping issue with redistributeInter

parent 21f4a780
No related branches found
No related tags found
2 merge requests: !24 Resolve "Add python3.x support", !15 WIP: Resolve "HySoP with tasks"
Pipeline #54213 failed
...@@ -536,12 +536,12 @@ class ComputationalGraph(ComputationalGraphNode): ...@@ -536,12 +536,12 @@ class ComputationalGraph(ComputationalGraphNode):
ops, already_printed = [], [] ops, already_printed = [], []
for (i, pn, n, op) in __recurse_nodes(self.nodes, nprefix=self.name+"."): for (i, pn, n, op) in __recurse_nodes(self.nodes, nprefix=self.name+"."):
for k in self._profiler.all_data.keys(): for k in self._profiler.all_data.keys():
if n == '.'.join(k.split('.')[:-1]) and not n in already_printed: if n == '.'.join(k.split('.')[:-1]) and not k in already_printed:
strdata = (i, str(op.mpi_params.task_id), pn, type(op).__name__, k.split('.')[-1]) strdata = (i, str(op.mpi_params.task_id), pn, type(op).__name__, k.split('.')[-1])
values = tuple(ff.format(self._profiler.all_data[k][_]) values = tuple(ff.format(self._profiler.all_data[k][_])
for ff, _ in zip(("{:.5g}", "{}", "{:.5g}", "{:.5g}", "{:.5g}", "{}"), (2, 1, 3, 4, 5, 0))) for ff, _ in zip(("{:.5g}", "{}", "{:.5g}", "{:.5g}", "{:.5g}", "{}"), (2, 1, 3, 4, 5, 0)))
ops += multiline_split(strdata + values, maxlen, split_sep, replace, newline_prefix) ops += multiline_split(strdata + values, maxlen, split_sep, replace, newline_prefix)
already_printed.append(n) already_printed.append(k)
isize = max(strlen(s[0]) for s in ops) isize = max(strlen(s[0]) for s in ops)
tasksize = max(max(strlen(s[1]) for s in ops), 6) tasksize = max(max(strlen(s[1]) for s in ops), 6)
......
...@@ -572,29 +572,8 @@ class RedistributeInter(RedistributeOperatorBase): ...@@ -572,29 +572,8 @@ class RedistributeInter(RedistributeOperatorBase):
if _is_target: if _is_target:
assert all(target_topo.mesh.local_resolution == ofield.resolution) assert all(target_topo.mesh.local_resolution == ofield.resolution)
# Compute if there is an overlap
src_overlap, dst_overlap = -1, -1
if _is_source:
src_overlap = source_topo.mpi_params.comm.allreduce(1 if _is_source and _is_target else 0)
if _is_target:
dst_overlap = target_topo.mpi_params.comm.allreduce(1 if _is_source and _is_target else 0)
if domain.task_rank() == 0 and (src_overlap == -1 or dst_overlap == -1):
rcv_overlap = domain.parent_comm.sendrecv(
src_overlap if dst_overlap == -1 else dst_overlap,
sendtag=self._other_task_id, recvtag=first_not_None((source_topo, target_topo)).mpi_params.task_id,
dest=domain.task_root_in_parent(self._other_task_id),
source=domain.task_root_in_parent(self._other_task_id))
src_overlap, dst_overlap = [rcv_overlap if _ == -1 else _ for _ in (src_overlap, dst_overlap)]
# ... then broadcast
if _is_source:
dst_overlap = source_topo.mpi_params.comm.bcast(dst_overlap, root=0)
if _is_target:
src_overlap = target_topo.mpi_params.comm.bcast(src_overlap, root=0)
assert (src_overlap+dst_overlap) % 2 == 0
self._has_overlap = src_overlap+dst_overlap > 0
# Create bridges and store comm types and indices # Create bridges and store comm types and indices
if self._has_overlap: if not domain.tasks_overlapping(source_id, target_id) is None:
self.bridge = BridgeOverlap(source=source_topo, target=target_topo, self.bridge = BridgeOverlap(source=source_topo, target=target_topo,
source_id=source_id, target_id=target_id, source_id=source_id, target_id=target_id,
dtype=self.dtype, order=get_mpi_order(first_not_None((ifield, ofield)).sdata)) dtype=self.dtype, order=get_mpi_order(first_not_None((ifield, ofield)).sdata))
...@@ -638,7 +617,7 @@ class RedistributeInter(RedistributeOperatorBase): ...@@ -638,7 +617,7 @@ class RedistributeInter(RedistributeOperatorBase):
@op_apply @op_apply
def apply(self, **kwds): def apply(self, **kwds):
comm = self.bridge.comm comm = self.bridge.comm
rank = self.mpi_params.rank rank = comm.Get_rank()
types = self._comm_types types = self._comm_types
indices = self._comm_indices indices = self._comm_indices
dFin, dFout = self.dFin, self.dFout dFin, dFout = self.dFin, self.dFout
......
...@@ -190,6 +190,9 @@ class DomainView(TaggedObjectView): ...@@ -190,6 +190,9 @@ class DomainView(TaggedObjectView):
"""Equivalent to self.long_description()""" """Equivalent to self.long_description()"""
return self.long_description() return self.long_description()
# Look up the overlap entry recorded for the task pair (ta, tb) in the
# underlying domain's `_overlapping_map` and return it unchanged.
# In the Domain construction code shown in this commit, every pair of distinct
# tasks starts at None and is set to the "largest task" when the two tasks
# overlap, so callers test the result against None to detect an overlap.
# NOTE(review): the map is only built for ta != tb, so calling this with
# ta == tb presumably raises KeyError -- confirm against callers.
def tasks_overlapping(self, ta, tb):
return self._domain._overlapping_map[ta][tb]
domain = property(_get_domain) domain = property(_get_domain)
dim = property(_get_dim) dim = property(_get_dim)
proc_tasks = property(_get_proc_tasks) proc_tasks = property(_get_proc_tasks)
...@@ -339,6 +342,8 @@ class Domain(RegisteredObject): ...@@ -339,6 +342,8 @@ class Domain(RegisteredObject):
# Create intercommunicators from current task to others # Create intercommunicators from current task to others
task_intercomm = {} task_intercomm = {}
# task overlapping map : gives the largest task of two overlapping tasks
overlapping_map = dict([(_, dict([(__, None) for __ in all_tasks if _ != __])) for _ in all_tasks])
# For all tasks the current rank is involved in # For all tasks the current rank is involved in
my_tasks = tuple((_ for _ in all_tasks if is_task_matters(_, proc_tasks[parent_rank]))) my_tasks = tuple((_ for _ in all_tasks if is_task_matters(_, proc_tasks[parent_rank])))
for tsource in my_tasks: for tsource in my_tasks:
...@@ -356,10 +361,12 @@ class Domain(RegisteredObject): ...@@ -356,10 +361,12 @@ class Domain(RegisteredObject):
if any([all([t in _ for _ in proc_tasks]) for t in all_tasks]): if any([all([t in _ for _ in proc_tasks]) for t in all_tasks]):
# TODO: review if nested tasks with differents ranks # TODO: review if nested tasks with differents ranks
# for the moment : ensure ranks are identical throw all local tasks # for the moment : ensure ranks are identical throw all local tasks
assert all([all([t.values()[0] == _ for _ in t.values()]) for t in all_task_ranks]) #assert all([all([t.values()[0] == _ for _ in t.values()]) for t in all_task_ranks]), all_task_ranks
# If nested tasks, we use the largest task communicator # If nested tasks, we use the largest task communicator
largest_task = [t for t in all_tasks if all([t in _ for _ in proc_tasks])][0] largest_task = [t for t in all_tasks if all([t in _ for _ in proc_tasks])][0]
intercomm = task_comm[largest_task] intercomm = task_comm[largest_task]
overlapping_map[tsource][tdest] = largest_task
overlapping_map[tdest][tsource] = largest_task
else: else:
raise NotImplementedError() raise NotImplementedError()
task_intercomm[tdest] = intercomm task_intercomm[tdest] = intercomm
...@@ -387,6 +394,7 @@ class Domain(RegisteredObject): ...@@ -387,6 +394,7 @@ class Domain(RegisteredObject):
self._proc_tasks = proc_tasks self._proc_tasks = proc_tasks
self._registered_topologies = {} self._registered_topologies = {}
self._frame = SymbolicFrame(dim=dim) self._frame = SymbolicFrame(dim=dim)
self._overlapping_map = overlapping_map
def register_topology(self, topo): def register_topology(self, topo):
"""Register a new topology on this domain. """Register a new topology on this domain.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment