mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement * add liscense and implement apis * add checkpointio apis * add torchddp fwd_bwd test * add support_lora methods * add checkpointio test and debug * delete unneeded codes * remove peft from LICENSE * add concrete methods for enable_lora * simplify enable_lora api * fix requirements
This commit is contained in:
committed by
Hongxin Liu
parent
c1594e4bad
commit
14b0d4c7e5
@@ -335,3 +335,20 @@ class CheckpointIO(ABC):
|
||||
"""
|
||||
state_dict = torch.load(checkpoint)
|
||||
lr_scheduler.load_state_dict(state_dict)
|
||||
|
||||
# ================================================================================
|
||||
# Abstract method for lora saving implementation.
|
||||
# ================================================================================
|
||||
|
||||
@abstractmethod
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
||||
|
||||
Args:
|
||||
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user