From 66576948e0ad71af28b9807e4634289b2996648e Mon Sep 17 00:00:00 2001
From: Alexander Dicke <119596967+AIexanderDicke@users.noreply.github.com>
Date: Sat, 9 Mar 2024 02:14:23 +0100
Subject: [PATCH] experimental[minor]: adds mixtral wrapper (#17423)

**Description:** Adds a chat wrapper for Mixtral models using the
[prompt template](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format).

---------

Co-authored-by: Bagatur
---
 .../chat_models/__init__.py                   | 13 ++++----
 .../chat_models/llm_wrapper.py                | 17 ++++++++++
 .../chat_models/test_llm_wrapper_mixtral.py   | 31 +++++++++++++++++++
 3 files changed, 55 insertions(+), 6 deletions(-)
 create mode 100644 libs/experimental/tests/unit_tests/chat_models/test_llm_wrapper_mixtral.py

diff --git a/libs/experimental/langchain_experimental/chat_models/__init__.py b/libs/experimental/langchain_experimental/chat_models/__init__.py
index af7574b4470..07be8ea7b73 100644
--- a/libs/experimental/langchain_experimental/chat_models/__init__.py
+++ b/libs/experimental/langchain_experimental/chat_models/__init__.py
@@ -17,10 +17,11 @@ an interface where "chat messages" are the inputs and outputs.
 
 AIMessage, BaseMessage, HumanMessage
 """  # noqa: E501
-from langchain_experimental.chat_models.llm_wrapper import Llama2Chat, Orca, Vicuna
+from langchain_experimental.chat_models.llm_wrapper import (
+    Llama2Chat,
+    Mixtral,
+    Orca,
+    Vicuna,
+)
 
-__all__ = [
-    "Llama2Chat",
-    "Orca",
-    "Vicuna",
-]
+__all__ = ["Llama2Chat", "Orca", "Vicuna", "Mixtral"]
diff --git a/libs/experimental/langchain_experimental/chat_models/llm_wrapper.py b/libs/experimental/langchain_experimental/chat_models/llm_wrapper.py
index 06aee3d9c54..b9fcde9bedb 100644
--- a/libs/experimental/langchain_experimental/chat_models/llm_wrapper.py
+++ b/libs/experimental/langchain_experimental/chat_models/llm_wrapper.py
@@ -148,6 +148,23 @@ class Llama2Chat(ChatWrapper):
     usr_0_end: str = " [/INST]"
 
 
+class Mixtral(ChatWrapper):
+    """See https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format"""  # noqa: E501
+
+    @property
+    def _llm_type(self) -> str:
+        return "mixtral"
+
+    sys_beg: str = "<s>[INST] "
+    sys_end: str = "\n"
+    ai_n_beg: str = " "
+    ai_n_end: str = " </s>"
+    usr_n_beg: str = " [INST] "
+    usr_n_end: str = " [/INST]"
+    usr_0_beg: str = ""
+    usr_0_end: str = " [/INST]"
+
+
 class Orca(ChatWrapper):
     """Wrapper for Orca-style models."""
 
diff --git a/libs/experimental/tests/unit_tests/chat_models/test_llm_wrapper_mixtral.py b/libs/experimental/tests/unit_tests/chat_models/test_llm_wrapper_mixtral.py
new file mode 100644
index 00000000000..881f78f6a77
--- /dev/null
+++ b/libs/experimental/tests/unit_tests/chat_models/test_llm_wrapper_mixtral.py
@@ -0,0 +1,31 @@
+import pytest
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+from langchain_experimental.chat_models import Mixtral
+from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
+
+
+@pytest.fixture
+def model() -> Mixtral:
+    return Mixtral(llm=FakeLLM())
+
+
+@pytest.fixture
+def model_cfg_sys_msg() -> Mixtral:
+    return Mixtral(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))
+
+
+def test_prompt(model: Mixtral) -> None:
+    messages = [
+        SystemMessage(content="sys-msg"),
+        HumanMessage(content="usr-msg-1"),
+        AIMessage(content="ai-msg-1"),
+        HumanMessage(content="usr-msg-2"),
+    ]
+
+    actual = model.predict_messages(messages).content  # type: ignore
+    expected = (
+        "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"  # noqa: E501
+    )
+
+    assert actual == expected
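
For reviewers, a minimal usage sketch of the new wrapper (not part of the patch). It reuses `FakeLLM`, the test double from the existing Llama2Chat tests; in real use you would pass any LangChain `LLM` backed by a Mixtral-8x7B-Instruct checkpoint instead, and the message contents below are purely illustrative:

```python
from langchain.schema import HumanMessage, SystemMessage

from langchain_experimental.chat_models import Mixtral
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM

# Mixtral wraps a plain LLM and renders the chat history into the
# "<s>[INST] ... [/INST] ... </s>" instruction format before each call.
model = Mixtral(llm=FakeLLM())

response = model.predict_messages(
    [
        SystemMessage(content="You are a concise assistant."),  # illustrative
        HumanMessage(content="Hello!"),  # illustrative
    ]
)
print(response.content)
```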