Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-10-01 01:38:46 +00:00
transfer python bindings code
gpt4all-bindings/python/tests/test_pyllmodel.py (new file, 44 lines added)
@@ -0,0 +1,44 @@
from io import StringIO
import sys

from gpt4all import pyllmodel

# TODO: Integration test for loadmodel and prompt.
# Right now, too slow b/c it requires file download.


def test_create_gptj():
    gptj = pyllmodel.GPTJModel()
    assert gptj.model_type == "gptj"


def test_create_llama():
    llama = pyllmodel.LlamaModel()
    assert llama.model_type == "llama"


def prompt_unloaded_gptj():
    gptj = pyllmodel.GPTJModel()
    # Redirect stdout so the error message printed by prompt() can be captured.
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    gptj.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "GPT-J ERROR: prompt won't work with an unloaded model!"


def prompt_unloaded_llama():
    llama = pyllmodel.LlamaModel()
    # Redirect stdout so the error message printed by prompt() can be captured.
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    llama.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "LLAMA ERROR: prompt won't work with an unloaded model!"
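The TODO above refers to an integration test for loading a model and prompting it, left out because it would require downloading a model file. A minimal sketch of what such a test could look like, assuming pyllmodel exposes a load_model(model_path) method (name inferred from the TODO's "loadmodel") and that a path to an already-downloaded model is supplied via a hypothetical GPT4ALL_TEST_MODEL environment variable:

import os

import pytest

from gpt4all import pyllmodel

# Hypothetical: path to an already-downloaded model file.
MODEL_PATH = os.environ.get("GPT4ALL_TEST_MODEL", "")


@pytest.mark.skipif(not MODEL_PATH, reason="requires a locally downloaded model file")
def test_gptj_load_and_prompt():
    gptj = pyllmodel.GPTJModel()
    # Assumed method name, based on the TODO's "loadmodel"; verify against pyllmodel.
    gptj.load_model(MODEL_PATH)
    # prompt() writes its output to stdout in these bindings (see the stdout capture above).
    gptj.prompt("hello there")

As with the unit tests above, this would be collected and run with pytest from gpt4all-bindings/python.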