[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add liscense and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
This commit is contained in:
Baizhou Zhang
2023-10-31 15:19:37 +08:00
committed by Hongxin Liu
parent c1594e4bad
commit 14b0d4c7e5
11 changed files with 265 additions and 7 deletions

View File

@@ -335,3 +335,20 @@ class CheckpointIO(ABC):
"""
state_dict = torch.load(checkpoint)
lr_scheduler.load_state_dict(state_dict)
# ================================================================================
# Abstract method for lora saving implementation.
# ================================================================================
@abstractmethod
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
Args:
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
"""