Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-21 01:24:04 +00:00)
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement
* add license and implement apis
* add checkpointio apis
* add torchddp fwd_bwd test
* add support_lora methods
* add checkpointio test and debug
* delete unneeded codes
* remove peft from LICENSE
* add concrete methods for enable_lora
* simplify enable_lora api
* fix requirements
Committed by: Hongxin Liu
Parent: c1594e4bad
Commit: 14b0d4c7e5
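For orientation, here is a rough end-to-end sketch (not part of the commit) of how the new booster LoRA APIs are meant to be used together, assuming peft is installed and a plugin with LoRA support such as the TorchDDP plugin named in the title; `build_model`, the checkpoint path, and all hyperparameters below are illustrative placeholders.

```python
import colossalai
import torch
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})      # assumes a torchrun launch
booster = Booster(plugin=TorchDDPPlugin())   # TorchDDP gains LoRA support in this commit

model = build_model()                        # hypothetical helper returning an nn.Module
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16, lora_dropout=0.1)

# 1. Wrap the model with LoRA adapters, then boost it as usual.
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# 2. ... run the training loop ...

# 3. Persist only the adapters and their configuration.
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter", use_safetensors=False)
```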
@@ -8,6 +8,14 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
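The SUPPORT_PEFT flag added above turns peft into an optional dependency: the import is attempted once when the module loads, and the LoRA entry points below raise an ImportError when it is missing. As a purely illustrative sketch (the import path and the objects below are assumptions, not shown in this diff), caller code can consult the same flag instead of catching the exception.

```python
# Illustrative sketch only: check availability up front rather than catching
# the ImportError raised by enable_lora when peft is missing.
from colossalai.booster.booster import SUPPORT_PEFT  # assumed module path

if SUPPORT_PEFT:
    model = booster.enable_lora(model, lora_config=lora_config)  # objects defined elsewhere
else:
    print("peft is not installed; skipping LoRA and fine-tuning all parameters")
```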
@@ -221,6 +229,38 @@ class Booster:
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)
 
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+    ) -> nn.Module:
+        """
+        Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, the LoRA configuration and weights are loaded from that directory.
+        LoRA in ColossalAI is implemented with the Hugging Face peft library, so the arguments for the LoRA configuration are the same as those of peft.
+
+        Args:
+            model (nn.Module): The model to be appended with LoRA modules.
+            pretrained_dir (str, optional): The path to the pretrained directory; it can be a local directory
+                or the model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
+                When set to None, new LoRA configs and weights are created for the model using the passed-in lora_config. Defaults to None.
+            lora_config (peft.LoraConfig, optional): The LoraConfig passed to peft. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+
+        assert self.plugin is not None, "Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide configuration for Lora when pretrained directory path isn't passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide pretrained directory path if not passing in lora configuration."
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.
 
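In short, the docstring above describes two mutually exclusive ways of calling enable_lora, sketched below with an assumed booster/model and an arbitrary checkpoint path (neither is taken from the commit): pass a fresh peft.LoraConfig to create new adapter weights, or pass pretrained_dir to reload previously saved ones.

```python
from peft import LoraConfig

# (a) Fresh adapters: pretrained_dir is None, so a LoraConfig is required.
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05)
model = booster.enable_lora(model, lora_config=config)

# (b) Resume from saved adapters (local path or Hub model_id); the config is
#     read from that directory, so lora_config may be omitted.
model = booster.enable_lora(model, pretrained_dir="ckpt/lora_adapter")
```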
@@ -323,3 +363,20 @@ class Booster:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and the adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safetensors when saving. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+        assert self.plugin is not None, "Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
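Since the adapters are written as a PEFT pretrained checkpoint directory, a directory produced by save_lora_as_pretrained can later be passed back through enable_lora's pretrained_dir. A hedged round-trip sketch, with a hypothetical build_model helper and an arbitrary path:

```python
# After training: persist only the adapter weights and their configuration.
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter", use_safetensors=True)

# In a later run: rebuild the base model and re-attach the saved adapters.
fresh_model = build_model()  # hypothetical helper returning the same base nn.Module
fresh_model = booster.enable_lora(fresh_model, pretrained_dir="ckpt/lora_adapter")
```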