mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement * add license and implement apis * add checkpointio apis * add torchddp fwd_bwd test * add support_lora methods * add checkpointio test and debug * delete unneeded code * remove peft from LICENSE * add concrete methods for enable_lora * simplify enable_lora api * fix requirements
This commit is contained in:
committed by
Hongxin Liu
parent
c1594e4bad
commit
14b0d4c7e5
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
def support_no_sync(self) -> bool:
    """Report whether this plugin offers a ``no_sync`` gradient-accumulation context.

    The Torch FSDP plugin does not implement ``no_sync``, so this always
    returns ``False``.
    """
    no_sync_available = False
    return no_sync_available
|
||||
|
||||
def support_lora(self) -> bool:
    """Report whether LoRA (low-rank adaptation) is supported by this plugin.

    LoRA is not implemented for the Torch FSDP plugin, so this always
    returns ``False``.
    """
    lora_supported = False
    return lora_supported
|
||||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
    """Context manager that would suspend gradient synchronization during accumulation.

    Not implemented for the Torch FSDP plugin.

    Raises:
        NotImplementedError: unconditionally — FSDP ``no_sync`` is not supported yet.
    """
    raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
|
||||
|
||||
@@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
    """Return the checkpoint-I/O handler for this plugin (a fresh ``TorchFSDPCheckpointIO``)."""
    return TorchFSDPCheckpointIO()
|
||||
|
||||
def enable_lora(
    self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
    """Wrap ``model`` with LoRA adapters (see ``support_lora``).

    Not implemented for the Torch FSDP plugin.

    Args:
        model: the model that would receive LoRA modules.
        pretrained_dir: optional directory of pretrained LoRA weights to load.
        lora_config: optional LoRA configuration mapping — presumably peft-style
            kwargs; confirm against the booster's other plugins.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user