mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
@@ -2,6 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ def test_overfit():
|
||||
# Build and convert model
|
||||
model = SimpleNN(input_size, hidden_size, num_classes)
|
||||
weight_to_compare = model.fc1.weight.detach().clone()
|
||||
model = convert_to_lora_module(model, lora_rank=30)
|
||||
model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
|
||||
|
||||
# Loss and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
@@ -50,7 +51,6 @@ def test_overfit():
|
||||
# Forward pass
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
print(loss)
|
||||
# Backward and optimize
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
@@ -65,5 +65,50 @@ def test_overfit():
|
||||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||
|
||||
|
||||
def test_lora_linear_accuracy():
|
||||
|
||||
weight = torch.randn(10, 5)
|
||||
linear = nn.Linear(5, 10)
|
||||
linear.weight.data = weight
|
||||
x = torch.randn(10, 5)
|
||||
out_linear = linear(x)
|
||||
|
||||
# lora linear Pissa
|
||||
linear.weight.data = weight
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
|
||||
out_lora = lora_linear(x)
|
||||
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
# lora linear
|
||||
linear.weight.data = weight
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
|
||||
out_lora = lora_linear(x)
|
||||
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
|
||||
def test_lora_embedding_accuracy():
|
||||
weight = torch.randn(10, 5)
|
||||
embedding = nn.Embedding(10, 5)
|
||||
embedding.weight.data = weight
|
||||
x = torch.randint(0, 10, (10,))
|
||||
out_embedding = embedding(x)
|
||||
|
||||
# lora embedding Pissa
|
||||
embedding.weight.data = weight
|
||||
lora_embedding = LoraEmbedding(
|
||||
embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
|
||||
)
|
||||
out_lora = lora_embedding(x)
|
||||
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
# lora embedding
|
||||
embedding.weight.data = weight
|
||||
lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
|
||||
out_lora = lora_embedding(x)
|
||||
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_overfit()
|
||||
test_lora_linear_accuracy()
|
||||
test_lora_embedding_accuracy()
|
||||
|
Reference in New Issue
Block a user