Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-16 07:03:40 +00:00)

Merge pull request #4176 from ver217/feature/pipeline-policy
[pipeline] fit shardformer policy

Commit 12e6d5df6d
@@ -3,7 +3,7 @@ from dataclasses import dataclass

 import torch.nn as nn

-from .basepolicy import Policy
+from .base_policy import Policy

 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]

@@ -2,9 +2,13 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module

+from colossalai.pipeline.stage_manager import PipelineStageManager
+
 from ..shard.shard_config import ShardConfig

@@ -71,9 +75,8 @@ class Policy(ABC):
     """

     def __init__(self) -> None:
-        self.shard_config = None
-        self.model = None
-        self.shard_config = None
+        self.shard_config: Optional[ShardConfig] = None
+        self.model: Optional[Module] = None

     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -94,6 +97,12 @@ class Policy(ABC):
         self.shard_config = shard_config
         self.config_sanity_check()

+    @property
+    def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
+        if self.shard_config is not None:
+            return self.shard_config.pipeline_stage_manager
+        return None
+
     @abstractmethod
     def config_sanity_check(self):
         """
@@ -151,3 +160,19 @@ class Policy(ABC):
             policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

         return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get layers that should be held in current stage. This method should be implemented by subclass.
+
+        Returns:
+            List[Module]: List of layers that should be hold in current stage
+        """
+        raise NotImplementedError
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """Get parameters that should be shared across stages. This method should be implemented by subclass.
+
+        Returns:
+            List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
+        """
+        return []
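The two hooks added to `Policy` above are what the sharder uses under pipeline parallelism: `get_held_layers()` tells it which submodules the current stage keeps, and `get_shared_params()` declares weights tied across stages. As a hedged sketch only (the class name, the `encoder.layer` attribute path, the even layer split, and the `stage`/`num_stages` attributes on the stage manager are illustrative assumptions, not part of this PR), a concrete policy might fill them in like this:

```python
from typing import Dict, List

from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.policies.base_policy import Policy


class MyTransformerPolicy(Policy):
    """Hypothetical subclass sketch; only the pipeline-related hooks are fleshed out."""

    def config_sanity_check(self):
        pass

    def preprocess(self) -> Module:
        return self.model

    def module_policy(self) -> Dict:
        return {}

    def postprocess(self) -> Module:
        return self.model

    def get_held_layers(self) -> List[Module]:
        # Decide which blocks this stage keeps, using the new
        # `pipeline_stage_manager` property added in this PR.
        blocks = list(self.model.encoder.layer)    # attribute path is illustrative
        stage_manager = self.pipeline_stage_manager
        if stage_manager is None:
            return blocks
        # Naive even split of blocks across stages, purely for illustration.
        per_stage = -(-len(blocks) // stage_manager.num_stages)
        start = stage_manager.stage * per_stage
        return blocks[start:start + per_stage]

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        # Map pipeline stage index -> tied tensor, e.g. input embeddings on the
        # first stage and the LM head on the last stage. Nothing tied here.
        return []
```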
@@ -3,7 +3,7 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
     'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
@@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
 from ..modeling.bloom import build_bloom_alibi_tensor_fn
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class BloomPolicy(Policy):
@@ -3,7 +3,7 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
     'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
@@ -4,7 +4,7 @@ import torch.nn as nn

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']

@@ -1,7 +1,7 @@
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
     'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
@@ -6,10 +6,10 @@ from colossalai.shardformer.layer import (
     Linear1D_Row,
     VocabParallelEmbedding1D,
 )
-from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription

 from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]

@@ -4,7 +4,7 @@ import torch.nn as nn

 from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row

-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['ViTPolicy']

@@ -1,8 +1,11 @@
 from dataclasses import dataclass
+from typing import Optional

 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+from colossalai.pipeline.stage_manager import PipelineStageManager
+
 __all__ = ['ShardConfig']


@@ -12,12 +15,14 @@ class ShardConfig:
     The config for sharding the huggingface model

     Args:
-        tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
+        tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
+        pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline.
         enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
         enable_all_optimization (bool): Whether to turn on all optimization, default is False.
     """
-    tensor_parallel_process_group: ProcessGroup = None
+    tensor_parallel_process_group: Optional[ProcessGroup] = None
+    pipeline_stage_manager: Optional[PipelineStageManager] = None
     enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_all_optimization: bool = False
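For reference, the expanded config surface looks like this when pipeline parallelism is wanted. This is a minimal sketch: it assumes torch.distributed is already initialized, and building the `PipelineStageManager` itself (process-group mesh, pipeline axis) is left out, so the field is shown with its `None` default, which keeps pipeline handling disabled.

```python
from colossalai.shardformer.shard.shard_config import ShardConfig

# `stage_manager` would be a PipelineStageManager built from the cluster utilities;
# leaving it as None (the documented default) means "no pipeline".
stage_manager = None

shard_config = ShardConfig(
    tensor_parallel_process_group=None,    # None -> use the global process group
    pipeline_stage_manager=stage_manager,  # new field introduced by this PR
    enable_tensor_parallelism=True,
    enable_fused_normalization=False,
    enable_all_optimization=False,
)
```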
@@ -1,11 +1,15 @@
 from typing import Any, Callable, Dict, List, Union

 import torch.nn as nn
+from torch import Tensor
+
+from colossalai.lazy import LazyTensor

 from .._utils import getattr_, setattr_
-from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..policies.auto_policy import get_autopolicy
+from ..policies.base_policy import Policy, SubModuleReplacementDescription
 from .shard_config import ShardConfig
+from .utils import set_tensors_to_none

 __all__ = ['ModelSharder', 'shard_model']

@@ -25,15 +29,18 @@ class ModelSharder(object):
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config

-    def shard(self) -> None:
+    def shard(self) -> List[Dict[int, Tensor]]:
         r"""
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
         self._preprocess()
+        self._release_unheld_layers()
         self._replace_module()
+        self._materialize()
         self._postprocess()
+        return self.policy.get_shared_params()

     def _preprocess(self) -> None:
         self.model = self.policy.preprocess()
@@ -172,3 +179,23 @@ class ModelSharder(object):
            )

            setattr_(org_layer, suffix, replace_layer)
+
+    def _release_unheld_layers(self) -> None:
+        r"""
+        Release the unheld layers in the model
+        """
+        if self.shard_config and self.shard_config.pipeline_stage_manager:
+            held_layers = self.policy.get_held_layers()
+            set_tensors_to_none(self.model, exclude=set(held_layers))
+
+    def _materialize(self) -> None:
+        r"""
+        Materialize the model if lazy initialization is used
+        """
+        for p in self.model.parameters():
+            if isinstance(p, LazyTensor):
+                p.materialize()
+
+        for b in self.model.buffers():
+            if isinstance(b, LazyTensor):
+                b.materialize()
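The new `_release_unheld_layers()` and `_materialize()` steps are ordered so that, under pipeline parallelism, parameters of layers the current stage does not hold are dropped before anything is materialized. A rough sketch of the intended lazy-initialization flow follows; `LazyInitContext` lives in `colossalai.lazy` (the same package the PR imports `LazyTensor` from), but the launch call, the top-level `colossalai.shardformer` re-exports, and the exact context-manager arguments are assumptions here, not something this diff shows.

```python
import colossalai
from transformers import BertConfig, BertForMaskedLM

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer  # top-level re-export assumed

colossalai.launch_from_torch(config={})  # any working distributed init would do

with LazyInitContext():
    # parameters are created as LazyTensors and hold no real storage yet
    model = BertForMaskedLM(BertConfig())

shard_former = ShardFormer(shard_config=ShardConfig())
model, shared_params = shard_former.optimize(model)
# With a pipeline_stage_manager set, _release_unheld_layers() would have dropped the
# layers this stage does not own before _materialize() allocated the remaining ones.
```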
@@ -1,8 +1,11 @@
+from typing import Dict, List, Tuple
+
 import torch.nn as nn
+from torch import Tensor

 from colossalai.cluster import DistCoordinator

-from ..policies.basepolicy import Policy
+from ..policies.base_policy import Policy
 from .shard_config import ShardConfig
 from .sharder import ModelSharder

@@ -24,7 +27,7 @@ class ShardFormer:
        org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        shard_config = ShardConfig()
        shard_former = ShardFormer(shard_config=shard_config)
-       model = shard_former.optimize(org_model)
+       model, shared_params = shard_former.optimize(org_model)
        ```
    """

@@ -32,7 +35,7 @@ class ShardFormer:
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config

-    def optimize(self, model: nn.Module, policy: Policy = None):
+    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
         r"""
         This method will optimize the model based on the given policy.

@@ -40,7 +43,9 @@ class ShardFormer:
            model (`torch.nn.Model`): the origin huggingface model
            shard_config (`ShardConfig`): the config for distribute information
            policy (`Policy`): the custom policy for sharding
+
+        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
-       sharder.shard()
-       return model
+       shared_params = sharder.shard()
+       return model, shared_params
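For callers of `ShardFormer`, the visible change is that `optimize()` now returns a tuple. A minimal sketch mirroring the docstring example above (the top-level `colossalai.shardformer` import path is an assumption, and a working distributed environment is required because `ShardFormer` creates a `DistCoordinator`):

```python
from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer  # top-level re-export assumed

org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_former = ShardFormer(shard_config=ShardConfig())

# before this PR: model = shard_former.optimize(org_model)
model, shared_params = shard_former.optimize(org_model)

# shared_params is a List[Dict[int, Tensor]]: each dict maps a pipeline stage index to
# the tensor that stage holds for a weight tied across stages, e.g.
# [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
```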
colossalai/shardformer/shard/utils.py (new file, 19 lines added)
@@ -0,0 +1,19 @@
+from typing import Set
+
+import torch.nn as nn
+
+
+def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
+    """Set all parameters and buffers of model to None
+
+    Args:
+        model (nn.Module): The model to set
+    """
+    if model in exclude:
+        return
+    for child in model.children():
+        set_tensors_to_none(child, exclude=exclude)
+    for n, p in model.named_parameters(recurse=False):
+        setattr(model, n, None)
+    for n, buf in model.named_buffers(recurse=False):
+        setattr(model, n, None)
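A quick way to see what the new helper does, mirroring the unit test added further below but with a plain CPU `nn.Sequential`:

```python
import torch.nn as nn

from colossalai.shardformer.shard.utils import set_tensors_to_none

mlp = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# release everything except the first linear layer
set_tensors_to_none(mlp, exclude={mlp[0]})

assert mlp[0].weight is not None                  # excluded subtree is untouched
assert mlp[1].weight is None and mlp[1].bias is None
assert len(list(mlp.parameters())) == 2           # only the excluded layer's weight and bias remain
```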
@@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
                                enable_tensor_parallelism=enable_tensor_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
-    sharded_model = shard_former.optimize(model_copy).cuda()
-    return org_model, sharded_model
+    sharded_model, shared_params = shard_former.optimize(model_copy)
+    return org_model, sharded_model.cuda()


 def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
tests/test_shardformer/test_shard_utils.py (new file, 27 lines added)
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+
+from colossalai.shardformer.shard.utils import set_tensors_to_none
+
+
+class Net(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
+        self.out = nn.Linear(3, 1)
+
+
+def test_release_layer():
+    orig_cuda_allocated = torch.cuda.memory_allocated()
+    model = Net().cuda()
+    set_tensors_to_none(model, exclude={model.layers[0]})
+    assert model.layers[1].weight is None
+    assert model.layers[1].bias is None
+    assert model.out.weight is None
+    assert model.out.bias is None
+    set_tensors_to_none(model)
+    assert model.layers[0].weight is None
+    assert model.layers[0].bias is None
+    assert len(list(model.parameters())) == 0
+    assert torch.cuda.memory_allocated() == orig_cuda_allocated
@@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create and shard model
         model = model_fn().cuda()
-        sharded_model = shardformer.optimize(model)
+        sharded_model, _ = shardformer.optimize(model)

         # add ddp
         sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)