transfer python bindings code

This commit is contained in:
Richard Guo
2023-05-10 13:38:32 -04:00
parent 75591061fd
commit 62031c22d3
18 changed files with 1068 additions and 0 deletions

View File

@@ -0,0 +1,62 @@
import pytest
from gpt4all.gpt4all import GPT4All
def test_invalid_model_type():
model_type = "bad_type"
with pytest.raises(ValueError):
GPT4All.get_model_from_type(model_type)
def test_valid_model_type():
model_type = "gptj"
assert GPT4All.get_model_from_type(model_type).model_type == model_type
def test_invalid_model_name():
model_name = "bad_filename.bin"
with pytest.raises(ValueError):
GPT4All.get_model_from_name(model_name)
def test_valid_model_name():
model_name = "ggml-gpt4all-l13b-snoozy"
model_type = "llama"
assert GPT4All.get_model_from_name(model_name).model_type == model_type
model_name += ".bin"
assert GPT4All.get_model_from_name(model_name).model_type == model_type
def test_build_prompt():
messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello there."
},
{
"role": "assistant",
"content": "Hi, how can I help you?"
},
{
"role": "user",
"content": "Reverse a list in Python."
}
]
expected_prompt = """You are a helpful assistant.\
\n### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.\
### Prompt:\
Hello there.\
Response: Hi, how can I help you?\
Reverse a list in Python.\
### Response:"""
print(expected_prompt)
full_prompt = GPT4All._build_prompt(messages, default_prompt_footer=True, default_prompt_header=True)
print("\n\n\n")
print(full_prompt)
assert len(full_prompt) == len(expected_prompt)

View File

@@ -0,0 +1,44 @@
from io import StringIO
import sys
from gpt4all import pyllmodel
# TODO: Integration test for loadmodel and prompt.
# # Right now, too slow b/c it requries file download.
def test_create_gptj():
gptj = pyllmodel.GPTJModel()
assert gptj.model_type == "gptj"
def test_create_llama():
llama = pyllmodel.LlamaModel()
assert llama.model_type == "llama"
def prompt_unloaded_gptj():
gptj = pyllmodel.GPTJModel()
old_stdout = sys.stdout
collect_response = StringIO()
sys.stdout = collect_response
gptj.prompt("hello there")
response = collect_response.getvalue()
sys.stdout = old_stdout
response = response.strip()
assert response == "GPT-J ERROR: prompt won't work with an unloaded model!"
def prompt_unloaded_llama():
llama = pyllmodel.LlamaModel()
old_stdout = sys.stdout
collect_response = StringIO()
sys.stdout = collect_response
llama.prompt("hello there")
response = collect_response.getvalue()
sys.stdout = old_stdout
response = response.strip()
assert response == "LLAMA ERROR: prompt won't work with an unloaded model!"