refactor: imports

This commit is contained in:
Zach Nussbaum 2023-05-04 03:16:26 +00:00
parent 0402d3e28a
commit d1b64d7eed
2 changed files with 9 additions and 5 deletions

View File

@ -1,8 +1,12 @@
from .configuration_gpt_jr import GPTJRConfig
from .modeling_gpt_jr import GPTJRForCausalLM
from .gpt_jr.configuration_gpt_jr import GPTJRConfig
from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
__all__ = [
"GPTJRConfig",
"GPTJRForCausalLM"
"GPTJRForCausalLM",
"PythiaSeekConfig",
"PythiaSeekForCausalLM",
]

View File

@ -35,7 +35,7 @@ encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
inputs = "What did the fox do?"
@ -43,7 +43,7 @@ print("Encoded inputs")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
print("Running model")
outputs = model(**tokenized_input, encoder_outputs=encoder_outputs)
outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
print(outputs)
print(outputs[0].shape)