mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement * add liscense and implement apis * add checkpointio apis * add torchddp fwd_bwd test * add support_lora methods * add checkpointio test and debug * delete unneeded codes * remove peft from LICENSE * add concrete methods for enable_lora * simplify enable_lora api * fix requirements
This commit is contained in:
committed by
Hongxin Liu
parent
c1594e4bad
commit
14b0d4c7e5
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
@@ -33,6 +33,10 @@ class Plugin(ABC):
|
||||
def support_no_sync(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def support_lora(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure(
|
||||
self,
|
||||
@@ -63,6 +67,12 @@ class Plugin(ABC):
|
||||
Context manager to disable gradient synchronization.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
|
||||
"""
|
||||
Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare_dataloader(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user