groq[patch]: add usage_metadata to (a)invoke and (a)stream (#22834)
Parent: e01e5d5a91
Commit: b626c3ca23
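Note: this patch teaches ChatGroq to populate usage_metadata (input, output, and total token counts) on the AIMessage returned by invoke/ainvoke, and on streamed AIMessageChunks via a single counted chunk. A minimal usage sketch; the model name and prompt are placeholders, and the printed counts depend on the live Groq response:

from langchain_groq import ChatGroq

llm = ChatGroq(model="llama3-8b-8192")  # placeholder model name

msg = llm.invoke("Hello")
# With this patch the returned AIMessage carries token counts, e.g.:
# {"input_tokens": 12, "output_tokens": 8, "total_tokens": 20}
print(msg.usage_metadata)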
libs/partners/groq/langchain_groq/chat_models.py

@@ -307,7 +307,7 @@ class ChatGroq(BaseChatModel):
             )
             chat_result = self._create_chat_result(response)
             generation = chat_result.generations[0]
-            message = generation.message
+            message = cast(AIMessage, generation.message)
             tool_call_chunks = [
                 {
                     "name": rtc["function"].get("name"),
@@ -322,6 +322,7 @@ class ChatGroq(BaseChatModel):
                     content=message.content,
                     additional_kwargs=message.additional_kwargs,
                     tool_call_chunks=tool_call_chunks,
+                    usage_metadata=message.usage_metadata,
                 ),
                 generation_info=generation.generation_info,
             )
@@ -337,30 +338,30 @@ class ChatGroq(BaseChatModel):

         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         for chunk in self.client.create(messages=message_dicts, **params):
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
+            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+            default_chunk_class = message_chunk.__class__
+            generation_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info or None
+            )

             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
+                )
+            yield generation_chunk

     async def _astream(
         self,
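Note: with _stream now building each ChatGenerationChunk from the full API chunk, token counts flow through to streamed AIMessageChunks. A consumption sketch; the model name and prompt are placeholders, and exactly one chunk is expected to carry counts, as the integration test further down checks:

from langchain_groq import ChatGroq

llm = ChatGroq(model="llama3-8b-8192")  # placeholder model name

full = None
for chunk in llm.stream("Count to three"):
    # Each chunk is an AIMessageChunk; summing them merges content and usage counts.
    full = chunk if full is None else full + chunk

print(full.content)
print(full.usage_metadata)  # {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}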
@@ -378,7 +379,7 @@ class ChatGroq(BaseChatModel):
             )
             chat_result = self._create_chat_result(response)
             generation = chat_result.generations[0]
-            message = generation.message
+            message = cast(AIMessage, generation.message)
             tool_call_chunks = [
                 {
                     "name": rtc["function"].get("name"),
@@ -393,6 +394,7 @@ class ChatGroq(BaseChatModel):
                     content=message.content,
                     additional_kwargs=message.additional_kwargs,
                     tool_call_chunks=tool_call_chunks,
+                    usage_metadata=message.usage_metadata,
                 ),
                 generation_info=generation.generation_info,
             )
@@ -408,7 +410,7 @@ class ChatGroq(BaseChatModel):

         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         async for chunk in await self.async_client.create(
             messages=message_dicts, **params
         ):
@@ -417,25 +419,25 @@ class ChatGroq(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
+            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+            default_chunk_class = message_chunk.__class__
+            generation_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info or None
+            )

             if run_manager:
                 await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                    token=generation_chunk.text,
+                    chunk=generation_chunk,
+                    logprobs=logprobs,
                 )
-            yield chunk
+            yield generation_chunk

    #
    # Internal methods
@@ -459,8 +461,19 @@ class ChatGroq(BaseChatModel):
         generations = []
         if not isinstance(response, dict):
             response = response.dict()
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                input_tokens = token_usage.get("prompt_tokens", 0)
+                output_tokens = token_usage.get("completion_tokens", 0)
+                message.usage_metadata = {
+                    "input_tokens": input_tokens,
+                    "output_tokens": output_tokens,
+                    "total_tokens": token_usage.get(
+                        "total_tokens", input_tokens + output_tokens
+                    ),
+                }
             generation_info = dict(finish_reason=res.get("finish_reason"))
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
||||||
@ -469,7 +482,6 @@ class ChatGroq(BaseChatModel):
|
|||||||
generation_info=generation_info,
|
generation_info=generation_info,
|
||||||
)
|
)
|
||||||
generations.append(gen)
|
generations.append(gen)
|
||||||
token_usage = response.get("usage", {})
|
|
||||||
llm_output = {
|
llm_output = {
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"model_name": self.model_name,
|
"model_name": self.model_name,
|
||||||
@ -892,9 +904,11 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
def _convert_delta_to_message_chunk(
|
def _convert_chunk_to_message_chunk(
|
||||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
|
choice = chunk["choices"][0]
|
||||||
|
_dict = choice["delta"]
|
||||||
role = cast(str, _dict.get("role"))
|
role = cast(str, _dict.get("role"))
|
||||||
content = cast(str, _dict.get("content") or "")
|
content = cast(str, _dict.get("content") or "")
|
||||||
additional_kwargs: Dict = {}
|
additional_kwargs: Dict = {}
|
||||||
@ -909,7 +923,21 @@ def _convert_delta_to_message_chunk(
|
|||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
if usage := (chunk.get("x_groq") or {}).get("usage"):
|
||||||
|
input_tokens = usage.get("prompt_tokens", 0)
|
||||||
|
output_tokens = usage.get("completion_tokens", 0)
|
||||||
|
usage_metadata = {
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
usage_metadata = None
|
||||||
|
return AIMessageChunk(
|
||||||
|
content=content,
|
||||||
|
additional_kwargs=additional_kwargs,
|
||||||
|
usage_metadata=usage_metadata,
|
||||||
|
)
|
||||||
elif role == "system" or default_class == SystemMessageChunk:
|
elif role == "system" or default_class == SystemMessageChunk:
|
||||||
return SystemMessageChunk(content=content)
|
return SystemMessageChunk(content=content)
|
||||||
elif role == "function" or default_class == FunctionMessageChunk:
|
elif role == "function" or default_class == FunctionMessageChunk:
|
||||||
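Note: the renamed converter now receives the whole streaming chunk rather than just the delta, so it can read usage from Groq's x_groq envelope. A hypothetical minimal chunk showing only the keys the converter touches; the shape is inferred from the code above, not taken from Groq API docs, and real payloads carry more fields:

chunk = {
    "choices": [{"delta": {"role": "assistant", "content": "Hi"}}],
    "x_groq": {"usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}},
}

delta = chunk["choices"][0]["delta"]              # what the old helper used to receive
usage = (chunk.get("x_groq") or {}).get("usage")  # None on all but the final chunk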
libs/partners/groq/poetry.lock (generated)
@@ -323,7 +323,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.2.4"
+version = "0.2.5"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -332,15 +332,12 @@ develop = true

 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.66"
+langsmith = "^0.1.75"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
 tenacity = "^8.1.0"

-[package.extras]
-extended-testing = ["jinja2 (>=3,<4)"]
-
 [package.source]
 type = "directory"
 url = "../../core"
@@ -364,13 +361,13 @@ url = "../../standard-tests"

 [[package]]
 name = "langsmith"
-version = "0.1.73"
+version = "0.1.76"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"},
-    {file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"},
+    {file = "langsmith-0.1.76-py3-none-any.whl", hash = "sha256:4b8cb14f2233d9673ce9e6e3d545359946d9690a2c1457ab01e7459ec97b964e"},
+    {file = "langsmith-0.1.76.tar.gz", hash = "sha256:5829f997495c0f9a39f91fe0a57e0cb702e8642e6948945f5bb9f46337db7732"},
 ]

 [package.dependencies]
@@ -918,4 +915,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "672ecb755a4d938d114d4ffa96455758ecc05943c06e49e9bad3dfe65ee3c810"
+content-hash = "3cbd3deff4e93bc6337655edfbb328e3e2d5c3dff337ce911c4327f39bc231f9"
libs/partners/groq/pyproject.toml

@@ -12,7 +12,7 @@ license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = ">=0.2.0,<0.3"
+langchain-core = ">=0.2.2,<0.3"
 groq = ">=0.4.1,<1"

 [tool.poetry.group.test]
libs/partners/groq/tests/integration_tests/test_chat_models.py

@@ -1,7 +1,7 @@
 """Test ChatGroq chat model."""

 import json
-from typing import Any
+from typing import Any, Optional

 import pytest
 from langchain_core.messages import (
@@ -93,9 +93,28 @@ async def test_astream() -> None:
     """Test streaming tokens from Groq."""
     chat = ChatGroq(max_tokens=10)

+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for token in chat.astream("Welcome to the Groqetship!"):
-        assert isinstance(token, BaseMessageChunk)
+        assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        if token.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )


 #
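Note: the test relies on chunk aggregation adding usage counts together. In the langchain-core 0.2.2+ line this PR pins, AIMessageChunk addition merges content and usage_metadata, so the single counted chunk survives the sum. A minimal illustration, assuming that merge behaviour:

from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(content="Hello ")
b = AIMessageChunk(
    content="world",
    usage_metadata={"input_tokens": 12, "output_tokens": 8, "total_tokens": 20},
)
merged = a + b
print(merged.content)         # "Hello world"
print(merged.usage_metadata)  # {"input_tokens": 12, "output_tokens": 8, "total_tokens": 20}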
libs/partners/groq/tests/integration_tests/test_standard.py

@@ -9,22 +9,11 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
 from langchain_groq import ChatGroq


-class TestMistralStandard(ChatModelIntegrationTests):
+class TestGroqStandard(ChatModelIntegrationTests):
     @pytest.fixture
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq

-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )
-
     @pytest.mark.xfail(reason="Not yet implemented.")
     def test_tool_message_histories_list_content(
         self,