Make test optional for prompt_helper if no llama_cpp is installed
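
The change applies a standard pytest pattern for optional dependencies: guard the import in a try/except ImportError and mark each test with skipif keyed on sys.modules. A minimal standalone sketch of that pattern (illustrative only, with a hypothetical test name, not the committed code):

    import sys

    import pytest

    try:
        import llama_cpp  # noqa: F401  # optional dependency
    except ImportError:
        llama_cpp = None


    @pytest.mark.skipif(
        "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
    )
    def test_requires_llama_cpp():
        # Runs only when llama-cpp-python is importable; reported as skipped otherwise.
        assert llama_cpp is not None

The sys.modules check works because a successful import (in the committed code, the transitive one inside prompt_helper) registers the module; if the import failed, the marker skips the test instead of letting it fail.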

Louis 2023-12-03 19:03:34 +01:00
parent af1463637b
commit 29816d8a3a


@@ -1,19 +1,31 @@
 import sys
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 import pytest
 from llama_index.llms import ChatMessage, MessageRole
-from private_gpt.components.llm.prompt.prompt_helper import (
-    DefaultPromptStyle,
-    LlamaCppPromptStyle,
-    LlamaIndexPromptStyle,
-    TemplatePromptStyle,
-    VigognePromptStyle,
-    get_prompt_style,
-)
+try:
+    from private_gpt.components.llm.prompt.prompt_helper import (
+        DefaultPromptStyle,
+        LlamaCppPromptStyle,
+        LlamaIndexPromptStyle,
+        TemplatePromptStyle,
+        VigognePromptStyle,
+        get_prompt_style,
+    )
+except ImportError:
+    DefaultPromptStyle = None
+    LlamaCppPromptStyle = None
+    LlamaIndexPromptStyle = None
+    TemplatePromptStyle = None
+    VigognePromptStyle = None
+    get_prompt_style = None
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 @pytest.mark.parametrize(
     ("prompt_style", "expected_prompt_style"),
     [
@@ -28,6 +40,9 @@ def test_get_prompt_style_success(prompt_style, expected_prompt_style):
     assert type(get_prompt_style(prompt_style)) == expected_prompt_style
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 def test_get_prompt_style_template_success():
     jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
     with NamedTemporaryFile("w") as tmp_file:
@@ -57,6 +72,9 @@ def test_get_prompt_style_template_success():
     assert prompt == expected_prompt
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 def test_get_prompt_style_failure():
     prompt_style = "unknown"
     with pytest.raises(ValueError) as exc_info:
@@ -64,6 +82,9 @@ def test_get_prompt_style_failure():
     assert str(exc_info.value) == f"Unknown prompt_style='{prompt_style}'"
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 def test_tag_prompt_style_format():
     prompt_style = VigognePromptStyle()
     messages = [
@@ -80,6 +101,9 @@ def test_tag_prompt_style_format():
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 def test_tag_prompt_style_format_with_system_prompt():
     system_prompt = "This is a system prompt from configuration."
     prompt_style = VigognePromptStyle(default_system_prompt=system_prompt)
@@ -111,6 +135,9 @@ def test_tag_prompt_style_format_with_system_prompt():
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 def test_llama2_prompt_style_format():
     prompt_style = LlamaIndexPromptStyle()
     messages = [
@@ -129,6 +156,9 @@ def test_llama2_prompt_style_format():
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+@pytest.mark.skipif(
+    "llama_cpp" not in sys.modules, reason="requires the llama-cpp-python library"
+)
 def test_llama2_prompt_style_with_system_prompt():
     system_prompt = "This is a system prompt from configuration."
     prompt_style = LlamaIndexPromptStyle(default_system_prompt=system_prompt)
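
For comparison, pytest's built-in importorskip collapses the try/except plus per-test marker into a single call that skips at collection time; a sketch of that alternative idiom (not what this commit uses):

    import pytest

    # Skips every test in this module during collection if llama_cpp is absent.
    llama_cpp = pytest.importorskip("llama_cpp")


    def test_requires_llama_cpp():
        assert llama_cpp is not None

The per-test skipif chosen in the commit keeps the module importable and lets unrelated tests in the same file run; importorskip trades that granularity for brevity.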