mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)
* support DDP for HybridPlugin/add tp+dp tests
* add docstring for HybridParallelPlugin
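The commit teaches HybridParallelPlugin to handle the data-parallel ranks that remain after tensor (and pipeline) parallelism via DDP, and adds tests that combine tp and dp. Below is a minimal sketch of the kind of tp+dp run those tests exercise, assuming four GPU ranks launched with torchrun, a small Hugging Face BERT model, and the launch/boost API of this ColossalAI version; the sizes and config values are illustrative and not taken from the commit.

import colossalai
import torch
from transformers import BertConfig, BertForSequenceClassification

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# A small BERT config whose attention head count is divisible by tp_size.
config = BertConfig(num_hidden_layers=2, hidden_size=128,
                    num_attention_heads=4, intermediate_size=256)
model = BertForSequenceClassification(config)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# On 4 ranks, tp_size=2 and pp_size=1 leave a data-parallel size of 2,
# which triggers the data-parallel (DDP) handling added by this commit.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)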
@@ -13,6 +13,7 @@ from torch.optim import Adam, Optimizer
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -259,3 +260,15 @@ def check_grad(org_model: Module,
         assert torch.allclose(
             org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
         ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
+
+
+def unwrap_model(module: Module,
+                 base_model_class_name: Optional[str] = None,
+                 base_model_attribute_name: Optional[str] = None):
+    if isinstance(module, HybridParallelModule):
+        module = module.unwrap()
+    if base_model_class_name is None:
+        return module
+    if module.__class__.__name__ == base_model_class_name:
+        return module
+    return getattr(module, base_model_attribute_name, None)
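The unwrap_model helper above first removes the HybridParallelModule wrapper, then either returns the module directly or falls back to a named attribute, so the tp+dp tests can reach the base model however it is wrapped. The following toy illustration of that class-name/attribute fallback reuses the unwrap_model defined in the hunk above; the class and attribute names are made up for the example.

import torch.nn as nn


class ToyBackbone(nn.Module):
    pass


class ToyForClassification(nn.Module):

    def __init__(self):
        super().__init__()
        self.backbone = ToyBackbone()


model = ToyForClassification()
# Class name matches: the (unwrapped) module itself is returned.
assert unwrap_model(model, "ToyForClassification", "backbone") is model
# Class name does not match: fall back to the named attribute.
assert unwrap_model(model, "ToyBackbone", "backbone") is model.backbone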