ColossalAI/colossalai/booster/plugin/torch_fsdp_plugin.py

import os
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse("1.12.0"):
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        FullStateDictConfig,
        MixedPrecision,
        ShardingStrategy,
    )
else:
    raise RuntimeError("FSDP is not supported for torch versions below 1.12.0.")

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger

from .dp_plugin_base import DPPluginBase

__all__ = ["TorchFSDPPlugin"]


class TorchFSDPCheckpointIO(GeneralCheckpointIO):
    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()
        self.logger = get_dist_logger()

    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
        model = model.unwrap()
        checkpoint = utils.load_state_dict(checkpoint)
        model.load_state_dict(checkpoint)

    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
        checkpoint = utils.load_state_dict(checkpoint)
        fsdp_model = optimizer.unwrap_model()
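        # FSDP scatters the full optimizer state from rank 0 and hands each rank only the
        # shard that matches its flattened parameters.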
        sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
        optimizer.load_state_dict(sharded_osd)

    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save model to checkpoint but only on master process.
        """
        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
        model = model.unwrap()
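        # Gather the full state dict on rank 0 only, offloading it to CPU to keep GPU memory low.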
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
            full_model_state = model.state_dict()
        utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
        """
        Save optimizer to checkpoint but only on master process.
        """
        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
        fsdp_model = optimizer.unwrap_model()
        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
        utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

    def save_sharded_model(
        self,
        model: ModelWrapper,
        checkpoint_path: str,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ):
        """
        Save model to checkpoint but only on master process.
        """
        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"

        if os.path.isfile(checkpoint_path):
            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
            return

        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
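        # Under FULL_STATE_DICT with rank0_only=True, the unsharded state dict is gathered on
        # rank 0 (offloaded to CPU) and then cut into size-limited shard files below.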
        with FSDP.state_dict_type(
            model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        ):
            state_dict = model.unwrap().state_dict()

        state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
        weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint_path)

        # In general cases, is_master is set to True to get the right behavior.
        total_size = utils.save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint_path,
            index_file=index_file,
            base_filename=weights_name,
            is_master=self.coordinator.is_master(),
            use_safetensors=use_safetensors,
        )

        # only save the index file on the master rank
        if self.coordinator.is_master():
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            utils.save_config_file(model.unwrap(), checkpoint_path)
            self.logger.info(
                f"The model is split into checkpoint shards. "
                f"You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_model(
        self,
        model: nn.Module,
        checkpoint_index_file: Path,
        strict: bool = False,
        use_safetensors: bool = False,
        load_sub_module: bool = True,
    ):
        """
        Load model from a sharded checkpoint.
        """
        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"

        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not utils.is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # read checkpoint index file
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
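        # Merge every shard file into one full state dict, then load it under FULL_STATE_DICT
        # so FSDP re-shards the parameters onto the local flattened buffers.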
        fsdp_state_dict = {}
        for shard_file in checkpoint_files:
            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))

        with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
            model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

    def save_sharded_optimizer(
        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
    ):
        """
        Save optimizer to checkpoint but only on master process.
        """
        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"

        if os.path.isfile(checkpoint):
            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)
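        # FSDP.full_optim_state_dict consolidates the optimizer state on rank 0 before it is
        # split into size-limited shard files below.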
        with FSDP.state_dict_type(
            optimizer.unwrap_model().unwrap(),
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            fsdp_optim_state = FSDP.full_optim_state_dict(
                optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
            )

        if self.coordinator.is_master():
            # Preparing file paths and index file.
            states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
            index_file = CheckpointIndexFile(checkpoint)

            index_file.append_meta_data("param_groups", param_group_file)
            group_file_path = os.path.join(checkpoint, param_group_file)
            utils.save_param_groups(fsdp_optim_state, group_file_path)

            sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)

            # Save shards of optimizer states.
            # In general cases, is_master is set to True to get the right behavior.
            total_size = utils.save_state_dict_shards(
                sharded_state_dict=sharded_state,
                checkpoint=checkpoint,
                index_file=index_file,
                base_filename=states_name,
                is_master=self.coordinator.is_master(),
                use_safetensors=False,
            )

            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            self.logger.info(
                f"The optimizer is split into checkpoint shards. "
                f"You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
        """
        Load optimizer from a sharded checkpoint.
        """
        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"

        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)

        # Load param_groups
        param_group_path = ckpt_index_file.get_param_group_filename()
        if param_group_path is None:
            raise RuntimeError(
                f"Invalid index file path {index_file_path} for an optimizer: "
                "no param group file is recorded in the index."
            )
        saved_param_groups = torch.load(param_group_path)

        # Load param
        fsdp_optim_state = {}
        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
        for shard_file in checkpoint_files:
            state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
            fsdp_optim_state.update(state_dict_shard)

        fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
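        # Convert the consolidated optimizer state back into this rank's sharded,
        # flat-parameter format before loading it into the wrapped optimizer.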
        with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
            fsdp_state = FSDP.optim_state_dict_to_load(
                model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
            )
        optimizer.load_state_dict(fsdp_state)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint but only on master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)


class TorchFSDPModel(ModelWrapper):
    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        self.module = FSDP(module, *args, **kwargs)

    def unwrap(self):
        return self.module


class FSDPOptimizerWrapper(OptimizerWrapper):
    def __init__(self, optimizer: Optimizer, model: nn.Module):
        self.model = model
        super().__init__(optimizer)

    def unwrap_model(self) -> nn.Module:
        return self.model


class TorchFSDPPlugin(DPPluginBase):
    """
    Plugin for PyTorch FSDP.

    ```python
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchFSDPPlugin

    model, train_dataset, optimizer, criterion = ...
    plugin = TorchFSDPPlugin()

    train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
    booster = Booster(plugin=plugin)
    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
    ```

    Args:
        See https://pytorch.org/docs/stable/fsdp.html for details.
    """

    if version.parse(torch.__version__) >= version.parse("1.12.0"):

        def __init__(
            self,
            process_group: Optional[ProcessGroup] = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Callable] = None,
            backward_prefetch: Optional[BackwardPrefetch] = None,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            sync_module_states: bool = False,
            fp8_communication: bool = False,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(
                process_group=process_group,
                sharding_strategy=sharding_strategy,
                cpu_offload=cpu_offload,
                auto_wrap_policy=auto_wrap_policy,
                backward_prefetch=backward_prefetch,
                mixed_precision=mixed_precision,
                ignored_modules=ignored_modules,
                param_init_fn=param_init_fn,
                sync_module_states=sync_module_states,
            )
            self.fp8_communication = fp8_communication
            self.logger = get_dist_logger()

    else:
        raise RuntimeError("FSDP is not supported for torch versions below 1.12.0.")

    def support_no_sync(self) -> bool:
        return False

    def support_lora(self) -> bool:
        return False

    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return ["fp16", "bf16"]

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ["cuda"]

    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        # wrap the model with PyTorch FSDP
        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)

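        # When fp8_communication is enabled, patch FSDP and register FP8 compression hooks
        # for both the parameter all-gather and the gradient reduction communication.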
        if self.fp8_communication:
            from colossalai.quantization.utils import patch_fsdp_params_comm_hook

            patch_fsdp_params_comm_hook()

            from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook

            fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)

            from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook

            fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)

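        # The optimizer is re-created over the FSDP-flattened parameters; any extra param
        # groups are collapsed into a single group built from `optimizer.defaults`.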
        if optimizer is not None:
            if len(optimizer.param_groups) > 1:
                self.logger.warning(
                    "TorchFSDPPlugin does not support optimizers that use multiple param groups. "
                    "The results may not be as expected if used."
                )
            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
            if not isinstance(optimizer, FSDPOptimizerWrapper):
                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)

        return fsdp_model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return TorchFSDPCheckpointIO()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError