diff --git a/test_gpt_jr.py b/test_gpt_jr.py
index 8f0dde8d..e8c406dc 100644
--- a/test_gpt_jr.py
+++ b/test_gpt_jr.py
@@ -2,13 +2,16 @@ import torch
 from gpt4all.models import GPTJRForCausalLM, GPTJRConfig
 from transformers import AutoTokenizer, AutoModel
 
+config = GPTJRConfig(encoder_dim=384, n_layer=4)
+print("loaded config")
+
 print("loading model")
-config = GPTJRConfig(encoder_ndim=384)
 model = GPTJRForCausalLM(config)
 print("loaded model")
 
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+tokenizer.pad_token = tokenizer.eos_token
 encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
 encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
 
 
@@ -23,19 +26,21 @@ text = "The quick brown fox jumps over the lazy dog."
 
 print("Encoded knn")
 tokenized = encoder_tokenizer(text, return_tensors="pt")
+# bs, seq_len, dim
 encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
 
 # make 2 neighbors
 # (bs, knn, encoding_dim)
-encoder_outputs = torch.stack([encodings, encodings]).unsqueeze(0)
+encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
 
 inputs = "What did the fox do?"
 
 print("Encoded inputs")
-tokenized_input = tokenizer(inputs, padding="max_length", truncation="true", return_tensors="pt")
+tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
 
 print("Running model")
 outputs = model(**tokenized_input, encoder_outputs=encoder_outputs)
-print(outputs.shape)
+print(outputs)
+print(outputs[0].shape)
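
Note: the second hunk calls a mean_pooling helper defined in the unchanged part of test_gpt_jr.py. For reference, here is a minimal sketch, assuming the file uses the standard sentence-transformers mean-pooling recipe (attention-mask-weighted average of token embeddings); the body below is an assumption, not copied from this diff:

import torch

def mean_pooling(model_output, attention_mask):
    # model_output[0]: token embeddings, shape (bs, seq_len, dim)
    token_embeddings = model_output[0]
    # expand the mask to (bs, seq_len, dim) so padded positions contribute 0
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # sum over seq_len, divide by the count of real tokens -> (bs, dim)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

Shape check for the changed neighbor-stacking line: for a single sentence, encodings is (1, 384), so torch.stack([encodings, encodings]) is (2, 1, 384). The added .squeeze() drops the singleton batch dim to (2, 384), and .unsqueeze(0) restores a batch axis, yielding the (bs, knn, encoding_dim) = (1, 2, 384) layout the comment describes; without the .squeeze(), the old line produced (1, 2, 1, 384).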