mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-06 15:38:26 +00:00
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686
.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
410 lines
20 KiB
Python
410 lines
20 KiB
Python
from contextlib import contextmanager
|
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
SUPPORT_PEFT = False
|
|
try:
|
|
import peft
|
|
|
|
SUPPORT_PEFT = True
|
|
except ImportError:
|
|
pass
|
|
|
|
import colossalai.interface.pretrained as pretrained_utils
|
|
from colossalai.checkpoint_io import GeneralCheckpointIO
|
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
|
from colossalai.quantization import BnbQuantizationConfig
|
|
|
|
from .accelerator import Accelerator
|
|
from .mixed_precision import MixedPrecision, mixed_precision_factory
|
|
from .plugin import Plugin
|
|
from .plugin.pp_plugin_base import PipelinePluginBase
|
|
|
|
__all__ = ["Booster"]
|
|
|
|
|
|
class Booster:
|
|
"""
|
|
Booster is a high-level API for training neural networks. It provides a unified interface for
|
|
training with different precision, accelerator, and plugin.
|
|
|
|
|
|
```python
|
|
# Following is pseudocode
|
|
|
|
colossalai.launch(...)
|
|
plugin = GeminiPlugin(...)
|
|
booster = Booster(precision='fp16', plugin=plugin)
|
|
|
|
model = GPT2()
|
|
optimizer = HybridAdam(model.parameters())
|
|
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
|
|
lr_scheduler = LinearWarmupScheduler()
|
|
criterion = GPTLMLoss()
|
|
|
|
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
|
|
|
|
for epoch in range(max_epochs):
|
|
for input_ids, attention_mask in dataloader:
|
|
outputs = model(input_ids.cuda(), attention_mask.cuda())
|
|
loss = criterion(outputs.logits, input_ids)
|
|
booster.backward(loss, optimizer)
|
|
optimizer.step()
|
|
lr_scheduler.step()
|
|
optimizer.zero_grad()
|
|
```
|
|
|
|
Args:
|
|
device (str or torch.device): The device to run the training. Default: None.
|
|
If plugin is not used or plugin doesn't control the device,
|
|
this argument will be set as training device ('cuda' will be used if argument is None).
|
|
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
|
|
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
|
|
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
|
|
plugin (Plugin): The plugin to run the training. Default: None.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
device: Optional[str] = None,
|
|
mixed_precision: Optional[Union[MixedPrecision, str]] = None,
|
|
plugin: Optional[Plugin] = None,
|
|
) -> None:
|
|
if plugin is not None:
|
|
assert isinstance(
|
|
plugin, Plugin
|
|
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
|
|
self.plugin = plugin
|
|
self.logger = get_dist_logger()
|
|
|
|
# set accelerator
|
|
if self.plugin and self.plugin.control_device():
|
|
self.accelerator = None
|
|
if device is not None:
|
|
self.logger.warning(
|
|
"The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
|
|
)
|
|
else:
|
|
device = device or "cuda"
|
|
self.accelerator = Accelerator(device)
|
|
|
|
# set precision
|
|
if self.plugin and self.plugin.control_precision():
|
|
if mixed_precision is not None:
|
|
self.logger.warning(
|
|
"The plugin will control the precision," "so the mixed_precision argument will be ignored.",
|
|
ranks=[0],
|
|
)
|
|
self.mixed_precision = None
|
|
elif mixed_precision is None:
|
|
self.mixed_precision = None
|
|
else:
|
|
# validate and set precision
|
|
if isinstance(mixed_precision, str):
|
|
# the user will take the default arguments for amp training
|
|
self.mixed_precision = mixed_precision_factory(mixed_precision)
|
|
elif isinstance(mixed_precision, MixedPrecision):
|
|
# the user can customize the arguments by passing the precision object
|
|
self.mixed_precision = mixed_precision
|
|
else:
|
|
raise ValueError(
|
|
f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
|
|
)
|
|
|
|
if self.plugin is not None and self.plugin.control_checkpoint_io():
|
|
self.checkpoint_io = self.plugin.get_checkpoint_io()
|
|
else:
|
|
self.checkpoint_io = GeneralCheckpointIO()
|
|
|
|
def boost(
|
|
self,
|
|
model: nn.Module,
|
|
optimizer: Optional[Optimizer] = None,
|
|
criterion: Optional[Callable] = None,
|
|
dataloader: Optional[DataLoader] = None,
|
|
lr_scheduler: Optional[LRScheduler] = None,
|
|
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
|
|
"""
|
|
Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.
|
|
|
|
Args:
|
|
model (nn.Module): Convert model into a wrapped model for distributive training.
|
|
The model might be decorated or partitioned by plugin's strategy after execution of this method.
|
|
optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
|
|
The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
|
|
criterion (Callable, optional): The function that calculates loss. Defaults to None.
|
|
dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
|
|
lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.
|
|
|
|
Returns:
|
|
List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
|
|
"""
|
|
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
|
|
# TODO(FrankLeeeee): consider multi-dataloader case
|
|
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
|
# transform model for mixed precision
|
|
if self.plugin:
|
|
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
|
|
model, optimizer, criterion, dataloader, lr_scheduler
|
|
)
|
|
|
|
if self.plugin and not self.plugin.control_device():
|
|
# transform model for accelerator
|
|
model = self.accelerator.configure_model(model)
|
|
|
|
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
|
|
# transform model for mixed precision
|
|
# when mixed_precision is specified and the plugin is not given or does not control the precision
|
|
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
|
|
|
|
if pretrained_path:
|
|
self.load_model(model, pretrained_path)
|
|
# clear pretrained path attr
|
|
orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
|
|
pretrained_utils.set_pretrained_path(orig_model, None)
|
|
|
|
return model, optimizer, criterion, dataloader, lr_scheduler
|
|
|
|
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
|
|
"""Execution of backward during training step.
|
|
|
|
Args:
|
|
loss (torch.Tensor): The loss for backpropagation.
|
|
optimizer (Optimizer): The optimizer to be updated.
|
|
"""
|
|
# TODO(frank lee): implement this method with plugin
|
|
optimizer.backward(loss)
|
|
|
|
def execute_pipeline(
|
|
self,
|
|
data_iter: Iterator,
|
|
model: nn.Module,
|
|
criterion: Callable[[Any, Any], torch.Tensor],
|
|
optimizer: Optional[Optimizer] = None,
|
|
return_loss: bool = True,
|
|
return_outputs: bool = False,
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Execute forward & backward when utilizing pipeline parallel.
|
|
Return loss or Huggingface style model outputs if needed.
|
|
|
|
Warning: This function is tailored for the scenario of pipeline parallel.
|
|
As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
|
|
when doing pipeline parallel training with booster, which will cause unexpected errors.
|
|
|
|
Args:
|
|
data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
|
|
1. wrap the dataloader to iterator through: iter(dataloader)
|
|
2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
|
|
model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
|
|
criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
|
'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
|
|
optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
|
|
return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
|
|
return_output (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
|
|
|
|
Returns:
|
|
Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
|
|
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
|
|
ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
|
|
"""
|
|
assert isinstance(
|
|
self.plugin, PipelinePluginBase
|
|
), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
|
|
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
|
|
|
|
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
|
|
"""Context manager to disable gradient synchronization across DP process groups.
|
|
Support torch DDP and Low Level ZeRO-1 for now.
|
|
|
|
Args:
|
|
model (nn.Module): The model to be disabled gradient synchronization, for DDP
|
|
optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1
|
|
|
|
Returns:
|
|
contextmanager: Context to disable gradient synchronization.
|
|
"""
|
|
assert (
|
|
self.plugin is not None
|
|
), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
|
|
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
|
|
return self.plugin.no_sync(model, optimizer)
|
|
|
|
def enable_lora(
|
|
self,
|
|
model: nn.Module,
|
|
pretrained_dir: Optional[str] = None,
|
|
lora_config: "peft.LoraConfig" = None,
|
|
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
|
|
quantize=False,
|
|
) -> nn.Module:
|
|
"""
|
|
Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
|
|
Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft.
|
|
|
|
Args:
|
|
model (nn.Module): The model to be appended with LoRA modules.
|
|
pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory
|
|
or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
|
|
When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None.
|
|
lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None.
|
|
"""
|
|
if not SUPPORT_PEFT:
|
|
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
|
|
|
|
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
|
|
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
|
|
if pretrained_dir is None:
|
|
assert (
|
|
lora_config is not None
|
|
), "Please provide configuration for Lora when pretrained directory path isn't passed in."
|
|
assert isinstance(
|
|
lora_config, peft.LoraConfig
|
|
), "The passed in configuration should be an instance of peft.LoraConfig."
|
|
if lora_config is None:
|
|
assert (
|
|
pretrained_dir is not None
|
|
), "Please provide pretrained directory path if not passing in lora configuration."
|
|
if quantize is True:
|
|
if bnb_quantization_config is not None:
|
|
self.logger.warning(
|
|
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
|
|
ranks=[0],
|
|
)
|
|
else:
|
|
bnb_quantization_config = BnbQuantizationConfig(
|
|
load_in_4bit=True,
|
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
|
bnb_4bit_use_double_quant=True,
|
|
bnb_4bit_quant_type="nf4",
|
|
)
|
|
|
|
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
|
|
|
|
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
|
|
"""Load model from checkpoint.
|
|
|
|
Args:
|
|
model (nn.Module or ModelWrapper): A model boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint. It must be a local path.
|
|
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
|
strict (bool, optional): whether to strictly enforce that the keys
|
|
in :attr:`state_dict` match the keys returned by this module's
|
|
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
|
|
"""
|
|
self.checkpoint_io.load_model(model, checkpoint, strict)
|
|
|
|
def save_model(
|
|
self,
|
|
model: Union[nn.Module, ModelWrapper],
|
|
checkpoint: str,
|
|
shard: bool = False,
|
|
gather_dtensor: bool = True,
|
|
prefix: Optional[str] = None,
|
|
size_per_shard: int = 1024,
|
|
use_safetensors: bool = False,
|
|
) -> None:
|
|
"""Save model to checkpoint.
|
|
|
|
Args:
|
|
model (nn.Module or ModelWrapper): A model boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint. It must be a local path.
|
|
It is a file path if ``shard=False``. Otherwise, it is a directory path.
|
|
shard (bool, optional): Whether to save checkpoint a sharded way.
|
|
If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
|
|
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
|
|
prefix (str, optional): A prefix added to parameter and buffer
|
|
names to compose the keys in state_dict. Defaults to None.
|
|
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
|
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
|
|
"""
|
|
self.checkpoint_io.save_model(
|
|
model,
|
|
checkpoint=checkpoint,
|
|
shard=shard,
|
|
gather_dtensor=gather_dtensor,
|
|
prefix=prefix,
|
|
size_per_shard=size_per_shard,
|
|
use_safetensors=use_safetensors,
|
|
)
|
|
|
|
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
|
|
"""Load optimizer from checkpoint.
|
|
|
|
Args:
|
|
optimizer (Optimizer): An optimizer boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint. It must be a local path.
|
|
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
|
prefix (str, optional): A prefix added to parameter and buffer
|
|
names to compose the keys in state_dict. Defaults to None.
|
|
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
|
"""
|
|
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
|
|
|
|
def save_optimizer(
|
|
self,
|
|
optimizer: Optimizer,
|
|
checkpoint: str,
|
|
shard: bool = False,
|
|
gather_dtensor: bool = True,
|
|
prefix: Optional[str] = None,
|
|
size_per_shard: int = 1024,
|
|
) -> None:
|
|
"""
|
|
Save optimizer to checkpoint.
|
|
|
|
Args:
|
|
optimizer (Optimizer): An optimizer boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint. It must be a local path.
|
|
It is a file path if ``shard=False``. Otherwise, it is a directory path.
|
|
shard (bool, optional): Whether to save checkpoint a sharded way.
|
|
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
|
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
|
prefix (str, optional): A prefix added to parameter and buffer
|
|
names to compose the keys in state_dict. Defaults to None.
|
|
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
|
"""
|
|
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
|
|
|
|
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
|
|
"""Save lr scheduler to checkpoint.
|
|
|
|
Args:
|
|
lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint. It must be a local file path.
|
|
"""
|
|
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
|
|
|
|
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
|
|
"""Load lr scheduler from checkpoint.
|
|
|
|
Args:
|
|
lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint. It must be a local file path.
|
|
"""
|
|
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
|
|
|
|
def save_lora_as_pretrained(
|
|
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
|
) -> None:
|
|
"""
|
|
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
|
|
|
Args:
|
|
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
|
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
|
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
|
"""
|
|
if not SUPPORT_PEFT:
|
|
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
|
|
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
|
|
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
|
|
self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
|