mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[hotfix] pass test_complete_workflow (#1877)
This commit is contained in:
@@ -50,7 +50,7 @@ def run_workflow(world_size, dev):
|
||||
annotated_gm.recompile()
|
||||
|
||||
# materialization and sharding
|
||||
ctx.lazy_init_parameters(annotated_gm)
|
||||
ctx.lazy_init_parameters(annotated_gm, device=dev)
|
||||
for param in model.parameters():
|
||||
assert not param.is_meta
|
||||
|
||||
@@ -84,4 +84,4 @@ def test_complete_workflow(world_size, dev):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_complete_workflow(1)
|
||||
test_complete_workflow(1, 'cuda')
|
||||
|
Reference in New Issue
Block a user