[zero] polish sharded param name (#484)

* [zero] polish sharded param name

* polish code

* polish

* polish code

* polish

* polish

* polish
Jiarui Fang
2022-03-22 14:36:16 +08:00
committed by GitHub
parent 9caa8b6481
commit b334822163
12 changed files with 55 additions and 222 deletions
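
Every diff below makes the same substitution: the ambiguously named `data` attribute of `ShardedParamV2` (also reached through a parameter's `col_attr`) becomes `sharded_data_tensor`, making it explicit that it holds a sharded tensor wrapper rather than a plain `torch.Tensor`, and call sites reach the underlying tensor through `sharded_data_tensor.payload`. A minimal sketch of the renamed access path, assuming colossalai is installed and initialized via `colossalai.launch(...)` the way these tests do; the variable names are illustrative, not from the diff:

import torch
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

# Wrap a plain parameter; process_group=None mirrors the tests in this commit.
param = torch.nn.Parameter(torch.randn(2, 3))
sparam = ShardedParamV2(param=param, process_group=None)

# Old access path (removed by this commit): sparam.data.payload
# New access path:
payload = sparam.sharded_data_tensor.payload      # the wrapped tensor
print(sparam.sharded_data_tensor.is_sharded)      # whether a shard strategy has sharded it yet
print(sparam.sharded_data_tensor.dtype)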

View File

@@ -17,7 +17,6 @@ def test_bucket_copy():
     for shape in shape_list:
         # on CPU
         src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
-        print(src_param)
         # on GPU
         tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))
@@ -29,9 +28,10 @@ def test_bucket_copy():
     copyer.flush()
     for src_param, tgt_param in zip(src_param_list, tgt_param_list):
-        print(tgt_param.data.payload)
-        diff = src_param.cpu().float() - tgt_param.data.payload.cpu().float()
-        assert torch.allclose(src_param.cpu().float(), tgt_param.data.payload.cpu().float(), rtol=1e-03,
-                              atol=1e-03), f"diff {diff}"
+        diff = src_param.cpu().float() - tgt_param.sharded_data_tensor.payload.cpu().float()
+        assert torch.allclose(src_param.cpu().float(),
+                              tgt_param.sharded_data_tensor.payload.cpu().float(),
+                              rtol=1e-03,
+                              atol=1e-03), f"diff {diff}"

View File

@@ -119,7 +119,7 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_params_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.data.payload.to(p.device).float()
+        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue

View File

@@ -34,10 +34,10 @@ def run_model_test(init_device, shard_strategy_class):
     for param in model.parameters():
         assert hasattr(param, 'col_attr')
-        assert param.col_attr.data.dtype == torch.half
-        assert param.col_attr.data.is_sharded
-        assert param.col_attr.data.payload.device.type == init_device.type, \
-            f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
+        assert param.col_attr.sharded_data_tensor.dtype == torch.half
+        assert param.col_attr.sharded_data_tensor.is_sharded
+        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
     print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
     print(f'numel {model_numel_tensor}')

View File

@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 from copy import deepcopy
 from functools import partial
@@ -8,13 +5,11 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
+from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_zero_data_parallel.common import CONFIG, allclose
@@ -52,7 +47,7 @@ def _run_shard_param_v2(rank, world_size, port):
     param_ref = deepcopy(param)
     sparam = ShardedParamV2(param=param, process_group=None)
-    allclose(sparam.data.payload, param_ref.data)
+    allclose(sparam.sharded_data_tensor.payload, param_ref.data)
     sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
@@ -65,69 +60,6 @@ def test_shard_param_v2(world_size):
     mp.spawn(run_func, nprocs=world_size)
-def _run_test_shard_param(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(torch.randn(2, 3))
-    param_ref = deepcopy(param)
-    sparam = ShardedParamV2(param=param, process_group=None)
-    print(sparam.data)
-    print(param_ref.data)
-    logger = get_dist_logger()
-    for get_components_func in non_distributed_component_funcs:
-        model_builder, *_ = get_components_func()
-        model = model_builder(checkpoint=True)
-        # add an attribute as col_attr to hijack the access to param.data
-        for _, param in model.named_parameters():
-            numel_ref = (param.numel() + world_size - 1) // world_size
-            param.col_attr = ShardedParam(param)
-            param.col_attr.shard()
-            param_data = param.col_attr.payload(torch.device('cpu'))
-            assert (numel_ref == param_data.numel())
-        for _, param in model.named_parameters():
-            param.col_attr.gather()
-            param_data = param.col_attr.payload(torch.device('cpu'))
-    disable_existing_loggers([logger])
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_param(world_size):
-    run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-def _run_init_shard_param(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(world_size, 3))
-    sparam = ShardedParam(param, None, True)
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    del sparam
-    param_shape = (world_size, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    param_shape = (world_size, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [world_size, 3])
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 4])
-def test_init_shard_param(world_size):
-    run_func = partial(_run_init_shard_param, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
     test_shard_tensor(2)
-    test_shard_param(2)
     test_shard_param_v2(2)
-    test_init_shard_param(4)