[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add license and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
This commit is contained in:
Baizhou Zhang
2023-10-31 15:19:37 +08:00
committed by Hongxin Liu
parent c1594e4bad
commit 14b0d4c7e5
11 changed files with 265 additions and 7 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@@ -33,6 +33,10 @@ class Plugin(ABC):
def support_no_sync(self) -> bool:
    """Return whether this plugin supports a no-sync mode.

    Presumably pairs with the plugin's context manager that disables
    gradient synchronization (see ``no_sync`` below) — confirm against
    concrete plugin implementations.
    """
    pass
@abstractmethod
def support_lora(self) -> bool:
    """Return whether this plugin supports LoRA fine-tuning.

    Concrete plugins should return ``True`` only if they implement
    ``enable_lora`` (declared later in this class).
    """
    pass
@abstractmethod
def configure(
self,
@@ -63,6 +67,12 @@ class Plugin(ABC):
Context manager to disable gradient synchronization.
"""
@abstractmethod
def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
    """
    Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().

    Args:
        model (nn.Module): the model to attach LoRA modules to.
        pretrained_dir (str): NOTE(review): presumably a directory of
            pretrained LoRA weights to load — confirm against concrete
            plugin implementations.
        lora_config (Dict): configuration options for the LoRA modules.

    Returns:
        nn.Module: the model with LoRA modules attached.
    """
@abstractmethod
def prepare_dataloader(
self,