Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 20:54:55 +00:00)
[doc] polish shardformer doc (#4779)
* fix example format in docstring
* polish shardformer doc

This commit is contained in: parent 26cd6d850c, commit a2db75546d
@@ -229,16 +229,17 @@ class GeminiPlugin(DPPluginBase):
     """
     Plugin for Gemini.

     Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import GeminiPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = GeminiPlugin()
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```python
+        from colossalai.booster import Booster
+        from colossalai.booster.plugin import GeminiPlugin
+
+        model, train_dataset, optimizer, criterion = ...
+        plugin = GeminiPlugin()
+
+        train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        booster = Booster(plugin=plugin)
+        model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```

     Args:
         chunk_config_dict (dict, optional): chunk configuration dictionary.
@@ -266,16 +266,17 @@ class HybridParallelPlugin(PipelinePluginBase):
     Tensor parallel, pipeline parallel and data parallel (DDP/ZeRO) can be picked and combined in this plugin.
     The size of tp and pp should be passed in by the user; the size of dp is then automatically calculated as dp_size = world_size / (tp_size * pp_size).

     Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import HybridParallelPlugin
-
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+        ```python
+        from colossalai.booster import Booster
+        from colossalai.booster.plugin import HybridParallelPlugin
+
+        model, train_dataset, optimizer, criterion = ...
+        plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+        train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        booster = Booster(plugin=plugin)
+        model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+        ```

     Args:
         tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
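As a quick illustration of the size relationship stated in the docstring above (dp_size = world_size / (tp_size * pp_size)), here is a minimal sketch; the concrete numbers are made up for the example and are not part of the commit:

```python
# Illustrative only: how the data-parallel size falls out of the world size
# once the tensor-parallel and pipeline-parallel sizes are fixed.
world_size = 8   # e.g. 8 processes launched via torchrun --nproc_per_node 8
tp_size = 2      # tensor-parallel groups of 2 ranks
pp_size = 2      # pipeline of 2 stages

assert world_size % (tp_size * pp_size) == 0, "tp_size * pp_size must divide world_size"
dp_size = world_size // (tp_size * pp_size)
print(f"dp_size = {dp_size}")  # -> 2
```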
@@ -213,16 +213,17 @@ class LowLevelZeroPlugin(DPPluginBase):
     """
     Plugin for low level zero.

     Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import LowLevelZeroPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = LowLevelZeroPlugin()
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```python
+        from colossalai.booster import Booster
+        from colossalai.booster.plugin import LowLevelZeroPlugin
+
+        model, train_dataset, optimizer, criterion = ...
+        plugin = LowLevelZeroPlugin()
+
+        train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        booster = Booster(plugin=plugin)
+        model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```

     Args:
         stage (int, optional): ZeRO stage. Defaults to 1.
@@ -130,16 +130,17 @@ class TorchDDPPlugin(DPPluginBase):
     """
     Plugin for PyTorch DDP.

     Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import TorchDDPPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = TorchDDPPlugin()
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```python
+        from colossalai.booster import Booster
+        from colossalai.booster.plugin import TorchDDPPlugin
+
+        model, train_dataset, optimizer, criterion = ...
+        plugin = TorchDDPPlugin()
+
+        train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        booster = Booster(plugin=plugin)
+        model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```

     Args:
         broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training. Defaults to True.
@@ -143,16 +143,17 @@ class TorchFSDPPlugin(DPPluginBase):
     """
     Plugin for PyTorch FSDP.

     Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import TorchFSDPPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = TorchFSDPPlugin()
-
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```python
+        from colossalai.booster import Booster
+        from colossalai.booster.plugin import TorchFSDPPlugin
+
+        model, train_dataset, optimizer, criterion = ...
+        plugin = TorchFSDPPlugin()
+
+        train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        booster = Booster(plugin=plugin)
+        model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+        ```

     Args:
         See https://pytorch.org/docs/stable/fsdp.html for details.
@@ -20,14 +20,16 @@ class DistCoordinator(metaclass=SingletonMeta):
    - master: the process with rank 0
    - node master: the process with local rank 0 on the current node

    Example:
-        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
-        >>> coordinator = DistCoordinator()
-        >>>
-        >>> if coordinator.is_master():
-        >>>     do_something()
-        >>>
-        >>> coordinator.print_on_master('hello world')
+        ```python
+        from colossalai.cluster.dist_coordinator import DistCoordinator
+        coordinator = DistCoordinator()
+
+        if coordinator.is_master():
+            do_something()
+
+        coordinator.print_on_master('hello world')
+        ```

    Attributes:
        rank (int): the rank of the current process
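To make the master / node-master distinction in the docstring above concrete, here is a small sketch that relies only on the `RANK` and `LOCAL_RANK` environment variables set by `torchrun`; it is an illustration, not code from this commit:

```python
import os

# torchrun sets these variables for every process it spawns.
rank = int(os.environ.get("RANK", 0))              # global rank across all nodes
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # rank within the current node

is_master = rank == 0             # exactly one process in the whole job
is_node_master = local_rank == 0  # one process per node

if is_master:
    print("global master: a safe place for logging or writing checkpoints")
if is_node_master:
    print("node master: a safe place for per-node work such as dataset download")
```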
@@ -131,11 +133,13 @@ class DistCoordinator(metaclass=SingletonMeta):
        other processes in the same process group. This is often useful when downloading is required
        as we only want to download in one process to prevent file corruption.

        Example:
-            >>> from colossalai.cluster import DistCoordinator
-            >>> dist_coordinator = DistCoordinator()
-            >>> with dist_coordinator.priority_execution():
-            >>>     dataset = CIFAR10(root='./data', download=True)
+            ```python
+            from colossalai.cluster import DistCoordinator
+            dist_coordinator = DistCoordinator()
+            with dist_coordinator.priority_execution():
+                dataset = CIFAR10(root='./data', download=True)
+            ```

        Args:
            executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
@@ -174,13 +178,14 @@ class DistCoordinator(metaclass=SingletonMeta):
        """
        A function wrapper that only executes the wrapped function on the master process (rank 0).

        Example:
-            >>> from colossalai.cluster import DistCoordinator
-            >>> dist_coordinator = DistCoordinator()
-            >>>
-            >>> @dist_coordinator.on_master_only()
-            >>> def print_on_master(msg):
-            >>>     print(msg)
+            ```python
+            from colossalai.cluster import DistCoordinator
+            dist_coordinator = DistCoordinator()
+
+            @dist_coordinator.on_master_only()
+            def print_on_master(msg):
+                print(msg)
+            ```
        """
        is_master = self.is_master(process_group)

@@ -214,9 +214,56 @@ In addition, xFormers's `cutlass_op` can serve as a backup for flash attention.
 Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
 The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.

-More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
-[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline.
+[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example and execute
+```bash
+torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
+```
+This starts finetuning a BERT model wrapped by `Shardformer`; the wrapping itself is performed by `HybridParallelPlugin`.
+
+Let's delve into the code of `finetune.py`:
+
+In the `main` function, the plugin is created through the following code:
+```python
+...
+elif args.plugin == "hybrid_parallel":
+    # modify the param accordingly for finetuning test cases
+    plugin = HybridParallelPlugin(
+        tp_size=1,
+        pp_size=2,
+        num_microbatches=None,
+        microbatch_size=1,
+        enable_all_optimization=True,
+        zero_stage=1,
+        precision="fp16",
+        initial_scale=1,
+    )
+```
+Here you can change the plugin configuration by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in the [Booster Plugins Doc](../basics/booster_plugins.md).
+
+If pipeline parallelism is not enabled, training proceeds in the same way as with other booster plugins (first boost with `Booster`, then do forward and backward in the usual way).
+However, if pipeline parallelism is enabled, there are several usages that differ from other cases:
+
+1. Before doing forward or backward, the criterion function (loss function) is wrapped to meet the argument requirements of running the pipeline:
+```python
+def _criterion(outputs, inputs):
+    outputs = output_transform_fn(outputs)
+    loss = criterion(outputs)
+    return loss
+```
+
+2. In the `train_epoch` function, the dataloader is converted into an `Iterator` before running the pipeline:
+```python
+train_dataloader_iter = iter(train_dataloader)
+```
+
+3. Do the forward and backward passes by calling the `Booster.execute_pipeline` method:
+```python
+outputs = booster.execute_pipeline(
+    train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
+)
+```
+The backward pass is completed inside this method, so there is no need to call `loss.backward()` afterwards.
+More details about `Booster.execute_pipeline` can be found in the [Booster API Doc](../basics/booster_api.md).


 #### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

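Putting the three pipeline-specific steps from the hunk above together, one training epoch might look roughly as follows. This is a condensed sketch only: `booster`, `model`, `optimizer`, `train_dataloader`, `_criterion` and the `execute_pipeline` call are taken from the snippets above, while the surrounding loop and `num_steps` are illustrative assumptions rather than the exact code of `finetune.py`:

```python
# Sketch: one epoch of pipeline-parallel training with Booster.
train_dataloader_iter = iter(train_dataloader)  # step 2: wrap the dataloader in an iterator
num_steps = len(train_dataloader)               # assumption: each pipeline step consumes one batch

for _ in range(num_steps):
    # step 3: forward and backward both happen inside execute_pipeline,
    # using the _criterion wrapper from step 1
    outputs = booster.execute_pipeline(
        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
    )
    # no loss.backward() here: the backward pass already ran inside execute_pipeline
    optimizer.step()
    optimizer.zero_grad()
```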
@@ -224,7 +271,26 @@ More details about this usage can be found in chapter [Booster API](../basics/bo
 You can also use Shardformer by manually calling the Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`.

 [Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
-is an example on how to trigger `Shardformer` through calling Shardformer APIs.
+is an example of how to trigger `Shardformer` by calling the Shardformer APIs. In the `train` function of the example code, the model is wrapped by `Shardformer` with the following lines:
+```python
+...
+if dist.get_world_size() > 1:
+    tp_group = dist.new_group(backend="nccl")
+
+    # First create configuration for Shardformer
+    shard_config = ShardConfig(
+        tensor_parallel_process_group=tp_group,
+        enable_tensor_parallelism=True,
+        enable_all_optimization=True
+    )
+
+    # Then create ShardFormer object with created config
+    shard_former = ShardFormer(shard_config=shard_config)
+
+    # Finally shard the model using ShardFormer.optimize method
+    model, _ = shard_former.optimize(model)
+...
+```

 ### Precautions

@@ -241,6 +307,8 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs.

 ## How Shardformer Works

+### Main Idea
+
 Generally, Shardformer works through the following four kinds of *replacements*:

 1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
@@ -207,8 +207,56 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:

 Launching `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the most recommended usage. The main reason is that pipeline parallelism cannot work properly without calling the `execute_pipeline` method of `Booster`. In addition, `HybridParallelPlugin` provides the ability to combine the features of `Shardformer` with other features, such as mixed precision training or Zero.

-More details about this usage can be found in the [Booster API Doc](../basics/booster_api.md) and the [Booster Plugins Doc](../basics/booster_plugins.md). [Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of launching `Shardformer` through `HybridParallelPlugin`.
+[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of launching `Shardformer` through `HybridParallelPlugin`.
+Move to the root directory of the example and execute:
+```bash
+torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
+```
+You can then finetune a BERT model wrapped by `Shardformer`; the wrapping is performed by `HybridParallelPlugin`.
+
+Let's dig deeper into the code of `finetune.py`:
+
+In the `main` function, the hybrid parallel plugin is created with the following code:
+```python
+...
+elif args.plugin == "hybrid_parallel":
+    # modify the param accordingly for finetuning test cases
+    plugin = HybridParallelPlugin(
+        tp_size=1,
+        pp_size=2,
+        num_microbatches=None,
+        microbatch_size=1,
+        enable_all_optimization=True,
+        zero_stage=1,
+        precision="fp16",
+        initial_scale=1,
+    )
+```
+Here you can change the plugin configuration by setting different values for `tp_size`, `pp_size` or `zero_stage`. More information about plugin configuration can be found in the [Booster Plugins Doc](../basics/booster_plugins.md).
+
+When pipeline parallelism is not enabled, the training flow is the same as with other plugins (first wrap the model and optimizer with Booster, then do the forward and backward passes in the usual way). However, when pipeline parallelism is enabled, there are several usages that differ from the usual case:
+
+1. Before doing forward and backward, the criterion function (loss function) needs to be processed to meet the argument-passing requirements of pipeline parallelism:
+```python
+def _criterion(outputs, inputs):
+    outputs = output_transform_fn(outputs)
+    loss = criterion(outputs)
+    return loss
+```
+
+2. In the `train_epoch` function, the dataloader needs to be converted into an `Iterator` before running the pipeline forward and backward:
+```python
+train_dataloader_iter = iter(train_dataloader)
+```
+
+3. Execute the forward and backward passes by calling the `Booster.execute_pipeline` method:
+```python
+outputs = booster.execute_pipeline(
+    train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
+)
+```
+This method runs the backward pass automatically, so there is no need to call `loss.backward()` after executing it.
+More information about `Booster.execute_pipeline` can be found in the [Booster API Doc](../basics/booster_api.md).

 #### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

@@ -216,7 +264,26 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:

 [Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
 is an example of launching `Shardformer` by calling the Shardformer APIs.
+In the `train` function of the example code, the model is wrapped by the following few lines of code:
+```python
+...
+if dist.get_world_size() > 1:
+    tp_group = dist.new_group(backend="nccl")
+
+    # First create configuration for Shardformer
+    shard_config = ShardConfig(
+        tensor_parallel_process_group=tp_group,
+        enable_tensor_parallelism=True,
+        enable_all_optimization=True
+    )
+
+    # Then create ShardFormer object with created config
+    shard_former = ShardFormer(shard_config=shard_config)
+
+    # Finally shard the model using ShardFormer.optimize method
+    model, _ = shard_former.optimize(model)
+...
+```

 ### Precautions

@@ -234,6 +301,8 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:

 ## How Shardformer Works

+### Main Idea
+
 Generally speaking, Shardformer works through the following four kinds of "replacements":

 1. Replacing the original PyTorch modules (e.g. `nn.Linear`, `nn.Embedding`) with distributed modules of our own design.