[Partner]: Add metadata to stream response (#22716)
Adds `response_metadata` to stream responses from OpenAI. This metadata is already returned by `invoke`, but was not previously populated for `stream`.

Co-authored-by: Chester Curme <chester.curme@gmail.com>
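As a usage sketch (not part of this commit): with this change, the chunks yielded by `stream` carry `response_metadata` just as the result of `invoke` does. The model name, prompt, and environment below are illustrative assumptions.

```python
# Minimal sketch: observing response_metadata on streamed chunks.
# Assumes `langchain-openai` is installed and OPENAI_API_KEY is set;
# the model name and prompt are arbitrary choices for illustration.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

full = None
for chunk in llm.stream("Write a one-line haiku about streams."):
    # AIMessageChunk supports `+`, which concatenates content and
    # merges the response_metadata dicts across chunks.
    full = chunk if full is None else full + chunk

# The final chunk (the one carrying a finish_reason) now also
# contributes model_name and system_fingerprint.
print(full.response_metadata)
# e.g. {'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125', ...}
```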
```diff
@@ -478,7 +478,7 @@ class BaseChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         with self.client.create(messages=message_dicts, **params) as response:
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -490,7 +490,7 @@ class BaseChatOpenAI(BaseChatModel):
                         output_tokens=token_usage.get("completion_tokens", 0),
                         total_tokens=token_usage.get("total_tokens", 0),
                     )
-                    chunk = ChatGenerationChunk(
+                    generation_chunk = ChatGenerationChunk(
                         message=default_chunk_class(
                             content="", usage_metadata=usage_metadata
                         )
@@ -501,24 +501,29 @@ class BaseChatOpenAI(BaseChatModel):
                 choice = chunk["choices"][0]
                 if choice["delta"] is None:
                     continue
-                chunk = _convert_delta_to_message_chunk(
+                message_chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
                 generation_info = {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
+                    if model_name := chunk.get("model"):
+                        generation_info["model_name"] = model_name
+                    if system_fingerprint := chunk.get("system_fingerprint"):
+                        generation_info["system_fingerprint"] = system_fingerprint
+
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
-                default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info or None
+                default_chunk_class = message_chunk.__class__
+                generation_chunk = ChatGenerationChunk(
+                    message=message_chunk, generation_info=generation_info or None
                 )
                 if run_manager:
                     run_manager.on_llm_new_token(
-                        chunk.text, chunk=chunk, logprobs=logprobs
+                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
-                yield chunk
+                yield generation_chunk
 
     def _generate(
         self,
```
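Note that `model_name` and `system_fingerprint` are only recorded on the chunk that carries a `finish_reason`, i.e. the last content chunk of the stream. A standalone sketch of that bookkeeping, using a fabricated raw chunk (all values hypothetical):

```python
# Fabricated raw chunk, shaped like the dicts the OpenAI SDK yields.
chunk = {
    "choices": [
        {"delta": {"content": ""}, "finish_reason": "stop", "logprobs": None}
    ],
    "model": "gpt-3.5-turbo-0125",
    "system_fingerprint": "fp_0123456789",
}
choice = chunk["choices"][0]

generation_info = {}
if finish_reason := choice.get("finish_reason"):
    generation_info["finish_reason"] = finish_reason
    # Only the finishing chunk records the model and fingerprint, so the
    # metadata is emitted exactly once per stream.
    if model_name := chunk.get("model"):
        generation_info["model_name"] = model_name
    if system_fingerprint := chunk.get("system_fingerprint"):
        generation_info["system_fingerprint"] = system_fingerprint

print(generation_info)
# {'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125',
#  'system_fingerprint': 'fp_0123456789'}
```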
```diff
@@ -596,7 +601,7 @@ class BaseChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         response = await self.async_client.create(messages=message_dicts, **params)
         async with response:
             async for chunk in response:
@@ -609,7 +614,7 @@ class BaseChatOpenAI(BaseChatModel):
                         output_tokens=token_usage.get("completion_tokens", 0),
                         total_tokens=token_usage.get("total_tokens", 0),
                     )
-                    chunk = ChatGenerationChunk(
+                    generation_chunk = ChatGenerationChunk(
                         message=default_chunk_class(
                             content="", usage_metadata=usage_metadata
                         )
@@ -620,24 +625,31 @@ class BaseChatOpenAI(BaseChatModel):
                 choice = chunk["choices"][0]
                 if choice["delta"] is None:
                     continue
-                chunk = _convert_delta_to_message_chunk(
+                message_chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
                 generation_info = {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
+                    if model_name := chunk.get("model"):
+                        generation_info["model_name"] = model_name
+                    if system_fingerprint := chunk.get("system_fingerprint"):
+                        generation_info["system_fingerprint"] = system_fingerprint
+
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
-                default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info or None
+                default_chunk_class = message_chunk.__class__
+                generation_chunk = ChatGenerationChunk(
+                    message=message_chunk, generation_info=generation_info or None
                 )
                 if run_manager:
                     await run_manager.on_llm_new_token(
-                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                        token=generation_chunk.text,
+                        chunk=generation_chunk,
+                        logprobs=logprobs,
                    )
-                yield chunk
+                yield generation_chunk
 
     async def _agenerate(
         self,
```
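The async path mirrors the sync one. A minimal sketch of streaming with `astream`, under the same assumptions as the sync example above:

```python
import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI(model="gpt-3.5-turbo")
    full = None
    async for chunk in llm.astream("Write a one-line haiku about streams."):
        # Aggregate chunks just as in the sync example; response_metadata
        # dicts are merged across chunks by `+`.
        full = chunk if full is None else full + chunk
    print(full.response_metadata)


asyncio.run(main())
```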