mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-20 09:01:06 +00:00
[Pipeline Middleware] Adapt scheduler for Topo (#2066)
* adapt scheduler for Topo * remoove comment * fix set input Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -108,7 +108,7 @@ def get_topology(gm: GraphModule):
|
||||
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
|
||||
topo_input_partition.add_output_val(p_output_val)
|
||||
topo.set_partitions(partition_id=0, partition=topo_input_partition)
|
||||
topo.set_input_partition(partition_id=0)
|
||||
topo.set_input_partition_id(partition_id=0)
|
||||
|
||||
for i, partition in enumerate(partitions):
|
||||
topo_mid_partition = Partition()
|
||||
@@ -140,6 +140,6 @@ def get_topology(gm: GraphModule):
|
||||
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
|
||||
find_input_in_partition(n, partitions, input_partitions)))
|
||||
topo.set_partitions(partition_id=1, partition=topo_output_partition)
|
||||
topo.set_output_partition(partition_id=1)
|
||||
topo.set_output_partition_id(partition_id=1)
|
||||
|
||||
return topo
|
@@ -71,6 +71,36 @@ class Partition(object):
|
||||
|
||||
def get_output_vals(self):
|
||||
return self._output_vals
|
||||
|
||||
# get the output offsets sent to dst_partition_id
|
||||
def get_output_offsets(self, dst_partition_id):
|
||||
res = []
|
||||
for offset, output_val in enumerate(self._output_vals):
|
||||
outputs = output_val.get()
|
||||
for val_pos in outputs:
|
||||
if val_pos.partition_id == dst_partition_id:
|
||||
res.append(offset)
|
||||
|
||||
return res
|
||||
|
||||
# get all input dst partition_ids
|
||||
def get_input_partition_ids(self):
|
||||
res = []
|
||||
for input_val in self._input_vals:
|
||||
val_pos = input_val.get()
|
||||
if val_pos.partition_id not in res:
|
||||
res.append(val_pos.partition_id)
|
||||
return res
|
||||
|
||||
# get all output dst partition_ids
|
||||
def get_output_partition_ids(self):
|
||||
res = []
|
||||
for output_val in self._output_vals:
|
||||
outputs = output_val.get()
|
||||
for val_pos in outputs:
|
||||
if val_pos.partition_id not in res:
|
||||
res.append(val_pos.partition_id)
|
||||
return res
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
@@ -107,11 +137,17 @@ class Topo(object):
|
||||
self._input_partition_id = input_partition_id
|
||||
self._output_partition_id = output_partition_id
|
||||
|
||||
def set_input_partition(self, partition_id: int):
|
||||
def set_input_partition_id(self, partition_id: int):
|
||||
self._input_partition_id = partition_id
|
||||
|
||||
def set_output_partition(self, partition_id: int):
|
||||
def set_output_partition_id(self, partition_id: int):
|
||||
self._output_partition_id = partition_id
|
||||
|
||||
def get_input_partition_id(self):
|
||||
return self._input_partition_id
|
||||
|
||||
def get_output_partition_id(self):
|
||||
return self._output_partition_id
|
||||
|
||||
def set_partitions(self, partition_id: int, partition: Partition):
|
||||
self._partitions[partition_id] = partition
|
||||
@@ -124,6 +160,9 @@ class Topo(object):
|
||||
res[partition_id] = partition
|
||||
return res
|
||||
|
||||
def get_mid_partition_ids(self):
|
||||
return list(self.get_mid_partitions().keys())
|
||||
|
||||
def get_input_partition(self):
|
||||
if self._input_partition_id is not None:
|
||||
return self._partitions[self._input_partition_id]
|
||||
@@ -133,6 +172,9 @@ class Topo(object):
|
||||
if self._output_partition_id is not None:
|
||||
return self._partitions[self._output_partition_id]
|
||||
return None
|
||||
|
||||
def get_partition_by_id(self, partition_id):
|
||||
return self._partitions[partition_id]
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
|
Reference in New Issue
Block a user