[Pipeline] Add Topo Class (#2059)

* use Topo class to rewrite DAG

* polish code

* polish code

* polish code

* add comment

* add else to unended if

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-12-02 18:13:20 +08:00
committed by GitHub
parent e4293e5077
commit 44ea461890
10 changed files with 451 additions and 283 deletions

View File

@@ -0,0 +1,43 @@
import pytest
import torch
import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP
BATCH_SIZE = 1
SEQ_LENGHT = 16
def test_opt():
MODEL_LIST = [
MLP,
transformers.OPTModel,
]
CONFIGS = [
{'dim': 10, 'layers': 12},
transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
]
def data_gen_MLP():
x = torch.zeros((16, 10))
kwargs = dict(x=x)
return kwargs
def data_gen_OPT():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
DATAGEN = [
data_gen_MLP,
data_gen_OPT,
]
for i, model_cls in enumerate(MODEL_LIST):
model = model_cls(config=CONFIGS[i])
top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i])
# print(f'{top_mod=}\n----\n{topo=}')
check_topo(top_mod, topo)
if __name__ == '__main__':
test_opt()

View File

@@ -0,0 +1,92 @@
import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.middleware.adaptor import get_fx_topology
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
class MLP(torch.nn.Module):
def __init__(self, config={}):
super().__init__()
dim = config['dim']
layers = config['layers']
self.layers = torch.nn.ModuleList()
for _ in range(layers):
self.layers.append(torch.nn.Linear(dim, dim, bias=False))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def split_model_and_get_DAG(model, data_gen):
model.eval()
# generate input sample
kwargs = data_gen()
# tracing model
tracer = ColoTracer()
try:
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
top_module, split_submodules = split_with_split_nodes_pass(annotated_model)
topo = get_fx_topology(top_module)
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo)
return top_module, split_submodules[0]._topo
def check_input(top_module, input_partition: Partition):
partition_output = input_partition.get_output_vals()
arg_pos = 0
for node in top_module.graph.nodes:
if node.op == 'placeholder':
cur_checkee = partition_output[arg_pos]
to_partition_and_offset = cur_checkee.get()
assert len(to_partition_and_offset) == len(node.users.keys())
arg_pos += 1
assert arg_pos == len(partition_output)
def check_submod(top_module, part_id, mid_partition: Partition):
partition_input = mid_partition.get_input_vals()
partition_output = mid_partition.get_output_vals()
cnt = 1
cur_node = None
for node in top_module.graph.nodes:
if node.name.startswith('submod'):
cnt += 1
if cnt == part_id:
cur_node = node
break
assert len(partition_input) == len(cur_node.args)
assert len(partition_output) == len(cur_node.users)
def check_topo(top_module, topo: Topo):
input_partition = topo.get_input_partition()
mid_partitions = topo.get_mid_partitions()
check_input(top_module, input_partition)
for part_id, submod in mid_partitions.items():
check_submod(top_module, part_id, submod)