mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
[utils] fixed lazy init context (#1867)
@@ -1,16 +1,18 @@
-import colossalai
-import torch
-import torch.nn as nn
-import pytest
-import torch.multiprocessing as mp
-import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
 
 
 class MLP(torch.nn.Module):
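This first hunk only regroups the imports (stdlib, then third-party, then colossalai). The MLP body itself is outside the diff; judging from MLP(16) below and the later assertion on model.linear1.weight, it is a small two-layer block along these lines (a hypothetical reconstruction for orientation, not the file's actual code):

    import torch

    class MLP(torch.nn.Module):
        # hypothetical reconstruction: two square Linear layers, sized so that
        # transformer_mlp_pass can shard linear1 across ranks
        def __init__(self, dim: int):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim)
            self.linear2 = torch.nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear2(self.linear1(x))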
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)
 
+    for param in model.parameters():
+        assert param.is_meta
+
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
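The added assertions pin down the contract of LazyInitContext: inside the context, parameters are created on PyTorch's meta device, carrying shape and dtype but no storage, so the model can be built and traced before any memory is allocated. The same effect can be demonstrated with stock PyTorch (a minimal illustration, independent of ColossalAI):

    import torch.nn as nn

    # parameters created on the meta device record only shape/dtype metadata;
    # no real storage is allocated until the module is materialized
    layer = nn.Linear(16, 16, device='meta')

    assert layer.weight.is_meta        # True: metadata only
    print(layer.weight.shape)          # torch.Size([16, 16])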
@@ -46,6 +51,8 @@ def run_workflow(world_size):
 
     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta
 
     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
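Here the mirror check is added: after ctx.lazy_init_parameters materializes the annotated graph, no parameter may remain on the meta device. The shape assertion spells out the 1D sharding rule: linear1's full weight is [out_features, in_features] = [16, 16], and the pass divides the output dimension across ranks, so each rank materializes only a [16 // world_size, 16] slice. The arithmetic for two ranks:

    # column-style 1D sharding of a [16, 16] Linear weight across 2 ranks
    world_size = 2
    full_shape = [16, 16]
    local_shape = [full_shape[0] // world_size, full_shape[1]]
    assert local_shape == [8, 16]      # each rank holds half the rows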
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 
 
 def run_dist(rank, world_size, port):
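The only change in this hunk is diagnostic: the bare assert gains a message, so a CI failure prints both tensors instead of an anonymous AssertionError. A tiny standalone illustration of the difference:

    import torch

    a, b = torch.zeros(2), torch.ones(2)
    try:
        assert torch.equal(a, b), f'{a} vs {b}'
    except AssertionError as err:
        print(err)    # tensor([0., 0.]) vs tensor([1., 1.])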
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
 
 
 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
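The __main__ guard now exercises the single-process case directly. The scaffold around it is not part of the diff; in ColossalAI tests of this era it typically follows the spawn pattern below (a sketch assembled from the file's imports and the run_dist signature, relying on the run_workflow defined in this file; not the verbatim test code):

    from functools import partial

    import pytest
    import torch.multiprocessing as mp

    import colossalai
    from colossalai.testing import rerun_if_address_is_in_use
    from colossalai.utils import free_port


    def run_dist(rank, world_size, port):
        # every spawned worker joins the same process group, then runs the workflow
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        run_workflow(world_size)


    @pytest.mark.dist
    @pytest.mark.parametrize('world_size', [1, 2])
    @rerun_if_address_is_in_use()
    def test_complete_workflow(world_size):
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)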