fix: testing works

Zach Nussbaum 2023-04-21 04:18:16 +00:00
parent df79fd64b0
commit aa814757fc


@@ -2,13 +2,16 @@ import torch
 from gpt4all.models import GPTJRForCausalLM, GPTJRConfig
 from transformers import AutoTokenizer, AutoModel
 
+config = GPTJRConfig(encoder_dim=384, n_layer=4)
+print("loaded config")
 print("loading model")
-config = GPTJRConfig(encoder_ndim=384)
 model = GPTJRForCausalLM(config)
 print("loaded model")
 
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+tokenizer.pad_token = tokenizer.eos_token
 encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
 encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
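Two of the fixes in this hunk are worth noting. The config kwarg was misspelled (encoder_ndim rather than encoder_dim), and since Hugging Face configs accept unknown kwargs without complaint, the intended encoder dimension likely never took effect. Separately, GPT-J's pretrained tokenizer ships without a pad token, so the padding="max_length" call later in the script would raise an error. A minimal standalone sketch of the pad-token workaround, assuming the stock EleutherAI/gpt-j-6b tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
print(tokenizer.pad_token)                 # None -- no pad token is defined
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
batch = tokenizer(["short", "a longer input"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)            # both rows padded to the same length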
@@ -23,19 +26,21 @@ text = "The quick brown fox jumps over the lazy dog."
 print("Encoded knn")
 tokenized = encoder_tokenizer(text, return_tensors="pt")
+# bs, seq_len, dim
 encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
 
 # make 2 neighbors
 # (bs, knn, encoding_dim)
-encoder_outputs = torch.stack([encodings, encodings]).unsqueeze(0)
+encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
 
 inputs = "What did the fox do?"
 print("Encoded inputs")
-tokenized_input = tokenizer(inputs, padding="max_length", truncation="true", return_tensors="pt")
+tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
 print("Running model")
 outputs = model(**tokenized_input, encoder_outputs=encoder_outputs)
-print(outputs.shape)
+print(outputs)
+print(outputs[0].shape)
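A note on the shape fix: with a single input sentence, encodings from mean pooling is (1, 384), so torch.stack([encodings, encodings]) is (2, 1, 384); the added .squeeze() collapses that to (2, 384) before .unsqueeze(0) produces the (1, 2, 384) tensor, i.e. (bs, knn, encoder_dim), that fakes two retrieved neighbors. Likewise, truncation takes the boolean True, not the string "true".

The script also calls mean_pooling, which is not shown in this diff and is presumably defined earlier in the file. For reference, a sketch of the standard sentence-transformers mean pooling (a masked average of token embeddings); the repo's actual helper may differ:

import torch

def mean_pooling(model_output, attention_mask):
    # model_output[0]: token embeddings of shape (bs, seq_len, dim)
    token_embeddings = model_output[0]
    # Expand the mask to (bs, seq_len, dim) so padding positions contribute nothing
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Masked mean over seq_len -> (bs, dim)
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)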