diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
index 556fecb6..983c728e 100644
--- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
@@ -32,6 +32,50 @@ def test_inference():
     assert True
 
 
+def do_long_input(model):
+    # Build a prompt long enough to exceed the batch limit (~160 tokens).
+    long_input = ' '.join(['hello how are you'] * 40)
+
+    with model.chat_session():
+        # llmodel should limit us to 128 even if we ask for more
+        model.generate(long_input, n_batch=512)
+        print(model.current_chat_session)
+
+
+def test_inference_long_orca_3b():
+    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
+    do_long_input(model)
+
+
+def test_inference_long_falcon():
+    model = GPT4All(model_name='ggml-model-gpt4all-falcon-q4_0.bin')
+    do_long_input(model)
+
+
+def test_inference_long_llama_7b():
+    model = GPT4All(model_name='orca-mini-7b.ggmlv3.q4_0.bin')
+    do_long_input(model)
+
+
+def test_inference_long_llama_13b():
+    model = GPT4All(model_name='ggml-nous-hermes-13b.ggmlv3.q4_0.bin')
+    do_long_input(model)
+
+
+def test_inference_long_mpt():
+    model = GPT4All(model_name='ggml-mpt-7b-chat.bin')
+    do_long_input(model)
+
+
+def test_inference_long_replit():
+    model = GPT4All(model_name='ggml-replit-code-v1-3b.bin')
+    do_long_input(model)
+
+
+def test_inference_long_groovy():
+    model = GPT4All(model_name='ggml-gpt4all-j-v1.3-groovy.bin')
+    do_long_input(model)
+
+
 def test_inference_hparams():
     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')