Revert "[sync] sync feature/shardformer with develop"

Frank Lee
2023-06-09 09:41:27 +08:00
committed by GitHub
parent 24651fdd4f
commit ddcf58cacf
48 changed files with 445 additions and 3876 deletions

@@ -6,9 +6,7 @@ import numpy as np
 import torch
 from packaging import version
-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import to_global
 from tests.kit.model_zoo.registry import ModelAttribute
@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     print(f'{model.__class__.__name__} pass')


-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
-                            sharding_spec_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()
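With the device_mesh and sharding_spec_dict parameters gone, the caller now supplies ready-made layouts. A minimal sketch of how such a layout_dict might be built, reusing the Layout construction from the removed branch in the next hunk; build_layout_dict and shape_dict are hypothetical names, not part of this commit:

```python
# Hypothetical helper (not part of this commit): builds the layout_dict that the
# reverted signature expects. The Layout keyword arguments mirror the removed
# branch in the next hunk; `shape_dict` is a made-up stand-in mapping parameter
# names to their global shapes.
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout


def build_layout_dict(device_mesh: DeviceMesh, sharding_spec_dict: dict, shape_dict: dict) -> dict:
    return {
        name: Layout(device_mesh=device_mesh, sharding_spec=spec, global_shape=shape_dict[name])
        for name, spec in sharding_spec_dict.items()
    }
```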
@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module,
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in sharding_spec_dict:
-            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
-            t2 = to_global(t2, layout)
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
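
Assembled from the hunks above, the post-revert helper reads roughly as follows. This is a sketch, not the verbatim file: the zip loop over the two state dicts is not visible in the diff and is an assumption inferred from the paired (n1, t1) / (n2, t2) names:

```python
# Sketch of assert_dist_model_equal as it reads after this revert, assembled
# from the hunks above. The enclosing zip loop is an assumption.
import torch

from colossalai.tensor.d_tensor.layout_converter import to_global


def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
    state = model.state_dict()
    distributed_state = distributed_model.state_dict()

    for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
        assert n1 == n2
        t1 = t1.cuda()
        t2 = t2.cuda()
        # Gather sharded parameters back to their global shape before comparing.
        if n2 in layout_dict:
            t2 = to_global(t2, layout_dict[n2])
        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
```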