mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change * polish code
This commit is contained in:
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
|
||||
from colossalai.tensor.d_tensor import to_global
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
@@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
|
||||
print(f'{model.__class__.__name__} pass')
|
||||
|
||||
|
||||
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
|
||||
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
|
||||
sharding_spec_dict: dict) -> None:
|
||||
state = model.state_dict()
|
||||
distributed_state = distributed_model.state_dict()
|
||||
|
||||
|
Reference in New Issue
Block a user