mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[doc] add shardformer support matrix/update tensor parallel documents (#4728)
* add compatibility matrix for shardformer doc * update tp doc
This commit is contained in:
@@ -2,14 +2,12 @@
|
||||
|
||||
Author: Zhengda Bian, Yongbin Li
|
||||
|
||||
> ⚠️ The information on this page is outdated and will be deprecated. Please check [Shardformer](./shardformer.md) for more information.
|
||||
|
||||
**Prerequisite**
|
||||
- [Define Your Configuration](../basics/define_your_config.md)
|
||||
- [Configure Parallelization](../basics/configure_parallelization.md)
|
||||
|
||||
**Example Code**
|
||||
- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)
|
||||
- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
|
||||
|
||||
**Related Paper**
|
||||
- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf)
|
||||
@@ -44,79 +42,7 @@ Given $P$ processors, we present the theoretical computation and memory cost, as
|
||||
|
||||
## Usage
|
||||
|
||||
To enable 1D tensor parallelism for our model, e.g. on 2 GPUs, we need to configure the parallelism setting as below.
|
||||
```python
|
||||
CONFIG = dict(parallel=dict(
|
||||
data=1,
|
||||
pipeline=1,
|
||||
tensor=dict(size=2, mode='1d'),
|
||||
))
|
||||
```
|
||||
Then Colossal-AI will automatically apply 1D parallelism to all the layers from `colossalai.nn`.
|
||||
|
||||
Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
|
||||
```python
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
import torch
|
||||
from colossalai.utils import print_rank_0
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, dim: int = 256):
|
||||
super().__init__()
|
||||
intermediate_dim = dim * 4
|
||||
self.dense_1 = col_nn.Linear(dim, intermediate_dim)
|
||||
print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}')
|
||||
self.activation = torch.nn.GELU()
|
||||
self.dense_2 = col_nn.Linear(intermediate_dim, dim)
|
||||
print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}')
|
||||
self.dropout = col_nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dense_1(x)
|
||||
print_rank_0(f'Output of the first linear layer: {x.shape}')
|
||||
x = self.activation(x)
|
||||
x = self.dense_2(x)
|
||||
print_rank_0(f'Output of the second linear layer: {x.shape}')
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
```
|
||||
|
||||
Launch Colossal-AI on 2 GPUs and build the model.
|
||||
|
||||
```python
|
||||
parser = colossalai.get_default_parser()
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
local_rank=args.local_rank,
|
||||
host=args.host,
|
||||
port=args.port)
|
||||
|
||||
m = MLP()
|
||||
```
|
||||
We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
|
||||
```shell
|
||||
Weight of the first linear layer: torch.Size([256, 512])
|
||||
Weight of the second linear layer: torch.Size([512, 256])
|
||||
```
|
||||
The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the column-parallel partitioning, it becomes `[256, 512]`.
|
||||
Similarly, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`.
|
||||
|
||||
We can run the model with some random inputs.
|
||||
```python
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
x = torch.randn((16, 256), device=get_current_device())
|
||||
torch.distributed.broadcast(x, src=0) # synchronize input
|
||||
|
||||
x = m(x)
|
||||
```
|
||||
Then we can see the shapes of activation results.
|
||||
```shell
|
||||
Output of the first linear layer: torch.Size([16, 512])
|
||||
Output of the second linear layer: torch.Size([16, 256])
|
||||
```
|
||||
The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs.
|
||||
1D tensor parallelism is implemented by `Shardformer` feature in the newest version of ColossalAI.
|
||||
For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
|
||||
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -60,83 +60,9 @@ Given $P=q\times q$ processors, we present the theoretical computation and memor
|
||||
|
||||
## Usage
|
||||
|
||||
To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallelism setting as below.
|
||||
```python
|
||||
CONFIG = dict(parallel=dict(
|
||||
data=1,
|
||||
pipeline=1,
|
||||
tensor=dict(size=4, mode='2d'),
|
||||
))
|
||||
```
|
||||
Then Colossal-AI will automatically apply 2D parallelism to all the layers from `colossalai.nn`.
|
||||
Currently the newest version of ColossalAI doesn't support 2D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases.
|
||||
For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
|
||||
|
||||
Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
|
||||
```python
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
import torch
|
||||
from colossalai.utils import print_rank_0
|
||||
For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md).
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, dim: int = 256):
|
||||
super().__init__()
|
||||
intermediate_dim = dim * 4
|
||||
self.dense_1 = col_nn.Linear(dim, intermediate_dim)
|
||||
print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
|
||||
self.activation = torch.nn.GELU()
|
||||
self.dense_2 = col_nn.Linear(intermediate_dim, dim)
|
||||
print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
|
||||
self.dropout = col_nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dense_1(x)
|
||||
print_rank_0(f'Output of the first linear layer: {x.shape}')
|
||||
x = self.activation(x)
|
||||
x = self.dense_2(x)
|
||||
print_rank_0(f'Output of the second linear layer: {x.shape}')
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
```
|
||||
Launch Colossal-AI on 4 GPUs and build the model
|
||||
```python
|
||||
parser = colossalai.get_default_parser()
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
local_rank=args.local_rank,
|
||||
host=args.host,
|
||||
port=args.port)
|
||||
|
||||
m = MLP()
|
||||
```
|
||||
We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
|
||||
```shell
|
||||
Weight of the first linear layer: torch.Size([128, 512])
|
||||
Weight of the second linear layer: torch.Size([512, 128])
|
||||
```
|
||||
The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2D parallelism, it becomes `[128, 512]` on each GPU.
|
||||
Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
|
||||
|
||||
We can run the model with some random inputs.
|
||||
```python
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
x = torch.randn((16, 256), device=get_current_device())
|
||||
# partition input
|
||||
torch.distributed.broadcast(x, src=0)
|
||||
x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
|
||||
x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
|
||||
print_rank_0(f'Input: {x.shape}')
|
||||
|
||||
x = m(x)
|
||||
```
|
||||
Then we can see the shapes of activation results.
|
||||
```shell
|
||||
Input: torch.Size([8, 128])
|
||||
Output of the first linear layer: torch.Size([8, 512])
|
||||
Output of the second linear layer: torch.Size([8, 128])
|
||||
```
|
||||
The activation tensors in 2D parallelism are all split in both row and column.
|
||||
E.g. the output of the first linear layer has the shape `[8, 512]`, while the second layer has the output of `[8, 128]`.
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -58,86 +58,9 @@ Given $P=q \times q \times d$ processors, we present the theoretical computation
|
||||
|
||||
## Usage
|
||||
|
||||
To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below.
|
||||
```python
|
||||
CONFIG = dict(parallel=dict(
|
||||
data=1,
|
||||
pipeline=1,
|
||||
tensor=dict(size=8, mode='2.5d', depth=2),
|
||||
))
|
||||
Currently the newest version of ColossalAI doesn't support 2.5D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases.
|
||||
For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
|
||||
|
||||
```
|
||||
Then Colossal-AI will automatically apply 2.5D parallelism to all the layers from `colossalai.nn`.
|
||||
For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md).
|
||||
|
||||
Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
|
||||
```python
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
import torch
|
||||
from colossalai.utils import print_rank_0
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, dim: int = 256):
|
||||
super().__init__()
|
||||
intermediate_dim = dim * 4
|
||||
self.dense_1 = col_nn.Linear(dim, intermediate_dim)
|
||||
print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
|
||||
self.activation = torch.nn.GELU()
|
||||
self.dense_2 = col_nn.Linear(intermediate_dim, dim)
|
||||
print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
|
||||
self.dropout = col_nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dense_1(x)
|
||||
print_rank_0(f'Output of the first linear layer: {x.shape}')
|
||||
x = self.activation(x)
|
||||
x = self.dense_2(x)
|
||||
print_rank_0(f'Output of the second linear layer: {x.shape}')
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
```
|
||||
Launch Colossal-AI on 8 GPUs and build the model
|
||||
```python
|
||||
parser = colossalai.get_default_parser()
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
local_rank=args.local_rank,
|
||||
host=args.host,
|
||||
port=args.port)
|
||||
|
||||
m = MLP()
|
||||
```
|
||||
We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
|
||||
```shell
|
||||
Weight of the first linear layer: torch.Size([128, 512])
|
||||
Weight of the second linear layer: torch.Size([512, 128])
|
||||
```
|
||||
The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2.5D parallelism, it becomes `[128, 512]` on each GPU.
|
||||
Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
|
||||
|
||||
We can run the model with some random inputs.
|
||||
```python
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
x = torch.randn((16, 256), device=get_current_device())
|
||||
# partition input
|
||||
torch.distributed.broadcast(x, src=0)
|
||||
x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
|
||||
x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
|
||||
x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
|
||||
print_rank_0(f'Input: {x.shape}')
|
||||
|
||||
x = m(x)
|
||||
```
|
||||
Then we can see the shapes of activation results.
|
||||
```shell
|
||||
Input: torch.Size([4, 128])
|
||||
Output of the first linear layer: torch.Size([4, 512])
|
||||
Output of the second linear layer: torch.Size([4, 128])
|
||||
```
|
||||
The activation tensors in 2.5D parallelism are all split by $d \times q$ in the row and $q$ in the column.
|
||||
E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`.
|
||||
Note, 2.5D parallelism use the same partition method as 2D parallelism for weights, where the difference is the partition of input.
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -67,85 +67,9 @@ Given $P=q \times q \times q$ processors, we present the theoretical computation
|
||||
|
||||
## Usage
|
||||
|
||||
To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below.
|
||||
```python
|
||||
CONFIG = dict(parallel=dict(
|
||||
data=1,
|
||||
pipeline=1,
|
||||
tensor=dict(size=8, mode='3d'),
|
||||
))
|
||||
```
|
||||
Then Colossal-AI will automatically apply 3D parallelism to all the layers from `colossalai.nn`.
|
||||
Currently the newest version of ColossalAI doesn't support 3D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases.
|
||||
For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
|
||||
|
||||
Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
|
||||
```python
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
import torch
|
||||
from colossalai.utils import print_rank_0
|
||||
For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md).
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, dim: int = 256):
|
||||
super().__init__()
|
||||
intermediate_dim = dim * 4
|
||||
self.dense_1 = col_nn.Linear(dim, intermediate_dim)
|
||||
print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
|
||||
self.activation = torch.nn.GELU()
|
||||
self.dense_2 = col_nn.Linear(intermediate_dim, dim)
|
||||
print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
|
||||
self.dropout = col_nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dense_1(x)
|
||||
print_rank_0(f'Output of the first linear layer: {x.shape}')
|
||||
x = self.activation(x)
|
||||
x = self.dense_2(x)
|
||||
print_rank_0(f'Output of the second linear layer: {x.shape}')
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
```
|
||||
Launch Colossal-AI on 8 GPUs and build the model
|
||||
```python
|
||||
parser = colossalai.get_default_parser()
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
local_rank=args.local_rank,
|
||||
host=args.host,
|
||||
port=args.port)
|
||||
|
||||
m = MLP()
|
||||
```
|
||||
We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
|
||||
```shell
|
||||
Weight of the first linear layer: torch.Size([128, 256])
|
||||
Weight of the second linear layer: torch.Size([512, 64])
|
||||
```
|
||||
The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 3D parallelism, it becomes `[128, 256]` on each GPU.
|
||||
Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 64]`.
|
||||
|
||||
We can run the model with some random inputs.
|
||||
```python
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
x = torch.randn((16, 256), device=get_current_device())
|
||||
# partition input
|
||||
torch.distributed.broadcast(x, src=0)
|
||||
x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
|
||||
x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
|
||||
x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
|
||||
print_rank_0(f'Input: {x.shape}')
|
||||
|
||||
x = m(x)
|
||||
```
|
||||
Then we can see the shapes of activation results.
|
||||
```shell
|
||||
Input: torch.Size([4, 128])
|
||||
Output of the first linear layer: torch.Size([4, 512])
|
||||
Output of the second linear layer: torch.Size([4, 128])
|
||||
```
|
||||
The activation tensors in 3D parallelism are all split by $q^2$ in the row and $q$ in the column.
|
||||
E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`.
|
||||
Note, although the results of 3D parallelism have the same shape as that of 2.5D parallelism for weights here, the content of each partition is different.
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -29,33 +29,6 @@ This module aims to make parallelization hassle-free for users who are not from
|
||||
Within a few lines of codes, users can turn a model into a state ready for distributed training.
|
||||
Also, Shardformer contains various optimization tools for acceleration and memory saving during forward/backward pass.
|
||||
|
||||
|
||||
## How Shardformer Works
|
||||
|
||||
Generally, Shardformer works through the following four kinds of *replacements*:
|
||||
|
||||
1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
|
||||
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
|
||||
Also, new `forward` methods will replace original ones so as to execute distributed computation, such as linear layers' split /gather operations executed under tensor parallelism.
|
||||
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
|
||||
|
||||
2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training.
|
||||
For example, when training LlaMa-2 with tensor parallel size as 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
|
||||
|
||||
3. Replacing the `forward` methods implemented by original Huggingface
|
||||
Transformers libraries with our customized `forward` methods.
|
||||
This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages.
|
||||
Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method.
|
||||
|
||||
4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer).
|
||||
By executing `ModelSharder.shard` method, current device will only keep the part of model parameters it's supposed to take care of.
|
||||
To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to current pipeline stage when using pipeline parallelism, or both of them.
|
||||
All other parameters are released so as to liberate memory usage.
|
||||
As a result, the optimizer will only compute the states corresponding to these part of parameters, causing the usage of memory to be further saved.
|
||||
|
||||
All of these replacements are implemented with manually written policies and forward functions.
|
||||
If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
|
||||
|
||||
## Usage
|
||||
|
||||
### Shardformer Configuration
|
||||
@@ -101,31 +74,187 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs.
|
||||
```
|
||||
when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
|
||||
|
||||
## How Shardformer Works
|
||||
|
||||
Generally, Shardformer works through the following four kinds of *replacements*:
|
||||
|
||||
1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
|
||||
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
|
||||
Also, new `forward` methods will replace original ones so as to execute distributed computation, such as linear layers' split /gather operations executed under tensor parallelism.
|
||||
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
|
||||
|
||||
2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training.
|
||||
For example, when training LlaMa-2 with tensor parallel size as 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
|
||||
|
||||
3. Replacing the `forward` methods implemented by original Huggingface
|
||||
Transformers libraries with our customized `forward` methods.
|
||||
This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages.
|
||||
Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method.
|
||||
|
||||
4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer).
|
||||
By executing `ModelSharder.shard` method, current device will only keep the part of model parameters it's supposed to take care of.
|
||||
To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to current pipeline stage when using pipeline parallelism, or both of them.
|
||||
All other parameters are released so as to liberate memory usage.
|
||||
As a result, the optimizer will only compute the states corresponding to these part of parameters, causing the usage of memory to be further saved.
|
||||
|
||||
All of these replacements are implemented with manually written policies and forward functions.
|
||||
If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
|
||||
|
||||
## Supporting Information
|
||||
|
||||
List of Huggingface transformers model families currently supported by Shardformer:
|
||||
- LlaMa-1/LlaMa-2
|
||||
- GPT2
|
||||
- BERT
|
||||
- OPT
|
||||
- BLOOM
|
||||
- T5
|
||||
- ViT
|
||||
- ChatGLM-2 6B
|
||||
- Whisper
|
||||
Model/Feature Compatibility Matrix:
|
||||
|
||||
List of optimization tools currently supported by Shardformer:
|
||||
- Flash Attention 2
|
||||
- JIT Fused Operator
|
||||
- xFormers
|
||||
- Fused Layer Normalization
|
||||
- Sequence Parallel
|
||||
- Sequence Overlap
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">Model/Feature</th>
|
||||
<th nowrap="nowrap" title="Tensor Parallel">Tensor<br />Parallel</th>
|
||||
<th nowrap="nowrap" align="center" title="Pipeline Parallel">Pipeline<br />Parallel</th>
|
||||
<th nowrap="nowrap" align="center" title="Lazy Initialization">Lazy<br />Initialization</th>
|
||||
<th nowrap="nowrap" align="center" title="xFormers">xFormers</th>
|
||||
<th nowrap="nowrap" align="center" title="Flash Attention 2">Flash<br />Attention 2</th>
|
||||
<th nowrap="nowrap" align="center" title="JIT Fused Operators">JIT Fused<br />Operators</th>
|
||||
<th nowrap="nowrap" align="center" title="Fused LayerNorm">Fused<br />LayerNorm</th>
|
||||
<th nowrap="nowrap" align="center" title="Sequence Parallel">Sequence<br />Parallel</th>
|
||||
<th nowrap="nowrap" align="center" title="Sequence Overlap">Sequence<br />Overlap</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Llama V1/V2</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">OPT</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">BLOOM</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">ChatGLM 2</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">BERT</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">GPT 2</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">T5</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">ViT</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Whisper</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">SAM</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Blip2</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="39"></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
List of model families we plan to support in the near future:
|
||||
- SAM
|
||||
- Blip2
|
||||
- RoBERTa
|
||||
- ALBERT
|
||||
- ERNIE
|
||||
@@ -135,9 +264,6 @@ List of model families we plan to support in the near future:
|
||||
- SwinTransformer V1/V2
|
||||
- qwen
|
||||
|
||||
These lists will grow longer as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project.
|
||||
|
||||
For more details about compatibility between each optimization tool and each supported model, please refer to chapter Roadmap in our [develop document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md).
|
||||
|
||||
The support matrix will grow larger as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project.
|
||||
|
||||
<!-- doc-test-command: echo -->
|
||||
|
Reference in New Issue
Block a user