mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test
This commit is contained in:
@@ -5,9 +5,9 @@ import colossalai
|
||||
from colossalai.legacy.amp.amp_type import AMP_TYPE
|
||||
from colossalai.legacy.trainer import Trainer
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import MultiTimer
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
BATCH_SIZE = 4
|
||||
IMG_SIZE = 32
|
||||
@@ -16,12 +16,14 @@ NUM_EPOCHS = 200
|
||||
CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH))
|
||||
|
||||
|
||||
@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"])
|
||||
@parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"])
|
||||
def run_trainer(model_name):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
model = model_builder()
|
||||
optimizer = optimizer_class(model.parameters(), lr=1e-3)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
train_dataloader = DummyDataloader(data_gen_fn)
|
||||
test_dataloader = DummyDataloader(data_gen_fn)
|
||||
criterion = lambda x: x.sum()
|
||||
engine, train_dataloader, *_ = colossalai.legacy.initialize(
|
||||
model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader
|
||||
)
|
||||
|
Reference in New Issue
Block a user