[shardformer] updated doc (#4016)

This commit is contained in:
Frank Lee 2023-06-16 16:15:10 +08:00
parent df018fc305
commit e253a07007

## 📚 Table of Contents

- [📚 Table of Contents](#-table-of-contents)
- [🔗 Introduction](#-introduction)
- [🔨 Usage](#-usage)
  - [Quick Start](#quick-start)
  - [Write your own policy](#write-your-own-policy)
- [🗺 Roadmap](#-roadmap)
- [💡 API Design](#-api-design)
  - [Distributed Modules](#distributed-modules)
  - [Shard Config](#shard-config)
  - [Policy](#policy)
  - [Model Sharder](#model-sharder)
  - [User-facing API](#user-facing-api)
## 🔗 Introduction

## 🔨 Usage
### Quick Start
The sample API usage is given below:

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch()

# create model
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create shard config
shard_config = ShardConfig(tensor_parallel_size=2,
                           data_parallel_size=1,
                           gather_output=True)

# shard the model
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model).to('cuda')

# do everything like normal
...
```
### Write your own policy
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
```python
from colossalai.shardformer import Policy

class MyPolicy(Policy):
    # implement your own policy
    ...

# init model and shard former
...

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.shard_model(model, my_policy)
```
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:

- [x] API Design
- [x] API Implementation
- [x] Unit Testing
- [ ] Policy Implementation
  - [ ] Hugging Face
    - [ ] NLP
      - [x] BERT
      - [ ] T5
      - [ ] LLaMA
      - [ ] GPT2
      - [ ] BLOOM
      - [ ] RoBERTa
      - [ ] ALBERT
      - [ ] ERNIE
      - [ ] GPT Neo
      - [ ] GPT-J
    - [ ] CV
      - [ ] ViT
      - [ ] BEiT
      - [ ] SwinTransformer
      - [ ] SwinTransformer V2
    - [ ] Audio
      - [ ] To be added
    - [ ] Multi-modal
      - [ ] To be added
## 💡 API Design

We will discuss the major components of `ShardFormer` below to help you better understand how things work.
This section serves as the design doc for Shardformer and the function signatures might differ from the actual implementation. Please refer to the code for more details.

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
<br/>
<b>This diagram is deprecated and needs to be updated.</b>
</p>

### Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
```python
from abc import abstractmethod
from typing import Tuple, Union

import torch
from torch.distributed import ProcessGroup

class ParallelModule(torch.nn.Module):

    @abstractmethod
    def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> "ParallelModule":
        """
        Convert a native module to a parallelized module.

        Examples:

        ```python
        # replace module
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """
        ...
```
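For intuition, the sketch below shows what a distributed module following this contract might look like for a column-parallel linear layer. This is a minimal sketch under assumed sharding semantics, not the actual `Linear1D_Col` implementation; the class name `MyLinear1DCol` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class MyLinear1DCol(nn.Module):
    """Hypothetical column-parallel linear following the ParallelModule contract."""

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor, process_group):
        super().__init__()
        # Each rank keeps only its slice of the output dimension.
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)
        self.process_group = process_group

    @staticmethod
    def from_native_module(module: nn.Linear, process_group) -> "MyLinear1DCol":
        # Split the weight along the output dimension across ranks (column-wise sharding).
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)
        w_shard = torch.chunk(module.weight.detach(), world_size, dim=0)[rank]
        b_shard = torch.chunk(module.bias.detach(), world_size, dim=0)[rank]
        return MyLinear1DCol(w_shard, b_shard, process_group)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each rank computes a slice of the output features; the slices can be
        # gathered afterwards if the full output is needed (cf. gather_output).
        return F.linear(x, self.weight, self.bias)
```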
### Shard Config
`ShardConfig` is a simple data class to tell `ShardFormer` how sharding will be performed.
```python
@dataclass
class ShardConfig:
data_parallel_size: int
tensor_parallel_size: int
...
# Some possible future config fields
pipeline_parallel_size: int # Support pipeline parallelism
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
inference_only: bool # only inject inference-suitable sharding policy
gather_output: bool # gather the model output
use_flash_attention: bool # whether to use flash attention to speed up attention
```
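Constructing the config is plain dataclass usage. Here is a small sketch, assuming only the fields already shown in the Quick Start; the "possible future" fields above may not exist yet.

```python
from colossalai.shardformer import ShardConfig

# tensor_parallel_size * data_parallel_size is expected to match the world size
shard_config = ShardConfig(tensor_parallel_size=2,
                           data_parallel_size=1,
                           gather_output=True)
```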
### Policy
The `Policy` class describes how to handle the model sharding.
It is merely a description; the actual sharding is performed by `ModelSharder`.
We abstract the policy into four stages:

1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding
2. Providing a new class: call `Policy.new_model_class` to get a new class for the model; this class replaces attributes and the forward function
3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` objects that tell `ModelSharder` how the submodules' attributes, child parameters, and deeper submodules will be substituted
4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model

More details can be found in `shardformer/policies/basepolicy.py`. A concrete example policy is sketched after the listing below.
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Type, Union

import torch.nn as nn

@dataclass
class ModulePolicyDescription:
    """
    Describe how the attributes and parameters will be transformed in a policy.

    Args:
        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
            must receive two arguments: module, process_group. One example is

            def example_replace_weight(module: torch.nn.Module, process_group):
                weight = module.weight
                new_weight = shard_rowwise(weight, process_group)
                module.weight = torch.nn.Parameter(new_weight)

        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a
            SubModuleReplacementDescription object which specifies the submodule to be replaced and the
            target module used for the replacement
    """
    attribute_replacement: Dict[str, Any]
    param_replacement: List[Callable]
    sub_module_replacement: List[SubModuleReplacementDescription]


@dataclass
class SubModuleReplacementDescription:
    """
    Describe how a submodule will be replaced.

    Args:
        suffix (str): used to get the submodule object
        target_module (ParallelModule): specifies the module class used to replace the submodule
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method
    """
    suffix: str
    target_module: ParallelModule
    kwargs: Dict[str, Any] = None


class Policy(ABC):

    def __init__(self):
        self.model = None

    def set_model(self, model: nn.Module) -> None:
        """
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        """
        self.model = model

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
        Perform some preprocessing on the model, such as resizing the embedding size.
        """
        ...

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """
        Return the dict for the modify policy; the key is the original layer class and the value is the
        argument for the modified layer.
        """
        ...

    @abstractmethod
    def new_model_class(self) -> Union[Type[nn.Module], None]:
        """
        Replace the class of the model to substitute the forward function and attributes.
        """
        ...

    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head.
        """
        ...
```
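To make the contract concrete, here is a sketch of what a minimal custom policy could look like under this design. `MyBlock` is a hypothetical submodule class of your model, the `gather_output` kwarg is assumed to be accepted by `Linear1D_Col.from_native_module`, and the `Linear1D_Col` import path is an assumption; only the `basepolicy` path appears in this doc.

```python
import torch.nn as nn
from colossalai.shardformer.policies.basepolicy import (
    Policy, ModulePolicyDescription, SubModuleReplacementDescription)
from colossalai.shardformer.layer import Linear1D_Col  # assumed import path

class MyBlock(nn.Module):
    """Hypothetical transformer block in your model, containing a `dense` nn.Linear."""
    ...

class MyPolicy(Policy):

    def preprocess(self) -> nn.Module:
        # e.g. resize the embedding so the vocabulary divides evenly across ranks
        return self.model

    def module_policy(self):
        # For every MyBlock in the model, replace its `dense` submodule with a
        # column-parallel linear and patch a (hypothetical) size attribute.
        return {
            MyBlock: ModulePolicyDescription(
                attribute_replacement={'hidden_size_per_rank': self.model.config.hidden_size // 2},
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix='dense',
                        target_module=Linear1D_Col,
                        kwargs={'gather_output': False},
                    ),
                ],
            )
        }

    def new_model_class(self):
        # keep the original model class unchanged
        return None

    def postprocess(self) -> nn.Module:
        return self.model
```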
### Model Sharder

`ModelSharder` is the class in charge of sharding the model based on the given policy.

```python
class ModelSharder:

    def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, policy: Policy = None):
        # TODO: decide whether the input policy is a class or an object
        ...

    def shard(self) -> None:
        """
        Shard the model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
        """
        ...

    def replace_model_class(self) -> None:
        """
        Replace the model's methods and attributes with our own defined class.

        E.g. we can replace the forward function of the original BertForMaskedLM object
        with the forward function we define in the BertForMaskedLM_ class.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layers according to the policy. Call Policy.module_policy() to get the policy
        descriptions, then call _replace_module recursively.
        """
        ...
```
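The recursive replacement itself is conceptually simple. Below is a simplified sketch of what `_replace_module` might do; this is not the actual implementation, and the dotted-suffix handling and process-group plumbing are assumptions.

```python
import torch.nn as nn

def _replace_module(module: nn.Module, policies: dict, process_group) -> None:
    """Simplified sketch: walk the module tree and apply matching policy descriptions."""
    for child in module.children():
        desc = policies.get(type(child))
        if desc is not None:
            for sub in desc.sub_module_replacement:
                # Resolve the submodule named by `suffix`, which may be a dotted path.
                *parents, leaf = sub.suffix.split('.')
                holder = child.get_submodule('.'.join(parents)) if parents else child
                native = getattr(holder, leaf)
                # Swap the native submodule for its parallel counterpart.
                parallel = sub.target_module.from_native_module(
                    native, process_group, **(sub.kwargs or {}))
                setattr(holder, leaf, parallel)
        # Recurse so that nested blocks are matched as well.
        _replace_module(child, policies, process_group)
```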
### User-facing API

We only expose a limited number of APIs to the user to keep their user experience simple and clean.

```python
class ShardFormer:
    """
    Parallelize model based on the given config and policy.

    Example:

    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    model = shard_former.shard_model(model, policy=policy)
    dataloader = shard_former.shard_dataset(dataset)
    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a colossalai.cluster.ProcessGroupManager to manage process groups for dp, tp and pp
        2. Serve as a store for the shard config
        """
        self.shard_config = shard_config
        self.pg_manager = None

    def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
        """
        Initialize the distributed process group according to the shard config.
        """
        pg_manager = ...
        self.pg_manager = pg_manager
        return pg_manager

    def shard_model(self, model: torch.nn.Module, policy: Policy) -> torch.nn.Module:
        """
        Shard the model for tensor parallelism and pipeline parallelism.
        """
        ...

    def shard_dataset(self, dataset: Dataset) -> DataLoader:
        """
        Shard the dataset for data parallelism.
        """
        ...
```
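Putting the pieces together under this design, a custom-policy workflow would look roughly like the sketch below. `MyPolicy` is the hypothetical policy sketched in the Policy section, and `shard_dataset` is part of the design doc and may differ in the actual implementation.

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch()

shard_config = ShardConfig(tensor_parallel_size=2, data_parallel_size=1)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()

# `model` and `dataset` are created beforehand, as in the Quick Start
model = shard_former.shard_model(model, policy=MyPolicy()).to('cuda')
dataloader = shard_former.shard_dataset(dataset)

# assuming each batch is a dict of tensors that includes labels
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
```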