[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add liscense and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
This commit is contained in:
Baizhou Zhang
2023-10-31 15:19:37 +08:00
committed by Hongxin Liu
parent c1594e4bad
commit 14b0d4c7e5
11 changed files with 265 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
return False
def support_lora(self) -> bool:
return False
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
@@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return TorchFSDPCheckpointIO()
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError