[zero] new interface for ShardedOptimv2 (#406)

This commit is contained in:
Jiarui Fang
2022-03-14 20:48:41 +08:00
committed by GitHub
parent a9c27be42e
commit 370f567e7d
9 changed files with 51 additions and 35 deletions

View File

@@ -19,11 +19,11 @@ def run_train():
# FIXME: test bert
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model_builder(checkpoint=False)
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer_builder(model),
optimizer=optimizer_class(model.parameters(), lr=1e-3),
criterion=criterion,
train_dataloader=train_dataloader)
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist
def test_engine():
world_size = 4
world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)