From 75c179c5e66fa30025e195b93e3f284160a366ea Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Sat, 23 Mar 2024 12:24:53 -0700
Subject: [PATCH] mistralai[patch]: streaming tool calls (#19469)

---
 .../langchain_mistralai/chat_models.py     |  8 +--
 libs/partners/mistralai/pyproject.toml     |  2 +-
 .../integration_tests/test_chat_models.py  | 61 +++++++++++++++++++
 3 files changed, 64 insertions(+), 7 deletions(-)

diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 2566f999cd3..a1a10c8edd3 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -128,13 +128,13 @@ def _convert_delta_to_message_chunk(
     _delta: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
     role = _delta.get("role")
-    content = _delta.get("content", "")
+    content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
         additional_kwargs: Dict = {}
         if tool_calls := _delta.get("tool_calls"):
-            additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
+            additional_kwargs["tool_calls"] = tool_calls
         return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
@@ -355,8 +355,6 @@ class ChatMistralAI(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
-            if not delta["content"]:
-                continue
             new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
@@ -384,8 +382,6 @@ class ChatMistralAI(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
-            if not delta["content"]:
-                continue
             new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml
index 33fe734b3dd..b076c1861e4 100644
--- a/libs/partners/mistralai/pyproject.toml
+++ b/libs/partners/mistralai/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-mistralai"
-version = "0.1.0rc1"
+version = "0.1.0rc2"
 description = "An integration package connecting Mistral and LangChain"
 authors = []
 readme = "README.md"
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index d4086643ebc..b292abf7003 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -1,5 +1,11 @@
 """Test ChatMistral chat model."""
 
+import json
+from typing import Any
+
+from langchain_core.messages import AIMessageChunk
+from langchain_core.pydantic_v1 import BaseModel
+
 from langchain_mistralai.chat_models import ChatMistralAI
 
 
@@ -83,3 +89,58 @@ def test_structured_output() -> None:
         "What weighs more a pound of bricks or a pound of feathers"
     )
     assert isinstance(result, dict)
+
+
+def test_streaming_structured_output() -> None:
+    llm = ChatMistralAI(model="mistral-large", temperature=0)
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    structured_llm = llm.with_structured_output(Person)
+    strm = structured_llm.stream("Erick, 27 years old")
+    chunk_num = 0
+    for chunk in strm:
+        assert chunk_num == 0, "should only have one chunk with model"
+        assert isinstance(chunk, Person)
+        assert chunk.name == "Erick"
+        assert chunk.age == 27
+        chunk_num += 1
+
+
+def test_streaming_tool_call() -> None:
+    llm = ChatMistralAI(model="mistral-large", temperature=0)
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    tool_llm = llm.bind_tools([Person])
+
+    # where it calls the tool
+    strm = tool_llm.stream("Erick, 27 years old")
+
+    additional_kwargs = None
+    for chunk in strm:
+        assert isinstance(chunk, AIMessageChunk)
+        assert chunk.content == ""
+        additional_kwargs = chunk.additional_kwargs
+
+    assert additional_kwargs is not None
+    assert "tool_calls" in additional_kwargs
+    assert len(additional_kwargs["tool_calls"]) == 1
+    assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
+    assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
+        "name": "Erick",
+        "age": 27,
+    }
+
+    # where it doesn't call the tool
+    strm = tool_llm.stream("What is 2+2?")
+    acc: Any = None
+    for chunk in strm:
+        assert isinstance(chunk, AIMessageChunk)
+        acc = chunk if acc is None else acc + chunk
+    assert acc.content != ""
+    assert "tool_calls" not in acc.additional_kwargs
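
Note (not part of the patch): a minimal sketch of what this change enables downstream — consuming streamed tool-call chunks and merging them client-side. It uses only the API the integration tests above exercise (ChatMistralAI, bind_tools, stream, and AIMessageChunk's "+" accumulation); running it assumes a valid MISTRAL_API_KEY in the environment, and the prompt and model name simply mirror the tests.

# Sketch only; before this patch, chunks with empty content (i.e. tool-call
# deltas) were dropped by the `if not delta["content"]: continue` guard.
from langchain_core.pydantic_v1 import BaseModel

from langchain_mistralai.chat_models import ChatMistralAI


class Person(BaseModel):
    name: str
    age: int


llm = ChatMistralAI(model="mistral-large", temperature=0)
tool_llm = llm.bind_tools([Person])

acc = None
for chunk in tool_llm.stream("Erick, 27 years old"):
    # AIMessageChunk supports "+", which concatenates content and merges
    # additional_kwargs (including the raw "tool_calls" deltas kept by
    # this patch).
    acc = chunk if acc is None else acc + chunk

print(acc.additional_kwargs.get("tool_calls"))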