mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
mistral: implement ls_params (#21867)
This commit is contained in:
parent
68a90e2252
commit
9c76739425
@ -29,6 +29,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
LangSmithParams,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
@ -343,6 +344,23 @@ class ChatMistralAI(BaseChatModel):
|
||||
filtered = {k: v for k, v in defaults.items() if v is not None}
|
||||
return filtered
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="mistral",
|
||||
ls_model_name=self.model,
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
|
||||
ls_params["ls_max_tokens"] = ls_max_tokens
|
||||
if ls_stop := stop or params.get("stop", None):
|
||||
ls_params["ls_stop"] = ls_stop
|
||||
return ls_params
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the client."""
|
||||
|
6
libs/partners/mistralai/poetry.lock
generated
6
libs/partners/mistralai/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
@ -392,7 +392,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.0rc1"
|
||||
version = "0.2.0"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -1051,4 +1051,4 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "183134cac43c4c1d16e85d4ddf22ce4e14b6e4c78eab67134703da616f12c5cb"
|
||||
content-hash = "4a5a57d01c791de831f03fb309541443dc8bb51f5068ccfb7bcb77490c2eb6c3"
|
||||
|
@ -12,7 +12,7 @@ license = "MIT"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = ">=0.1.46,<0.3"
|
||||
langchain-core = ">=0.2.0,<0.3"
|
||||
tokenizers = ">=0.15.1,<1"
|
||||
httpx = ">=0.25.2,<1"
|
||||
httpx-sse = ">=0.3.1,<1"
|
||||
|
@ -13,14 +13,3 @@ class TestMistralStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatMistralAI
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_standard_params(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_standard_params(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user