diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 4b60d269d71..cdbc5d37bb8 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -139,7 +139,8 @@ def _convert_mistral_chat_message_to_message( _message: dict, ) -> BaseMessage: role = _message["role"] - assert role == "assistant", f"Expected role to be 'assistant', got {role}" + if role != "assistant": + raise ValueError(f"Expected role to be 'assistant', got {role}") content = cast(str, _message["content"]) additional_kwargs: dict = {} @@ -398,7 +399,8 @@ class ChatMistralAI(BaseChatModel): max_tokens: Optional[int] = None top_p: float = 1 """Decode using nucleus sampling: consider the smallest set of tokens whose - probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" + probability sum is at least ``top_p``. Must be in the closed interval + ``[0.0, 1.0]``.""" random_seed: Optional[int] = None safe_mode: Optional[bool] = None streaming: bool = False diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 72f692a0aed..575ceb0c637 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -54,21 +54,23 @@ class MistralAIEmbeddings(BaseModel, Embeddings): Name of MistralAI model to use. Key init args — client params: - api_key: Optional[SecretStr] - The API key for the MistralAI API. If not provided, it will be read from the - environment variable `MISTRAL_API_KEY`. - max_retries: int - The number of times to retry a request if it fails. - timeout: int - The number of seconds to wait for a response before timing out. - wait_time: int - The number of seconds to wait before retrying a request in case of 429 error. - max_concurrent_requests: int - The maximum number of concurrent requests to make to the Mistral API. + api_key: Optional[SecretStr] + The API key for the MistralAI API. If not provided, it will be read from the + environment variable ``MISTRAL_API_KEY``. + max_retries: int + The number of times to retry a request if it fails. + timeout: int + The number of seconds to wait for a response before timing out. + wait_time: int + The number of seconds to wait before retrying a request in case of 429 + error. + max_concurrent_requests: int + The maximum number of concurrent requests to make to the Mistral API. See full list of supported init args and their descriptions in the params section. Instantiate: + .. code-block:: python from __module_name__ import MistralAIEmbeddings @@ -80,6 +82,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): ) Embed single text: + .. code-block:: python input_text = "The meaning of life is 42" @@ -91,9 +94,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings): [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] Embed multiple text: + .. code-block:: python - input_texts = ["Document 1...", "Document 2..."] + input_texts = ["Document 1...", "Document 2..."] vectors = embed.embed_documents(input_texts) print(len(vectors)) # The first 3 coordinates for the first vector @@ -105,10 +109,11 @@ class MistralAIEmbeddings(BaseModel, Embeddings): [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] Async: + .. code-block:: python vector = await embed.aembed_query(input_text) - print(vector[:3]) + print(vector[:3]) # multiple: # await embed.aembed_documents(input_texts) @@ -188,8 +193,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings): return self def _get_batches(self, texts: list[str]) -> Iterable[list[str]]: - """Split a list of texts into batches of less than 16k tokens - for Mistral API.""" + """Split a list of texts into batches of less than 16k tokens for Mistral + API.""" batch: list[str] = [] batch_tokens = 0 diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index 4ef62d4ce69..4d4aef83662 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -48,7 +48,7 @@ disallow_untyped_defs = "True" target-version = "py39" [tool.ruff.lint] -select = ["E", "F", "I", "T201", "UP"] +select = ["E", "F", "I", "T201", "UP", "S"] ignore = [ "UP007", ] [tool.coverage.run] @@ -61,3 +61,9 @@ markers = [ "compile: mark placeholder test used to compile integration tests without running them", ] asyncio_mode = "auto" + +[tool.ruff.lint.extend-per-file-ignores] +"tests/**/*.py" = [ + "S101", # Tests need assertions + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes +]