mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[fx]Split partition with DAG information (#2025)
* add DAG to split_module * add comment * add test case for DAG * remove print Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
85
tests/test_fx/test_pipeline/test_DAG/dag_utils.py
Normal file
85
tests/test_fx/test_pipeline/test_DAG/dag_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||
from colossalai.fx import ColoTracer
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
MANUAL_SEED = 0
|
||||
random.seed(MANUAL_SEED)
|
||||
np.random.seed(MANUAL_SEED)
|
||||
torch.manual_seed(MANUAL_SEED)
|
||||
|
||||
def split_model_and_get_DAG(model, data_gen):
|
||||
model.eval()
|
||||
|
||||
# generate input sample
|
||||
kwargs = data_gen()
|
||||
|
||||
# get origin output and rng state
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
output = model(**kwargs)
|
||||
|
||||
# tracing model
|
||||
tracer = ColoTracer()
|
||||
try:
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
# apply transform passes
|
||||
annotated_model = balanced_split_pass(gm, 2)
|
||||
top_module, split_submodules = split_with_split_nodes_pass(annotated_model)
|
||||
|
||||
return top_module, split_submodules[0]._DAG
|
||||
|
||||
def check_input(input, input_node, top_module):
|
||||
for user in input_node.users.keys():
|
||||
partition_name = user.name
|
||||
assert partition_name in input['output']
|
||||
|
||||
def check_submod(submod_partition, node, top_module):
|
||||
for arg in node.args:
|
||||
input_part_name = None
|
||||
if arg.op == 'placeholder':
|
||||
input_part_name = 'MODEL_INPUT'
|
||||
elif not arg.name.startswith('getitem'):
|
||||
input_part_name = arg.name
|
||||
else:
|
||||
input_part_name = arg.args[0].name
|
||||
assert input_part_name in submod_partition['input']
|
||||
|
||||
for user in node.users:
|
||||
output_part_names = []
|
||||
if user.op == 'output':
|
||||
output_part_names.append('MODEL_OUTPUT')
|
||||
elif not user.name.startswith('getitem'):
|
||||
output_part_names.append(user.name)
|
||||
else:
|
||||
for n in user.users:
|
||||
if n.op == 'output':
|
||||
output_part_names.append('MODEL_OUTPUT')
|
||||
else:
|
||||
output_part_names.append(n.name)
|
||||
|
||||
for output_part_name in output_part_names:
|
||||
assert output_part_name in submod_partition['output']
|
||||
|
||||
def check_DAG(top_module, DAG):
|
||||
assert 'input_partition' in DAG
|
||||
input_partition = DAG['input_partition']
|
||||
|
||||
for node in top_module.graph.nodes:
|
||||
# check input
|
||||
if node.op == 'placeholder':
|
||||
assert node.name in input_partition
|
||||
input = input_partition[node.name]
|
||||
check_input(input, node, top_module)
|
||||
elif node.op == 'call_module':
|
||||
assert node.name in DAG
|
||||
submod_partition = DAG[node.name]
|
||||
check_submod(submod_partition, node, top_module)
|
||||
|
31
tests/test_fx/test_pipeline/test_DAG/test_dag.py
Normal file
31
tests/test_fx/test_pipeline/test_DAG/test_dag.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from dag_utils import split_model_and_get_DAG, check_DAG
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGHT = 16
|
||||
|
||||
|
||||
@pytest.mark.skip('balance split v2 is not ready')
|
||||
def test_opt():
|
||||
MODEL_LIST = [
|
||||
transformers.OPTModel,
|
||||
#transformers.OPTForCausalLM,
|
||||
]
|
||||
|
||||
config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
|
||||
|
||||
def data_gen():
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return kwargs
|
||||
|
||||
for model_cls in MODEL_LIST:
|
||||
model = model_cls(config=config)
|
||||
top_mod, DAG = split_model_and_get_DAG(model, data_gen)
|
||||
check_DAG(top_mod, DAG)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_opt()
|
Reference in New Issue
Block a user