[utils] fixed lazy init context (#1867)

Frank Lee
2022-11-10 15:17:20 +08:00
committed by GitHub
parent 50c4cb0167
commit e6ec99d389
2 changed files with 35 additions and 25 deletions


@@ -1,16 +1,18 @@
-import colossalai
-import torch
-import torch.nn as nn
-import pytest
-import torch.multiprocessing as mp
-import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
class MLP(torch.nn.Module):
@@ -35,6 +37,9 @@ def run_workflow(world_size):
    with LazyInitContext() as ctx:
        model = MLP(16)
+
+    for param in model.parameters():
+        assert param.is_meta

    # tracing
    tracer = ColoTracer()
    graph = tracer.trace(model)
@@ -46,6 +51,8 @@ def run_workflow(world_size):
    # materialization and sharding
    ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta

    # check sharding
    assert list(model.linear1.weight.shape) == [16 // world_size, 16]
@@ -57,7 +64,7 @@ def run_workflow(world_size):
    data = torch.rand(4, 16)
    non_fx_out = model(data)
    fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
def run_dist(rank, world_size, port):
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
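
For context, a minimal sketch of the lazy-init workflow this test exercises, assembled only from the calls visible in the hunks above. The TinyMLP stand-in and the torch.fx.GraphModule wrapping step are assumptions (the test's MLP definition and the construction of annotated_gm are not shown in this diff), and the 1D sharding annotation via transformer_mlp_pass is elided because its signature does not appear here either.

import torch
import torch.nn as nn

from colossalai.fx import ColoTracer
from colossalai.utils.model.lazy_init_context import LazyInitContext


class TinyMLP(nn.Module):
    # stand-in for the MLP used in the test; its definition is not shown in this diff
    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.linear1(x))


# build the model under LazyInitContext: parameters are created on the meta
# device, so no real storage is allocated yet
with LazyInitContext() as ctx:
    model = TinyMLP(16)
assert all(p.is_meta for p in model.parameters())

# trace the meta-initialized model; in the test the resulting graph module is
# further annotated for 1D sharding with transformer_mlp_pass before materialization
tracer = ColoTracer()
graph = tracer.trace(model)
gm = torch.fx.GraphModule(model, graph)  # assumption: how annotated_gm is built

# materialize the parameters in place; with a sharding-annotated graph module
# (as in the test) they come back already sharded across the process group
ctx.lazy_init_parameters(gm)
assert not any(p.is_meta for p in model.parameters())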