mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 10:23:30 +00:00
experimental[minor]: adds mixtral wrapper (#17423)
**Description:** Adds a chat wrapper for Mixtral models using the [prompt template](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format). --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4f4300723b
commit
66576948e0
@ -17,10 +17,11 @@ an interface where "chat messages" are the inputs and outputs.
|
|||||||
AIMessage, BaseMessage, HumanMessage
|
AIMessage, BaseMessage, HumanMessage
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
from langchain_experimental.chat_models.llm_wrapper import Llama2Chat, Orca, Vicuna
|
from langchain_experimental.chat_models.llm_wrapper import (
|
||||||
|
Llama2Chat,
|
||||||
|
Mixtral,
|
||||||
|
Orca,
|
||||||
|
Vicuna,
|
||||||
|
)
|
||||||
|
|
||||||
# Public surface of this module; Mixtral is appended last to preserve the
# pre-existing order of the other exports.
__all__ = ["Llama2Chat", "Orca", "Vicuna", "Mixtral"]
|
@ -148,6 +148,23 @@ class Llama2Chat(ChatWrapper):
|
|||||||
usr_0_end: str = " [/INST]"
|
usr_0_end: str = " [/INST]"
|
||||||
|
|
||||||
|
|
||||||
|
class Mixtral(ChatWrapper):
    """Chat wrapper emitting the Mixtral-8x7B-Instruct prompt format.

    See https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format

    The delimiter fields below are presumably consumed by ``ChatWrapper`` to
    assemble the final prompt string (ChatWrapper's logic is not visible in
    this file view — TODO confirm against llm_wrapper.py).
    """  # noqa: E501

    @property
    def _llm_type(self) -> str:
        # Identifier reported for this wrapper type.
        return "mixtral"

    # Wrapped around the system message (start-of-sequence + open instruction).
    sys_beg: str = "<s>[INST] "
    sys_end: str = "\n"
    # Wrapped around each AI message; "</s>" closes the assistant turn.
    ai_n_beg: str = " "
    ai_n_end: str = " </s>"
    # Wrapped around every user message after the first.
    usr_n_beg: str = " [INST] "
    usr_n_end: str = " [/INST]"
    # The first user message is merged into the system segment, so it only
    # needs the closing "[/INST]" tag.
    usr_0_beg: str = ""
    usr_0_end: str = " [/INST]"
||||||
class Orca(ChatWrapper):
|
class Orca(ChatWrapper):
|
||||||
"""Wrapper for Orca-style models."""
|
"""Wrapper for Orca-style models."""
|
||||||
|
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from langchain_experimental.chat_models import Mixtral
|
||||||
|
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model() -> Mixtral:
    """Mixtral wrapper over a fake LLM, with no preset system message."""
    fake_llm = FakeLLM()
    return Mixtral(llm=fake_llm)
|
|
||||||
|
@pytest.fixture
def model_cfg_sys_msg() -> Mixtral:
    """Mixtral wrapper over a fake LLM, preconfigured with system message 'sys-msg'."""
    sys_msg = SystemMessage(content="sys-msg")
    return Mixtral(llm=FakeLLM(), system_message=sys_msg)
|
def test_prompt(model: Mixtral) -> None:
    """A sys/user/ai/user conversation must render in the Mixtral instruction format."""
    conversation = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    expected = (
        "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"  # noqa: E501
    )
    rendered = model.predict_messages(conversation).content  # type: ignore

    assert rendered == expected
|
Loading…
Reference in New Issue
Block a user