qianfan generate/agenerate with usage_metadata (#25332)

This commit is contained in:
Chen Xiabin 2024-08-13 21:24:41 +08:00 committed by GitHub
parent ebbe609193
commit 24155aa1ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -512,6 +512,7 @@ class QianfanChatEndpoint(BaseChatModel):
         if self.streaming:
             completion = ""
             chat_generation_info: Dict = {}
+            usage_metadata: Optional[UsageMetadata] = None
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 chat_generation_info = (
                     chunk.generation_info
@@ -519,7 +520,14 @@ class QianfanChatEndpoint(BaseChatModel):
                     else chat_generation_info
                 )
                 completion += chunk.text
-            lc_msg = AIMessage(content=completion, additional_kwargs={})
+                if isinstance(chunk.message, AIMessageChunk):
+                    usage_metadata = chunk.message.usage_metadata
+            lc_msg = AIMessage(
+                content=completion,
+                additional_kwargs={},
+                usage_metadata=usage_metadata,
+            )
             gen = ChatGeneration(
                 message=lc_msg,
                 generation_info=dict(finish_reason="stop"),
@@ -527,7 +535,7 @@ class QianfanChatEndpoint(BaseChatModel):
             return ChatResult(
                 generations=[gen],
                 llm_output={
-                    "token_usage": chat_generation_info.get("usage", {}),
+                    "token_usage": usage_metadata or {},
                     "model_name": self.model,
                 },
             )
@@ -556,6 +564,7 @@ class QianfanChatEndpoint(BaseChatModel):
         if self.streaming:
             completion = ""
             chat_generation_info: Dict = {}
+            usage_metadata: Optional[UsageMetadata] = None
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                 chat_generation_info = (
                     chunk.generation_info
@@ -564,7 +573,14 @@ class QianfanChatEndpoint(BaseChatModel):
                 )
                 completion += chunk.text
-            lc_msg = AIMessage(content=completion, additional_kwargs={})
+                if isinstance(chunk.message, AIMessageChunk):
+                    usage_metadata = chunk.message.usage_metadata
+            lc_msg = AIMessage(
+                content=completion,
+                additional_kwargs={},
+                usage_metadata=usage_metadata,
+            )
             gen = ChatGeneration(
                 message=lc_msg,
                 generation_info=dict(finish_reason="stop"),
@@ -572,7 +588,7 @@ class QianfanChatEndpoint(BaseChatModel):
             return ChatResult(
                 generations=[gen],
                 llm_output={
-                    "token_usage": chat_generation_info.get("usage", {}),
+                    "token_usage": usage_metadata or {},
                     "model_name": self.model,
                 },
             )