Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-06-25 23:13:06 +00:00
fix: testing works
commit aa814757fc, parent df79fd64b0
@@ -2,13 +2,16 @@ import torch
from gpt4all.models import GPTJRForCausalLM, GPTJRConfig
from transformers import AutoTokenizer, AutoModel
config = GPTJRConfig(encoder_dim=384, n_layer=4)
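# Note (an assumption, not part of the commit): encoder_dim=384 presumably matches the
# 384-dimensional sentence embeddings produced by all-MiniLM-L6-v2 below; n_layer=4
# keeps the test model small.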
print("loaded config")
print("loading model")
config = GPTJRConfig(encoder_ndim=384)
model = GPTJRForCausalLM(config)
print("loaded model")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
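# GPT-J's tokenizer has no pad token by default, so the EOS token is reused for padding.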
tokenizer.pad_token = tokenizer.eos_token
encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
@@ -23,19 +26,21 @@ text = "The quick brown fox jumps over the lazy dog."
print("Encoded knn")
tokenized = encoder_tokenizer(text, return_tensors="pt")
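# mean_pooling is defined earlier in the script, outside this hunk. A sketch of the
# usual sentence-transformers mean pooling (an assumption, not necessarily the exact
# helper used here):
def mean_pooling(model_output, attention_mask):
    # model_output[0]: token embeddings of shape (bs, seq_len, dim)
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # average the token embeddings, ignoring padding positions
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)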
# bs, seq_len, dim
encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_outputs = torch.stack([encodings, encodings]).unsqueeze(0)
encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
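# Shape walk-through (assuming a single sentence, bs == 1):
# encodings: (1, 384) -> torch.stack of two copies: (2, 1, 384)
# .squeeze(): (2, 384) -> .unsqueeze(0): (1, 2, 384), i.e. (bs, knn, encoding_dim)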
inputs = "What did the fox do?"
print("Encoded inputs")
tokenized_input = tokenizer(inputs, padding="max_length", truncation="true", return_tensors="pt")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
print("Running model")
outputs = model(**tokenized_input, encoder_outputs=encoder_outputs)
print(outputs.shape)
print(outputs)
print(outputs[0].shape)
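The prints above only dump raw tensors. As a possible follow-up check (an assumption, not part of this commit): if the model's first output is a logits tensor of shape (batch, seq_len, vocab_size), the greedy continuation can be decoded for a quick sanity check.

logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
predicted_ids = logits.argmax(dim=-1)
print(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))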