[LowLevelZero] low level zero support lora (#5153)

* low level zero support lora

low level zero support lora

* add checkpoint test

* add checkpoint test

* fix

* fix

* fix

* fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* test ci

* git # This is a combination of 3 commits.

Update low_level_zero_plugin.py

Update low_level_zero_plugin.py

fix

fix

fix

* fix naming

fix naming

fix naming

fix
This commit is contained in:
flybird11111
2023-12-21 17:01:01 +08:00
committed by Hongxin Liu
parent 14b0d4c7e5
commit 8954a0c2e2
8 changed files with 264 additions and 8 deletions

View File

@@ -1,5 +1,7 @@
import enum
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
@@ -7,6 +9,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -42,6 +45,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class OptimizerParamCheckState(enum.Enum):
ORIGIN_PARAM_FINDED = 0
ORIGIN_PARAM_NOT_FIND = -1
LORA_PARM_EXISTED = -2
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
@@ -209,6 +218,19 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
class LowLevelZeroPlugin(DPPluginBase):
"""
@@ -288,6 +310,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.lora_enabled = False
self.verbose = verbose
# set class name with stage, for better error message
@@ -311,6 +334,72 @@ class LowLevelZeroPlugin(DPPluginBase):
def supported_devices(self) -> List[str]:
return ["cuda", "npu"]
def support_lora(self) -> bool:
return True
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
from peft import PeftModel, get_peft_model
assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
origin_param_id = id(origin_param)
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group["params"]:
if id(p) == origin_param_id:
return group_id
return -1
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
origin_param_id = id(origin_param)
lora_param_id = id(lora_param)
target_group_id = None
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group["params"]:
if id(p) == lora_param_id:
# check if the lora parameter exists.
return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
if id(p) == origin_param_id:
target_group_id = group_id
if target_group_id is not None:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
else:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
def add_lora_params_to_optimizer(self, model, optimizer):
"""add lora parameters to optimizer"""
name2param = {}
for name, param in model.named_parameters():
name2param[name] = param
for name, param in name2param.items():
if "lora_A" in name or "lora_B" in name:
origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
and group_id is not None
and group_id >= 0
):
optimizer.param_groups[group_id]["params"].append(param)
def configure(
self,
model: nn.Module,
@@ -319,6 +408,15 @@ class LowLevelZeroPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if self.lora_enabled:
from peft import PeftModel
assert isinstance(
model, PeftModel
), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
if optimizer is not None:
self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
@@ -340,8 +438,3 @@ class LowLevelZeroPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.no_sync()
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError