mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-02 21:23:32 +00:00)
mistral[patch]: add usage_metadata to (a)invoke and (a)stream (#22781)
parent 20e3662acf
commit 936aedd10c
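What this enables, as a minimal sketch (not part of the diff; it assumes a configured Mistral API key, and the model name below is illustrative):

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-large-latest")  # model name is an assumption

# (a)invoke: the returned AIMessage now carries token counts.
msg = llm.invoke("Hello")
print(msg.usage_metadata)
# e.g. {'input_tokens': 4, 'output_tokens': 10, 'total_tokens': 14}

# (a)stream: only one chunk carries usage_metadata, so summing the chunks
# yields a full message whose totals are not double counted.
full = None
for chunk in llm.stream("Hello"):
    full = chunk if full is None else full + chunk
print(full.usage_metadata)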
libs/partners/mistralai/langchain_mistralai/chat_models.py

@@ -186,9 +186,10 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)


-def _convert_delta_to_message_chunk(
-    _delta: Dict, default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
+    _delta = chunk["choices"][0]["delta"]
     role = _delta.get("role")
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
@@ -216,10 +217,19 @@ def _convert_delta_to_message_chunk(
             pass
         else:
             tool_call_chunks = []
+        if token_usage := chunk.get("usage"):
+            usage_metadata = {
+                "input_tokens": token_usage.get("prompt_tokens", 0),
+                "output_tokens": token_usage.get("completion_tokens", 0),
+                "total_tokens": token_usage.get("total_tokens", 0),
+            }
+        else:
+            usage_metadata = None
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
@@ -484,14 +494,21 @@ class ChatMistralAI(BaseChatModel):

     def _create_chat_result(self, response: Dict) -> ChatResult:
         generations = []
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             finish_reason = res.get("finish_reason")
+            message = _convert_mistral_chat_message_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
             gen = ChatGeneration(
-                message=_convert_mistral_chat_message_to_message(res["message"]),
+                message=message,
                 generation_info={"finish_reason": finish_reason},
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})

         llm_output = {"token_usage": token_usage, "model": self.model}
         return ChatResult(generations=generations, llm_output=llm_output)
@@ -525,8 +542,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -552,8 +568,7 @@ class ChatMistralAI(BaseChatModel):
        ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
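A rough sketch of why only the usage-bearing chunk gets usage_metadata (assumed langchain-core chunk-addition behavior, not part of this diff): adding AIMessageChunks merges their usage_metadata, so counting a single chunk keeps aggregated totals from being double counted.

from langchain_core.messages import AIMessageChunk

# Only the final chunk carries usage_metadata in this design.
first = AIMessageChunk(content="Hello ")
last = AIMessageChunk(
    content="world",
    usage_metadata={"input_tokens": 4, "output_tokens": 2, "total_tokens": 6},
)
full = first + last
print(full.content)         # Hello world
print(full.usage_metadata)  # {'input_tokens': 4, 'output_tokens': 2, 'total_tokens': 6}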
libs/partners/mistralai/poetry.lock (generated, 15 changed lines)
@@ -392,7 +392,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.2.0"
+version = "0.2.5"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -401,15 +401,12 @@ develop = true

 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.0"
+langsmith = "^0.1.75"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
 tenacity = "^8.1.0"

-[package.extras]
-extended-testing = ["jinja2 (>=3,<4)"]
-
 [package.source]
 type = "directory"
 url = "../../core"
@@ -433,13 +430,13 @@ url = "../../standard-tests"

 [[package]]
 name = "langsmith"
-version = "0.1.58"
+version = "0.1.76"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.58-py3-none-any.whl", hash = "sha256:1148cc836ec99d1b2f37cd2fa3014fcac213bb6bad798a2b21bb9111c18c9768"},
-    {file = "langsmith-0.1.58.tar.gz", hash = "sha256:a5060933c1fb3006b498ec849677993329d7e6138bdc2ec044068ab806e09c39"},
+    {file = "langsmith-0.1.76-py3-none-any.whl", hash = "sha256:4b8cb14f2233d9673ce9e6e3d545359946d9690a2c1457ab01e7459ec97b964e"},
+    {file = "langsmith-0.1.76.tar.gz", hash = "sha256:5829f997495c0f9a39f91fe0a57e0cb702e8642e6948945f5bb9f46337db7732"},
 ]

 [package.dependencies]
@@ -1051,4 +1048,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "4a5a57d01c791de831f03fb309541443dc8bb51f5068ccfb7bcb77490c2eb6c3"
+content-hash = "af4576b4e41d3e01716cff9476d6130dd0c5ef7b98bfd02fefd1f5b730574b6e"
libs/partners/mistralai/pyproject.toml

@@ -12,7 +12,7 @@ license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = ">=0.2.0,<0.3"
+langchain-core = ">=0.2.2,<0.3"
 tokenizers = ">=0.15.1,<1"
 httpx = ">=0.25.2,<1"
 httpx-sse = ">=0.3.1,<1"
libs/partners/mistralai/tests/integration_tests/test_chat_models.py

@@ -1,11 +1,12 @@
 """Test ChatMistral chat model."""

 import json
-from typing import Any
+from typing import Any, Optional

 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
+    BaseMessageChunk,
     HumanMessage,
 )
 from langchain_core.pydantic_v1 import BaseModel
@@ -25,8 +26,28 @@ async def test_astream() -> None:
     """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()
+
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        if token.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )


 async def test_abatch() -> None:
libs/partners/mistralai/tests/integration_tests/test_standard.py

@@ -20,14 +20,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
             "model": "mistral-large-latest",
             "temperature": 0,
         }
-
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )