[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add liscense and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
This commit is contained in:
Baizhou Zhang
2023-10-31 15:19:37 +08:00
committed by Hongxin Liu
parent c1594e4bad
commit 14b0d4c7e5
11 changed files with 265 additions and 7 deletions

View File

@@ -3,7 +3,7 @@ import logging
import os
import random
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
@@ -444,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
return False
def support_lora(self) -> bool:
return False
def control_precision(self) -> bool:
return True
@@ -573,3 +576,8 @@ class GeminiPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError