core[patch]: Add LLM output to message response_metadata (#19158)

This makes token usage information easier to access on chat model responses.
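For example, after this change (a minimal sketch, assuming langchain_openai's ChatOpenAI with an OPENAI_API_KEY configured):

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()
msg = llm.invoke([HumanMessage(content="hello")])
# Provider-level llm_output (token_usage, model_name, system_fingerprint, ...)
# is now merged into the AIMessage's response_metadata alongside
# per-generation info such as finish_reason.
print(msg.response_metadata["token_usage"])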

CC @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Jacob Lee 2024-03-18 13:58:32 -07:00 committed by GitHub
parent 6fa1438334
commit bd329e9aad
2 changed files with 48 additions and 6 deletions


@@ -615,6 +615,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
             )
+        if len(result.generations) == 1 and result.llm_output is not None:
+            result.generations[0].message.response_metadata = {
+                **result.llm_output,
+                **result.generations[0].message.response_metadata,
+            }
         if check_cache and llm_cache:
             llm_cache.update(prompt, llm_string, result.generations)
         return result
@@ -651,6 +656,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
             )
+        if len(result.generations) == 1 and result.llm_output is not None:
+            result.generations[0].message.response_metadata = {
+                **result.llm_output,
+                **result.generations[0].message.response_metadata,
+            }
         if check_cache and llm_cache:
             await llm_cache.aupdate(prompt, llm_string, result.generations)
         return result
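The spread order above means keys already present on the message's response_metadata take precedence over llm_output; a standalone sketch of that merge (illustrative values, not taken from the diff):

llm_output = {"token_usage": {"total_tokens": 42}, "model_name": "gpt-3.5-turbo"}
generation_metadata = {"finish_reason": "stop", "model_name": "gpt-3.5-turbo-0125"}
merged = {**llm_output, **generation_metadata}
assert merged["model_name"] == "gpt-3.5-turbo-0125"  # per-generation value wins
assert merged["token_usage"]["total_tokens"] == 42    # llm_output fills the gaps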


@@ -403,35 +403,67 @@ def test_invoke() -> None:
     assert isinstance(result.content, str)


-def test_logprobs() -> None:
+def test_response_metadata() -> None:
     llm = ChatOpenAI()
     result = llm.invoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
     assert result.response_metadata
+    assert all(
+        k in result.response_metadata
+        for k in (
+            "token_usage",
+            "model_name",
+            "logprobs",
+            "system_fingerprint",
+            "finish_reason",
+        )
+    )
     assert "content" in result.response_metadata["logprobs"]


-async def test_async_logprobs() -> None:
+async def test_async_response_metadata() -> None:
     llm = ChatOpenAI()
     result = await llm.ainvoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
     assert result.response_metadata
+    assert all(
+        k in result.response_metadata
+        for k in (
+            "token_usage",
+            "model_name",
+            "logprobs",
+            "system_fingerprint",
+            "finish_reason",
+        )
+    )
     assert "content" in result.response_metadata["logprobs"]


-def test_logprobs_streaming() -> None:
+def test_response_metadata_streaming() -> None:
     llm = ChatOpenAI()
     full: Optional[BaseMessageChunk] = None
     for chunk in llm.stream("I'm Pickle Rick", logprobs=True):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
     assert cast(BaseMessageChunk, full).response_metadata
+    assert all(
+        k in cast(BaseMessageChunk, full).response_metadata
+        for k in (
+            "logprobs",
+            "finish_reason",
+        )
+    )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]


-async def test_async_logprobs_streaming() -> None:
+async def test_async_response_metadata_streaming() -> None:
     llm = ChatOpenAI()
     full: Optional[BaseMessageChunk] = None
     async for chunk in llm.astream("I'm Pickle Rick", logprobs=True):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
     assert cast(BaseMessageChunk, full).response_metadata
+    assert all(
+        k in cast(BaseMessageChunk, full).response_metadata
+        for k in (
+            "logprobs",
+            "finish_reason",
+        )
+    )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]