(a)stream
commit 03035b8d73 (parent 3951f90a9e)
@@ -356,6 +356,27 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
+def _convert_responses_chunk_to_generation_chunk(
+    chunk: Any,
+) -> Optional[ChatGenerationChunk]:
+    content = ""
+    generation_info = None
+    usage_metadata = None
+    if chunk.type == "response.output_text.delta":
+        content = chunk.delta
+    elif chunk.type == "response.completed":
+        token_usage = chunk.response.usage.model_dump() if chunk.response.usage else {}
+        usage_metadata = _create_usage_metadata_responses(token_usage)
+        generation_info = {"model_name": chunk.response.model}
+    else:
+        return None
+
+    return ChatGenerationChunk(
+        message=AIMessageChunk(content=content, usage_metadata=usage_metadata),
+        generation_info=generation_info,
+    )
+
+
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
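For orientation (not part of the diff): a minimal sketch of how the new converter treats the different Responses streaming events. The SimpleNamespace objects are stand-ins for the event classes emitted by the openai client, and the unhandled event type string is only illustrative.

from types import SimpleNamespace

# Text deltas become content-bearing chunks.
delta_event = SimpleNamespace(type="response.output_text.delta", delta="Hel")
chunk = _convert_responses_chunk_to_generation_chunk(delta_event)
assert chunk is not None and chunk.text == "Hel"

# "response.completed" carries usage and the model name; any other event type
# is dropped by returning None, so the streaming loop simply skips it.
other_event = SimpleNamespace(type="response.output_item.added")
assert _convert_responses_chunk_to_generation_chunk(other_event) is None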
@@ -410,7 +431,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
 
 
 def _is_builtin_tool(tool: dict) -> bool:
-    return set(tool.keys()) == {"type"}
+    return "type" in tool and tool["type"] != "function"
 
 
 def _transform_payload_for_responses(payload: dict) -> dict:
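A quick illustration of the behavioral change (the tool dicts below are examples, not taken from the diff): the old predicate only accepted a dict whose sole key was "type", while the new one treats any tool whose type is not "function" as a built-in tool, even when it carries extra configuration keys.

assert _is_builtin_tool({"type": "web_search_preview"})
assert not _is_builtin_tool({"type": "function", "function": {"name": "get_weather"}})

# A built-in tool with extra keys: rejected by the old set-of-keys check,
# accepted by the new implementation.
assert _is_builtin_tool({"type": "file_search", "vector_store_ids": ["vs_abc"]})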
@@ -741,6 +762,54 @@ class BaseChatOpenAI(BaseChatModel):
         )
         return generation_chunk
 
+    def _stream_responses(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        responses_payload = _transform_payload_for_responses(payload)
+        context_manager = self.root_client.responses.create(**responses_payload)
+
+        with context_manager as response:
+            for chunk in response:
+                if generation_chunk := _convert_responses_chunk_to_generation_chunk(
+                    chunk
+                ):
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
+
+    async def _astream_responses(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        responses_payload = _transform_payload_for_responses(payload)
+        context_manager = await self.root_async_client.responses.create(
+            **responses_payload
+        )
+
+        async with context_manager as response:
+            async for chunk in response:
+                if generation_chunk := _convert_responses_chunk_to_generation_chunk(
+                    chunk
+                ):
+                    if run_manager:
+                        await run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
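Both helpers drive the openai client's Responses streaming interface, which supports the same with/for pattern used above. As a standalone sketch with the raw client (not part of the diff; the model name and prompt are placeholders, and an openai version that ships the Responses API is assumed):

import openai

client = openai.OpenAI()
with client.responses.create(
    model="gpt-4o",
    input="Write one sentence about streaming.",
    stream=True,
) as response:
    for event in response:
        # Same event types the converter above keys on.
        if event.type == "response.output_text.delta":
            print(event.delta, end="", flush=True)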
@@ -2143,27 +2212,38 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
+        if "tools" in kwargs and any(
+            _is_builtin_tool(tool) for tool in kwargs["tools"]
+        ):
+            return super()._stream_responses(*args, **kwargs)
+        else:
+            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+            # Note: stream_options is not a valid parameter for Azure OpenAI.
+            # To support users proxying Azure through ChatOpenAI, here we only specify
+            # stream_options if include_usage is set to True.
+            # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
+            # for release notes.
+            if stream_usage:
+                kwargs["stream_options"] = {"include_usage": stream_usage}
 
-        return super()._stream(*args, **kwargs)
+            return super()._stream(*args, **kwargs)
 
     async def _astream(
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
+        if "tools" in kwargs and any(
+            _is_builtin_tool(tool) for tool in kwargs["tools"]
+        ):
+            async for chunk in super()._astream_responses(*args, **kwargs):
+                yield chunk
+        else:
+            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+            if stream_usage:
+                kwargs["stream_options"] = {"include_usage": stream_usage}
 
-        async for chunk in super()._astream(*args, **kwargs):
-            yield chunk
+            async for chunk in super()._astream(*args, **kwargs):
+                yield chunk
 
     def with_structured_output(
         self,
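With this routing in place, streaming a prompt alongside a built-in tool now flows through the Responses API; the integration test hunks below exercise exactly that. A consumer-level sketch mirroring those tests (not part of the diff; the model name is illustrative and an OPENAI_API_KEY is assumed to be set):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")

full = None
for chunk in llm.stream(
    "What was a positive news story from today?",
    tools=[{"type": "web_search_preview"}],
):
    # Chunks accumulate into a single AIMessageChunk with content and usage.
    full = chunk if full is None else full + chunk

print(full.content)
print(full.usage_metadata)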
@@ -1245,6 +1245,18 @@ def test_web_search() -> None:
     assert response.response_metadata["model_name"]
     assert response.response_metadata["status"]
 
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream(
+        "What was a positive news story from today?",
+        tools=[{"type": "web_search_preview"}],
+    ):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.content
+    assert full.usage_metadata
+
 
 async def test_web_search_async() -> None:
     llm = ChatOpenAI(model="gpt-4o")
@@ -1260,3 +1272,14 @@ async def test_web_search_async() -> None:
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
     assert response.response_metadata["status"]
+
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream(
+        "What was a positive news story from today?",
+        tools=[{"type": "web_search_preview"}],
+    ):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.content
+    assert full.usage_metadata