Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-10 12:22:28 +00:00)

[shardformer] Refactor shardformer api (#4001)

* fix an error in readme
* simplify code
* refactor shardformer
* add todo
* remove slicer
* resolve code review

Parent: 611971248c
Commit: d3bc530849
@@ -1 +1 @@
-from .shard import ShardConfig, shard_model
+from .shard import ShardConfig, ShardFormer
@@ -1,5 +1,7 @@
 import torch.nn as nn
 
+from .basepolicy import Policy
+
 
 def build_policies():
     r"""
@@ -41,47 +43,25 @@ def build_policies():
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
     from transformers.models.llama.modeling_llama import LlamaModel
 
-    from .llama import LlamaPolicy
-    auto_policy_dict[LlamaModel] = LlamaPolicy
-
-    from transformers import LlamaForSequenceClassification
-
-    from .llama import LlamaForSequenceClassificationPolicy
-    auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
-
-    from transformers import LlamaForCausalLM
-
-    from .llama import LlamaForCausalLMPolicy
-    auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
-
-    from transformers import BertForMultipleChoice
-
-    from .bert import BertForMultipleChoicePolicy
-    auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
-
-    from transformers import GPT2Model
-
-    from .gpt2 import GPT2Policy
-    auto_policy_dict[GPT2Model] = GPT2Policy
-
-    from transformers import GPT2LMHeadModel
-
-    from .gpt2 import GPT2LMHeadModelPolicy
-    auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
-
-    from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
-    from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
-    t5 = {
-        T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
-        T5EncoderModel: T5EncoderModelPolicy,
-        T5Model: T5ModelPolicy,
-    }
-    auto_policy_dict.update(t5)
+    # from .llama import LlamaPolicy
+    # auto_policy_dict[LlamaModel] = LlamaPolicy
+    # from transformers import LlamaForSequenceClassification
+    # from .llama import LlamaForSequenceClassificationPolicy
+    # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
+    # from transformers import LlamaForCausalLM
+    # from .llama import LlamaForCausalLMPolicy
+    # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
+    # from transformers import GPT2Model
+    # from .gpt2 import GPT2Policy
+    # auto_policy_dict[GPT2Model] = GPT2Policy
+    # from transformers import GPT2LMHeadModel
+    # from .gpt2 import GPT2LMHeadModelPolicy
+    # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
 
     return auto_policy_dict
 
 
-def get_autopolicy(model: nn.Module):
+def get_autopolicy(model: nn.Module) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -97,7 +77,7 @@ def get_autopolicy(model: nn.Module):
         raise NotImplementedError(
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
         )
-    return policy
+    return policy()
 
 
 # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
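With `return policy()`, `get_autopolicy` now hands back a `Policy` instance rather than the class itself. A minimal sketch of the intended call pattern, assuming one of the registered models above (the checkpoint name is illustrative):

```python
from transformers import BertForSequenceClassification

from colossalai.shardformer.policies.autopolicy import get_autopolicy

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
policy = get_autopolicy(model)    # a Policy instance now, not a class
policy.set_model(model)           # attach the model before preprocess()/module_policy()
```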
@@ -1,102 +1,65 @@
 # part of code modified from https://github.com/tunib-ai/parallelformers
 
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Type, Union
 
 import torch.nn as nn
 
+from ..shard.shard_config import ShardConfig
 
-@dataclass
-class Argument:
-    r"""
-    The argument class for the policy
-
-    Args:
-        attr_dict (Dict[str, Any]): The dict for the param setting
-        param_funcs (:class:`List[Callable]`): The list for the param functions
-    """
-    attr_dict: Dict[str, Any]
-    param_funcs: List[Callable]
+
+class ParallelModule():
+
+    def __init__(self):
+        pass
 
 
 @dataclass
-class Layer:
+class SubModuleReplacementDescription:
     r"""
-    The layer object for the policy
+    Describe how a submodule will be replaced
 
     Args:
-        suffix: (str): the suffix of the layer.
-        replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
-        ignore (bool): Whether to ignore this layer if it is not in the model
-        reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
-            but in GPT2 `Conv1D` layer is [in, out] which is reversed.
-        n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
-            but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
-            each device should have a part of Q, K and V weight.
+        suffix (str): used to get the submodule object
+        target_module (ParallelModule): specifies the module class used to replace the submodule
+        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
     """
-    suffix: str = None
-    replace_layer: Any = None
-    ignore: bool = False
-    reversed: bool = False
-    n_cast: int = None
+    suffix: str
+    target_module: ParallelModule
+    kwargs: Dict[str, Any] = None
 
 
 @dataclass
-class Col_Layer(Layer):
+class ModulePolicyDescription:
     r"""
-    Class for col shard layer in tensor parrallel
+    Describe how the attributes and parameters will be transformed in a policy
 
     Args:
-        weight (str): The weight suffix of the layer
-        bias (str): The bias suffix of the layer
-        gather_output (bool): Whether to gather the output of the layer
+        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
+        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
+            must receive two arguments: module, process_group. One example is
+
+            ```python
+            def example_replace_weight(module: torch.nn.Module, process_group):
+                weight = module.weight
+                new_weight = shard_rowwise(weight, process_group)
+                module.weight = torch.nn.Parameter(new_weight)
+            ```
+
+        sub_module_replacement: each element in the list is a SubModuleReplacementDescription object which specifies
+            the module to be replaced and the target module used for replacement
     """
-    weight: str = None
-    bias: str = None
-    gather_output: bool = False
+    attribute_replacement: Dict[str, Any]
+    param_replacement: List[Callable]
+    sub_module_replacement: List[SubModuleReplacementDescription]
 
 
-@dataclass
-class Row_Layer(Layer):
-    r"""
-    Class for col shard layer in tensor parrallel
-
-    Args:
-        weight (str): The weight suffix of the layer
-        bias (str): The bias suffix of the layer
-    """
-    weight: str = None
-    bias: str = None
-
-
-@dataclass
-class Dropout_Layer(Layer):
-    r"""
-    Class for dropout layer in tensor parrallel
-
-    Args:
-        p (str): The dropout rate suffix of the layer
-    """
-    p: str = None
-
-
-@dataclass
-class Embedding_Layer(Layer):
-    r"""
-    Class for col shard layer in tensor parrallel
-
-    Args:
-        weight (str): The weight suffix of the layer
-    """
-    weight: str = None
-    gather_output: bool = True
-
-
-class Policy():
+class Policy(ABC):
     r"""
     The base class for all the policies
 
     For each different model, it should have a different policy class, like BertPolicy for Bert Model
     or OPTPolicy for OPT model.
 
     AutoPolicy:
         Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
         to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
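To make the relationship between the two new dataclasses concrete, here is a hedged sketch of how a policy author might fill them in; the suffix string and sharded value are illustrative (they mirror the BertPolicy later in this diff), and `ParallelModule` is still the placeholder class defined above:

```python
description = ModulePolicyDescription(
    attribute_replacement={
        # attribute on the matched module -> value after sharding
        "attention.self.num_attention_heads": 12 // 2,
    },
    param_replacement=[],    # in-place param shard functions; none yet
    sub_module_replacement=[
        SubModuleReplacementDescription(
            suffix="attention.self.query",    # dotted path under the matched module
            target_module=ParallelModule,     # class used to rebuild the submodule
        ),
    ],
)
```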
@@ -111,137 +74,75 @@ class Policy():
 
     """
 
-    @staticmethod
-    def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
+    def __init__(self) -> None:
+        self.model = None
+
+    def set_model(self, model: nn.Module) -> None:
+        r"""
+        Set model as an attribute of the Policy object so that we can access the model's attributes.
+
+        Args:
+            model (:class:`nn.Module`): The model to be sharded
+        """
+        self.model = model
+
+    @abstractmethod
+    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+        r"""
+        Perform some preprocessing of the model, like reshaping the embedding layer
+        """
+
+    @abstractmethod
+    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the modify policy, the key is the original layer class and the value is the
         argument for the modify layer
 
-        Args:
-            model_config (:class:`tansformer.Config`): The config of transformer model
-            world_size (int)): The world size of sharding model
-
         Return:
             Dict for the modify policy,
         ::
             {
-                origin layer class1 (nn.Module): Argument(
-                    attr_dict = {
-                        argument1: value1,
-                        argument2: value2,
+                origin layer class1 (nn.Module): ModulePolicyDescription(
+                    attribute_replacement = {
+                        "attribute1": value1,
+                        "attribute2": value2,
                         ...
                     },
-                    param_funcs = [
-                        staticmethod1,
-                        staticmethod2,
+                    param_replacement = [
+                        function1,
+                        function2,
+                        ...
+                    ],
+                    sub_module_replacement = [
+                        `SubModuleReplacementDescription` description1,
+                        `SubModuleReplacementDescription` description2,
                         ...
                     ]
                 ),
-                origin layer class2 (nn.Module): Argument(
-                    attr_dict = {
-                        argument1: value1,
-                        argument2: value2,
-                        ...
-                    },
-                    param_funcs = [
-                        staticmethod1,
-                        staticmethod2,
-                        ...
-                    ]
+                origin layer class2 (nn.Module): ModulePolicyDescription(
+                    ...
                 ),
                 ...
             }
 
         """
-        raise NotImplementedError
 
-    @staticmethod
-    def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
+    @abstractmethod
+    def new_model_class(self) -> Union[Type[nn.Module], None]:
         r"""
-        Return the dict for the inject model
+        Return the new model class for the new model, None means no need to modify the model class
 
         Return:
-            The injected model, key is the original model and value is the new shardmodel
-        ::
-            (OrignModel, CustomModel)
-            in `CustomModel`, we can overwrite the forward and backward process
-        """
-        return None
+            New model class
 
-    @staticmethod
-    def binding_policy() -> Union[Dict[str, str], None]:
-        r"""
-        Return the dict for the binding model, None means no need to bind
-
-        Return:
-            This method should return the binding relationship for some layers share the weight or bias,
-            the key and value is the suffix of the weight or bias of the model
-        ::
-            return {
-                "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
-            }
+        E.g.
+        ```
+        return BertModel_
+        ```
         """
-        return None
 
-    @staticmethod
-    def attn_in() -> Union[List, None]:
+    @abstractmethod
+    def postprocess(self) -> nn.Module:
         r"""
-        Attention qkv layer
-        In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
-        ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
-        in ``Layer`` object can refer to the ``Layer`` class.
-
-        Returns:
-            List[Layer]: List of layer object, each layer is the new
-        """
-        return None
-
-    @staticmethod
-    def attn_out() -> Union[List, None]:
-        r"""
-        Attention output projection layer
-
-        Returns:
-            List[Layer]: List of layer object
-        """
-        return None
-
-    @staticmethod
-    def mlp_in() -> Union[List, None]:
-        r"""
-        h -> 4h mlp layer
-
-        Returns:
-            List[Layer]: List of layer object
-        """
-        return None
-
-    @staticmethod
-    def mlp_out() -> Union[List, None]:
-        r"""
-        4h -> h mlp layer
-
-        Returns:
-            List[Layer]: List of layer object
-        """
-        return None
-
-    @staticmethod
-    def embedding() -> Union[List, None]:
-        r"""
-        Partially slice the embedding layer
-
-        Return:
-            List[Layer]: List of layer object
-        """
-        return None
-
-    @staticmethod
-    def unembedding() -> Union[List, None]:
-        r"""
-        Partially slice the embedding layer, None means there is no unembedding layer
-
-        Return:
-            List[Layer]: List of layer object
-        """
-        return None
+        Perform some postprocessing of the model, like binding the weight of embedding layer with
+        the classifier layer
+        """
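Taken together, the abstract methods define a four-step lifecycle that `ModelSharder.shard()` (later in this diff) drives in a fixed order. A sketch of that order, assuming `policy`, `model` and `shard_config` already exist:

```python
policy.set_model(model)                            # 1. attach the model to the policy
model = policy.preprocess(shard_config)            # 2. e.g. pad the vocab size
descriptions = policy.module_policy(shard_config)  # 3. declare what to replace, and how
# ... the sharder walks the model and applies the descriptions here ...
model = policy.postprocess()                       # 4. e.g. re-tie shared weights
```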
@@ -1,220 +1,77 @@
-from typing import Any, Callable, Dict, List, Tuple, Type
-
 import torch.nn as nn
 from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
 
 import colossalai.shardformer.layer.layers as col_nn
 
-from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer
+from ..shard.shard_config import ShardConfig
+from ..utils import getattr_, setattr_
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+
+class ParallelModule():
+
+    def __init__(self):
+        pass
 
 
 class BertPolicy(Policy):
 
-    @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
+    def preprocess(self, shard_config: ShardConfig = None):
+        # reshape the embedding layer
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        # TODO:
+        vocab_size = self.model.config.vocab_size
+        world_size = shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model
+
+    def module_policy(self, shard_config: ShardConfig = None):
         return {
             BertLayer:
-                Argument(
-                    attr_dict={
+                ModulePolicyDescription(
+                    attribute_replacement={
                         # 1. shard hidden size
-                        "attention.self.all_head_size": config.hidden_size // world_size,
-                        "crossattention.self.all_head_size": config.hidden_size // world_size,
+                        "attention.self.all_head_size":
+                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                        "crossattention.self.all_head_size":
+                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
                         # 2. shard number of heads
-                        "attention.self.num_attention_heads": config.num_attention_heads // world_size,
-                        "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
+                        "attention.self.num_attention_heads":
+                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                        "crossattention.self.num_attention_heads":
+                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
                     },
-                    param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
-            BertEmbeddings:
-                Argument(
-                    attr_dict={
-                        # 1. shard vocab size
-                        "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
-                    },
-                    param_funcs=[
-                        BertPolicy.embedding,
-                    ]),
+                    param_replacement=[],
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.query",
+                            target_module=ParallelModule,
+                        ),
+                    ])
         }
 
-    @staticmethod
-    def attn_in():
-        return [
-            Col_Layer(
-                suffix="attention.self.query",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="attention.self.key",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="attention.self.value",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Dropout_Layer(
-                suffix="attention.self.dropout",
-                p="p",
-                replace_layer=col_nn.Dropout1D,
-            ),
-            Col_Layer(
-                suffix="crossattention.self.query",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-                ignore=True,
-            ),
-            Col_Layer(
-                suffix="crossattention.self.key",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-                ignore=True,
-            ),
-            Col_Layer(
-                suffix="crossattention.self.value",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-                ignore=True,
-            ),
-        ]
-
-    @staticmethod
-    def attn_out():
-        return [
-            Row_Layer(
-                suffix="attention.output.dense",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-            Dropout_Layer(
-                suffix="attention.output.dropout",
-                p="p",
-                replace_layer=col_nn.Dropout1D,
-            ),
-            Row_Layer(
-                suffix="crossattention.output.dense",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Row,
-                ignore=True,
-            ),
-        ]
-
-    @staticmethod
-    def mlp_in():
-        return [
-            Col_Layer(
-                suffix="intermediate.dense",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-        ]
-
-    @staticmethod
-    def mlp_out():
-        return [
-            Row_Layer(
-                suffix="output.dense",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-            Dropout_Layer(
-                suffix="output.dropout",
-                p="p",
-                replace_layer=col_nn.Dropout1D,
-            )
-        ]
-
-    @staticmethod
-    def embedding():
-        return [Col_Layer(
-            suffix="word_embeddings",
-            weight="weight",
-            replace_layer=col_nn.VocabParallelEmbedding1D,
-        )]
-
-    @staticmethod
-    def unembedding():
-        return [
-            Col_Layer(
-                suffix="decoder",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-                gather_output=True,
-            )
-        ]
-
-
-# BertModel
-class BertModelPolicy(BertPolicy):
-
-    @staticmethod
-    def argument_policy(config, world_size):
-        return BertPolicy.argument_policy(config, world_size)
-
-
-# BertForPretraining
-class BertForPretrainingPolicy(BertPolicy):
-
-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
-        }
-        argument.update(base_argument)
-        return argument
-
-    @staticmethod
-    def inject_policy():
+    def new_model_class(self):
+        # do nothing
         return None
 
-    @staticmethod
-    def binding_policy():
-        return {
-            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
-        }
+    def postprocess(self):
+        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
 
-
-# BertForMaskedLM
-from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
-
 
 class BertForMaskedLMPolicy(BertPolicy):
 
-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
-        }
-        argument.update(base_argument)
-        return argument
-
-    @staticmethod
-    def inject_policy():
-        # return (BertForMaskedLM, BertForMaskedLM_)
-        return None
-
-    @staticmethod
-    def binding_policy():
-        return {
-            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
-        }
+    def __init__(self) -> None:
+        super().__init__()
 
 
 # BertLMHeadModel
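The new `preprocess` above pads the vocabulary so that it divides evenly across tensor-parallel ranks. A quick worked check of the padding formula, with illustrative numbers (BERT's 30522-token vocab on a 4-way tensor-parallel group):

```python
vocab_size, world_size = 30522, 4

# 30522 % 4 == 2, so 4 - 2 == 2 padding rows are added
new_vocab_size = vocab_size + world_size - vocab_size % world_size

assert new_vocab_size == 30524
assert new_vocab_size % world_size == 0    # each rank now holds 7631 rows
```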
@@ -231,36 +88,5 @@ class BertLMHeadModelPolicy(BertPolicy):
         argument.update(base_argument)
         return argument
 
-    @staticmethod
-    def inject_policy():
-        return None
-
-    @staticmethod
-    def binding_policy():
-        return {
-            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
-        }
-
-
-# BertForNextSentencePrediction
-class BertForNextSentencePredictionPolicy(BertPolicy):
-
-    @staticmethod
-    def argument_policy(config, world_size):
-        return BertPolicy.argument_policy(config, world_size)
-
-
-# BertForSequenceClassification
-class BertForSequenceClassificationPolicy(BertPolicy):
-
-    @staticmethod
-    def argument_policy(config, world_size):
-        return BertPolicy.argument_policy(config, world_size)
-
-
-# BertForMultipleChoice
-class BertForMultipleChoicePolicy(BertPolicy):
-
-    @staticmethod
-    def argument_policy(config, world_size):
-        return BertPolicy.argument_policy(config, world_size)
+    def __init__(self) -> None:
+        super().__init__()
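The `postprocess` shown earlier replaces the old declarative `binding_policy` dict: the weight tying is now done imperatively. What it achieves is standard PyTorch parameter sharing, sketched here on a toy embedding/decoder pair (shapes illustrative):

```python
import torch.nn as nn

embed = nn.Embedding(10, 4)    # stand-in for bert.embeddings.word_embeddings
decoder = nn.Linear(4, 10)     # stand-in for cls.predictions.decoder

# After tying, both modules hold the *same* Parameter object, which is what
# the binding_map loop does for dotted paths via getattr_/setattr_.
decoder.weight = embed.weight
assert decoder.weight is embed.weight
```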
@@ -1,5 +1,5 @@
 from .shard_config import ShardConfig
-from .sharder import ModelSharder, shard_model
-from .slicer import Slicer
+from .sharder import ModelSharder
+from .shardformer import ShardFormer
 
-__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
+__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer']
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import List, Literal
 
 __all__ = ['ShardConfig']
 
@@ -9,10 +10,18 @@ class ShardConfig:
     The config for sharding the huggingface model
 
     Args:
-        rank (int): The rank of local process
-        world_size (int): The world size of the distributed process
+        data_parallel_size (int): The size of data parallel
+        tensor_parallel_size (int): The size of tensor parallel
+        pipeline_parallel_size (int): The size of pipeline parallel
+        tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']`
+        inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
+            will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    rank: int = None
-    world_size: int = None
+    data_parallel_size: int
+    tensor_parallel_size: int
+    pipeline_parallel_size: int
+    tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
+    inference_only: bool = True
     gather_output: bool = True
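Because the first four fields now have no defaults, every `ShardConfig` must spell out the full parallel layout. A minimal sketch for a 2-way tensor-parallel run (values illustrative; only '1d' is wired up later in this commit):

```python
config = ShardConfig(
    data_parallel_size=1,
    tensor_parallel_size=2,
    pipeline_parallel_size=1,
    tensor_parallel_mode='1d',
    # inference_only and gather_output keep their defaults of True
)
```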
@@ -4,11 +4,12 @@ import torch
 import torch.nn as nn
 from transformers.pytorch_utils import Conv1D
 
+from colossalai.cluster.process_group_manager import ProcessGroupManager
+
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
-from ..utils.utils import getattr_, hasattr_, setattr_
+from ..policies.basepolicy import Policy
+from ..utils.utils import setattr_
 from .shard_config import ShardConfig
-from .slicer import Slicer
 
 __all__ = ['ModelSharder', 'shard_model']
 
@@ -28,20 +29,23 @@ class ModelSharder(object):
             model: nn.Module,
             policy: Policy,
             shard_config: ShardConfig = None,    # TODO
-    ) -> None:
+            pg_manager: ProcessGroupManager = None) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
-        self.slicer = Slicer(shard_config)
         self.shard_config = shard_config
-        self.model_config = self.model.config
+        self.pg_manager = pg_manager
 
     def shard(self) -> None:
-        self.reshape_embedding()
-        self.inject_model(self.model)
-        self.replace_layer(self.model)
-        self.bind_layer(self.model)
+        r"""
+        Shard the model according to the policy
+        """
+        self.policy.set_model(self.model)
+        self.preprocess()
+        self.replace_model_class()
+        self.replace_module()
+        self.postprocess()
 
-    def reshape_embedding(self,) -> None:
+    def reshape_embedding(self) -> None:
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
@@ -52,10 +56,13 @@ class ModelSharder(object):
             self.model.resize_token_embeddings(new_vocab_size)
             self.model_config = self.model.config
 
-    def inject_model(
-            self,
-            model: nn.Module,
-    ) -> None:
+    def preprocess(self) -> None:
+        self.model = self.policy.preprocess(self.shard_config)
+
+    def postprocess(self) -> None:
+        self.model = self.policy.postprocess()
+
+    def replace_model_class(self,) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
@@ -64,49 +71,43 @@ class ModelSharder(object):
         ::
             BertForMaskedLM.forward -> BertForMaskedLM_.forward
         """
-        inject_policy = self.policy.inject_policy()
-        if inject_policy is None:
+        new_model_class = self.policy.new_model_class()
+        if new_model_class is None:
             return
 
-        if inject_policy is None:
-            return
-        org_model_cls = inject_policy[0]
-        shard_model_cls = inject_policy[1]
-
-        if model.__class__ == org_model_cls:
-            for key in shard_model_cls.__dict__.keys():
-                if hasattr(model.__class__, key):
-                    setattr(
-                        model.__class__,
-                        key,
-                        getattr(shard_model_cls, key),
-                    )
-        else:
-            raise NotImplementedError(f"{model.__class__} is not implemented so far")
-
-    def replace_layer(
-            self,
-            model: nn.Module,
-    ) -> None:
+        for key in new_model_class.__dict__.keys():
+            if hasattr(self.model.__class__, key):
+                setattr(
+                    self.model.__class__,
+                    key,
+                    getattr(new_model_class, key),
+                )
+
+    def replace_module(self,) -> None:
         r"""
-        Replace the layer according to the policy, and replace the layer one by one
+        Replace the module according to the policy, and replace the module one by one
 
         Args:
-            model (:class:`torch.nn.Module`): The layer to shard
+            model (:class:`torch.nn.Module`): The model to shard
         """
-        argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
-        for argument_policy in argument_policies.items():
-            origin_layer_cls = argument_policy[0]
-            attr_dict = argument_policy[1].attr_dict
-            param_funcs = argument_policy[1].param_funcs
-            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
+        print(self.policy)
+        module_descriptions = self.policy.module_policy(self.shard_config)
+        print(f"*******{module_descriptions}")
+        for module_description in module_descriptions.items():
+            origin_layer_cls = module_description[0]
+            attr_replacement = module_description[1].attribute_replacement
+            param_replacement = module_description[1].param_replacement
+            sub_module_replacement = module_description[1].sub_module_replacement
+            self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
+                                          sub_module_replacement)
 
-    def traverse_replace_layer(
+    def _recursive_replace_layer(
             self,
-            layer: nn.Module,
+            module: nn.Module,
             origin_cls: nn.Module,
-            attr_dict: Dict[str, Any],
-            param_funcs: List[Callable],
+            attr_replacement: Dict[str, Any],
+            param_replacement: List[Callable],
+            sub_module_replacement: List[Callable],
     ) -> None:
         r"""
         Reverse the replace layer operation
@@ -114,21 +115,52 @@ class ModelSharder(object):
         Args:
             layer (:class:`torch.nn.Module`): The object of layer to shard
             origin_cls (:class:`transformers.model`): The origin layer class
-            attr_dict (Dict): The attribute dict to modify
-            policy_cls (:class:`Policy`): The policy class
+            attr_replacement (Dict): The attribute dict to modify
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
+            sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
         """
-        if layer.__class__ == origin_cls:
-            for k, v in attr_dict.items():
-                setattr_(layer, k, v, ignore=True)
-            self.shard_one_layer(layer, param_funcs)
-        for name, child in layer.named_children():
-            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
-        return layer
+        if module.__class__ == origin_cls:
+            self._replace_attr(module, attr_replacement)
+            self._replace_param(module, param_replacement)
+            self._replace_sub_module(module, sub_module_replacement)
+        for name, child in module.named_children():
+            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
+                                          sub_module_replacement)
 
-    def shard_one_layer(
+    def _replace_attr(
+            self,
+            module: nn.Module,
+            attr_replacement: Dict[str, Any],
+    ) -> None:
+        r"""
+        Replace the attribute of the layer
+
+        Args:
+            layer (:class:`torch.nn.Module`): The object of layer to shard
+            attr_replacement (Dict): The attribute dict to modify
+        """
+        for k, v in attr_replacement.items():
+            setattr_(module, k, v, ignore=True)
+
+    def _replace_param(
+            self,
+            module: nn.Module,
+            param_replacement: List[Callable],
+    ) -> None:
+        r"""
+        Replace the parameter of the layer
+
+        Args:
+            layer (:class:`torch.nn.Module`): The object of layer to shard
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
+        """
+        # TODO: support parameter shard
+        pass
+
+    def _replace_sub_module(
             self,
             org_layer: nn.Module,
-            param_funcs: List[Callable],
+            sub_module_replacement: List[Callable],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@@ -138,145 +170,14 @@ class ModelSharder(object):
             param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
 
         """
-        for func in param_funcs:
-            policy_layers = func()
-            for policy_layer in policy_layers:
-                suffix = policy_layer.suffix
-                replace_layer_cls = policy_layer.replace_layer
-                ignore = policy_layer.ignore
-                reversed = policy_layer.reversed
-                n_cast = policy_layer.n_cast
-
-                assert replace_layer_cls is not None, 'replace_layer should not be None'
-
-                # create new object to replace the origin layer
-                # Linear
-                suffix_layer = getattr_(org_layer, suffix, ignore=True)
-                assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
-                if suffix_layer is None and ignore:
-                    continue
-                if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
-                    weight = None
-                    bias = None
-                    weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
-                    bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None
-
-                    if weight_attr is not None:
-                        if hasattr_(org_layer, weight_attr):
-                            weight = getattr_(org_layer, weight_attr)
-                        else:
-                            raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
-
-                    if bias_attr is not None:
-                        if hasattr_(org_layer, bias_attr):
-                            bias = getattr_(org_layer, bias_attr)
-                        else:
-                            raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
-
-                    # set the sliced weight and bias to the new nn_col layer
-                    assert weight is not None or bias is not None
-
-                    # slice weight and bias
-                    weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
-
-                    if replace_layer_cls.__name__ == "Linear1D_Row":
-                        replace_layer = replace_layer_cls(weight.shape[1],
-                                                          weight.shape[0],
-                                                          bias=False if bias is None else True)
-                    elif replace_layer_cls.__name__ == "Linear1D_Col":
-                        gather_output = policy_layer.gather_output and self.shard_config.gather_output
-                        replace_layer = replace_layer_cls(weight.shape[0],
-                                                          weight.shape[1],
-                                                          bias=False if bias is None else True,
-                                                          gather_output=gather_output)
-                    elif replace_layer_cls.__name__ == "Embedding1D":
-                        gather_output = policy_layer.gather_output
-                        replace_layer = replace_layer_cls(weight.shape[0],
-                                                          weight.shape[1],
-                                                          gather_output=gather_output)
-                    elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
-                        replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
-                                                          getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
-                        # setattr_(org_layer, suffix, replace_layer, ignore=ignore)
-                        # self.set_param(replace_layer, weight, bias)
-                    else:
-                        raise NotImplementedError(
-                            f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
-                    setattr_(org_layer, suffix, replace_layer, ignore=ignore)
-                    self.set_param(replace_layer, weight, bias)
-                # dropout
-                elif isinstance(policy_layer, Dropout_Layer):
-                    p_attr = suffix + '.' + policy_layer.p
-                    p = getattr_(org_layer, p_attr, ignore=True)
-                    replace_layer = replace_layer_cls(p)
-                    setattr_(org_layer, suffix, replace_layer, ignore=ignore)
-                else:
-                    raise NotImplementedError(
-                        f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")
-
-    def set_param(self,
-                  layer: Any,
-                  weight: torch.Tensor = None,
-                  bias: torch.Tensor = None,
-                  layer_attr: str = "") -> None:
-        r"""
-        Reset the weight and bias of the layer object
-
-        Args:
-            layer (:class:`torch.nn.Module`): The layer object
-            layer_attr (str): The attribute name of the layer
-            weight (:class:`torch.Tensor`): The weight of the layer
-            bias (:class:`torch.Tensor`): The bias of the layer
-        """
-        assert weight is not None or bias is not None
-        if weight is not None:
-            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
-            self.set_layer_size(layer, layer_attr, weight.shape)
-        if bias is not None:
-            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))
-
-    def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
-        r"""
-        Set the layer attribute
-
-        Args:
-            layer (:class:`torch.nn.Module`): The layer object
-            layer_attr (str): The attribute name of the layer
-            size (:class:`torch.Size`): The size of the tensor
-        """
-        # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
-        attrs = ["out_features", "in_features"]
-        for i, attr in enumerate(attrs):
-            if hasattr_(layer, f"{layer_attr}.{attr}"):
-                setattr_(layer, f"{layer_attr}.{attr}", size[i])
-
-    def bind_layer(self, model: nn.Module) -> None:
-        r"""
-        Bind the layer according to the binding policy
-
-        Args:
-            model (:class:`torch.nn.Module`): The shard model
-        """
-        binding_map = self.policy.binding_policy()
-        if binding_map is None:
-            return
-        for k, v in binding_map.items():
-            param = getattr_(model, k)
-            param = nn.Parameter(param)
-            setattr_(model, k, param)
-            setattr_(model, v, param)
-
-
-def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
-    r"""
-    The function is used to shard the PyTorch model.
-
-    Args:
-        model (`torch.nn.Model`): the origin huggingface model
-        shard_config (`ShardConfig`): the config for distribute information
-        policy (`Policy`): the custom policy for sharding
-    """
-    # TODO: init shard_config automatically
-    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
-    sharder.shard()
-    return model
+        for description in sub_module_replacement:
+            suffix = description.suffix
+            target_module = description.target_module
+            kwargs = description.kwargs
+
+            assert target_module is not None, 'target_module should not be None'
+
+            # TODO: integrate with new layer
+            # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
+            replace_layer = None
+            setattr_(org_layer, suffix, replace_layer)
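The core of the new sharder is `_recursive_replace_layer`, a depth-first walk that patches every instance of the target class in place. A self-contained sketch of the same traversal pattern (toy module and factory, for illustration only):

```python
import torch.nn as nn

def recursive_replace(module: nn.Module, origin_cls: type, factory) -> None:
    # Replace every direct child that matches, then recurse into the rest.
    for name, child in module.named_children():
        if child.__class__ is origin_cls:
            setattr(module, name, factory(child))
        else:
            recursive_replace(child, origin_cls, factory)

# Toy usage: swap every nn.Linear for a fresh one of the same shape.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
recursive_replace(model, nn.Linear, lambda m: nn.Linear(m.in_features, m.out_features))
```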
colossalai/shardformer/shard/shardformer.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import torch.nn as nn
+from torch.utils.data import Dataset
+
+from colossalai.cluster import DistCoordinator, ProcessGroupManager
+
+from ..policies.basepolicy import Policy
+from .shard_config import ShardConfig
+from .sharder import ModelSharder
+
+
+class ShardFormer:
+    """
+    Parallelize model based on the given config and policy
+
+    Example:
+
+    ```python
+    from colossalai.shardformer import ShardFormer, ShardConfig
+    from transformers import BertForMaskedLM
+    import colossalai
+    import torch
+
+    colossalai.launch_from_torch(config={})
+
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    shard_config = ShardConfig(
+        tensor_parallel_size=2,
+        data_parallel_size=1,
+        pipeline_parallel_size=1,
+        tensor_parallel_mode='1d',
+        inference_only=True,
+        gather_output=True
+    )
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    model = shard_former.shard_model(org_model)
+    ```
+    """
+
+    def __init__(self, shard_config: ShardConfig):
+        """
+        Do two things:
+        1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
+        2. serve as a store for the shard config
+        """
+        self.coordinator = DistCoordinator()
+        self.shard_config = shard_config
+        self.pg_manager = None
+
+    def init_distributed(self) -> ProcessGroupManager:
+        """
+        Initialize the distributed process group according to the shard config
+        """
+        pg_manager = ProcessGroupManager()
+        if (self.shard_config.tensor_parallel_mode == '1d'):
+            pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
+        self.pg_manager = pg_manager
+        return pg_manager
+
+    def shard_model(self, model: nn.Module, policy: Policy = None):
+        r"""
+        The function is used to shard the PyTorch model.
+
+        Args:
+            model (`torch.nn.Model`): the origin huggingface model
+            shard_config (`ShardConfig`): the config for distribute information
+            policy (`Policy`): the custom policy for sharding
+        """
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
+        sharder.shard()
+        return model
+
+    def shard_dataset(self, dataset: Dataset):
+        """
+        Shard dataset for DP
+        """
+        pass
@@ -1,163 +0,0 @@
-import torch
-
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
-from .shard_config import ShardConfig
-
-dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}
-
-
-class Slicer():
-
-    def __init__(
-        self,
-        shardconfig: ShardConfig    #TODO
-    ) -> None:
-        self.shardconfig = shardconfig
-
-    def slice_weight_bias(
-        self,
-        weight: torch.Tensor,
-        bias: torch.Tensor,
-        policy_layer_cls: Layer,
-        n_cast: int = None,
-        reversed: bool = False,
-    ):
-        r"""
-        Slice the weight and bias according to policy layer cls
-        ``Layer`` -> do nothing
-        ``Col_Layer`` -> slice the weight and bias along dim 1
-        ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias
-
-        Args:
-            weight (:class:`torch.nn.Module`): The weight of the layer
-            bias: (:class:`torch.nn.Module`): The bias of the layer
-            policy_layer_class (:class:`Policy`): The class represent how to slice the tensor
-        """
-        if policy_layer_cls in [Layer, Dropout_Layer]:
-            return weight, bias
-
-        dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
-        # print(weight.shape, dim)
-        if policy_layer_cls == Col_Layer:
-            weight = self.slice_tensor(weight, dim, False, n_cast)
-            bias = self.slice_tensor(bias, 0, True, n_cast)
-        elif policy_layer_cls == Row_Layer:
-            weight = self.slice_tensor(weight, dim, False, n_cast)
-        elif policy_layer_cls == Embedding_Layer:
-            weight = self.slice_tensor(weight, dim, False, n_cast)
-        else:
-            raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
-        if reversed:
-            weight = weight.transpose(0, 1).contiguous()
-        return weight, bias
-
-    def slice_tensor(
-        self,
-        tensor_in: torch.Tensor,
-        dim: int,
-        is_bias: bool,
-        n_cast: int = None,
-    ) -> torch.Tensor:
-        r"""
-        Slice tensor according to the config
-
-        Args:
-            tensor_in (:class:`torch.Tensor`): The tensor to slice
-            dim (int): The dimension to slice
-            is_bias (bool): Whether the tensor is bias
-        """
-        if tensor_in is None:
-            return None
-        if not is_bias:
-            return self.slice_2d(tensor_in, dim, n_cast)
-        else:
-            return self.slice_1d(tensor_in, n_cast)
-
-    def slice_2d(
-        self,
-        tensor: torch.Tensor,
-        dim: int,
-        n_cast: int = None,
-    ) -> torch.Tensor:
-        r"""
-        Slice the 2D tensor
-
-        Args:
-            tensor (:class:`torch.Tensor`): The tensor to slice
-            dim (int): The dimension to slice
-        """
-        assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
-        if dim == 0:
-            return self.slice_row(tensor, n_cast)
-        elif dim == 1:
-            return self.slice_col(tensor, n_cast)
-
-    def slice_1d(
-        self,
-        tensor: torch.Tensor,
-        n_cast: int = None,
-    ) -> torch.Tensor:
-        r"""
-        Slice the 1D tensor
-
-        Args:
-            tensor (:class:`torch.Tensor`): The tensor to slice
-
-        Returns:
-            :class:`torch.Tensor`: The sliced tensor
-        """
-        if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
-        else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
-            chunk_list = [
-                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
-            ]
-            return torch.cat(chunk_list, dim=0).contiguous()
-
-    def slice_col(
-        self,
-        tensor: torch.Tensor,
-        n_cast: int = None,
-    ) -> torch.Tensor:
-        r"""
-        Slice the tensor in column
-
-        Args:
-            tensor (:class:`torch.Tensor`): The tensor to slice
-
-        Returns:
-            :class:`torch.Tensor`: The sliced tensor
-
-        """
-        if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
-        else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
-            chunk_list = [
-                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
-            ]
-            return torch.cat(chunk_list, dim=1).contiguous()
-
-    def slice_row(
-        self,
-        tensor: torch.Tensor,
-        n_cast: int = None,
-    ) -> torch.Tensor:
-        r"""
-        Slice the tensor in column
-
-        Args:
-            tensor (:class:`torch.Tensor`): The tensor to slice
-
-        Returns:
-            :class:`torch.Tensor`: The sliced tensor
-        """
-        if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
-        else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
-            chunk_list = [
-                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
-            ]
-            return torch.cat(chunk_list, dim=0).contiguous()
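For the record, the `n_cast` logic in the deleted slicer interleaved chunks so that every rank kept a piece of each head (or each of Q, K and V in a fused projection). A self-contained reconstruction of that rule on a toy tensor, matching the `slice_1d`/`slice_row` bodies above:

```python
import torch

world_size, rank, n_cast = 2, 0, 3    # e.g. a fused q/k/v weight on 2 ranks
tensor = torch.arange(12)             # stand-in for the sharded dimension

# chunk into world_size * n_cast pieces, then take every world_size-th chunk
# starting at this rank -- exactly the removed chunk_list construction
chunks = tensor.chunk(world_size * n_cast, dim=0)
mine = torch.cat([chunks[i] for i in range(rank, len(chunks), world_size)], dim=0)

print(mine)    # tensor([0, 1, 4, 5, 8, 9]): one slice from each of q, k, v
```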
@@ -0,0 +1 @@
+from .utils import getattr_, hasattr_, setattr_
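The `getattr_`/`hasattr_`/`setattr_` helpers re-exported here address submodules by dotted suffix throughout the sharder and policies. Their bodies are not part of this diff; the following is only a sketch of the behavior the call sites rely on (signatures inferred from usage, including the `ignore` keyword):

```python
import functools

def getattr_(obj, attr: str, ignore: bool = False):
    # resolve a dotted path such as "attention.self.query"
    try:
        return functools.reduce(getattr, attr.split('.'), obj)
    except AttributeError:
        if ignore:
            return None
        raise

def setattr_(obj, attr: str, value, ignore: bool = False):
    # set the leaf of a dotted path, e.g. splicing a replaced layer back in
    parent, _, leaf = attr.rpartition('.')
    target = getattr_(obj, parent, ignore=ignore) if parent else obj
    if target is not None:
        setattr(target, leaf, value)
```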