From 9c767394257cef6f43bdb18ca0ce63f155a513e4 Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 20 May 2024 14:49:48 -0400 Subject: [PATCH] mistral: implement ls_params (#21867) --- .../langchain_mistralai/chat_models.py | 18 ++++++++++++++++++ libs/partners/mistralai/poetry.lock | 6 +++--- libs/partners/mistralai/pyproject.toml | 2 +- .../tests/unit_tests/test_standard.py | 11 ----------- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 6e22d0eefce..2a18876a0f7 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -29,6 +29,7 @@ from langchain_core.callbacks import ( from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, + LangSmithParams, agenerate_from_stream, generate_from_stream, ) @@ -343,6 +344,23 @@ class ChatMistralAI(BaseChatModel): filtered = {k: v for k, v in defaults.items() if v is not None} return filtered + def _get_ls_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> LangSmithParams: + """Get standard params for tracing.""" + params = self._get_invocation_params(stop=stop, **kwargs) + ls_params = LangSmithParams( + ls_provider="mistral", + ls_model_name=self.model, + ls_model_type="chat", + ls_temperature=params.get("temperature", self.temperature), + ) + if ls_max_tokens := params.get("max_tokens", self.max_tokens): + ls_params["ls_max_tokens"] = ls_max_tokens + if ls_stop := stop or params.get("stop", None): + ls_params["ls_stop"] = ls_stop + return ls_params + @property def _client_params(self) -> Dict[str, Any]: """Get the parameters used for the client.""" diff --git a/libs/partners/mistralai/poetry.lock b/libs/partners/mistralai/poetry.lock index 5b986f4c235..72a47e32356 100644 --- a/libs/partners/mistralai/poetry.lock +++ b/libs/partners/mistralai/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -392,7 +392,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.0rc1" +version = "0.2.0" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -1051,4 +1051,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "183134cac43c4c1d16e85d4ddf22ce4e14b6e4c78eab67134703da616f12c5cb" +content-hash = "4a5a57d01c791de831f03fb309541443dc8bb51f5068ccfb7bcb77490c2eb6c3" diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index 6d5a08a30a8..b3767a43f45 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -12,7 +12,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = ">=0.1.46,<0.3" +langchain-core = ">=0.2.0,<0.3" tokenizers = ">=0.15.1,<1" httpx = ">=0.25.2,<1" httpx-sse = ">=0.3.1,<1" diff --git a/libs/partners/mistralai/tests/unit_tests/test_standard.py b/libs/partners/mistralai/tests/unit_tests/test_standard.py index 03cb5429fa4..46ef3ec3a44 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_standard.py +++ b/libs/partners/mistralai/tests/unit_tests/test_standard.py @@ -13,14 +13,3 @@ class TestMistralStandard(ChatModelUnitTests): @pytest.fixture def chat_model_class(self) -> Type[BaseChatModel]: return ChatMistralAI - - @pytest.mark.xfail(reason="Not implemented.") - def test_standard_params( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_standard_params( - chat_model_class, - chat_model_params, - )