Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading
* polish code
@@ -7,7 +7,8 @@ import torch
 from packaging import version
 
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout_converter import to_global
+from colossalai.tensor.d_tensor import to_global
+from colossalai.tensor.d_tensor.layout import Layout
 from tests.kit.model_zoo.registry import ModelAttribute
 
 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in layout_dict:
-            t2 = to_global(t2, layout_dict[n2])
+        if n2 in sharding_spec_dict:
+            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
+            t2.dist_layout = layout
+            t2 = to_global(t2)
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
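For readers following the API change, here is a minimal sketch of what the updated comparison helper does, reconstructed from the hunk above. Only the Layout / dist_layout / to_global lines are taken directly from the diff; the extra device_mesh and sharding_spec_dict parameters and the zip() loop over state_dict() items are assumptions inferred from the variables the hunk uses, not copied from the upstream file.

# Hedged sketch (not the exact upstream code) of the comparison helper after this change.
# The signature beyond (model, distributed_model) and the zip() loop are assumptions;
# the Layout / dist_layout / to_global lines mirror the hunk above.
import torch

from colossalai.tensor.d_tensor import to_global
from colossalai.tensor.d_tensor.layout import Layout


def assert_dist_model_equal(model, distributed_model, device_mesh, sharding_spec_dict):
    state = model.state_dict()
    dist_state = distributed_model.state_dict()

    for (n1, t1), (n2, t2) in zip(state.items(), dist_state.items()):
        assert n1 == n2
        t1 = t1.cuda()
        t2 = t2.cuda()
        if n2 in sharding_spec_dict:
            # Attach the distributed layout to the sharded tensor, then gather it
            # back to its global shape before comparing with the single-device copy.
            layout = Layout(device_mesh=device_mesh,
                            sharding_spec=sharding_spec_dict[n2],
                            global_shape=t1.shape)
            t2.dist_layout = layout
            t2 = to_global(t2)
        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

The design change the hunk reflects: to_global is now imported from colossalai.tensor.d_tensor and reads the layout from the tensor's dist_layout attribute, rather than being imported from layout_converter and taking the layout as a second argument.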