[Pipeline Middleware] Adapt scheduler for Topo (#2066)

* adapt scheduler for Topo

* remoove comment

* fix set input

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-12-05 20:23:41 +08:00
committed by GitHub
parent b3b89865e2
commit 597cdd3006
4 changed files with 160 additions and 128 deletions

View File

@@ -108,7 +108,7 @@ def get_topology(gm: GraphModule):
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition(partition_id=0)
topo.set_input_partition_id(partition_id=0)
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
@@ -140,6 +140,6 @@ def get_topology(gm: GraphModule):
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition(partition_id=1)
topo.set_output_partition_id(partition_id=1)
return topo

View File

@@ -71,6 +71,36 @@ class Partition(object):
def get_output_vals(self):
return self._output_vals
# get the output offsets sent to dst_partition_id
def get_output_offsets(self, dst_partition_id):
res = []
for offset, output_val in enumerate(self._output_vals):
outputs = output_val.get()
for val_pos in outputs:
if val_pos.partition_id == dst_partition_id:
res.append(offset)
return res
# get all input dst partition_ids
def get_input_partition_ids(self):
res = []
for input_val in self._input_vals:
val_pos = input_val.get()
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
# get all output dst partition_ids
def get_output_partition_ids(self):
res = []
for output_val in self._output_vals:
outputs = output_val.get()
for val_pos in outputs:
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
def __str__(self) -> str:
res = ''
@@ -107,11 +137,17 @@ class Topo(object):
self._input_partition_id = input_partition_id
self._output_partition_id = output_partition_id
def set_input_partition(self, partition_id: int):
def set_input_partition_id(self, partition_id: int):
self._input_partition_id = partition_id
def set_output_partition(self, partition_id: int):
def set_output_partition_id(self, partition_id: int):
self._output_partition_id = partition_id
def get_input_partition_id(self):
return self._input_partition_id
def get_output_partition_id(self):
return self._output_partition_id
def set_partitions(self, partition_id: int, partition: Partition):
self._partitions[partition_id] = partition
@@ -124,6 +160,9 @@ class Topo(object):
res[partition_id] = partition
return res
def get_mid_partition_ids(self):
return list(self.get_mid_partitions().keys())
def get_input_partition(self):
if self._input_partition_id is not None:
return self._partitions[self._input_partition_id]
@@ -133,6 +172,9 @@ class Topo(object):
if self._output_partition_id is not None:
return self._partitions[self._output_partition_id]
return None
def get_partition_by_id(self, partition_id):
return self._partitions[partition_id]
def __str__(self) -> str:
res = ''