diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/auto_policy.py similarity index 99% rename from colossalai/shardformer/policies/autopolicy.py rename to colossalai/shardformer/policies/auto_policy.py index 085e3150c..8e961a240 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch.nn as nn -from .basepolicy import Policy +from .base_policy import Policy __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/base_policy.py similarity index 83% rename from colossalai/shardformer/policies/basepolicy.py rename to colossalai/shardformer/policies/base_policy.py index 2d347542f..16f3fa14e 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,9 +2,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Type, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.pipeline.stage_manager import PipelineStageManager from ..shard.shard_config import ShardConfig @@ -71,9 +75,8 @@ class Policy(ABC): """ def __init__(self) -> None: - self.shard_config = None - self.model = None - self.shard_config = None + self.shard_config: Optional[ShardConfig] = None + self.model: Optional[Module] = None def set_model(self, model: nn.Module) -> None: r""" @@ -94,6 +97,12 @@ class Policy(ABC): self.shard_config = shard_config self.config_sanity_check() + @property + def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: + if self.shard_config is not None: + return self.shard_config.pipeline_stage_manager + return None + @abstractmethod def config_sanity_check(self): """ @@ -151,3 +160,19 @@ class Policy(ABC): policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) return policy + + def get_held_layers(self) -> List[Module]: + """Get layers that should be held in current stage. This method should be implemented by subclass. + + Returns: + List[Module]: List of layers that should be hold in current stage + """ + raise NotImplementedError + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """Get parameters that should be shared across stages. This method should be implemented by subclass. + + Returns: + List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. 
[{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] + """ + return [] diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 9c2736cc6..b69ee7209 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -3,7 +3,7 @@ import torch.nn as nn import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index a0b5340f7..8d6f07d4a 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class BloomPolicy(Policy): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 549cdbf87..598f393c0 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -3,7 +3,7 @@ import torch.nn as nn import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 157785bdc..391938b27 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -4,7 +4,7 @@ import torch.nn as nn from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f4..c4c6cde01 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,7 +1,7 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index cde59ab77..6167e8161 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -6,10 +6,10 @@ from colossalai.shardformer.layer import ( Linear1D_Row, 
VocabParallelEmbedding1D, ) -from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index eaebe2eee..3f6bbd106 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -4,7 +4,7 @@ import torch.nn as nn from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['ViTPolicy'] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275..75fad4eb7 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,8 +1,11 @@ from dataclasses import dataclass +from typing import Optional import torch.distributed as dist from torch.distributed import ProcessGroup +from colossalai.pipeline.stage_manager import PipelineStageManager + __all__ = ['ShardConfig'] @@ -12,12 +15,14 @@ class ShardConfig: The config for sharding the huggingface model Args: - tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group. + pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline. enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False. 
""" - tensor_parallel_process_group: ProcessGroup = None + tensor_parallel_process_group: Optional[ProcessGroup] = None + pipeline_stage_manager: Optional[PipelineStageManager] = None enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 201e0a08c..ca2f46a18 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,11 +1,15 @@ from typing import Any, Callable, Dict, List, Union import torch.nn as nn +from torch import Tensor + +from colossalai.lazy import LazyTensor from .._utils import getattr_, setattr_ -from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy, SubModuleReplacementDescription +from ..policies.auto_policy import get_autopolicy +from ..policies.base_policy import Policy, SubModuleReplacementDescription from .shard_config import ShardConfig +from .utils import set_tensors_to_none __all__ = ['ModelSharder', 'shard_model'] @@ -25,15 +29,18 @@ class ModelSharder(object): self.policy = get_autopolicy(self.model) if policy is None else policy self.shard_config = shard_config - def shard(self) -> None: + def shard(self) -> List[Dict[int, Tensor]]: r""" Shard the model according to the policy """ self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) self._preprocess() + self._release_unheld_layers() self._replace_module() + self._materialize() self._postprocess() + return self.policy.get_shared_params() def _preprocess(self) -> None: self.model = self.policy.preprocess() @@ -172,3 +179,23 @@ class ModelSharder(object): ) setattr_(org_layer, suffix, replace_layer) + + def _release_unheld_layers(self) -> None: + r""" + Release the unheld layers in the model + """ + if self.shard_config and self.shard_config.pipeline_stage_manager: + held_layers = self.policy.get_held_layers() + set_tensors_to_none(self.model, exclude=set(held_layers)) + + def _materialize(self) -> None: + r""" + Materialize the model if lazy initialization is used + """ + for p in self.model.parameters(): + if isinstance(p, LazyTensor): + p.materialize() + + for b in self.model.buffers(): + if isinstance(b, LazyTensor): + b.materialize() diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 3fce12463..7a0d75bf2 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,8 +1,11 @@ +from typing import Dict, List, Tuple + import torch.nn as nn +from torch import Tensor from colossalai.cluster import DistCoordinator -from ..policies.basepolicy import Policy +from ..policies.base_policy import Policy from .shard_config import ShardConfig from .sharder import ModelSharder @@ -24,7 +27,7 @@ class ShardFormer: org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') shard_config = ShardConfig() shard_former = ShardFormer(shard_config=shard_config) - model = shard_former.optimize(org_model) + model, shared_params = shard_former.optimize(org_model) ``` """ @@ -32,7 +35,7 @@ class ShardFormer: self.coordinator = DistCoordinator() self.shard_config = shard_config - def optimize(self, model: nn.Module, policy: Policy = None): + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: r""" This method will optimize the model based on the given policy. 
@@ -40,7 +43,9 @@ class ShardFormer: model (`torch.nn.Model`): the origin huggingface model shard_config (`ShardConfig`): the config for distribute information policy (`Policy`): the custom policy for sharding + + Returns: the sharded model and the shared parameters """ sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) - sharder.shard() - return model + shared_params = sharder.shard() + return model, shared_params diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py new file mode 100644 index 000000000..2bac37bfe --- /dev/null +++ b/colossalai/shardformer/shard/utils.py @@ -0,0 +1,19 @@ +from typing import Set + +import torch.nn as nn + + +def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None: + """Set all parameters and buffers of model to None + + Args: + model (nn.Module): The model to set + """ + if model in exclude: + return + for child in model.children(): + set_tensors_to_none(child, exclude=exclude) + for n, p in model.named_parameters(recurse=False): + setattr(model, n, None) + for n, buf in model.named_buffers(recurse=False): + setattr(model, n, None) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd3..e03014f3f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle enable_tensor_parallelism=enable_tensor_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) - sharded_model = shard_former.optimize(model_copy).cuda() - return org_model, sharded_model + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model, sharded_model.cuda() def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): diff --git a/tests/test_shardformer/test_shard_utils.py b/tests/test_shardformer/test_shard_utils.py new file mode 100644 index 000000000..220b8291c --- /dev/null +++ b/tests/test_shardformer/test_shard_utils.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + +from colossalai.shardformer.shard.utils import set_tensors_to_none + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + self.out = nn.Linear(3, 1) + + +def test_release_layer(): + orig_cuda_allocated = torch.cuda.memory_allocated() + model = Net().cuda() + set_tensors_to_none(model, exclude={model.layers[0]}) + assert model.layers[1].weight is None + assert model.layers[1].bias is None + assert model.out.weight is None + assert model.out.bias is None + set_tensors_to_none(model) + assert model.layers[0].weight is None + assert model.layers[0].bias is None + assert len(list(model.parameters())) == 0 + assert torch.cuda.memory_allocated() == orig_cuda_allocated diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 9f8a5db6c..f29c8d6f6 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create and shard model model = model_fn().cuda() - sharded_model = shardformer.optimize(model) + sharded_model, _ = shardformer.optimize(model) # add ddp 
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
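
Usage note (not part of the patch): with this change, `ShardFormer.optimize` returns a tuple of the sharded model and the shared-parameter groups, and policies gain `get_held_layers` / `get_shared_params` hooks for pipeline parallelism. The sketch below mirrors the example in the `ShardFormer` docstring to show the updated call pattern; it assumes `transformers` is installed, that the distributed environment is launched first, and that `ShardConfig` / `ShardFormer` are exported from `colossalai.shardformer` as in that docstring. The printing loop at the end is purely illustrative.

```python
import colossalai
from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer

# ShardFormer uses DistCoordinator internally, so a distributed environment
# is assumed to be initialized first (shown here with launch_from_torch).
colossalai.launch_from_torch(config={})

# Any supported HuggingFace model works; BERT is used only because the
# ShardFormer docstring uses it.
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# `optimize` now returns a tuple: the sharded model plus the parameter
# groups that should stay tied across pipeline stages.
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(org_model)

# `shared_params` is a List[Dict[int, Tensor]]: each dict maps a pipeline
# stage index to that stage's copy of a tied parameter, e.g.
# [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}].
# Without a pipeline_stage_manager in the ShardConfig it is typically empty.
for group in shared_params:
    for stage, param in group.items():
        print(stage, tuple(param.shape))
```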