[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add license and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded code

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
Author: Baizhou Zhang
Date: 2023-10-31 15:19:37 +08:00
Committed by: Hongxin Liu
Parent: c1594e4bad
Commit: 14b0d4c7e5
11 changed files with 265 additions and 7 deletions
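
For context, here is a minimal usage sketch of the new booster-level LoRA API with the TorchDDP plugin. Only the plugin-level support_lora/enable_lora hooks appear in the diff below; the Booster.enable_lora call, the dict-style lora_config, and the toy model are assumptions made for illustration.

```python
# Hypothetical usage sketch (not code from this PR): enable LoRA through the
# booster before boosting the model with the TorchDDP plugin.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


class TinyNet(nn.Module):
    """Toy stand-in for a real transformer; 'proj' is the assumed LoRA target module."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 128)
        self.head = nn.Linear(128, 2)

    def forward(self, x):
        return self.head(self.proj(x))


colossalai.launch_from_torch(config={})  # assumes a torchrun-launched process group

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = TinyNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# New capability flag: only plugins that support LoRA accept enable_lora.
if plugin.support_lora():
    # lora_config is assumed to hold peft LoraConfig kwargs.
    model = booster.enable_lora(
        model, lora_config={"r": 8, "lora_alpha": 16, "target_modules": ["proj"]}
    )

model, optimizer, *_ = booster.boost(model, optimizer)
```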


@@ -4,7 +4,7 @@ import warnings
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union

 import numpy as np
 import torch
@@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
     def support_no_sync(self) -> bool:
         return True

+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True
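
The new support_lora flag mirrors the existing support_no_sync capability check. As a hedged illustration (the Booster internals are not part of this hunk), a caller might gate on the flag like this; the helper name and error message are made up for the example:

```python
# Assumed guard pattern around the new capability flag; illustrative only.
def enable_lora_if_supported(plugin, model, lora_config=None):
    if not plugin.support_lora():
        raise NotImplementedError(f"{type(plugin).__name__} does not support LoRA")
    return plugin.enable_lora(model, lora_config=lora_config)
```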
@@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.zero_stage != 2
         ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError
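
HybridParallelPlugin only stubs the hook for now. A plugin that does support LoRA (TorchDDP, per the commit title) would presumably implement enable_lora along the lines below. This is a sketch using the peft library directly, not the actual ColossalAI implementation; in particular, treating lora_config as a dict of LoraConfig kwargs is an assumption.

```python
# Sketch of a plugin-side enable_lora built on peft; assumptions noted above.
from typing import Dict, Optional

from torch.nn import Module
from peft import LoraConfig, PeftModel, get_peft_model


def enable_lora(
    self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
    if pretrained_dir is None:
        # Fresh adapters: wrap the base model with newly initialized LoRA weights.
        return get_peft_model(model, LoraConfig(**(lora_config or {})))
    # Resume: load previously trained LoRA weights from a checkpoint directory.
    return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
```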