mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 23:41:46 +00:00
mistralai[patch]: streaming tool calls (#19469)
This commit is contained in:
@@ -128,13 +128,13 @@ def _convert_delta_to_message_chunk(
|
||||
_delta: Dict, default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _delta.get("role")
|
||||
content = _delta.get("content", "")
|
||||
content = _delta.get("content") or ""
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: Dict = {}
|
||||
if tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
|
||||
additional_kwargs["tool_calls"] = tool_calls
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
@@ -355,8 +355,6 @@ class ChatMistralAI(BaseChatModel):
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if not delta["content"]:
|
||||
continue
|
||||
new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
@@ -384,8 +382,6 @@ class ChatMistralAI(BaseChatModel):
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if not delta["content"]:
|
||||
continue
|
||||
new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
|
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-mistralai"
|
||||
version = "0.1.0rc1"
|
||||
version = "0.1.0rc2"
|
||||
description = "An integration package connecting Mistral and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
@@ -1,5 +1,11 @@
|
||||
"""Test ChatMistral chat model."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
|
||||
|
||||
@@ -83,3 +89,58 @@ def test_structured_output() -> None:
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_streaming_structured_output() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large", temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
structured_llm = llm.with_structured_output(Person)
|
||||
strm = structured_llm.stream("Erick, 27 years old")
|
||||
chunk_num = 0
|
||||
for chunk in strm:
|
||||
assert chunk_num == 0, "should only have one chunk with model"
|
||||
assert isinstance(chunk, Person)
|
||||
assert chunk.name == "Erick"
|
||||
assert chunk.age == 27
|
||||
chunk_num += 1
|
||||
|
||||
|
||||
def test_streaming_tool_call() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large", temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
# where it calls the tool
|
||||
strm = tool_llm.stream("Erick, 27 years old")
|
||||
|
||||
additional_kwargs = None
|
||||
for chunk in strm:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
assert chunk.content == ""
|
||||
additional_kwargs = chunk.additional_kwargs
|
||||
|
||||
assert additional_kwargs is not None
|
||||
assert "tool_calls" in additional_kwargs
|
||||
assert len(additional_kwargs["tool_calls"]) == 1
|
||||
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
|
||||
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
|
||||
"name": "Erick",
|
||||
"age": 27,
|
||||
}
|
||||
|
||||
# where it doesn't call the tool
|
||||
strm = tool_llm.stream("What is 2+2?")
|
||||
acc: Any = None
|
||||
for chunk in strm:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
acc = chunk if acc is None else acc + chunk
|
||||
assert acc.content != ""
|
||||
assert "tool_calls" not in acc.additional_kwargs
|
||||
|
Reference in New Issue
Block a user