[autoparallel] adapt autoparallel with new analyzer (#3261)

* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
Author: YuliangLiu0306
Date: 2023-03-30 17:47:24 +08:00
Committed by: GitHub
Parent: e78a1e949a
Commit: fee2af8610
36 changed files with 481 additions and 386 deletions


@@ -1,11 +1,13 @@
 import pytest
 import torch
 import transformers
-from topo_utils import split_model_and_get_DAG, check_topo, MLP
+from topo_utils import MLP, check_topo, split_model_and_get_DAG
 
 BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
+@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
 def test_opt():
     MODEL_LIST = [
         MLP,
@@ -13,7 +15,10 @@ def test_opt():
     ]
 
     CONFIGS = [
-        {'dim': 10, 'layers': 12},
+        {
+            'dim': 10,
+            'layers': 12
+        },
         transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
     ]
 
@@ -21,15 +26,15 @@ def test_opt():
         x = torch.zeros((16, 10))
         kwargs = dict(x=x)
         return kwargs
 
     def data_gen_OPT():
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
         attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         return kwargs
 
     DATAGEN = [
         data_gen_MLP,
         data_gen_OPT,
     ]
@@ -39,5 +44,6 @@ def test_opt():
         # print(f'{top_mod=}\n----\n{topo=}')
         check_topo(top_mod, topo)
 
 if __name__ == '__main__':
-    test_opt()
+    test_opt()
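
Note: the loop body of test_opt() that actually drives these helpers falls between the hunks shown above and is not part of this diff. The following is a minimal sketch of how the pieces presumably fit together; the call signatures of split_model_and_get_DAG and check_topo, and the run_topo_checks wrapper itself, are assumptions for illustration and may differ from the real topo_utils helpers.

# Hypothetical sketch only: assumes split_model_and_get_DAG(model, kwargs)
# returns (top_module, topo) and that check_topo(top_module, topo) raises on a
# mismatch; the real helpers in topo_utils may take different arguments.
from topo_utils import check_topo, split_model_and_get_DAG


def run_topo_checks(model_list, configs, datagens):
    for model_cls, config, data_gen in zip(model_list, configs, datagens):
        # MLP is built from a kwargs dict, transformers.OPTModel from an OPTConfig.
        model = model_cls(**config) if isinstance(config, dict) else model_cls(config)
        kwargs = data_gen()    # e.g. data_gen_OPT() returns input_ids / attention_mask
        top_mod, topo = split_model_and_get_DAG(model, kwargs)
        check_topo(top_mod, topo)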