mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 02:20:49 +00:00
[autoparallel] adapt autoparallel with new analyzer (#3261)
* [autoparallel] adapt autoparallel with new analyzer * fix all node handler tests * polish * polish
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from topo_utils import split_model_and_get_DAG, check_topo, MLP
|
||||
from topo_utils import MLP, check_topo, split_model_and_get_DAG
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGHT = 16
|
||||
|
||||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
def test_opt():
|
||||
MODEL_LIST = [
|
||||
MLP,
|
||||
@@ -13,7 +15,10 @@ def test_opt():
|
||||
]
|
||||
|
||||
CONFIGS = [
|
||||
{'dim': 10, 'layers': 12},
|
||||
{
|
||||
'dim': 10,
|
||||
'layers': 12
|
||||
},
|
||||
transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
|
||||
]
|
||||
|
||||
@@ -21,15 +26,15 @@ def test_opt():
|
||||
x = torch.zeros((16, 10))
|
||||
kwargs = dict(x=x)
|
||||
return kwargs
|
||||
|
||||
|
||||
def data_gen_OPT():
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return kwargs
|
||||
|
||||
|
||||
DATAGEN = [
|
||||
data_gen_MLP,
|
||||
data_gen_MLP,
|
||||
data_gen_OPT,
|
||||
]
|
||||
|
||||
@@ -39,5 +44,6 @@ def test_opt():
|
||||
# print(f'{top_mod=}\n----\n{topo=}')
|
||||
check_topo(top_mod, topo)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_opt()
|
||||
test_opt()
|
||||
|
Reference in New Issue
Block a user