Mirror of https://github.com/hpcaitech/ColossalAI.git
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement
* add license and implement apis
* add checkpointio apis
* add torchddp fwd_bwd test
* add support_lora methods
* add checkpointio test and debug
* delete unneeded codes
* remove peft from LICENSE
* add concrete methods for enable_lora
* simplify enable_lora api
* fix requirements
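The commit adds `booster.enable_lora()` / `booster.save_lora_as_pretrained()` plus a concrete TorchDDP implementation. A minimal end-to-end sketch of the intended flow is below; the toy model, optimizer, LoRA hyperparameters, and checkpoint path are illustrative placeholders, not taken from this commit, and a distributed environment is assumed to be initialized already.

# Usage sketch only: model, LoRA hyperparameters, and paths are placeholders.
# Assumes one process per GPU has been launched (e.g. torchrun) and the
# distributed environment is initialized (e.g. colossalai.launch_from_torch).
import torch
import torch.nn as nn
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))  # placeholder model
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["0", "2"], lora_dropout=0.1)

booster = Booster(plugin=TorchDDPPlugin())
model = booster.enable_lora(model, lora_config=lora_config)  # must run before boost()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... forward pass / booster.backward(loss, optimizer) / optimizer.step() ...

booster.save_lora_as_pretrained(model, "./lora_ckpt")  # writes only the adapters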
@@ -8,6 +8,14 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ class Booster:
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)

+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+    ) -> nn.Module:
+        """
+        Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, the LoRA configuration and weights are loaded from that directory.
+        LoRA in ColossalAI is implemented with the Hugging Face peft library, so the arguments for LoRA configuration are the same as those of peft.
+
+        Args:
+            model (nn.Module): The model to be wrapped with LoRA modules.
+            pretrained_dir (str, optional): The path to the pretrained directory; can be a local directory
+                or the model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
+                When set to None, new LoRA configs and weights are created for the model using the passed-in lora_config. Defaults to None.
+            lora_config (peft.LoraConfig, optional): The LoraConfig passed to peft. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+
+        assert self.plugin is not None, "Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide configuration for Lora when pretrained directory path isn't passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide pretrained directory path if not passing in lora configuration."
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.

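The method has two entry points, sketched below; `booster` and `base_model` are assumed to exist already, and the target module names and directory are illustrative.

from peft import LoraConfig

# 1) Fresh adapters: provide a peft.LoraConfig and leave pretrained_dir as None.
config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = booster.enable_lora(base_model, lora_config=config)

# 2) Resume from saved adapters: point pretrained_dir at a directory previously
#    written by booster.save_lora_as_pretrained (or a PEFT repo id on the Hub).
model = booster.enable_lora(base_model, pretrained_dir="./lora_ckpt")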
@@ -323,3 +363,20 @@ class Booster:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to save in safetensors format. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+        assert self.plugin is not None, "Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
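A short sketch of the save call; the directory name is illustrative and `model` is assumed to have been boosted with LoRA enabled.

# Writes only the adapter weights and adapter_config.json, not the base model.
booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)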
@@ -2,7 +2,7 @@ import gc
 import logging
 import os
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -360,6 +360,9 @@ class GeminiPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return False

+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True

@@ -408,3 +411,8 @@ class GeminiPlugin(DPPluginBase):

     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
@@ -3,7 +3,7 @@ import random
 from contextlib import nullcontext
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union

 import numpy as np
 import torch
@@ -753,6 +753,9 @@ class HybridParallelPlugin(PipelinePluginBase):
     def support_no_sync(self) -> bool:
         return False

+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True

@@ -891,3 +894,8 @@ class HybridParallelPlugin(PipelinePluginBase):

     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError
@@ -3,7 +3,7 @@ import os
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -295,6 +295,9 @@ class LowLevelZeroPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return self.stage == 1

+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True

@@ -336,3 +339,8 @@ class LowLevelZeroPlugin(DPPluginBase):
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -33,6 +33,10 @@ class Plugin(ABC):
     def support_no_sync(self) -> bool:
         pass

+    @abstractmethod
+    def support_lora(self) -> bool:
+        pass
+
     @abstractmethod
     def configure(
         self,
@@ -63,6 +67,12 @@ class Plugin(ABC):
         Context manager to disable gradient synchronization.
         """

+    @abstractmethod
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        """
+        Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
+        """
+
     @abstractmethod
     def prepare_dataloader(
         self,
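A hypothetical minimal sketch of the contract these two abstract methods define (not part of this commit; TorchDDPPlugin below is the real implementation): a plugin advertises LoRA support via support_lora() and performs the wrapping in enable_lora(), while user code still calls booster.enable_lora(), which dispatches here.

# Hypothetical plugin sketch; other required abstract methods omitted for brevity.
class MyLoRAPlugin(Plugin):
    def support_lora(self) -> bool:
        return True

    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
        from peft import get_peft_model

        # Wrap the un-boosted model with LoRA; the Booster boosts it afterwards.
        return get_peft_model(model, lora_config)  # pretrained_dir handling omitted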
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -116,6 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)

+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and adapter configuration file to the checkpoint directory.
+        """
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        if self.coordinator.is_master():
+            peft_model = model.unwrap()
+            assert isinstance(
+                peft_model, PeftModel
+            ), "The model doesn't have lora adapters, please enable lora before saving."
+            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+

 class TorchDDPModel(ModelWrapper):
     def __init__(self, module: nn.Module, *args, **kwargs) -> None:
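The directory written here is a standard PEFT adapter checkpoint, so it can be reloaded either through booster.enable_lora(..., pretrained_dir=...) or directly with peft. A reload sketch, where the path and the helper that rebuilds the base model are illustrative:

from peft import PeftModel

base_model = build_base_model()  # hypothetical helper that reconstructs the original model
model = PeftModel.from_pretrained(base_model, "./lora_ckpt", is_trainable=True)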
@@ -173,6 +189,9 @@ class TorchDDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return True

+    def support_lora(self) -> bool:
+        return True
+
     def control_precision(self) -> bool:
         return False

@@ -216,3 +235,14 @@ class TorchDDPPlugin(DPPluginBase):
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
         return model.module.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
+        if pretrained_dir is None:
+            return get_peft_model(model, lora_config)
+        else:
+            return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
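Since the fresh-adapter path relies on peft.get_peft_model, the base weights come back frozen and only the adapter parameters remain trainable (a property of peft, not of this plugin). A quick check, assuming `model` was returned by enable_lora:

# Only LoRA adapter parameters should require gradients after enable_lora.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,}")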
@@ -1,6 +1,6 @@
 import warnings
 from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -190,7 +190,10 @@ class TorchFSDPPlugin(DPPluginBase):
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

     def support_no_sync(self) -> bool:
-        False
+        return False

+    def support_lora(self) -> bool:
+        return False
+
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
@@ -235,3 +238,8 @@ class TorchFSDPPlugin(DPPluginBase):

     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchFSDPCheckpointIO()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
@@ -327,3 +327,20 @@ class CheckpointIO(ABC):
         """
         state_dict = torch.load(checkpoint)
         lr_scheduler.load_state_dict(state_dict)
+
+    # ================================================================================
+    # Abstract method for lora saving implementation.
+    # ================================================================================
+
+    @abstractmethod
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to save in safetensors format. Defaults to False.
+        """
@@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
                     self.__class__.__name__, "\n\t".join(error_msgs)
                 )
             )
+
+    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+        raise NotImplementedError