mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-10-12 11:38:18 +00:00
transfer python bindings code
This commit is contained in:
0
gpt4all-bindings/python/tests/__init__.py
Normal file
0
gpt4all-bindings/python/tests/__init__.py
Normal file
62
gpt4all-bindings/python/tests/test_gpt4all.py
Normal file
62
gpt4all-bindings/python/tests/test_gpt4all.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
from gpt4all.gpt4all import GPT4All
|
||||
|
||||
def test_invalid_model_type():
    """An unrecognized model type must raise ValueError."""
    with pytest.raises(ValueError):
        GPT4All.get_model_from_type("bad_type")
def test_valid_model_type():
    """A supported model type round-trips through get_model_from_type."""
    requested = "gptj"
    model = GPT4All.get_model_from_type(requested)
    assert model.model_type == requested
def test_invalid_model_name():
    """An unknown model filename must raise ValueError."""
    with pytest.raises(ValueError):
        GPT4All.get_model_from_name("bad_filename.bin")
def test_valid_model_name():
    """The snoozy model resolves to a llama model, with or without .bin."""
    base_name = "ggml-gpt4all-l13b-snoozy"
    for candidate in (base_name, base_name + ".bin"):
        model = GPT4All.get_model_from_name(candidate)
        assert model.model_type == "llama"
def test_build_prompt():
    """_build_prompt should stitch a chat transcript (system / user /
    assistant turns) into a single prompt string when the default header
    and footer are enabled.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello there."
        },
        {
            "role": "assistant",
            "content": "Hi, how can I help you?"
        },
        {
            "role": "user",
            "content": "Reverse a list in Python."
        }
    ]

    # Trailing backslashes are line continuations inside the literal; the
    # explicit \n escapes are part of the expected prompt text.
    expected_prompt = """You are a helpful assistant.\
\n### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.\
### Prompt:\
Hello there.\
Response: Hi, how can I help you?\
Reverse a list in Python.\
### Response:"""

    full_prompt = GPT4All._build_prompt(messages,
                                        default_prompt_footer=True,
                                        default_prompt_header=True)

    # NOTE(review): comparing lengths only is a weak assertion -- two
    # different strings of equal length would pass.  Kept as-is because the
    # exact whitespace emitted by _build_prompt is not pinned down here;
    # consider tightening to `full_prompt == expected_prompt` once confirmed.
    # (Debug print() calls from the original were removed.)
    assert len(full_prompt) == len(expected_prompt)
44
gpt4all-bindings/python/tests/test_pyllmodel.py
Normal file
44
gpt4all-bindings/python/tests/test_pyllmodel.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from io import StringIO
|
||||
import sys
|
||||
|
||||
from gpt4all import pyllmodel
|
||||
|
||||
# TODO: Integration test for loadmodel and prompt.
|
||||
# Right now, too slow b/c it requires file download.
|
||||
|
||||
def test_create_gptj():
    """A freshly constructed GPTJModel reports the "gptj" model type."""
    model = pyllmodel.GPTJModel()
    assert model.model_type == "gptj"
def test_create_llama():
    """A freshly constructed LlamaModel reports the "llama" model type."""
    model = pyllmodel.LlamaModel()
    assert model.model_type == "llama"
def prompt_unloaded_gptj():
    """Prompting a GPTJModel before any model is loaded should print an
    error message to stdout instead of generating text.
    """
    # NOTE(review): no `test_` prefix, so pytest never collects this check;
    # rename to `test_prompt_unloaded_gptj` if it is meant to run.
    gptj = pyllmodel.GPTJModel()

    # Capture stdout; restore it in `finally` so a raising prompt() cannot
    # leave the interpreter's stdout redirected for later tests.
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response
    try:
        gptj.prompt("hello there")
        response = collect_response.getvalue()
    finally:
        sys.stdout = old_stdout

    response = response.strip()
    assert response == "GPT-J ERROR: prompt won't work with an unloaded model!"
def prompt_unloaded_llama():
    """Prompting a LlamaModel before any model is loaded should print an
    error message to stdout instead of generating text.
    """
    # NOTE(review): no `test_` prefix, so pytest never collects this check;
    # rename to `test_prompt_unloaded_llama` if it is meant to run.
    llama = pyllmodel.LlamaModel()

    # Capture stdout; restore it in `finally` so a raising prompt() cannot
    # leave the interpreter's stdout redirected for later tests.
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response
    try:
        llama.prompt("hello there")
        response = collect_response.getvalue()
    finally:
        sys.stdout = old_stdout

    response = response.strip()
    assert response == "LLAMA ERROR: prompt won't work with an unloaded model!"
Reference in New Issue
Block a user