mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
experimental[minor]: adds mixtral wrapper (#17423)
**Description:** Adds a chat wrapper for Mixtral models using the [prompt template](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format). --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_experimental.chat_models import Mixtral
|
||||
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model() -> Mixtral:
|
||||
return Mixtral(llm=FakeLLM())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_cfg_sys_msg() -> Mixtral:
|
||||
return Mixtral(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))
|
||||
|
||||
|
||||
def test_prompt(model: Mixtral) -> None:
|
||||
messages = [
|
||||
SystemMessage(content="sys-msg"),
|
||||
HumanMessage(content="usr-msg-1"),
|
||||
AIMessage(content="ai-msg-1"),
|
||||
HumanMessage(content="usr-msg-2"),
|
||||
]
|
||||
|
||||
actual = model.predict_messages(messages).content # type: ignore
|
||||
expected = (
|
||||
"<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]" # noqa: E501
|
||||
)
|
||||
|
||||
assert actual == expected
|
Reference in New Issue
Block a user