[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement
* add license and implement apis
* add checkpointio apis
* add torchddp fwd_bwd test
* add support_lora methods
* add checkpointio test and debug
* delete unneeded code
* remove peft from LICENSE
* add concrete methods for enable_lora
* simplify enable_lora api
* fix requirements
Committed by: Hongxin Liu
Parent: c1594e4bad
Commit: 14b0d4c7e5
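
As context for the diff below, here is a minimal usage sketch of the booster-level LoRA flow this PR describes (LoRA APIs for the booster, LoRA support for TorchDDP). The call booster.enable_lora(...), its placement before booster.boost(...), and passing a peft LoraConfig where the plugin hint says Optional[Dict] are assumptions inferred from the commit message and the plugin signature added below, not code taken from this commit.

# Hypothetical usage sketch of the new LoRA booster API; names and call order
# are assumptions, not taken verbatim from this commit.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from peft import LoraConfig  # peft is the new requirement mentioned in the commit message


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 32)

    def forward(self, x):
        return self.proj(x)


colossalai.launch_from_torch(config={})

model = TinyNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

# The plugin-level hint in this commit is Optional[Dict]; a peft LoraConfig
# (or an equivalent dict of LoRA hyperparameters) is assumed here.
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["proj"])

if plugin.support_lora():
    model = booster.enable_lora(model, lora_config=lora_config)

model, optimizer, *_ = booster.boost(model, optimizer)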
@@ -4,7 +4,7 @@ import warnings
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
 import torch
@@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
     def support_no_sync(self) -> bool:
         return True
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True
 
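support_lora joins support_no_sync and control_checkpoint_io as a capability flag, and HybridParallelPlugin opts out here. Below is a hedged sketch of the guard a caller could place in front of enable_lora; the helper name, exact call site, and error message are assumptions, not code from this PR.

from typing import Dict, Optional

from torch.nn import Module


def enable_lora_via_plugin(
    plugin, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
    # Hypothetical helper: only delegate when the plugin declares LoRA support.
    # With this commit, HybridParallelPlugin reports False and would be rejected here,
    # while a LoRA-capable plugin (e.g. TorchDDP per the PR title) would pass through.
    if not plugin.support_lora():
        raise ValueError(f"LoRA is not supported by plugin {type(plugin).__name__}")
    return plugin.enable_lora(model, pretrained_dir=pretrained_dir, lora_config=lora_config)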
@@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.zero_stage != 2
         ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError
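For contrast with the stub above, here is a sketch of what a concrete enable_lora override could look like in a plugin that returns True from support_lora (for example TorchDDPPlugin, per the PR title). It assumes lora_config is a peft LoraConfig despite the Optional[Dict] hint and uses the public peft APIs get_peft_model and PeftModel.from_pretrained; the actual implementation in this PR may differ.

from typing import Dict, Optional

from peft import PeftModel, get_peft_model
from torch.nn import Module


def enable_lora(
    self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
    # Sketch only: inject fresh LoRA adapters when no checkpoint directory is given,
    # otherwise resume adapters previously saved with peft.
    if pretrained_dir is None:
        return get_peft_model(model, lora_config)
    return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)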