From d1b64d7eed450653af835af3d64a6f1763ae2741 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Thu, 4 May 2023 03:16:26 +0000 Subject: [PATCH] refactor: imports --- gpt4all/models/__init__.py | 10 +++++++--- gpt4all/models/gpt_jr/test_gpt_jr.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/gpt4all/models/__init__.py b/gpt4all/models/__init__.py index 138b5b41..b0ce79f9 100644 --- a/gpt4all/models/__init__.py +++ b/gpt4all/models/__init__.py @@ -1,8 +1,12 @@ -from .configuration_gpt_jr import GPTJRConfig -from .modeling_gpt_jr import GPTJRForCausalLM +from .gpt_jr.configuration_gpt_jr import GPTJRConfig +from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM + +from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig __all__ = [ "GPTJRConfig", - "GPTJRForCausalLM" + "GPTJRForCausalLM", + "PythiaSeekConfig", + "PythiaSeekForCausalLM", ] \ No newline at end of file diff --git a/gpt4all/models/gpt_jr/test_gpt_jr.py b/gpt4all/models/gpt_jr/test_gpt_jr.py index 77a2c2fd..6bc4a727 100644 --- a/gpt4all/models/gpt_jr/test_gpt_jr.py +++ b/gpt4all/models/gpt_jr/test_gpt_jr.py @@ -35,7 +35,7 @@ encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"]) # make 2 neighbors # (bs, knn, encoding_dim) -encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0) +encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0) inputs = "What did the fox do?" @@ -43,7 +43,7 @@ print("Encoded inputs") tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt") print("Running model") -outputs = model(**tokenized_input, encoder_outputs=encoder_outputs) +outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states) print(outputs) print(outputs[0].shape)