mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
mistralai[patch]: ruff fixes and rules (#31918)
* bump ruff deps * add more thorough ruff rules * fix said rules
This commit is contained in:
@@ -94,8 +94,7 @@ def _create_retry_decorator(
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
|
||||
|
||||
"""Return a tenacity retry decorator, preconfigured to handle exceptions."""
|
||||
errors = [httpx.RequestError, httpx.StreamError]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
@@ -103,12 +102,12 @@ def _create_retry_decorator(
|
||||
|
||||
|
||||
def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
|
||||
"""Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9"""
|
||||
"""Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9."""
|
||||
return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))
|
||||
|
||||
|
||||
def _base62_encode(num: int) -> str:
|
||||
"""Encodes a number in base62 and ensures result is of a specified length."""
|
||||
"""Encode a number in base62 and ensures result is of a specified length."""
|
||||
base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
if num == 0:
|
||||
return base62[0]
|
||||
@@ -122,17 +121,15 @@ def _base62_encode(num: int) -> str:
|
||||
|
||||
|
||||
def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
|
||||
"""Convert a tool call ID to a Mistral-compatible format"""
|
||||
"""Convert a tool call ID to a Mistral-compatible format."""
|
||||
if _is_valid_mistral_tool_call_id(tool_call_id):
|
||||
return tool_call_id
|
||||
else:
|
||||
hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
|
||||
hash_int = int.from_bytes(hash_bytes, byteorder="big")
|
||||
base62_str = _base62_encode(hash_int)
|
||||
if len(base62_str) >= 9:
|
||||
return base62_str[:9]
|
||||
else:
|
||||
return base62_str.rjust(9, "0")
|
||||
hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
|
||||
hash_int = int.from_bytes(hash_bytes, byteorder="big")
|
||||
base62_str = _base62_encode(hash_int)
|
||||
if len(base62_str) >= 9:
|
||||
return base62_str[:9]
|
||||
return base62_str.rjust(9, "0")
|
||||
|
||||
|
||||
def _convert_mistral_chat_message_to_message(
|
||||
@@ -140,7 +137,8 @@ def _convert_mistral_chat_message_to_message(
|
||||
) -> BaseMessage:
|
||||
role = _message["role"]
|
||||
if role != "assistant":
|
||||
raise ValueError(f"Expected role to be 'assistant', got {role}")
|
||||
msg = f"Expected role to be 'assistant', got {role}"
|
||||
raise ValueError(msg)
|
||||
content = cast(str, _message["content"])
|
||||
|
||||
additional_kwargs: dict = {}
|
||||
@@ -170,9 +168,12 @@ def _raise_on_error(response: httpx.Response) -> None:
|
||||
"""Raise an error if the response is an error."""
|
||||
if httpx.codes.is_error(response.status_code):
|
||||
error_message = response.read().decode("utf-8")
|
||||
raise httpx.HTTPStatusError(
|
||||
msg = (
|
||||
f"Error response {response.status_code} "
|
||||
f"while fetching {response.url}: {error_message}",
|
||||
f"while fetching {response.url}: {error_message}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
msg,
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
@@ -182,9 +183,12 @@ async def _araise_on_error(response: httpx.Response) -> None:
|
||||
"""Raise an error if the response is an error."""
|
||||
if httpx.codes.is_error(response.status_code):
|
||||
error_message = (await response.aread()).decode("utf-8")
|
||||
raise httpx.HTTPStatusError(
|
||||
msg = (
|
||||
f"Error response {response.status_code} "
|
||||
f"while fetching {response.url}: {error_message}",
|
||||
f"while fetching {response.url}: {error_message}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
msg,
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
@@ -220,10 +224,9 @@ async def acompletion_with_retry(
|
||||
llm.async_client, "POST", "/chat/completions", json=kwargs
|
||||
)
|
||||
return _aiter_sse(event_source)
|
||||
else:
|
||||
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
|
||||
await _araise_on_error(response)
|
||||
return response.json()
|
||||
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
|
||||
await _araise_on_error(response)
|
||||
return response.json()
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
@@ -237,7 +240,7 @@ def _convert_chunk_to_message_chunk(
|
||||
content = _delta.get("content") or ""
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: dict = {}
|
||||
response_metadata = {}
|
||||
if raw_tool_calls := _delta.get("tool_calls"):
|
||||
@@ -281,12 +284,11 @@ def _convert_chunk_to_message_chunk(
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
@@ -321,18 +323,24 @@ def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
return {"role": message.role, "content": message.content}
|
||||
if isinstance(message, HumanMessage):
|
||||
return {"role": "user", "content": message.content}
|
||||
if isinstance(message, AIMessage):
|
||||
message_dict: dict[str, Any] = {"role": "assistant"}
|
||||
tool_calls = []
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_calls.append(_format_tool_call_for_mistral(tool_call))
|
||||
tool_calls.extend(
|
||||
[
|
||||
_format_tool_call_for_mistral(tool_call)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
for invalid_tool_call in message.invalid_tool_calls:
|
||||
tool_calls.append(
|
||||
tool_calls.extend(
|
||||
_format_invalid_tool_call_for_mistral(invalid_tool_call)
|
||||
for invalid_tool_call in message.invalid_tool_calls
|
||||
)
|
||||
elif "tool_calls" in message.additional_kwargs:
|
||||
for tc in message.additional_kwargs["tool_calls"]:
|
||||
@@ -359,9 +367,9 @@ def _convert_message_to_mistral_chat_message(
|
||||
if "prefix" in message.additional_kwargs:
|
||||
message_dict["prefix"] = message.additional_kwargs["prefix"]
|
||||
return message_dict
|
||||
elif isinstance(message, SystemMessage):
|
||||
return dict(role="system", content=message.content)
|
||||
elif isinstance(message, ToolMessage):
|
||||
if isinstance(message, SystemMessage):
|
||||
return {"role": "system", "content": message.content}
|
||||
if isinstance(message, ToolMessage):
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
@@ -370,8 +378,8 @@ def _convert_message_to_mistral_chat_message(
|
||||
message.tool_call_id
|
||||
),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
msg = f"Got unknown type {message}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class ChatMistralAI(BaseChatModel):
|
||||
@@ -380,10 +388,10 @@ class ChatMistralAI(BaseChatModel):
|
||||
# The type for client and async_client is ignored because the type is not
|
||||
# an Optional after the model is initialized and the model_validator
|
||||
# is run.
|
||||
client: httpx.Client = Field( # type: ignore # : meta private:
|
||||
client: httpx.Client = Field( # type: ignore[assignment] # : meta private:
|
||||
default=None, exclude=True
|
||||
)
|
||||
async_client: httpx.AsyncClient = Field( # type: ignore # : meta private:
|
||||
async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # : meta private:
|
||||
default=None, exclude=True
|
||||
) #: :meta private:
|
||||
mistral_api_key: Optional[SecretStr] = Field(
|
||||
@@ -417,8 +425,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
return _build_model_kwargs(values, all_required_field_names)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
@@ -432,8 +439,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
"safe_prompt": self.safe_mode,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
filtered = {k: v for k, v in defaults.items() if v is not None}
|
||||
return filtered
|
||||
return {k: v for k, v in defaults.items() if v is not None}
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
@@ -481,13 +487,11 @@ class ChatMistralAI(BaseChatModel):
|
||||
yield event.json()
|
||||
|
||||
return iter_sse()
|
||||
else:
|
||||
response = self.client.post(url="/chat/completions", json=kwargs)
|
||||
_raise_on_error(response)
|
||||
return response.json()
|
||||
response = self.client.post(url="/chat/completions", json=kwargs)
|
||||
_raise_on_error(response)
|
||||
return response.json()
|
||||
|
||||
rtn = _completion_with_retry(**kwargs)
|
||||
return rtn
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
@@ -502,8 +506,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
return combined
|
||||
return {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
@@ -545,10 +548,12 @@ class ChatMistralAI(BaseChatModel):
|
||||
)
|
||||
|
||||
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
msg = "temperature must be in the range [0.0, 1.0]"
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
msg = "top_p must be in the range [0.0, 1.0]"
|
||||
raise ValueError(msg)
|
||||
|
||||
return self
|
||||
|
||||
@@ -557,7 +562,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream: Optional[bool] = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
@@ -669,7 +674,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream: Optional[bool] = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
@@ -689,7 +694,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, # noqa: PYI051
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
@@ -707,15 +712,15 @@ class ChatMistralAI(BaseChatModel):
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
kwargs: Any additional parameters are passed directly to
|
||||
``self.bind(**kwargs)``.
|
||||
"""
|
||||
|
||||
"""
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice:
|
||||
tool_names = []
|
||||
for tool in formatted_tools:
|
||||
if "function" in tool and (name := tool["function"].get("name")):
|
||||
tool_names.append(name)
|
||||
elif name := tool.get("name"):
|
||||
if ("function" in tool and (name := tool["function"].get("name"))) or (
|
||||
name := tool.get("name")
|
||||
):
|
||||
tool_names.append(name)
|
||||
else:
|
||||
pass
|
||||
@@ -738,7 +743,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
r"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema:
|
||||
@@ -785,6 +790,12 @@ class ChatMistralAI(BaseChatModel):
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
kwargs: Any additional parameters are passed directly to
|
||||
``self.bind(**kwargs)``. This is useful for passing in
|
||||
parameters such as ``tool_choice`` or ``tools`` to control
|
||||
which tool the model should call, or to pass in parameters such as
|
||||
``stop`` to control when the model should stop generating output.
|
||||
|
||||
Returns:
|
||||
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||
|
||||
@@ -968,14 +979,16 @@ class ChatMistralAI(BaseChatModel):
|
||||
""" # noqa: E501
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
msg = f"Received unsupported arguments {kwargs}"
|
||||
raise ValueError(msg)
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
# TODO: Update to pass in tool name as tool_choice if/when Mistral supports
|
||||
# specifying a tool.
|
||||
llm = self.bind_tools(
|
||||
@@ -1014,10 +1027,11 @@ class ChatMistralAI(BaseChatModel):
|
||||
)
|
||||
elif method == "json_schema":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"schema must be specified when method is 'json_schema'. "
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
response_format = _convert_to_openai_response_format(schema, strict=True)
|
||||
llm = self.bind(
|
||||
response_format=response_format,
|
||||
@@ -1041,8 +1055,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
@@ -1072,7 +1085,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
def _convert_to_openai_response_format(
|
||||
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
|
||||
) -> dict:
|
||||
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
|
||||
"""Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""
|
||||
if (
|
||||
isinstance(schema, dict)
|
||||
and "json_schema" in schema
|
||||
|
||||
@@ -17,20 +17,20 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
from tokenizers import Tokenizer # type: ignore[import]
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_TOKENS = 16_000
|
||||
"""A batching parameter for the Mistral API. This is NOT the maximum number of tokens
|
||||
accepted by the embedding model for each document/chunk, but rather the maximum number
|
||||
accepted by the embedding model for each document/chunk, but rather the maximum number
|
||||
of tokens that can be sent in a single request to the Mistral API (across multiple
|
||||
documents/chunks)"""
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
|
||||
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""
|
||||
|
||||
@staticmethod
|
||||
def encode_batch(texts: list[str]) -> list[list[str]]:
|
||||
@@ -126,9 +126,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
# The type for client and async_client is ignored because the type is not
|
||||
# an Optional after the model is initialized and the model_validator
|
||||
# is run.
|
||||
client: httpx.Client = Field(default=None) # type: ignore # : :meta private:
|
||||
client: httpx.Client = Field(default=None) # type: ignore[assignment] # :meta private:
|
||||
|
||||
async_client: httpx.AsyncClient = Field( # type: ignore # : meta private:
|
||||
async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # :meta private:
|
||||
default=None
|
||||
)
|
||||
mistral_api_key: SecretStr = Field(
|
||||
@@ -153,7 +153,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate configuration."""
|
||||
|
||||
api_key_str = self.mistral_api_key.get_secret_value()
|
||||
# todo: handle retries
|
||||
if not self.client:
|
||||
@@ -187,14 +186,14 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"Could not download mistral tokenizer from Huggingface for "
|
||||
"calculating batch sizes. Set a Huggingface token via the "
|
||||
"HF_TOKEN environment variable to download the real tokenizer. "
|
||||
"Falling back to a dummy tokenizer that uses `len()`."
|
||||
"Falling back to a dummy tokenizer that uses `len()`.",
|
||||
stacklevel=2,
|
||||
)
|
||||
self.tokenizer = DummyTokenizer()
|
||||
return self
|
||||
|
||||
def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
|
||||
"""Split a list of texts into batches of less than 16k tokens for Mistral
|
||||
API."""
|
||||
"""Split list of texts into batches of less than 16k tokens for Mistral API."""
|
||||
batch: list[str] = []
|
||||
batch_tokens = 0
|
||||
|
||||
@@ -224,6 +223,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
try:
|
||||
batch_responses = []
|
||||
@@ -238,16 +238,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
def _embed_batch(batch: list[str]) -> Response:
|
||||
response = self.client.post(
|
||||
url="/embeddings",
|
||||
json=dict(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
),
|
||||
json={
|
||||
"model": self.model,
|
||||
"input": batch,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
for batch in self._get_batches(texts):
|
||||
batch_responses.append(_embed_batch(batch))
|
||||
batch_responses = [
|
||||
_embed_batch(batch) for batch in self._get_batches(texts)
|
||||
]
|
||||
return [
|
||||
list(map(float, embedding_obj["embedding"]))
|
||||
for response in batch_responses
|
||||
@@ -265,16 +266,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
try:
|
||||
batch_responses = await asyncio.gather(
|
||||
*[
|
||||
self.async_client.post(
|
||||
url="/embeddings",
|
||||
json=dict(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
),
|
||||
json={
|
||||
"model": self.model,
|
||||
"input": batch,
|
||||
},
|
||||
)
|
||||
for batch in self._get_batches(texts)
|
||||
]
|
||||
@@ -296,6 +298,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
@@ -307,5 +310,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
|
||||
"""
|
||||
return (await self.aembed_documents([text]))[0]
|
||||
|
||||
Reference in New Issue
Block a user