[chat]: add lora merge weights config (#4766)

* feat: modify lora merge weights fn

* feat: add lora merge weights config
Author:    Wenhao Chen
Committer: GitHub
Date:      2023-09-21 16:23:59 +08:00
Parent:    493a5efeab
Commit:    901ab1eedd

4 changed files with 60 additions and 30 deletions


@@ -1,4 +1,6 @@
+import dataclasses
 import math
+import warnings
 from typing import Optional

 import loralib as lora
@@ -7,6 +9,14 @@ import torch.nn as nn
 import torch.nn.functional as F


+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
@@ -53,31 +61,31 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         def T(w):
             return w.T if self.fan_in_fan_out else w

-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
-                    # FIXME(csric): temporary fix
-                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
-                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
-                    self.reset_parameters()
-                else:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                        # FIXME(csric): temporary fix
+                        self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                        self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                        self.reset_parameters()
+                    else:
+                        self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                    delattr(self, "lora_A")
+                    delattr(self, "lora_B")
+                self.merged = True

-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                delattr(self, "lora_A")
-                delattr(self, "lora_B")
-            self.merged = True
         return self

     def forward(self, x: torch.Tensor):
         def T(w):
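As a sanity check on the merge logic above, the following standalone snippet (plain tensors, illustrative shapes; `scaling` stands for loralib's `lora_alpha / r`) shows that folding `lora_B @ lora_A * scaling` into the base weight leaves the layer output unchanged, which is what `train(mode=False)` relies on:

```python
import torch

out_features, in_features, r = 8, 16, 4
scaling = 1.0 / r  # lora_alpha = 1, matching the defaults in this file

weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)
x = torch.randn(2, in_features)

# Unmerged: base projection plus the low-rank update applied on the side.
y_unmerged = x @ weight.T + (x @ lora_A.T @ lora_B.T) * scaling

# Merged: fold the update into the weight once, then do a single matmul,
# as train(mode=False) does to self.weight.data.
merged_weight = weight + (lora_B @ lora_A) * scaling
y_merged = x @ merged_weight.T

print(torch.allclose(y_unmerged, y_merged, atol=1e-5))  # True
```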
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
     return lora_linear
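Putting the pieces together, intended usage after this change might look like the sketch below. The import path is assumed (the diff does not show the file's location), so treat it as a wiring sketch rather than a verbatim example.

```python
import torch.nn as nn

from coati.models import lora  # assumed import path; adjust to the real package layout

linear = nn.Linear(16, 8)
lora_linear = lora._lora_linear_wrapper(linear, lora_rank=4)  # rank must be <= in_features

# Training: keep the low-rank factors separate so lora_A / lora_B stay trainable.
lora.LORA_MANAGER.merge_weights = False
lora_linear.train()

# Inference/export: flip the global switch, then eval() (which routes through
# train(mode=False)) merges the LoRA update into .weight and drops lora_A / lora_B.
lora.LORA_MANAGER.merge_weights = True
lora_linear.eval()
```

The design choice here is that merging becomes a process-wide policy rather than a per-layer constructor flag, which is why `merge_weights` disappears from both `LoraLinear.__init__` and `_lora_linear_wrapper`.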