(a)stream

This commit is contained in:
Chester Curme 2025-03-11 16:05:40 -04:00
parent 3951f90a9e
commit 03035b8d73
2 changed files with 118 additions and 15 deletions

View File

@ -356,6 +356,27 @@ def _convert_delta_to_message_chunk(
return default_class(content=content, id=id_) # type: ignore
def _convert_responses_chunk_to_generation_chunk(
chunk: Any,
) -> Optional[ChatGenerationChunk]:
content = ""
generation_info = None
usage_metadata = None
if chunk.type == "response.output_text.delta":
content = chunk.delta
elif chunk.type == "response.completed":
token_usage = chunk.response.usage.model_dump() if chunk.response.usage else {}
usage_metadata = _create_usage_metadata_responses(token_usage)
generation_info = {"model_name": chunk.response.model}
else:
return None
return ChatGenerationChunk(
message=AIMessageChunk(content=content, usage_metadata=usage_metadata),
generation_info=generation_info,
)
def _update_token_usage(
overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
) -> Union[int, dict]:
@ -410,7 +431,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
def _is_builtin_tool(tool: dict) -> bool:
return set(tool.keys()) == {"type"}
return "type" in tool and tool["type"] != "function"
def _transform_payload_for_responses(payload: dict) -> dict:
@ -741,6 +762,54 @@ class BaseChatOpenAI(BaseChatModel):
)
return generation_chunk
def _stream_responses(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
responses_payload = _transform_payload_for_responses(payload)
context_manager = self.root_client.responses.create(**responses_payload)
with context_manager as response:
for chunk in response:
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk
):
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
yield generation_chunk
async def _astream_responses(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
responses_payload = _transform_payload_for_responses(payload)
context_manager = await self.root_async_client.responses.create(
**responses_payload
)
async with context_manager as response:
async for chunk in response:
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk
):
if run_manager:
await run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
yield generation_chunk
def _stream(
self,
messages: List[BaseMessage],
@ -2143,27 +2212,38 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> Iterator[ChatGenerationChunk]:
"""Set default stream_options."""
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
# Note: stream_options is not a valid parameter for Azure OpenAI.
# To support users proxying Azure through ChatOpenAI, here we only specify
# stream_options if include_usage is set to True.
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
# for release notes.
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
if "tools" in kwargs and any(
_is_builtin_tool(tool) for tool in kwargs["tools"]
):
return super()._stream_responses(*args, **kwargs)
else:
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
# Note: stream_options is not a valid parameter for Azure OpenAI.
# To support users proxying Azure through ChatOpenAI, here we only specify
# stream_options if include_usage is set to True.
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
# for release notes.
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
return super()._stream(*args, **kwargs)
return super()._stream(*args, **kwargs)
async def _astream(
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Set default stream_options."""
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
if "tools" in kwargs and any(
_is_builtin_tool(tool) for tool in kwargs["tools"]
):
async for chunk in super()._astream_responses(*args, **kwargs):
yield chunk
else:
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
async for chunk in super()._astream(*args, **kwargs):
yield chunk
async for chunk in super()._astream(*args, **kwargs):
yield chunk
def with_structured_output(
self,

View File

@ -1245,6 +1245,18 @@ def test_web_search() -> None:
assert response.response_metadata["model_name"]
assert response.response_metadata["status"]
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.content
assert full.usage_metadata
async def test_web_search_async() -> None:
llm = ChatOpenAI(model="gpt-4o")
@ -1260,3 +1272,14 @@ async def test_web_search_async() -> None:
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
assert response.response_metadata["status"]
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.content
assert full.usage_metadata