Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[LowLevelZero] low level zero support lora (#5153)
Squashed commit: add LoRA support to the low-level zero plugin and its checkpoint IO, add checkpoint tests, clean up naming, and iterate on CI fixes (includes several updates to low_level_zero_plugin.py).
Committed by: Hongxin Liu
Parent: 14b0d4c7e5
Commit: 8954a0c2e2
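A minimal usage sketch of the workflow this commit enables. Everything beyond the imports is illustrative: TinyNet, the LoRA hyperparameters, and the learning rate are placeholder choices, the plugin's enable_lora is called directly (a Booster-level wrapper may also exist depending on version), and a distributed environment (e.g. via colossalai.launch_from_torch) must already be initialized before boost is called.

import torch
import torch.nn as nn
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Placeholder model with named Linear layers for LoRA to target.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(64, 64)
        self.proj_out = nn.Linear(64, 8)

    def forward(self, x):
        return self.proj_out(torch.relu(self.proj_in(x)))

plugin = LowLevelZeroPlugin(stage=1, precision="fp16")
booster = Booster(plugin=plugin)

# LoRA must be enabled on the raw module *before* boosting; the plugin asserts this order.
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["proj_in", "proj_out"])
model = plugin.enable_lora(TinyNet(), lora_config=lora_config)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# configure() folds any missing LoRA adapter weights into the optimizer's param groups.
model, optimizer, *_ = booster.boost(model, optimizer)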
@@ -1,5 +1,7 @@
import enum
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
@@ -7,6 +9,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -42,6 +45,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


class OptimizerParamCheckState(enum.Enum):
    ORIGIN_PARAM_FINDED = 0
    ORIGIN_PARAM_NOT_FIND = -1
    LORA_PARM_EXISTED = -2


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    def __init__(self, module: nn.Module, precision: str) -> None:
        super().__init__(module)
@@ -209,6 +218,19 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
        model.update_master_params()

    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        peft_model = model.unwrap()
        assert isinstance(
            peft_model, PeftModel
        ), "The model doesn't have lora adapters, please enable lora before saving."
        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)


class LowLevelZeroPlugin(DPPluginBase):
    """
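A hedged sketch of how the new save_lora_as_pretrained hook could be driven once training is done, continuing from a model boosted with this plugin; the ./lora_ckpt path is a placeholder, and reaching the method through plugin.get_checkpoint_io() is one possible route (a Booster-level wrapper may exist depending on version).

# `plugin` and the boosted `model` are assumed to come from the usage sketch above.
checkpoint_io = plugin.get_checkpoint_io()
checkpoint_io.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)

# The saved adapters can later be re-attached to a fresh base model before boosting:
# model = plugin.enable_lora(TinyNet(), pretrained_dir="./lora_ckpt")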
@@ -288,6 +310,7 @@ class LowLevelZeroPlugin(DPPluginBase):
            cpu_offload=cpu_offload,
            master_weights=master_weights,
        )
        self.lora_enabled = False
        self.verbose = verbose

        # set class name with stage, for better error message
@@ -311,6 +334,72 @@ class LowLevelZeroPlugin(DPPluginBase):
    def supported_devices(self) -> List[str]:
        return ["cuda", "npu"]

    def support_lora(self) -> bool:
        return True

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
        self.lora_enabled = True
        warnings.warn("You have enabled LoRA training. Please check the hyperparameters such as lr")

        if pretrained_dir is None:
            peft_model = get_peft_model(model, lora_config)
        else:
            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
        return peft_model

    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
        origin_param_id = id(origin_param)
        for group_id, param_group in enumerate(optimizer.param_groups):
            for p in param_group["params"]:
                if id(p) == origin_param_id:
                    return group_id
        return -1

    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
        origin_param_id = id(origin_param)
        lora_param_id = id(lora_param)
        target_group_id = None
        for group_id, param_group in enumerate(optimizer.param_groups):
            for p in param_group["params"]:
                if id(p) == lora_param_id:
                    # check if the lora parameter exists.
                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
                if id(p) == origin_param_id:
                    target_group_id = group_id
        if target_group_id is not None:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
        else:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND

    def add_lora_params_to_optimizer(self, model, optimizer):
        """add lora parameters to optimizer"""
        name2param = {}
        for name, param in model.named_parameters():
            name2param[name] = param

        for name, param in name2param.items():
            if "lora_A" in name or "lora_B" in name:
                origin_key = name.replace("lora_A.", "")
                origin_key = origin_key.replace("lora_B.", "")
                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                origin_param = name2param[origin_key]
                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                    warnings.warn(
                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
                    )
                elif (
                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
                    and group_id is not None
                    and group_id >= 0
                ):
                    optimizer.param_groups[group_id]["params"].append(param)

    def configure(
        self,
        model: nn.Module,
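Because get_param_group_id and add_lora_params_to_optimizer match parameters purely by identity, their behaviour can be illustrated with a plain PyTorch optimizer; the toy module, parameter names, and find_group helper below are illustrative stand-ins, not part of the plugin.

import torch
import torch.nn as nn

base = nn.Linear(4, 4)
lora_A = nn.Parameter(torch.zeros(2, 4))  # stand-in for a LoRA adapter weight
optimizer = torch.optim.SGD(base.parameters(), lr=0.1)

def find_group(optimizer, origin_param, lora_param):
    # Mirrors the three-argument get_param_group_id above: bail out if the LoRA
    # parameter is already registered, otherwise remember the group that holds
    # the original (base) parameter.
    target = None
    for gid, group in enumerate(optimizer.param_groups):
        for p in group["params"]:
            if p is lora_param:
                return target, "LORA_PARM_EXISTED"
            if p is origin_param:
                target = gid
    if target is not None:
        return target, "ORIGIN_PARAM_FINDED"
    return target, "ORIGIN_PARAM_NOT_FIND"

gid, state = find_group(optimizer, base.weight, lora_A)
if state == "ORIGIN_PARAM_FINDED":
    # The same append that configure() performs via add_lora_params_to_optimizer.
    optimizer.param_groups[gid]["params"].append(lora_A)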
@@ -319,6 +408,15 @@ class LowLevelZeroPlugin(DPPluginBase):
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        if self.lora_enabled:
            from peft import PeftModel

            assert isinstance(
                model, PeftModel
            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
            if optimizer is not None:
                self.add_lora_params_to_optimizer(model, optimizer)

        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)

@@ -340,8 +438,3 @@ class LowLevelZeroPlugin(DPPluginBase):
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        assert isinstance(optimizer, LowLevelZeroOptimizer)
        return optimizer.no_sync()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError