mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
EXPERIMENTAL Generic LLM wrapper to support chat model interface with configurable chat prompt format (#8295)
## Update 2023-09-08 This PR now supports further models in addition to Lllama-2 chat models. See [this comment](#issuecomment-1668988543) for further details. The title of this PR has been updated accordingly. ## Original PR description This PR adds a generic `Llama2Chat` model, a wrapper for LLMs able to serve Llama-2 chat models (like `LlamaCPP`, `HuggingFaceTextGenInference`, ...). It implements `BaseChatModel`, converts a list of chat messages into the [required Llama-2 chat prompt format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2) and forwards the formatted prompt as `str` to the wrapped `LLM`. Usage example: ```python # uses a locally hosted Llama2 chat model llm = HuggingFaceTextGenInference( inference_server_url="http://127.0.0.1:8080/", max_new_tokens=512, top_k=50, temperature=0.1, repetition_penalty=1.03, ) # Wrap llm to support Llama2 chat prompt format. # Resulting model is a chat model model = Llama2Chat(llm=llm) messages = [ SystemMessage(content="You are a helpful assistant."), MessagesPlaceholder(variable_name="chat_history"), HumanMessagePromptTemplate.from_template("{text}"), ] prompt = ChatPromptTemplate.from_messages(messages) memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) chain = LLMChain(llm=model, prompt=prompt, memory=memory) # use chat model in a conversation # ... ``` Also part of this PR are tests and a demo notebook. - Tag maintainer: @hwchase17 - Twitter handle: `@mrt1nz` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""**Chat Models** are a variation on language models.
|
||||
|
||||
While Chat Models use language models under the hood, the interface they expose
|
||||
is a bit different. Rather than expose a "text in, text out" API, they expose
|
||||
an interface where "chat messages" are the inputs and outputs.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseLanguageModel --> BaseChatModel --> <name> # Examples: ChatOpenAI, ChatGooglePalm
|
||||
|
||||
**Main helpers:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
AIMessage, BaseMessage, HumanMessage
|
||||
""" # noqa: E501
|
||||
|
||||
from langchain_experimental.chat_models.llm_wrapper import Llama2Chat, Orca, Vicuna
|
||||
|
||||
__all__ = [
|
||||
"Llama2Chat",
|
||||
"Orca",
|
||||
"Vicuna",
|
||||
]
|
@@ -0,0 +1,163 @@
|
||||
"""Generic Wrapper for chat LLMs, with sample implementations
|
||||
for Llama-2-chat, Llama-2-instruct and Vicuna models.
|
||||
"""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" # noqa: E501
|
||||
|
||||
|
||||
class ChatWrapper(BaseChatModel):
|
||||
llm: LLM
|
||||
sys_beg: str
|
||||
sys_end: str
|
||||
ai_n_beg: str
|
||||
ai_n_end: str
|
||||
usr_n_beg: str
|
||||
usr_n_end: str
|
||||
usr_0_beg: Optional[str] = None
|
||||
usr_0_end: Optional[str] = None
|
||||
|
||||
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = self.llm._generate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = await self.llm._agenerate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
raise ValueError("at least one HumanMessage must be provided")
|
||||
|
||||
if not isinstance(messages[0], SystemMessage):
|
||||
messages = [self.system_message] + messages
|
||||
|
||||
if not isinstance(messages[1], HumanMessage):
|
||||
raise ValueError(
|
||||
"messages list must start with a SystemMessage or UserMessage"
|
||||
)
|
||||
|
||||
if not isinstance(messages[-1], HumanMessage):
|
||||
raise ValueError("last message must be a HumanMessage")
|
||||
|
||||
prompt_parts = []
|
||||
|
||||
if self.usr_0_beg is None:
|
||||
self.usr_0_beg = self.usr_n_beg
|
||||
|
||||
if self.usr_0_end is None:
|
||||
self.usr_0_end = self.usr_n_end
|
||||
|
||||
prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
|
||||
prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)
|
||||
|
||||
for ai_message, human_message in zip(messages[2::2], messages[3::2]):
|
||||
if not isinstance(ai_message, AIMessage) or not isinstance(
|
||||
human_message, HumanMessage
|
||||
):
|
||||
raise ValueError(
|
||||
"messages must be alternating human- and ai-messages, "
|
||||
"optionally prepended by a system message"
|
||||
)
|
||||
|
||||
prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
|
||||
prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)
|
||||
|
||||
return "".join(prompt_parts)
|
||||
|
||||
@staticmethod
|
||||
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
|
||||
chat_generations = []
|
||||
|
||||
for g in llm_result.generations[0]:
|
||||
chat_generation = ChatGeneration(
|
||||
message=AIMessage(content=g.text), generation_info=g.generation_info
|
||||
)
|
||||
chat_generations.append(chat_generation)
|
||||
|
||||
return ChatResult(
|
||||
generations=chat_generations, llm_output=llm_result.llm_output
|
||||
)
|
||||
|
||||
|
||||
class Llama2Chat(ChatWrapper):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "llama-2-chat"
|
||||
|
||||
sys_beg: str = "<s>[INST] <<SYS>>\n"
|
||||
sys_end: str = "\n<</SYS>>\n\n"
|
||||
ai_n_beg: str = " "
|
||||
ai_n_end: str = " </s>"
|
||||
usr_n_beg: str = "<s>[INST] "
|
||||
usr_n_end: str = " [/INST]"
|
||||
usr_0_beg: str = ""
|
||||
usr_0_end: str = " [/INST]"
|
||||
|
||||
|
||||
class Orca(ChatWrapper):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "orca-style"
|
||||
|
||||
sys_beg: str = "### System:\n"
|
||||
sys_end: str = "\n\n"
|
||||
ai_n_beg: str = "### Assistant:\n"
|
||||
ai_n_end: str = "\n\n"
|
||||
usr_n_beg: str = "### User:\n"
|
||||
usr_n_end: str = "\n\n"
|
||||
|
||||
|
||||
class Vicuna(ChatWrapper):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vicuna-style"
|
||||
|
||||
sys_beg: str = ""
|
||||
sys_end: str = " "
|
||||
ai_n_beg: str = "ASSISTANT: "
|
||||
ai_n_end: str = " </s>"
|
||||
usr_n_beg: str = "USER: "
|
||||
usr_n_end: str = " "
|
24
libs/experimental/poetry.lock
generated
24
libs/experimental/poetry.lock
generated
@@ -2806,6 +2806,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.20.3"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pytest-asyncio-0.20.3.tar.gz", hash = "sha256:83cbf01169ce3e8eb71c6c278ccb0574d1a7a3bb8eaaf5e50e0ad342afb33b36"},
|
||||
{file = "pytest_asyncio-0.20.3-py3-none-any.whl", hash = "sha256:f129998b209d04fcc65c96fc85c11e5316738358909a8399e93be553d7656442"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=6.1.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
@@ -3871,9 +3889,7 @@ python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"},
|
||||
{file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"},
|
||||
@@ -3910,9 +3926,7 @@ files = [
|
||||
{file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"},
|
||||
{file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"},
|
||||
@@ -4871,4 +4885,4 @@ extended-testing = ["faker", "presidio-analyzer", "presidio-anonymizer", "senten
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "b834d2b8bcfb0c10549937841a9c6838ca8fde99d23e6c6deb8a6e3f4f4e43af"
|
||||
content-hash = "ba9be2e62d1507b2f370b4388604d8e3e5afb3d495691f12d15d0128f162539d"
|
||||
|
@@ -34,6 +34,7 @@ setuptools = "^67.6.1"
|
||||
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
||||
# Any dependencies that do not meet that criteria will be removed.
|
||||
pytest = "^7.3.0"
|
||||
pytest-asyncio = "^0.20.3"
|
||||
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
|
@@ -0,0 +1,157 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_experimental.chat_models import Llama2Chat
|
||||
from langchain_experimental.chat_models.llm_wrapper import DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class FakeLLM(LLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return prompt
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return prompt
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-llm"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model() -> Llama2Chat:
|
||||
return Llama2Chat(llm=FakeLLM())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_cfg_sys_msg() -> Llama2Chat:
|
||||
return Llama2Chat(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))
|
||||
|
||||
|
||||
def test_default_system_message(model: Llama2Chat) -> None:
|
||||
messages = [HumanMessage(content="usr-msg-1")]
|
||||
|
||||
actual = model.predict_messages(messages).content # type: ignore
|
||||
expected = (
|
||||
f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||
)
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_configured_system_message(
|
||||
model_cfg_sys_msg: Llama2Chat,
|
||||
) -> None:
|
||||
messages = [HumanMessage(content="usr-msg-1")]
|
||||
|
||||
actual = model_cfg_sys_msg.predict_messages(messages).content # type: ignore
|
||||
expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configured_system_message_async(
|
||||
model_cfg_sys_msg: Llama2Chat,
|
||||
) -> None:
|
||||
messages = [HumanMessage(content="usr-msg-1")]
|
||||
|
||||
actual = await model_cfg_sys_msg.apredict_messages(messages) # type: ignore
|
||||
expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||
|
||||
assert actual.content == expected
|
||||
|
||||
|
||||
def test_provided_system_message(
|
||||
model_cfg_sys_msg: Llama2Chat,
|
||||
) -> None:
|
||||
messages = [
|
||||
SystemMessage(content="custom-sys-msg"),
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
]
|
||||
|
||||
actual = model_cfg_sys_msg.predict_messages(messages).content
|
||||
expected = "<s>[INST] <<SYS>>\ncustom-sys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
|
||||
messages = [
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
AIMessage(content="ai-msg-1"),
|
||||
HumanMessage(content="usr-msg-2"),
|
||||
AIMessage(content="ai-msg-2"),
|
||||
HumanMessage(content="usr-msg-3"),
|
||||
]
|
||||
|
||||
actual = model_cfg_sys_msg.predict_messages(messages).content
|
||||
expected = (
|
||||
"<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST] ai-msg-1 </s>"
|
||||
"<s>[INST] usr-msg-2 [/INST] ai-msg-2 </s><s>[INST] usr-msg-3 [/INST]"
|
||||
)
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_no_message(model: Llama2Chat) -> None:
|
||||
with pytest.raises(ValueError) as info:
|
||||
model.predict_messages([])
|
||||
|
||||
assert info.value.args[0] == "at least one HumanMessage must be provided"
|
||||
|
||||
|
||||
def test_ai_message_first(model: Llama2Chat) -> None:
|
||||
with pytest.raises(ValueError) as info:
|
||||
model.predict_messages([AIMessage(content="ai-msg-1")])
|
||||
|
||||
assert (
|
||||
info.value.args[0]
|
||||
== "messages list must start with a SystemMessage or UserMessage"
|
||||
)
|
||||
|
||||
|
||||
def test_human_ai_messages_not_alternating(model: Llama2Chat) -> None:
|
||||
messages = [
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
HumanMessage(content="usr-msg-2"),
|
||||
HumanMessage(content="ai-msg-1"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
model.predict_messages(messages) # type: ignore
|
||||
|
||||
assert info.value.args[0] == (
|
||||
"messages must be alternating human- and ai-messages, "
|
||||
"optionally prepended by a system message"
|
||||
)
|
||||
|
||||
|
||||
def test_last_message_not_human_message(model: Llama2Chat) -> None:
|
||||
messages = [
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
AIMessage(content="ai-msg-1"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
model.predict_messages(messages)
|
||||
|
||||
assert info.value.args[0] == "last message must be a HumanMessage"
|
@@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_experimental.chat_models import Orca
|
||||
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model() -> Orca:
|
||||
return Orca(llm=FakeLLM())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_cfg_sys_msg() -> Orca:
|
||||
return Orca(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))
|
||||
|
||||
|
||||
def test_prompt(model: Orca) -> None:
|
||||
messages = [
|
||||
SystemMessage(content="sys-msg"),
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
AIMessage(content="ai-msg-1"),
|
||||
HumanMessage(content="usr-msg-2"),
|
||||
]
|
||||
|
||||
actual = model.predict_messages(messages).content # type: ignore
|
||||
expected = "### System:\nsys-msg\n\n### User:\nusr-msg-1\n\n### Assistant:\nai-msg-1\n\n### User:\nusr-msg-2\n\n" # noqa: E501
|
||||
|
||||
assert actual == expected
|
@@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_experimental.chat_models import Vicuna
|
||||
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model() -> Vicuna:
|
||||
return Vicuna(llm=FakeLLM())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_cfg_sys_msg() -> Vicuna:
|
||||
return Vicuna(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))
|
||||
|
||||
|
||||
def test_prompt(model: Vicuna) -> None:
|
||||
messages = [
|
||||
SystemMessage(content="sys-msg"),
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
AIMessage(content="ai-msg-1"),
|
||||
HumanMessage(content="usr-msg-2"),
|
||||
]
|
||||
|
||||
actual = model.predict_messages(messages).content # type: ignore
|
||||
expected = "sys-msg USER: usr-msg-1 ASSISTANT: ai-msg-1 </s>USER: usr-msg-2 "
|
||||
|
||||
assert actual == expected
|
Reference in New Issue
Block a user