mistral: implement ls_params (#21867)

This commit is contained in:
ccurme 2024-05-20 14:49:48 -04:00 committed by GitHub
parent 68a90e2252
commit 9c76739425
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 15 deletions

View File

@ -29,6 +29,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
@ -343,6 +344,23 @@ class ChatMistralAI(BaseChatModel):
filtered = {k: v for k, v in defaults.items() if v is not None}
return filtered
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="mistral",
ls_model_name=self.model,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
return ls_params
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the client."""

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -392,7 +392,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.0rc1"
version = "0.2.0"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -1051,4 +1051,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "183134cac43c4c1d16e85d4ddf22ce4e14b6e4c78eab67134703da616f12c5cb"
content-hash = "4a5a57d01c791de831f03fb309541443dc8bb51f5068ccfb7bcb77490c2eb6c3"

View File

@ -12,7 +12,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1.46,<0.3"
langchain-core = ">=0.2.0,<0.3"
tokenizers = ">=0.15.1,<1"
httpx = ">=0.25.2,<1"
httpx-sse = ">=0.3.1,<1"

View File

@ -13,14 +13,3 @@ class TestMistralStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatMistralAI
@pytest.mark.xfail(reason="Not implemented.")
def test_standard_params(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_standard_params(
chat_model_class,
chat_model_params,
)