diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py index e2270209..6b1e6a8c 100644 --- a/tests/test_prompt_helper.py +++ b/tests/test_prompt_helper.py @@ -1,19 +1,31 @@ +import sys from pathlib import Path from tempfile import NamedTemporaryFile import pytest from llama_index.llms import ChatMessage, MessageRole -from private_gpt.components.llm.prompt.prompt_helper import ( - DefaultPromptStyle, - LlamaCppPromptStyle, - LlamaIndexPromptStyle, - TemplatePromptStyle, - VigognePromptStyle, - get_prompt_style, +try: + from private_gpt.components.llm.prompt.prompt_helper import ( + DefaultPromptStyle, + LlamaCppPromptStyle, + LlamaIndexPromptStyle, + TemplatePromptStyle, + VigognePromptStyle, + get_prompt_style, + ) +except ImportError: + DefaultPromptStyle = None + LlamaCppPromptStyle = None + LlamaIndexPromptStyle = None + TemplatePromptStyle = None + VigognePromptStyle = None + get_prompt_style = None + + +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" ) - - @pytest.mark.parametrize( ("prompt_style", "expected_prompt_style"), [ @@ -28,6 +40,9 @@ def test_get_prompt_style_success(prompt_style, expected_prompt_style): assert type(get_prompt_style(prompt_style)) == expected_prompt_style +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" +) def test_get_prompt_style_template_success(): jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: " with NamedTemporaryFile("w") as tmp_file: @@ -57,6 +72,9 @@ def test_get_prompt_style_template_success(): assert prompt == expected_prompt +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" +) def test_get_prompt_style_failure(): prompt_style = "unknown" with pytest.raises(ValueError) as exc_info: @@ -64,6 +82,9 @@ def test_get_prompt_style_failure(): assert str(exc_info.value) == f"Unknown prompt_style='{prompt_style}'" +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" +) def test_tag_prompt_style_format(): prompt_style = VigognePromptStyle() messages = [ @@ -80,6 +101,9 @@ def test_tag_prompt_style_format(): assert prompt_style.messages_to_prompt(messages) == expected_prompt +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" +) def test_tag_prompt_style_format_with_system_prompt(): system_prompt = "This is a system prompt from configuration." prompt_style = VigognePromptStyle(default_system_prompt=system_prompt) @@ -111,6 +135,9 @@ def test_tag_prompt_style_format_with_system_prompt(): assert prompt_style.messages_to_prompt(messages) == expected_prompt +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" +) def test_llama2_prompt_style_format(): prompt_style = LlamaIndexPromptStyle() messages = [ @@ -129,6 +156,9 @@ def test_llama2_prompt_style_format(): assert prompt_style.messages_to_prompt(messages) == expected_prompt +@pytest.mark.skipif( + "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library" +) def test_llama2_prompt_style_with_system_prompt(): system_prompt = "This is a system prompt from configuration." prompt_style = LlamaIndexPromptStyle(default_system_prompt=system_prompt)