[Partner]: Add metadata to stream response (#22716)

Adds `response_metadata` to streamed responses from OpenAI. `invoke` already
returns this metadata, but it was never populated for `stream`.
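
A minimal sketch of the resulting behavior (the model name and the exact
metadata keys shown are illustrative):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

full = None
for chunk in llm.stream("Tell me a joke"):
    # AIMessageChunk supports +, so partial chunks can be merged as they arrive
    full = chunk if full is None else full + chunk

# Previously response_metadata stayed empty when streaming; with this change
# the final chunk contributes finish_reason, model_name, etc.
print(full.response_metadata)
# e.g. {'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}
```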

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>

Author: Hakan Özdemir
Date: 2024-06-17 16:46:50 +03:00
Committed by: GitHub
Parent: 42a379c75c
Commit: c437b1aab7

3 changed files with 56 additions and 23 deletions

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -478,7 +478,7 @@ class BaseChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         with self.client.create(messages=message_dicts, **params) as response:
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -490,7 +490,7 @@ class BaseChatOpenAI(BaseChatModel):
                             output_tokens=token_usage.get("completion_tokens", 0),
                             total_tokens=token_usage.get("total_tokens", 0),
                         )
-                        chunk = ChatGenerationChunk(
+                        generation_chunk = ChatGenerationChunk(
                             message=default_chunk_class(
                                 content="", usage_metadata=usage_metadata
                             )
@@ -501,24 +501,29 @@ class BaseChatOpenAI(BaseChatModel):
                 choice = chunk["choices"][0]
                 if choice["delta"] is None:
                     continue
-                chunk = _convert_delta_to_message_chunk(
+                message_chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
                 generation_info = {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
+                    if model_name := chunk.get("model"):
+                        generation_info["model_name"] = model_name
+                    if system_fingerprint := chunk.get("system_fingerprint"):
+                        generation_info["system_fingerprint"] = system_fingerprint
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
-                default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info or None
+                default_chunk_class = message_chunk.__class__
+                generation_chunk = ChatGenerationChunk(
+                    message=message_chunk, generation_info=generation_info or None
                 )
                 if run_manager:
                     run_manager.on_llm_new_token(
-                        chunk.text, chunk=chunk, logprobs=logprobs
+                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
-                yield chunk
+                yield generation_chunk

     def _generate(
         self,
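The heart of the change is the extra `generation_info` bookkeeping above:
`model` and `system_fingerprint` are read from the top-level chunk (not the
choice) and, sitting under the `finish_reason` check, attach only to the final
streamed chunk. The async path below gets the identical treatment. A
standalone sketch of that logic (`build_generation_info` is a hypothetical
name, not a helper in langchain-openai):

```python
from typing import Any, Dict, Optional

def build_generation_info(chunk: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Hypothetical helper mirroring the hunk above (not part of the library)."""
    choice = chunk["choices"][0]
    generation_info: Dict[str, Any] = {}
    if finish_reason := choice.get("finish_reason"):
        generation_info["finish_reason"] = finish_reason
        # "model" and "system_fingerprint" are top-level fields of an OpenAI
        # streaming chunk; attaching them here means they appear only on the
        # chunk that carries finish_reason, i.e. the last content chunk.
        if model_name := chunk.get("model"):
            generation_info["model_name"] = model_name
        if system_fingerprint := chunk.get("system_fingerprint"):
            generation_info["system_fingerprint"] = system_fingerprint
    if logprobs := choice.get("logprobs"):
        generation_info["logprobs"] = logprobs
    return generation_info or None
```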
@@ -596,7 +601,7 @@ class BaseChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         response = await self.async_client.create(messages=message_dicts, **params)
         async with response:
             async for chunk in response:
@@ -609,7 +614,7 @@ class BaseChatOpenAI(BaseChatModel):
                             output_tokens=token_usage.get("completion_tokens", 0),
                             total_tokens=token_usage.get("total_tokens", 0),
                         )
-                        chunk = ChatGenerationChunk(
+                        generation_chunk = ChatGenerationChunk(
                             message=default_chunk_class(
                                 content="", usage_metadata=usage_metadata
                             )
@@ -620,24 +625,31 @@ class BaseChatOpenAI(BaseChatModel):
                 choice = chunk["choices"][0]
                 if choice["delta"] is None:
                     continue
-                chunk = _convert_delta_to_message_chunk(
+                message_chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
                 generation_info = {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
+                    if model_name := chunk.get("model"):
+                        generation_info["model_name"] = model_name
+                    if system_fingerprint := chunk.get("system_fingerprint"):
+                        generation_info["system_fingerprint"] = system_fingerprint
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
-                default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info or None
+                default_chunk_class = message_chunk.__class__
+                generation_chunk = ChatGenerationChunk(
+                    message=message_chunk, generation_info=generation_info or None
                 )
                 if run_manager:
                     await run_manager.on_llm_new_token(
-                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                        token=generation_chunk.text,
+                        chunk=generation_chunk,
+                        logprobs=logprobs,
                     )
-                yield chunk
+                yield generation_chunk

     async def _agenerate(
         self,
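With the async hunks mirroring the sync ones, `astream` now surfaces the same
metadata. A quick end-to-end check, again a sketch with an illustrative model
name:

```python
import asyncio

from langchain_openai import ChatOpenAI

async def main() -> None:
    llm = ChatOpenAI(model="gpt-3.5-turbo")
    full = None
    async for chunk in llm.astream("Tell me a joke"):
        full = chunk if full is None else full + chunk
    # The merged chunk now exposes the same response_metadata that
    # ainvoke() returns.
    print(full.response_metadata)

asyncio.run(main())
```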