mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-21 10:26:57 +00:00)

Commit 03035b8d73: (a)stream
Parent: 3951f90a9e
@@ -356,6 +356,27 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
+def _convert_responses_chunk_to_generation_chunk(
+    chunk: Any,
+) -> Optional[ChatGenerationChunk]:
+    content = ""
+    generation_info = None
+    usage_metadata = None
+    if chunk.type == "response.output_text.delta":
+        content = chunk.delta
+    elif chunk.type == "response.completed":
+        token_usage = chunk.response.usage.model_dump() if chunk.response.usage else {}
+        usage_metadata = _create_usage_metadata_responses(token_usage)
+        generation_info = {"model_name": chunk.response.model}
+    else:
+        return None
+
+    return ChatGenerationChunk(
+        message=AIMessageChunk(content=content, usage_metadata=usage_metadata),
+        generation_info=generation_info,
+    )
+
+
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
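The helper above recognizes only two Responses streaming events: text deltas become message content, and the final response.completed event carries usage and model metadata; every other event type is dropped. A self-contained sketch of that mapping, using SimpleNamespace stand-ins for the SDK event objects (illustrative only, not part of the diff; field names follow the code above):

from types import SimpleNamespace

# Illustrative stand-ins for Responses API stream events (not real SDK objects).
delta_event = SimpleNamespace(type="response.output_text.delta", delta="Hello")
completed_event = SimpleNamespace(
    type="response.completed",
    response=SimpleNamespace(model="gpt-4o-2024-08-06", usage=None),
)
other_event = SimpleNamespace(type="response.created")


def describe(event):
    # Mirrors the branching in _convert_responses_chunk_to_generation_chunk.
    if event.type == "response.output_text.delta":
        return {"content": event.delta}
    if event.type == "response.completed":
        return {"content": "", "model_name": event.response.model}
    return None  # any other event type yields no generation chunk


print(describe(delta_event))      # {'content': 'Hello'}
print(describe(completed_event))  # {'content': '', 'model_name': 'gpt-4o-2024-08-06'}
print(describe(other_event))      # None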
@@ -410,7 +431,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
 
 
 def _is_builtin_tool(tool: dict) -> bool:
-    return set(tool.keys()) == {"type"}
+    return "type" in tool and tool["type"] != "function"
 
 
 def _transform_payload_for_responses(payload: dict) -> dict:
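The one-line change above widens what counts as a built-in (Responses API) tool: the old check only matched tool dicts whose sole key is "type", while the new check treats any dict tool whose type is not "function" as built-in, so built-in tools that carry extra configuration keys are also recognized. A small sketch of the difference (the tool payloads are illustrative examples, not taken from the diff):

def is_builtin_old(tool: dict) -> bool:
    return set(tool.keys()) == {"type"}


def is_builtin_new(tool: dict) -> bool:
    return "type" in tool and tool["type"] != "function"


web_search = {"type": "web_search_preview"}
file_search = {"type": "file_search", "vector_store_ids": ["vs_123"]}  # extra config key
function_tool = {
    "type": "function",
    "function": {"name": "get_time", "parameters": {"type": "object", "properties": {}}},
}

print(is_builtin_old(web_search), is_builtin_new(web_search))        # True True
print(is_builtin_old(file_search), is_builtin_new(file_search))      # False True
print(is_builtin_old(function_tool), is_builtin_new(function_tool))  # False False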
@@ -741,6 +762,54 @@ class BaseChatOpenAI(BaseChatModel):
         )
         return generation_chunk
 
+    def _stream_responses(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        responses_payload = _transform_payload_for_responses(payload)
+        context_manager = self.root_client.responses.create(**responses_payload)
+
+        with context_manager as response:
+            for chunk in response:
+                if generation_chunk := _convert_responses_chunk_to_generation_chunk(
+                    chunk
+                ):
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
+
+    async def _astream_responses(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        responses_payload = _transform_payload_for_responses(payload)
+        context_manager = await self.root_async_client.responses.create(
+            **responses_payload
+        )
+
+        async with context_manager as response:
+            async for chunk in response:
+                if generation_chunk := _convert_responses_chunk_to_generation_chunk(
+                    chunk
+                ):
+                    if run_manager:
+                        await run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
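With _stream_responses and _astream_responses in place, user code streams through the Responses API simply by passing a built-in tool, as the integration tests added below do. A minimal sketch of the synchronous path (model name and prompt follow those tests; requires an OpenAI API key):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")

for chunk in llm.stream(
    "What was a positive news story from today?",
    tools=[{"type": "web_search_preview"}],
):
    # Each chunk is an AIMessageChunk built by _convert_responses_chunk_to_generation_chunk.
    print(chunk.content, end="", flush=True)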
@@ -2143,6 +2212,11 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
+        if "tools" in kwargs and any(
+            _is_builtin_tool(tool) for tool in kwargs["tools"]
+        ):
+            return super()._stream_responses(*args, **kwargs)
+        else:
             stream_usage = self._should_stream_usage(stream_usage, **kwargs)
             # Note: stream_options is not a valid parameter for Azure OpenAI.
             # To support users proxying Azure through ChatOpenAI, here we only specify
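The routing added above only fires when a built-in tool appears in tools; plain function tools keep the existing Chat Completions streaming path. A sketch of the two cases (the function-tool payload is an illustrative example, not from the diff):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")

# Built-in tool: _is_builtin_tool is True, so _stream delegates to _stream_responses.
responses_stream = llm.stream("hi", tools=[{"type": "web_search_preview"}])

# Function tool: stays on the Chat Completions streaming path with stream_options handling.
completions_stream = llm.stream(
    "hi",
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_time",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ],
)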
@@ -2158,6 +2232,12 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
+        if "tools" in kwargs and any(
+            _is_builtin_tool(tool) for tool in kwargs["tools"]
+        ):
+            async for chunk in super()._astream_responses(*args, **kwargs):
+                yield chunk
+        else:
             stream_usage = self._should_stream_usage(stream_usage, **kwargs)
             if stream_usage:
                 kwargs["stream_options"] = {"include_usage": stream_usage}
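The async counterpart mirrors the sync routing, delegating to _astream_responses and re-yielding its chunks. A minimal usage sketch based on the async integration test added below (requires an OpenAI API key):

import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI(model="gpt-4o")
    async for chunk in llm.astream(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    ):
        # Each chunk is an AIMessageChunk produced via _astream_responses.
        print(chunk.content, end="", flush=True)


asyncio.run(main())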
@@ -1245,6 +1245,18 @@ def test_web_search() -> None:
     assert response.response_metadata["model_name"]
     assert response.response_metadata["status"]
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream(
+        "What was a positive news story from today?",
+        tools=[{"type": "web_search_preview"}],
+    ):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.content
+    assert full.usage_metadata
 
 
 async def test_web_search_async() -> None:
     llm = ChatOpenAI(model="gpt-4o")
@@ -1260,3 +1272,14 @@ async def test_web_search_async() -> None:
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
     assert response.response_metadata["status"]
+
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream(
+        "What was a positive news story from today?",
+        tools=[{"type": "web_search_preview"}],
+    ):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.content
+    assert full.usage_metadata