mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
Revert "[sync] sync feature/shardformer with develop"
This commit is contained in:
@@ -6,9 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.d_tensor.layout_converter import to_global
|
||||
from tests.kit.model_zoo.registry import ModelAttribute
|
||||
|
||||
@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
|
||||
print(f'{model.__class__.__name__} pass')
|
||||
|
||||
|
||||
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
|
||||
sharding_spec_dict: dict) -> None:
|
||||
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
|
||||
state = model.state_dict()
|
||||
distributed_state = distributed_model.state_dict()
|
||||
|
||||
@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
|
||||
assert n1 == n2
|
||||
t1 = t1.cuda()
|
||||
t2 = t2.cuda()
|
||||
if n2 in sharding_spec_dict:
|
||||
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
|
||||
t2 = to_global(t2, layout)
|
||||
if n2 in layout_dict:
|
||||
t2 = to_global(t2, layout_dict[n2])
|
||||
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
|
||||
|
Reference in New Issue
Block a user