Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
Revert "[sync] sync feature/shardformer with develop"
@@ -6,9 +6,7 @@ import numpy as np
 import torch
 from packaging import version

-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import to_global
 from tests.kit.model_zoo.registry import ModelAttribute

@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
         print(f'{model.__class__.__name__} pass')


-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
-                            sharding_spec_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()

@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in sharding_spec_dict:
-            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
-            t2 = to_global(t2, layout)
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
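For reference, applying the two hunks above yields the following comparison helper. Only the loop header that zips the two state dicts is an assumption here (it sits outside the shown hunks); every other line appears in the diff.

import torch

from colossalai.tensor.d_tensor.layout_converter import to_global


def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
    state = model.state_dict()
    distributed_state = distributed_model.state_dict()

    # Assumed loop header: iterate both state dicts in parallel (not shown in the hunks).
    for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
        assert n1 == n2
        t1 = t1.cuda()
        t2 = t2.cuda()
        if n2 in layout_dict:
            # Gather the shard back to its global shape before comparing.
            t2 = to_global(t2, layout_dict[n2])
        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'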
@@ -26,19 +26,23 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
     return dim


-def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
     shard_dim = find_shard_dim(original_tensor.shape)
     dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
-    return target_sharding_spec
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    return layout


 def _get_current_name(prefix: str, name: str) -> str:
     return f'{prefix}.{name}'.lstrip('.')


-def generate_sharding_spec_dict(model: nn.Module) -> dict:
-    sharding_spec_dict = {}
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+    layout_dict = {}

     @torch.no_grad()
     def generate_recursively(module: nn.Module, prefix: str = ''):
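Reassembling the + side of the hunk above gives the helper below, with the imports it needs. find_shard_dim comes from earlier in the same file (its body is not part of this diff); the DeviceMesh and Layout import paths are taken from the first hunk of this commit, while the ShardingSpec import path is an assumption.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec    # assumed import path


def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
    # Shard the dimension picked by find_shard_dim over mesh axis 0; an empty
    # dim_partition_dict leaves the tensor fully replicated.
    shard_dim = find_shard_dim(original_tensor.shape)    # defined earlier in this file
    dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
    # The Layout binds the spec to a concrete device mesh and the tensor's global shape.
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=target_sharding_spec,
                    entire_shape=original_tensor.shape)
    return layout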
@@ -49,17 +53,17 @@ def generate_sharding_spec_dict(model: nn.Module) -> dict:
         # initialize tensors directly attached to the current module
         for name, param in module.named_parameters(recurse=False):
             if isinstance(param, LazyTensor):
-                sharding_spec = make_sharding_spec(param)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, param)
+                layout_dict[_get_current_name(prefix, name)] = layout

         for name, buf in module.named_buffers(recurse=False):
             if isinstance(buf, LazyTensor):
-                sharding_spec = make_sharding_spec(buf)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, buf)
+                layout_dict[_get_current_name(prefix, name)] = layout

     generate_recursively(model)

-    return sharding_spec_dict
+    return layout_dict


 @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
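Together with the previous hunk, the restored generate_layout_dict walks the module tree once and records one Layout per LazyTensor, keyed by its dotted name. A sketch of the whole helper after the revert, reusing make_layout, _get_current_name, LazyTensor and DeviceMesh exactly as defined or imported in this file; the recursion into child modules sits above the shown lines, so that step is an assumption.

import torch
import torch.nn as nn


def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
    layout_dict = {}

    @torch.no_grad()
    def generate_recursively(module: nn.Module, prefix: str = ''):
        # Assumed traversal step: recurse into submodules first (not shown in the hunks).
        for name, child in module.named_children():
            generate_recursively(child, prefix=_get_current_name(prefix, name))

        # initialize tensors directly attached to the current module
        for name, param in module.named_parameters(recurse=False):
            if isinstance(param, LazyTensor):
                layout = make_layout(device_mesh, param)
                layout_dict[_get_current_name(prefix, name)] = layout

        for name, buf in module.named_buffers(recurse=False):
            if isinstance(buf, LazyTensor):
                layout = make_layout(device_mesh, buf)
                layout_dict[_get_current_name(prefix, name)] = layout

    generate_recursively(model)

    return layout_dict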
@@ -81,9 +85,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
         ctx = LazyInitContext()
         with ctx:
             deferred_model = model_fn()
-        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
-        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
-        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
+        layout_dict = generate_layout_dict(deferred_model, device_mesh)
+        ctx.distribute(deferred_model, layout_dict, verbose=True)
+        assert_dist_model_equal(model, deferred_model, layout_dict)


 def run_dist(rank, world_size, port) -> None:
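Net effect on the distributed lazy-init test: LazyInitContext.distribute now takes the layout dict directly instead of a device mesh plus sharding-spec dict, and the same dict drives the equality check. A minimal sketch of the restored flow; check_one_model is a hypothetical wrapper, and model_fn, model and device_mesh are assumed to come from the model-zoo setup that sits outside these hunks.

import torch.nn as nn

from colossalai.lazy.lazy_init import LazyInitContext


def check_one_model(model_fn, model: nn.Module, device_mesh) -> None:
    # Hypothetical wrapper around the lines restored in the last hunk.
    ctx = LazyInitContext()
    with ctx:
        # Parameters and buffers stay lazy until ctx.distribute materializes them.
        deferred_model = model_fn()
    # One Layout per lazy parameter/buffer, keyed by dotted name.
    layout_dict = generate_layout_dict(deferred_model, device_mesh)
    # Materialize and shard the deferred model according to the layouts.
    ctx.distribute(deferred_model, layout_dict, verbose=True)
    # Every shard, gathered back to its global shape, must match the eager reference model.
    assert_dist_model_equal(model, deferred_model, layout_dict)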