mistralai[minor]: 0.1.0rc0, remove mistral sdk (#19420)

This commit is contained in:
Erick Friis 2024-03-21 18:24:58 -07:00 committed by GitHub
parent e980c14d6a
commit 53ac1ebbbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 237 additions and 448 deletions

View File

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import importlib.util
import logging import logging
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
AsyncContextManager,
AsyncIterator, AsyncIterator,
Callable, Callable,
Dict, Dict,
@ -18,6 +18,8 @@ from typing import (
cast, cast,
) )
import httpx
from httpx_sse import EventSource, aconnect_sse, connect_sse
from langchain_core._api import beta from langchain_core._api import beta
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -54,19 +56,6 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.constants import ENDPOINT as DEFAULT_MISTRAL_ENDPOINT
from mistralai.exceptions import (
MistralAPIException,
MistralConnectionException,
MistralException,
)
from mistralai.models.chat_completion import (
ChatCompletionResponse as MistralChatCompletionResponse,
)
from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
from mistralai.models.chat_completion import DeltaMessage as MistralDeltaMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -79,36 +68,34 @@ def _create_retry_decorator(
) -> Callable[[Any], Any]: ) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle exceptions""" """Returns a tenacity retry decorator, preconfigured to handle exceptions"""
errors = [ errors = [httpx.RequestError, httpx.StreamError]
MistralException,
MistralAPIException,
MistralConnectionException,
]
return create_base_retry_decorator( return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
) )
def _convert_mistral_chat_message_to_message( def _convert_mistral_chat_message_to_message(
_message: MistralChatMessage, _message: Dict,
) -> BaseMessage: ) -> BaseMessage:
role = _message.role role = _message["role"]
content = cast(Union[str, List], _message.content) assert role == "assistant", f"Expected role to be 'assistant', got {role}"
if role == "user": content = cast(str, _message["content"])
return HumanMessage(content=content)
elif role == "assistant": additional_kwargs: Dict = {}
additional_kwargs: Dict = {} if tool_calls := _message.get("tool_calls"):
if hasattr(_message, "tool_calls") and getattr(_message, "tool_calls"): additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
additional_kwargs["tool_calls"] = [ return AIMessage(content=content, additional_kwargs=additional_kwargs)
tc.model_dump() for tc in getattr(_message, "tool_calls")
]
return AIMessage(content=content, additional_kwargs=additional_kwargs) async def _aiter_sse(
elif role == "system": event_source_mgr: AsyncContextManager[EventSource],
return SystemMessage(content=content) ) -> AsyncIterator[Dict]:
elif role == "tool": """Iterate over the server-sent events."""
return ToolMessage(content=content, name=_message.name) # type: ignore[attr-defined] async with event_source_mgr as event_source:
else: async for event in event_source.aiter_sse():
return ChatMessage(content=content, role=role) if event.data == "[DONE]":
return
yield event.json()
async def acompletion_with_retry( async def acompletion_with_retry(
@ -121,28 +108,33 @@ async def acompletion_with_retry(
@retry_decorator @retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any: async def _completion_with_retry(**kwargs: Any) -> Any:
stream = kwargs.pop("stream", False) if "stream" not in kwargs:
kwargs["stream"] = False
stream = kwargs["stream"]
if stream: if stream:
return llm.async_client.chat_stream(**kwargs) event_source = aconnect_sse(
llm.async_client, "POST", "/chat/completions", json=kwargs
)
return _aiter_sse(event_source)
else: else:
return await llm.async_client.chat(**kwargs) response = await llm.async_client.post(url="/chat/completions", json=kwargs)
return response.json()
return await _completion_with_retry(**kwargs) return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk( def _convert_delta_to_message_chunk(
_delta: MistralDeltaMessage, default_class: Type[BaseMessageChunk] _delta: Dict, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
role = getattr(_delta, "role") role = _delta.get("role")
content = getattr(_delta, "content", "") content = _delta.get("content", "")
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {} additional_kwargs: Dict = {}
if hasattr(_delta, "tool_calls") and getattr(_delta, "tool_calls"): if tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = [ additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
tc.model_dump() for tc in getattr(_delta, "tool_calls")
]
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk: elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content) return SystemMessageChunk(content=content)
@ -154,44 +146,48 @@ def _convert_delta_to_message_chunk(
def _convert_message_to_mistral_chat_message( def _convert_message_to_mistral_chat_message(
message: BaseMessage, message: BaseMessage,
) -> MistralChatMessage: ) -> Dict:
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
mistral_message = MistralChatMessage(role=message.role, content=message.content) return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
mistral_message = MistralChatMessage(role="user", content=message.content) return dict(role="user", content=message.content)
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs: if "tool_calls" in message.additional_kwargs:
from mistralai.models.chat_completion import ( # type: ignore[attr-defined]
ToolCall as MistralToolCall,
)
tool_calls = [ tool_calls = [
MistralToolCall.model_validate(tc) {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
for tc in message.additional_kwargs["tool_calls"] for tc in message.additional_kwargs["tool_calls"]
] ]
else: else:
tool_calls = None tool_calls = None
mistral_message = MistralChatMessage( return {
role="assistant", content=message.content, tool_calls=tool_calls "role": "assistant",
) "content": message.content,
"tool_calls": tool_calls,
}
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
mistral_message = MistralChatMessage(role="system", content=message.content) return dict(role="system", content=message.content)
elif isinstance(message, ToolMessage): elif isinstance(message, ToolMessage):
mistral_message = MistralChatMessage( return {
role="tool", content=message.content, name=message.name "role": "tool",
) "content": message.content,
"name": message.name,
}
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
return mistral_message
class ChatMistralAI(BaseChatModel): class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API.""" """A chat model that uses the MistralAI API."""
client: MistralClient = Field(default=None) #: :meta private: client: httpx.Client = Field(default=None) #: :meta private:
async_client: MistralAsyncClient = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = None mistral_api_key: Optional[SecretStr] = None
endpoint: str = DEFAULT_MISTRAL_ENDPOINT endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
max_concurrent_requests: int = 64 max_concurrent_requests: int = 64
@ -204,6 +200,7 @@ class ChatMistralAI(BaseChatModel):
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
random_seed: Optional[int] = None random_seed: Optional[int] = None
safe_mode: bool = False safe_mode: bool = False
streaming: bool = False
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
@ -214,7 +211,7 @@ class ChatMistralAI(BaseChatModel):
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
"top_p": self.top_p, "top_p": self.top_p,
"random_seed": self.random_seed, "random_seed": self.random_seed,
"safe_mode": self.safe_mode, "safe_prompt": self.safe_mode,
} }
filtered = {k: v for k, v in defaults.items() if v is not None} filtered = {k: v for k, v in defaults.items() if v is not None}
return filtered return filtered
@ -228,45 +225,60 @@ class ChatMistralAI(BaseChatModel):
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any: ) -> Any:
"""Use tenacity to retry the completion call.""" """Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager) # retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator # @retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any: def _completion_with_retry(**kwargs: Any) -> Any:
stream = kwargs.pop("stream", False) if "stream" not in kwargs:
kwargs["stream"] = False
stream = kwargs["stream"]
if stream: if stream:
return self.client.chat_stream(**kwargs)
else:
return self.client.chat(**kwargs)
return _completion_with_retry(**kwargs) def iter_sse() -> Iterator[Dict]:
with connect_sse(
self.client, "POST", "/chat/completions", json=kwargs
) as event_source:
for event in event_source.iter_sse():
if event.data == "[DONE]":
return
yield event.json()
return iter_sse()
else:
return self.client.post(url="/chat/completions", json=kwargs).json()
rtn = _completion_with_retry(**kwargs)
return rtn
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p.""" """Validate api key, python package exists, temperature, and top_p."""
mistralai_spec = importlib.util.find_spec("mistralai")
if mistralai_spec is None:
raise MistralException(
"Could not find mistralai python package. "
"Please install it with `pip install mistralai`"
)
values["mistral_api_key"] = convert_to_secret_str( values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env( get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default="" values, "mistral_api_key", "MISTRAL_API_KEY", default=""
) )
) )
values["client"] = MistralClient( api_key_str = values["mistral_api_key"].get_secret_value()
api_key=values["mistral_api_key"].get_secret_value(), # todo: handle retries
endpoint=values["endpoint"], values["client"] = httpx.Client(
max_retries=values["max_retries"], base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"], timeout=values["timeout"],
) )
values["async_client"] = MistralAsyncClient( # todo: handle retries and max_concurrency
api_key=values["mistral_api_key"].get_secret_value(), values["async_client"] = httpx.AsyncClient(
endpoint=values["endpoint"], base_url=values["endpoint"],
max_retries=values["max_retries"], headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"], timeout=values["timeout"],
max_concurrent_requests=values["max_concurrent_requests"],
) )
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
@ -285,7 +297,7 @@ class ChatMistralAI(BaseChatModel):
stream: Optional[bool] = None, stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
should_stream = stream if stream is not None else False should_stream = stream if stream is not None else self.streaming
if should_stream: if should_stream:
stream_iter = self._stream( stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs messages, stop=stop, run_manager=run_manager, **kwargs
@ -299,27 +311,23 @@ class ChatMistralAI(BaseChatModel):
) )
return self._create_chat_result(response) return self._create_chat_result(response)
def _create_chat_result( def _create_chat_result(self, response: Dict) -> ChatResult:
self, response: MistralChatCompletionResponse
) -> ChatResult:
generations = [] generations = []
for res in response.choices: for res in response["choices"]:
finish_reason = getattr(res, "finish_reason") finish_reason = res.get("finish_reason")
if finish_reason:
finish_reason = finish_reason.value
gen = ChatGeneration( gen = ChatGeneration(
message=_convert_mistral_chat_message_to_message(res.message), message=_convert_mistral_chat_message_to_message(res["message"]),
generation_info={"finish_reason": finish_reason}, generation_info={"finish_reason": finish_reason},
) )
generations.append(gen) generations.append(gen)
token_usage = getattr(response, "usage") token_usage = response.get("usage", {})
token_usage = vars(token_usage) if token_usage else {}
llm_output = {"token_usage": token_usage, "model": self.model} llm_output = {"token_usage": token_usage, "model": self.model}
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[MistralChatMessage], Dict[str, Any]]: ) -> Tuple[List[Dict], Dict[str, Any]]:
params = self._client_params params = self._client_params
if stop is not None or "stop" in params: if stop is not None or "stop" in params:
if "stop" in params: if "stop" in params:
@ -340,20 +348,24 @@ class ChatMistralAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in self.completion_with_retry( for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params messages=message_dicts, run_manager=run_manager, **params
): ):
if len(chunk.choices) == 0: if len(chunk["choices"]) == 0:
continue continue
delta = chunk.choices[0].delta delta = chunk["choices"][0]["delta"]
if not delta.content: if not delta["content"]:
continue continue
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__ # make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
if run_manager: if run_manager:
run_manager.on_llm_new_token(token=chunk.content, chunk=chunk) run_manager.on_llm_new_token(
yield ChatGenerationChunk(message=chunk) token=cast(str, new_chunk.content), chunk=gen_chunk
)
yield gen_chunk
async def _astream( async def _astream(
self, self,
@ -365,20 +377,24 @@ class ChatMistralAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
async for chunk in await acompletion_with_retry( async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params self, messages=message_dicts, run_manager=run_manager, **params
): ):
if len(chunk.choices) == 0: if len(chunk["choices"]) == 0:
continue continue
delta = chunk.choices[0].delta delta = chunk["choices"][0]["delta"]
if not delta.content: if not delta["content"]:
continue continue
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__ # make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
if run_manager: if run_manager:
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk) await run_manager.on_llm_new_token(
yield ChatGenerationChunk(message=chunk) token=cast(str, new_chunk.content), chunk=gen_chunk
)
yield gen_chunk
async def _agenerate( async def _agenerate(
self, self,

View File

@ -2,6 +2,7 @@ import asyncio
import logging import logging
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional
import httpx
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import ( from langchain_core.pydantic_v1 import (
BaseModel, BaseModel,
@ -11,12 +12,6 @@ from langchain_core.pydantic_v1 import (
root_validator, root_validator,
) )
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.constants import (
ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
)
from mistralai.exceptions import MistralException
from tokenizers import Tokenizer # type: ignore from tokenizers import Tokenizer # type: ignore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,10 +35,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
) )
""" """
client: MistralClient = Field(default=None) #: :meta private: client: httpx.Client = Field(default=None) #: :meta private:
async_client: MistralAsyncClient = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = None mistral_api_key: Optional[SecretStr] = None
endpoint: str = DEFAULT_MISTRAL_ENDPOINT endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
max_concurrent_requests: int = 64 max_concurrent_requests: int = 64
@ -64,18 +59,26 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
values, "mistral_api_key", "MISTRAL_API_KEY", default="" values, "mistral_api_key", "MISTRAL_API_KEY", default=""
) )
) )
values["client"] = MistralClient( api_key_str = values["mistral_api_key"].get_secret_value()
api_key=values["mistral_api_key"].get_secret_value(), # todo: handle retries
endpoint=values["endpoint"], values["client"] = httpx.Client(
max_retries=values["max_retries"], base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"], timeout=values["timeout"],
) )
values["async_client"] = MistralAsyncClient( # todo: handle retries and max_concurrency
api_key=values["mistral_api_key"].get_secret_value(), values["async_client"] = httpx.AsyncClient(
endpoint=values["endpoint"], base_url=values["endpoint"],
max_retries=values["max_retries"], headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"], timeout=values["timeout"],
max_concurrent_requests=values["max_concurrent_requests"],
) )
if values["tokenizer"] is None: if values["tokenizer"] is None:
values["tokenizer"] = Tokenizer.from_pretrained( values["tokenizer"] = Tokenizer.from_pretrained(
@ -115,18 +118,21 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
""" """
try: try:
batch_responses = ( batch_responses = (
self.client.embeddings( self.client.post(
model=self.model, url="/embeddings",
input=batch, json=dict(
model=self.model,
input=batch,
),
) )
for batch in self._get_batches(texts) for batch in self._get_batches(texts)
) )
return [ return [
list(map(float, embedding_obj.embedding)) list(map(float, embedding_obj["embedding"]))
for response in batch_responses for response in batch_responses
for embedding_obj in response.data for embedding_obj in response.json()["data"]
] ]
except MistralException as e: except Exception as e:
logger.error(f"An error occurred with MistralAI: {e}") logger.error(f"An error occurred with MistralAI: {e}")
raise raise
@ -142,19 +148,22 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
try: try:
batch_responses = await asyncio.gather( batch_responses = await asyncio.gather(
*[ *[
self.async_client.embeddings( self.async_client.post(
model=self.model, url="/embeddings",
input=batch, json=dict(
model=self.model,
input=batch,
),
) )
for batch in self._get_batches(texts) for batch in self._get_batches(texts)
] ]
) )
return [ return [
list(map(float, embedding_obj.embedding)) list(map(float, embedding_obj["embedding"]))
for response in batch_responses for response in batch_responses
for embedding_obj in response.data for embedding_obj in response.json()["data"]
] ]
except MistralException as e: except Exception as e:
logger.error(f"An error occurred with MistralAI: {e}") logger.error(f"An error occurred with MistralAI: {e}")
raise raise

View File

@ -206,13 +206,13 @@ typing = ["typing-extensions (>=4.8)"]
[[package]] [[package]]
name = "fsspec" name = "fsspec"
version = "2024.2.0" version = "2024.3.1"
description = "File-system specification" description = "File-system specification"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
{file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
] ]
[package.extras] [package.extras]
@ -273,13 +273,13 @@ trio = ["trio (>=0.22.0,<0.25.0)"]
[[package]] [[package]]
name = "httpx" name = "httpx"
version = "0.25.2" version = "0.27.0"
description = "The next generation HTTP client." description = "The next generation HTTP client."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"},
{file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"},
] ]
[package.dependencies] [package.dependencies]
@ -295,15 +295,26 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"] http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"] socks = ["socksio (==1.*)"]
[[package]]
name = "httpx-sse"
version = "0.4.0"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
]
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.20.3" version = "0.21.4"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "huggingface_hub-0.20.3-py3-none-any.whl", hash = "sha256:d988ae4f00d3e307b0c80c6a05ca6dbb7edba8bba3079f74cda7d9c2e562a7b6"}, {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
{file = "huggingface_hub-0.20.3.tar.gz", hash = "sha256:94e7f8e074475fbc67d6a71957b678e1b4a74ff1b64a644fd6cbb83da962d05d"}, {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
] ]
[package.dependencies] [package.dependencies]
@ -320,11 +331,12 @@ all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi",
cli = ["InquirerPy (==0.3.4)"] cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"]
inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
tensorflow = ["graphviz", "pydot", "tensorflow"] tensorflow = ["graphviz", "pydot", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
torch = ["torch"] torch = ["safetensors", "torch"]
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]] [[package]]
@ -376,7 +388,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.1.27" version = "0.1.33"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -402,13 +414,13 @@ url = "../../core"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.8" version = "0.1.31"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langsmith-0.1.8-py3-none-any.whl", hash = "sha256:f4320fd80ec9d311a648e7d4c44e0814e6e5454772c5026f40db0307bc07e287"}, {file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"},
{file = "langsmith-0.1.8.tar.gz", hash = "sha256:ab5f1cdfb7d418109ea506d41928fb8708547db2f6c7f7da7cfe997f3c55767b"}, {file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"},
] ]
[package.dependencies] [package.dependencies]
@ -416,40 +428,6 @@ orjson = ">=3.9.14,<4.0.0"
pydantic = ">=1,<3" pydantic = ">=1,<3"
requests = ">=2,<3" requests = ">=2,<3"
[[package]]
name = "mistralai"
version = "0.0.12"
description = ""
optional = false
python-versions = ">=3.8,<4.0"
files = [
{file = "mistralai-0.0.12-py3-none-any.whl", hash = "sha256:d489d1f0a31bf0edbe15c6d12f68b943148d2a725a088be0d8a5d4c888f8436c"},
{file = "mistralai-0.0.12.tar.gz", hash = "sha256:fe652836146a15bdce7691a95803a32c53c641c5400093447ffa93bf2ed296b2"},
]
[package.dependencies]
httpx = ">=0.25.2,<0.26.0"
orjson = ">=3.9.10,<4.0.0"
pydantic = ">=2.5.2,<3.0.0"
[[package]]
name = "mistralai"
version = "0.1.2"
description = ""
optional = false
python-versions = ">=3.9,<4.0"
files = [
{file = "mistralai-0.1.2-py3-none-any.whl", hash = "sha256:5e74e5ef0c0f15058892d73b00c659e06e9882c00838a1ad9862d93c77336847"},
{file = "mistralai-0.1.2.tar.gz", hash = "sha256:eb915fd15075f71bdbfce9cb476bb647322b1ce1e93b19ab0047728067466397"},
]
[package.dependencies]
httpx = ">=0.25.2,<0.26.0"
orjson = ">=3.9.10,<4.0.0"
pandas = ">=2.2.0,<3.0.0"
pyarrow = ">=15.0.0,<16.0.0"
pydantic = ">=2.5.2,<3.0.0"
[[package]] [[package]]
name = "mypy" name = "mypy"
version = "0.991" version = "0.991"
@ -511,51 +489,6 @@ files = [
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
] ]
[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]] [[package]]
name = "orjson" name = "orjson"
version = "3.9.15" version = "3.9.15"
@ -626,79 +559,6 @@ files = [
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
] ]
[[package]]
name = "pandas"
version = "2.2.1"
description = "Powerful data structures for data analysis, time series, and statistics"
optional = false
python-versions = ">=3.9"
files = [
{file = "pandas-2.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8df8612be9cd1c7797c93e1c5df861b2ddda0b48b08f2c3eaa0702cf88fb5f88"},
{file = "pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f573ab277252ed9aaf38240f3b54cfc90fff8e5cab70411ee1d03f5d51f3944"},
{file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f02a3a6c83df4026e55b63c1f06476c9aa3ed6af3d89b4f04ea656ccdaaaa359"},
{file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51"},
{file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c2ce852e1cf2509a69e98358e8458775f89599566ac3775e70419b98615f4b06"},
{file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53680dc9b2519cbf609c62db3ed7c0b499077c7fefda564e330286e619ff0dd9"},
{file = "pandas-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:94e714a1cca63e4f5939cdce5f29ba8d415d85166be3441165edd427dc9f6bc0"},
{file = "pandas-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f821213d48f4ab353d20ebc24e4faf94ba40d76680642fb7ce2ea31a3ad94f9b"},
{file = "pandas-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c70e00c2d894cb230e5c15e4b1e1e6b2b478e09cf27cc593a11ef955b9ecc81a"},
{file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e97fbb5387c69209f134893abc788a6486dbf2f9e511070ca05eed4b930b1b02"},
{file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101d0eb9c5361aa0146f500773395a03839a5e6ecde4d4b6ced88b7e5a1a6403"},
{file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7d2ed41c319c9fb4fd454fe25372028dfa417aacb9790f68171b2e3f06eae8cd"},
{file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5d3c00557d657c8773ef9ee702c61dd13b9d7426794c9dfeb1dc4a0bf0ebc7"},
{file = "pandas-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:06cf591dbaefb6da9de8472535b185cba556d0ce2e6ed28e21d919704fef1a9e"},
{file = "pandas-2.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:88ecb5c01bb9ca927ebc4098136038519aa5d66b44671861ffab754cae75102c"},
{file = "pandas-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f6ec3baec203c13e3f8b139fb0f9f86cd8c0b94603ae3ae8ce9a422e9f5bee"},
{file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a935a90a76c44fe170d01e90a3594beef9e9a6220021acfb26053d01426f7dc2"},
{file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c391f594aae2fd9f679d419e9a4d5ba4bce5bb13f6a989195656e7dc4b95c8f0"},
{file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9d1265545f579edf3f8f0cb6f89f234f5e44ba725a34d86535b1a1d38decbccc"},
{file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:11940e9e3056576ac3244baef2fedade891977bcc1cb7e5cc8f8cc7d603edc89"},
{file = "pandas-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acf681325ee1c7f950d058b05a820441075b0dd9a2adf5c4835b9bc056bf4fb"},
{file = "pandas-2.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9bd8a40f47080825af4317d0340c656744f2bfdb6819f818e6ba3cd24c0e1397"},
{file = "pandas-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df0c37ebd19e11d089ceba66eba59a168242fc6b7155cba4ffffa6eccdfb8f16"},
{file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:739cc70eaf17d57608639e74d63387b0d8594ce02f69e7a0b046f117974b3019"},
{file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d3558d263073ed95e46f4650becff0c5e1ffe0fc3a015de3c79283dfbdb3df"},
{file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4aa1d8707812a658debf03824016bf5ea0d516afdea29b7dc14cf687bc4d4ec6"},
{file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76f27a809cda87e07f192f001d11adc2b930e93a2b0c4a236fde5429527423be"},
{file = "pandas-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:1ba21b1d5c0e43416218db63037dbe1a01fc101dc6e6024bcad08123e48004ab"},
{file = "pandas-2.2.1.tar.gz", hash = "sha256:0ab90f87093c13f3e8fa45b48ba9f39181046e8f3317d3aadb2fffbb1b978572"},
]
[package.dependencies]
numpy = [
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
tzdata = ">=2022.7"
[package.extras]
all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"]
aws = ["s3fs (>=2022.11.0)"]
clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"]
compression = ["zstandard (>=0.19.0)"]
computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"]
consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"]
feather = ["pyarrow (>=10.0.1)"]
fss = ["fsspec (>=2022.11.0)"]
gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"]
hdf5 = ["tables (>=3.8.0)"]
html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"]
mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"]
output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"]
parquet = ["pyarrow (>=10.0.1)"]
performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"]
plot = ["matplotlib (>=3.6.3)"]
postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"]
pyarrow = ["pyarrow (>=10.0.1)"]
spss = ["pyreadstat (>=1.2.0)"]
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
[[package]] [[package]]
name = "pluggy" name = "pluggy"
version = "1.4.0" version = "1.4.0"
@ -714,63 +574,15 @@ files = [
dev = ["pre-commit", "tox"] dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"] testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pyarrow"
version = "15.0.0"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"},
{file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"},
{file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"},
{file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"},
{file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"},
{file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"},
{file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"},
{file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"},
{file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"},
{file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"},
{file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"},
{file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"},
{file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"},
{file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"},
{file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"},
{file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"},
]
[package.dependencies]
numpy = ">=1.16.6,<2"
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "2.6.2" version = "2.6.4"
description = "Data validation using Python type hints" description = "Data validation using Python type hints"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"}, {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"},
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"}, {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"},
] ]
[package.dependencies] [package.dependencies]
@ -912,31 +724,6 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
]
[package.dependencies]
six = ">=1.5"
[[package]]
name = "pytz"
version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]
[[package]] [[package]]
name = "pyyaml" name = "pyyaml"
version = "6.0.1" version = "6.0.1"
@ -1044,17 +831,6 @@ files = [
{file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"},
] ]
[[package]]
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
files = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
[[package]] [[package]]
name = "sniffio" name = "sniffio"
version = "1.3.1" version = "1.3.1"
@ -1249,17 +1025,6 @@ files = [
{file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
] ]
[[package]]
name = "tzdata"
version = "2024.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
{file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
]
[[package]] [[package]]
name = "urllib3" name = "urllib3"
version = "2.2.1" version = "2.2.1"
@ -1280,4 +1045,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "ccb95664a734631dde949975506ab160f65cdd222b28bf4f702fb4b11644f418" content-hash = "3d4fde33e55ded42474f7f42fbe34ce877f1deccaffdeed17d4ea26c47d07842"

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "langchain-mistralai" name = "langchain-mistralai"
version = "0.0.5" version = "0.1.0rc0"
description = "An integration package connecting Mistral and LangChain" description = "An integration package connecting Mistral and LangChain"
authors = [] authors = []
readme = "README.md" readme = "README.md"
@ -13,8 +13,9 @@ license = "MIT"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
langchain-core = "^0.1.27" langchain-core = "^0.1.27"
mistralai = [{version = "^0.1", python = "^3.9"}, {version = ">=0.0.11,<0.2", python="3.8"}]
tokenizers = "^0.15.1" tokenizers = "^0.15.1"
httpx = ">=0.25.2,<1"
httpx-sse = ">=0.3.1,<1"
[tool.poetry.group.test] [tool.poetry.group.test]
optional = true optional = true
@ -24,17 +25,17 @@ pytest = "^7.3.0"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true } langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.codespell] [tool.poetry.group.codespell]
optional = true optional = true
[tool.poetry.group.codespell.dependencies] [tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0" codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint] [tool.poetry.group.lint]
optional = true optional = true

View File

@ -1,6 +1,7 @@
"""Test MistralAI Chat API wrapper.""" """Test MistralAI Chat API wrapper."""
import os import os
from typing import Any, AsyncGenerator, Generator from typing import Any, AsyncGenerator, Dict, Generator
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -13,16 +14,6 @@ from langchain_core.messages import (
SystemMessage, SystemMessage,
) )
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
from mistralai.models.chat_completion import ( # type: ignore[import]
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
from mistralai.models.chat_completion import (
ChatMessage as MistralChatMessage,
)
from langchain_mistralai.chat_models import ( # type: ignore[import] from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI, ChatMistralAI,
_convert_message_to_mistral_chat_message, _convert_message_to_mistral_chat_message,
@ -31,13 +22,11 @@ from langchain_mistralai.chat_models import ( # type: ignore[import]
os.environ["MISTRAL_API_KEY"] = "foo" os.environ["MISTRAL_API_KEY"] = "foo"
@pytest.mark.requires("mistralai")
def test_mistralai_model_param() -> None: def test_mistralai_model_param() -> None:
llm = ChatMistralAI(model="foo") llm = ChatMistralAI(model="foo")
assert llm.model == "foo" assert llm.model == "foo"
@pytest.mark.requires("mistralai")
def test_mistralai_initialization() -> None: def test_mistralai_initialization() -> None:
"""Test ChatMistralAI initialization.""" """Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using a secret key provided # Verify that ChatMistralAI can be initialized using a secret key provided
@ -50,37 +39,37 @@ def test_mistralai_initialization() -> None:
[ [
( (
SystemMessage(content="Hello"), SystemMessage(content="Hello"),
MistralChatMessage(role="system", content="Hello"), dict(role="system", content="Hello"),
), ),
( (
HumanMessage(content="Hello"), HumanMessage(content="Hello"),
MistralChatMessage(role="user", content="Hello"), dict(role="user", content="Hello"),
), ),
( (
AIMessage(content="Hello"), AIMessage(content="Hello"),
MistralChatMessage(role="assistant", content="Hello"), dict(role="assistant", content="Hello", tool_calls=None),
), ),
( (
ChatMessage(role="assistant", content="Hello"), ChatMessage(role="assistant", content="Hello"),
MistralChatMessage(role="assistant", content="Hello"), dict(role="assistant", content="Hello"),
), ),
], ],
) )
def test_convert_message_to_mistral_chat_message( def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: MistralChatMessage message: BaseMessage, expected: Dict
) -> None: ) -> None:
result = _convert_message_to_mistral_chat_message(message) result = _convert_message_to_mistral_chat_message(message)
assert result == expected assert result == expected
def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse: def _make_completion_response_from_token(token: str) -> Dict:
return ChatCompletionStreamResponse( return dict(
id="abc123", id="abc123",
model="fake_model", model="fake_model",
choices=[ choices=[
ChatCompletionResponseStreamChoice( dict(
index=0, index=0,
delta=DeltaMessage(content=token), delta=dict(content=token),
finish_reason=None, finish_reason=None,
) )
], ],
@ -88,13 +77,19 @@ def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResp
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator: def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
for token in ["Hello", " how", " can", " I", " help", "?"]: def it() -> Generator:
yield _make_completion_response_from_token(token) for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
return it()
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator: async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
for token in ["Hello", " how", " can", " I", " help", "?"]: async def it() -> AsyncGenerator:
yield _make_completion_response_from_token(token) for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
return it()
class MyCustomHandler(BaseCallbackHandler): class MyCustomHandler(BaseCallbackHandler):
@ -104,7 +99,10 @@ class MyCustomHandler(BaseCallbackHandler):
self.last_token = token self.last_token = token
@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream) @patch(
"langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
new=mock_chat_stream,
)
def test_stream_with_callback() -> None: def test_stream_with_callback() -> None:
callback = MyCustomHandler() callback = MyCustomHandler()
chat = ChatMistralAI(callbacks=[callback]) chat = ChatMistralAI(callbacks=[callback])
@ -112,7 +110,7 @@ def test_stream_with_callback() -> None:
assert callback.last_token == token.content assert callback.last_token == token.content
@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream) @patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
async def test_astream_with_callback() -> None: async def test_astream_with_callback() -> None:
callback = MyCustomHandler() callback = MyCustomHandler()
chat = ChatMistralAI(callbacks=[callback]) chat = ChatMistralAI(callbacks=[callback])