Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] Update sharded model v2 using sharded param v2 (#323)
@@ -7,12 +7,14 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
-from colossalai.zero.init_ctx import ZeroInitContext
-from common import CONFIG
 from colossalai.utils import free_port
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils.tensor_shard_strategy import \
+    TensorShardStrategy
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+from common import CONFIG, Net
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
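
For orientation between the hunks: `run_dist` is the per-rank test body, and ColossalAI's distributed tests conventionally drive it through `torch.multiprocessing`, with the rendezvous port picked by the newly imported `free_port`. A minimal sketch of the entry point that likely sits below these hunks; the test name and `world_size` value are assumptions, not part of this diff:

from functools import partial

import pytest
import torch.multiprocessing as mp
from colossalai.utils import free_port

@pytest.mark.dist
def test_zero_init_context():  # hypothetical name; the real entry point is not shown in this diff
    world_size = 2  # assumed world size
    # Bind everything except the rank, which mp.spawn supplies to each process;
    # run_dist is the function defined in the hunk above.
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)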
@@ -25,11 +27,11 @@ def run_dist(rank, world_size, port):
                          shard_param=True):
         model = model_builder(checkpoint=True)
 
-    for param in model.parameters():
-        assert hasattr(param, 'ca_attr')
-        assert param.ca_attr.data.dtype == torch.half
-        assert param.ca_attr._data_sharded_tensor.is_sharded
-        assert param.ca_attr.data.device.type == 'cuda'
+    for param in model.parameters():
+        assert hasattr(param, 'col_attr')
+        assert param.col_attr.data.dtype == torch.half
+        assert param.col_attr.data.is_sharded
+        assert param.col_attr.data.payload.device.type == 'cuda'
 
 
 @pytest.mark.dist
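
The substance of the change is in these assertions: under sharded param v2, `param.col_attr.data` is no longer a raw fp16 tensor with a separate `_data_sharded_tensor` alongside it; it is itself a sharded-tensor wrapper, so shard state and device are checked through `is_sharded` and `payload`. A toy mock, just enough to replay the four new assertions on CPU; the class names and internals here are illustrative stand-ins, not ColossalAI's implementation:

import torch

class FakeShardedTensor:
    """Illustrative stand-in for sharded param v2's tensor wrapper."""

    def __init__(self, payload: torch.Tensor, is_sharded: bool = True):
        self.payload = payload        # this rank's local shard
        self.is_sharded = is_sharded  # set once the shard strategy splits the tensor

    @property
    def dtype(self) -> torch.dtype:
        # Delegate dtype to the shard so `col_attr.data.dtype` stays meaningful.
        return self.payload.dtype

class FakeColAttr:
    """Illustrative stand-in for the `col_attr` attached to each parameter."""

    def __init__(self, shard: torch.Tensor):
        self.data = FakeShardedTensor(shard)

# Replay the four new assertions against the mock. A CPU half tensor stands in
# for the GPU shard; the real test asserts device.type == 'cuda'.
param = torch.nn.Parameter(torch.empty(4))
param.col_attr = FakeColAttr(torch.empty(2, dtype=torch.half))

assert hasattr(param, 'col_attr')
assert param.col_attr.data.dtype == torch.half
assert param.col_attr.data.is_sharded
assert param.col_attr.data.payload.device.type == 'cpu'

Checking the device through `payload` rather than through `.data` directly makes the expectation explicit: after `ZeroInitContext` exits, it is each rank's local shard that is supposed to live on the GPU.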