mpt bindings

This commit is contained in:
Richard Guo
2023-05-11 15:26:20 -04:00
committed by Richard Guo
parent d56aada08c
commit 36a6e824f0
4 changed files with 46 additions and 2 deletions

View File

@@ -14,6 +14,24 @@ def test_create_llama():
llama = pyllmodel.LlamaModel()
assert llama.model_type == "llama"
def test_create_mpt():
mpt = pyllmodel.MPTModel()
assert mpt.model_type == "mpt"
def prompt_unloaded_mpt():
mpt = pyllmodel.MPTModel()
old_stdout = sys.stdout
collect_response = StringIO()
sys.stdout = collect_response
mpt.prompt("hello there")
response = collect_response.getvalue()
sys.stdout = old_stdout
response = response.strip()
assert response == "MPT ERROR: prompt won't work with an unloaded model!"
def prompt_unloaded_gptj():
gptj = pyllmodel.GPTJModel()
old_stdout = sys.stdout