[Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style
This commit is contained in:
YeAnbang
2024-07-31 14:10:17 +08:00
committed by GitHub
parent 09c5f72595
commit 30f4e31a33
13 changed files with 552 additions and 252 deletions

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from coati.models import convert_to_lora_module
from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
from torch.utils.data import DataLoader, TensorDataset
@@ -38,7 +39,7 @@ def test_overfit():
# Build and convert model
model = SimpleNN(input_size, hidden_size, num_classes)
weight_to_compare = model.fc1.weight.detach().clone()
model = convert_to_lora_module(model, lora_rank=30)
model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
@@ -50,7 +51,6 @@ def test_overfit():
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
print(loss)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
@@ -65,5 +65,50 @@ def test_overfit():
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
def test_lora_linear_accuracy():
weight = torch.randn(10, 5)
linear = nn.Linear(5, 10)
linear.weight.data = weight
x = torch.randn(10, 5)
out_linear = linear(x)
# lora linear Pissa
linear.weight.data = weight
lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
out_lora = lora_linear(x)
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
# lora linear
linear.weight.data = weight
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
out_lora = lora_linear(x)
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
def test_lora_embedding_accuracy():
weight = torch.randn(10, 5)
embedding = nn.Embedding(10, 5)
embedding.weight.data = weight
x = torch.randint(0, 10, (10,))
out_embedding = embedding(x)
# lora embedding Pissa
embedding.weight.data = weight
lora_embedding = LoraEmbedding(
embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
)
out_lora = lora_embedding(x)
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
# lora embedding
embedding.weight.data = weight
lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
out_lora = lora_embedding(x)
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
if __name__ == "__main__":
test_overfit()
test_lora_linear_accuracy()
test_lora_embedding_accuracy()