Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[zero] sharded model support the reuse of fp16 shard (#495)
* sharded model supports reuse fp16 shard
* rename variable
* polish code
* polish code
* polish code
@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
                           use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy())
+                          shard_strategy=TensorShardStrategy(),
+                          reuse_fp16_shard=False)
 
 _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                               initial_scale=2**5,
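For orientation, a minimal sketch of the test-side model config after this hunk. The TensorShardStrategy import path and the idea that the dict is unpacked into ShardedModelV2 as keyword arguments are not shown in this commit; they are stated here as assumptions.

    from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path

    # Mirrors _ZERO_MODEL_CONFIG after this change: the only new key is
    # reuse_fp16_shard, defaulting to False so existing configs keep the old behaviour.
    model_config = dict(reduce_scatter_bucket_size_mb=25,
                        offload_config=None,
                        gradient_predivide_factor=1.0,
                        use_memory_tracer=False,
                        shard_strategy=TensorShardStrategy(),
                        reuse_fp16_shard=False)

    # Presumably consumed as ShardedModelV2(module, **model_config) on a
    # ZeRO-initialized module; that wiring lives outside this hunk.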
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)
 
 
-def check_sharded_params_padding(model, zero_model, loose=False):
+def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+        if reuse_fp16_shard:
+            zero_p = zero_p.data.to(p.device).float()
+        else:
+            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
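The per-rank comparison pattern in check_sharded_model_params can be illustrated without a distributed setup; the rank and world size below are made-up stand-ins for dist.get_rank() and dist.get_world_size().

    import torch

    world_size, rank = 4, 1                      # stand-ins for the real process group values
    p = torch.randn(10)                          # full (unsharded) parameter from the reference model
    chunks = torch.flatten(p).chunk(world_size)  # one chunk per rank; tiny tensors may yield fewer chunks
    if rank < len(chunks):                       # hence the `if rank >= len(chunks): continue` guard above
        my_chunk = chunks[rank]
        # The real helper compares this chunk against the rank-local shard:
        # zero_p.data when reuse_fp16_shard=True, otherwise
        # zero_p.col_attr.sharded_data_tensor.payload.
        print(rank, tuple(my_chunk.shape))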
@@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from common import CONFIG, check_sharded_params_padding
+from common import CONFIG, check_sharded_model_params
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
                                 offload_config=dict(device='cpu') if cpu_offload else None,
-                                use_memory_tracer=gpu_margin_mem_ratio > 0.0)
+                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                reuse_fp16_shard=use_cpuadam)
 
     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         data, label = data.cuda(), label.cuda()
         _run_step(apex_model, apex_optimizer, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
-        check_sharded_params_padding(model, zero_model, loose=True)
+        check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
     for param in model.parameters():
         assert not has_inf_or_nan(param)
 
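Condensing the hunks above, the optimizer test threads the new flag through both the model wrapper and the checker. The lines below are taken from those hunks and annotated; the surrounding setup (ZeRO initialization, model_builder, sharded_optim, the data loader) is elided, so treat this as a sketch rather than runnable test code.

    # reuse_fp16_shard is enabled exactly when CPUAdam is exercised (use_cpuadam=True).
    zero_model = ShardedModelV2(zero_model,
                                shard_strategy,
                                offload_config=dict(device='cpu') if cpu_offload else None,
                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
                                reuse_fp16_shard=use_cpuadam)

    # ... training steps on apex_model and zero_model as in the hunk above ...

    # The checker must be told whether the fp16 shard was reused, since that
    # determines where the up-to-date shard values live (zero_p.data vs. the
    # sharded_data_tensor payload).
    check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)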
@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
+from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)
 
 
 def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
     elif parallel_config == ZERO_PARALLEL_CONFIG:
-        check_sharded_params_padding(torch_model, colo_model, loose=True)
+        check_sharded_model_params(torch_model, colo_model, loose=True)
 
 
 # FIXME: enable this test in next PR