[zero] sharded model supports the reuse of fp16 shard (#495)

* sharded model supports reuse of fp16 shard

* rename variable

* polish code

* polish code

* polish code
ver217
2022-03-23 14:59:59 +08:00
committed by GitHub
parent f24b5ed201
commit 9ec1ce6ab1
7 changed files with 62 additions and 42 deletions
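
The two hunks below only cover the test-side rename that came with this change; the fp16-shard reuse itself lives in the sharded model/optimizer code, which is not part of this excerpt. As a rough, hypothetical sketch of the idea (the class name, flag name, and method below are illustrative, not the exact Colossal-AI API): when reuse is enabled, the fp16 parameter shard's storage doubles as the buffer for the reduced gradient, so no separate fp16 gradient shard needs to be allocated.

import torch

class ShardParam:
    """Illustrative container for one parameter shard (not the real Colossal-AI class)."""

    def __init__(self, fp32_shard: torch.Tensor, reuse_fp16_shard: bool = False):
        self.fp32_shard = fp32_shard                  # master copy kept by the optimizer
        self.fp16_shard = fp32_shard.detach().half()  # working copy used in forward/backward
        self.reuse_fp16_shard = reuse_fp16_shard
        # Only allocate a dedicated fp16 gradient buffer when reuse is disabled.
        self.grad_shard = None if reuse_fp16_shard else torch.zeros_like(self.fp16_shard)

    def store_reduced_grad(self, reduced_grad: torch.Tensor) -> None:
        if self.reuse_fp16_shard:
            # Reuse the fp16 parameter shard's memory for the gradient; the
            # parameter values are re-materialized from fp32_shard at the next step.
            self.fp16_shard.copy_(reduced_grad.half())
            self.grad_shard = self.fp16_shard
        else:
            self.grad_shard.copy_(reduced_grad.half())

The payoff in this sketch is one fewer fp16 buffer per shard on the GPU, at the cost of restoring the fp16 parameters from the fp32 master copy before the next forward pass.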

@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
+from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)

 def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
     elif parallel_config == ZERO_PARALLEL_CONFIG:
-        check_sharded_params_padding(torch_model, colo_model, loose=True)
+        check_sharded_model_params(torch_model, colo_model, loose=True)
 # FIXME: enable this test in next PR
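
For context on the rename above: check_sharded_params_padding became check_sharded_model_params in the common module imported in the first hunk, whose implementation is not shown in this diff. A rough sketch of what such a check typically does follows; gather_full_param is a hypothetical stand-in for however the full fp32 parameter is reconstructed from its shards, and the tolerances are illustrative.

import torch

def gather_full_param(p: torch.nn.Parameter) -> torch.Tensor:
    # Hypothetical placeholder: the real test reconstructs the full fp32
    # parameter from its ZeRO shards before comparing. Here we assume the
    # parameter is already fully materialized.
    return p.detach().float()

def check_sharded_model_params(torch_model, colo_model, loose: bool = False) -> None:
    # Compare every parameter of the reference (DDP) model against the
    # corresponding parameter of the sharded model, with a looser tolerance
    # when fp16 training makes exact agreement unrealistic.
    rtol, atol = (1e-3, 1e-3) if loose else (1e-5, 1e-5)
    for (name, torch_p), colo_p in zip(torch_model.named_parameters(), colo_model.parameters()):
        full_p = gather_full_param(colo_p)
        assert torch.allclose(torch_p.detach().float(), full_p, rtol=rtol, atol=atol), \
            f'parameter mismatch: {name}'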