Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-19 00:16:51 +00:00)
[Pipeline] Add Topo Class (#2059)
* use Topo class to rewrite DAG
* polish code
* add comment
* add an else branch to a previously unterminated if

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
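For context: the two test files added below trace a model with ColoTracer, split the traced graph into pipeline stages, derive a Topo object describing the inter-stage data flow via get_fx_topology, and then check that topology against the top-level fx graph. As a rough, library-free illustration of what such stage splitting produces, here is a minimal sketch using only stock torch.fx; split_module stands in for the balanced_split_pass / split_with_split_nodes_pass pair used in topo_utils.py, and the model and names are invented for the example:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class TinyMLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10, bias=False)
        self.l2 = torch.nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.l2(self.l1(x))

model = TinyMLP()
traced = symbolic_trace(model)

# Assign each call node to stage 0 or 1 by name (illustrative policy only).
def stage_of(node):
    return 0 if node.name == 'l1' else 1

top = split_module(traced, model, stage_of)
print(top.graph)    # top-level DAG: x -> submod_0 -> submod_1 -> output

The Topo class added by this commit records, for the analogous top module produced by ColossalAI's split passes, which partition produces each value and which partitions consume it.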
tests/test_fx/test_pipeline/test_topo/test_topo.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import pytest
import torch
import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP

BATCH_SIZE = 1
SEQ_LENGTH = 16

def test_opt():
    MODEL_LIST = [
        MLP,
        transformers.OPTModel,
    ]

    CONFIGS = [
        {'dim': 10, 'layers': 12},
        transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
    ]

    def data_gen_MLP():
        x = torch.zeros((16, 10))
        kwargs = dict(x=x)
        return kwargs

    def data_gen_OPT():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    DATAGEN = [
        data_gen_MLP,
        data_gen_OPT,
    ]

    # build, split, and verify each model; the three lists are index-aligned
    for i, model_cls in enumerate(MODEL_LIST):
        model = model_cls(config=CONFIGS[i])
        top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i])
        # print(f'{top_mod=}\n----\n{topo=}')
        check_topo(top_mod, topo)

if __name__ == '__main__':
    test_opt()
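A note on running the test: the __main__ guard above lets the file execute as a plain script, and pytest collects test_opt automatically. For a programmatic run (the path assumes the repository root as working directory; transformers and colossalai must be installed):

import pytest
raise SystemExit(pytest.main(['-q', 'tests/test_fx/test_pipeline/test_topo/test_topo.py']))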
tests/test_fx/test_pipeline/test_topo/topo_utils.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import random

import numpy as np
import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.middleware.adaptor import get_fx_topology

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

class MLP(torch.nn.Module):

    def __init__(self, config={}):
        super().__init__()
        dim = config['dim']
        layers = config['layers']
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(torch.nn.Linear(dim, dim, bias=False))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate an input sample
    kwargs = data_gen()

    # trace the model on meta tensors (no real memory is allocated)
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # annotate split points and split the graph into pipeline stages
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    # attach the derived topology to every split submodule
    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_topo', topo)

    return top_module, split_submodules[0]._topo

def check_input(top_module, input_partition: Partition):
    partition_output = input_partition.get_output_vals()
    arg_pos = 0
    for node in top_module.graph.nodes:
        if node.op == 'placeholder':
            # each placeholder's recorded consumers must match its fx fan-out
            cur_checkee = partition_output[arg_pos]
            to_partition_and_offset = cur_checkee.get()
            assert len(to_partition_and_offset) == len(node.users.keys())
            arg_pos += 1

    assert arg_pos == len(partition_output)

def check_submod(top_module, part_id, mid_partition: Partition):
    partition_input = mid_partition.get_input_vals()
    partition_output = mid_partition.get_output_vals()

    # locate the fx node corresponding to this partition id by
    # counting submod_* calls in the top-level graph
    cnt = 1
    cur_node = None
    for node in top_module.graph.nodes:
        if node.name.startswith('submod'):
            cnt += 1
        if cnt == part_id:
            cur_node = node
            break

    assert len(partition_input) == len(cur_node.args)
    assert len(partition_output) == len(cur_node.users)

def check_topo(top_module, topo: Topo):
    input_partition = topo.get_input_partition()
    mid_partitions = topo.get_mid_partitions()

    check_input(top_module, input_partition)
    for part_id, submod in mid_partitions.items():
        check_submod(top_module, part_id, submod)
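To make the invariant concrete: check_input asserts that the number of (partition, offset) consumers recorded for each placeholder equals that placeholder's fan-out in the fx graph, and check_submod performs the analogous arity check per stage. A toy rendering of the input-side bookkeeping with plain dicts follows; the tuple layout and the partition id 2 for the first mid partition are assumptions inferred from the counter in check_submod, not the real PartitionOutputVal format:

# Toy stand-in for the consumer records a Topo keeps per placeholder.
fx_fanout = {'input_ids': 1, 'attention_mask': 1}    # len(node.users) per placeholder
recorded = {
    'input_ids': [(2, 0)],        # assumed: consumed by first mid partition, arg offset 0
    'attention_mask': [(2, 1)],   # assumed: consumed by first mid partition, arg offset 1
}

for name, consumers in recorded.items():
    assert len(consumers) == fx_fanout[name]
assert len(recorded) == len(fx_fanout)    # every placeholder accounted for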