[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)

* support DDP for HybridPlugin/add tp+dp tests

* add docstring for HybridParallelPlugin
This commit is contained in:
Baizhou Zhang
2023-08-16 16:11:57 +08:00
committed by GitHub
parent 424629fea0
commit 6ef33f75aa
10 changed files with 199 additions and 100 deletions

View File

@@ -13,6 +13,7 @@ from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -259,3 +260,15 @@ def check_grad(org_model: Module,
assert torch.allclose(
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
def unwrap_model(module: Module,
base_model_class_name: Optional[str] = None,
base_model_attribute_name: Optional[str] = None):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if base_model_class_name is None:
return module
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)