mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-07 20:41:24 +00:00
refactor: imports
This commit is contained in:
parent
0402d3e28a
commit
d1b64d7eed
@ -1,8 +1,12 @@
|
||||
from .configuration_gpt_jr import GPTJRConfig
|
||||
from .modeling_gpt_jr import GPTJRForCausalLM
|
||||
from .gpt_jr.configuration_gpt_jr import GPTJRConfig
|
||||
from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
|
||||
|
||||
from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GPTJRConfig",
|
||||
"GPTJRForCausalLM"
|
||||
"GPTJRForCausalLM",
|
||||
"PythiaSeekConfig",
|
||||
"PythiaSeekForCausalLM",
|
||||
]
|
@ -35,7 +35,7 @@ encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
|
||||
|
||||
# make 2 neighbors
|
||||
# (bs, knn, encoding_dim)
|
||||
encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
|
||||
encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
|
||||
|
||||
inputs = "What did the fox do?"
|
||||
|
||||
@ -43,7 +43,7 @@ print("Encoded inputs")
|
||||
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
|
||||
|
||||
print("Running model")
|
||||
outputs = model(**tokenized_input, encoder_outputs=encoder_outputs)
|
||||
outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
print(outputs)
|
||||
print(outputs[0].shape)
|
||||
|
Loading…
Reference in New Issue
Block a user