openai[patch]: support streaming with json_schema response format (#29044)
- Stream JSON string content. Final chunk includes parsed representation (following OpenAI [docs](https://platform.openai.com/docs/guides/structured-outputs#streaming)).
- Mildly (?) breaking change: if you were using streaming with `response_format` before, usage metadata will disappear unless you set `stream_usage=True`.

## Response format

Before:

After:

## with_structured_output

For pydantic output, behavior of `with_structured_output` is unchanged (except for warning disappearing), because we pluck the parsed representation straight from OpenAI, and OpenAI doesn't return it until the stream is completed. Open to alternatives (e.g., parsing from content or intermediate dict chunks generated by OpenAI).

Before:

After:
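For the `response_format` case above, a minimal usage sketch (not taken from the PR): it assumes langchain-openai with this patch, `OPENAI_API_KEY` in the environment, and a made-up `math_answer` schema. Content streams as a JSON string; usage metadata only survives if `stream_usage=True` is set.

```python
# Minimal sketch, not from the PR: streaming with a json_schema response_format.
# The model name and schema below are illustrative assumptions.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)  # keep usage metadata while streaming

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "math_answer",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
            "additionalProperties": False,
        },
    },
}

full = None
for chunk in llm.bind(response_format=response_format).stream("What is 2 + 2?"):
    print(chunk.content, end="", flush=True)  # raw JSON string content, streamed
    full = chunk if full is None else full + chunk

print()
print(full.usage_metadata)     # populated because stream_usage=True was set
print(full.additional_kwargs)  # the parsed representation is expected here on the final chunk
```

That the parsed object surfaces via `additional_kwargs` on the final chunk is inferred from `_get_generation_chunk_from_completion` in the diff below; treat the last line as an assumption rather than a documented contract.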
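For the `with_structured_output` case, a companion sketch under the same assumptions (Pydantic v2, illustrative names). Because the parsed representation is plucked from OpenAI's final completion, streaming still yields the parsed object in one piece at the end.

```python
# Sketch only: pydantic output via with_structured_output(..., method="json_schema").
# Per the description, behavior is unchanged by this patch apart from the old
# warning going away; the parsed object arrives once the stream completes.
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class MathAnswer(BaseModel):  # illustrative schema, not from the PR
    answer: str


structured_llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    MathAnswer, method="json_schema"
)

for parsed in structured_llm.stream("What is 2 + 2?"):
    print(parsed)  # effectively a single MathAnswer, emitted once the stream finishes
```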
@@ -318,8 +318,14 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
     chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
 ) -> Optional[ChatGenerationChunk]:
+    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+        return None
     token_usage = chunk.get("usage")
-    choices = chunk.get("choices", [])
+    choices = (
+        chunk.get("choices", [])
+        # from beta.chat.completions.stream
+        or chunk.get("chunk", {}).get("choices", [])
+    )
 
     usage_metadata: Optional[UsageMetadata] = (
         _create_usage_metadata(token_usage) if token_usage else None
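For orientation on the hunk above: a rough sketch of the dict shapes `_convert_chunk_to_generation_chunk` now distinguishes. The event payloads from `beta.chat.completions.stream` are assumed here (they are dumped to dicts upstream); only the keys checked in the hunk matter.

```python
# Assumed shapes, for illustration only; not copied from the openai SDK.

# A plain chat.completions chunk: handled exactly as before.
raw_chunk = {"choices": [{"index": 0, "delta": {"content": '{"ans'}}], "usage": None}

# A wrapper event from beta.chat.completions.stream: no top-level "choices" key,
# so the helper falls back to chunk["chunk"]["choices"].
chunk_event = {"type": "chunk", "chunk": raw_chunk}

# An accumulated-content event from the same stream: the helper returns None for
# it so the same text is not emitted twice.
content_delta_event = {"type": "content.delta", "delta": 'wer"'}
```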
@@ -660,13 +666,24 @@ class BaseChatOpenAI(BaseChatModel):
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
 
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_client.beta.chat.completions.stream(**payload)
+            context_manager = response_stream
         else:
-            response = self.client.create(**payload)
-        with response:
+            if self.include_response_headers:
+                raw_response = self.client.with_raw_response.create(**payload)
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = self.client.create(**payload)
+            context_manager = response
+        with context_manager as response:
             is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -686,6 +703,16 @@ class BaseChatOpenAI(BaseChatModel):
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     def _generate(
         self,
@@ -794,13 +821,29 @@ class BaseChatOpenAI(BaseChatModel):
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_async_client.beta.chat.completions.stream(
+                **payload
+            )
+            context_manager = response_stream
         else:
-            response = await self.async_client.create(**payload)
-        async with response:
+            if self.include_response_headers:
+                raw_response = await self.async_client.with_raw_response.create(
+                    **payload
+                )
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = await self.async_client.create(**payload)
+            context_manager = response
+        async with context_manager as response:
             is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
@@ -820,6 +863,16 @@ class BaseChatOpenAI(BaseChatModel):
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = await response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     async def _agenerate(
         self,
@@ -1010,25 +1063,6 @@ class BaseChatOpenAI(BaseChatModel):
         num_tokens += 3
         return num_tokens
 
-    def _should_stream(
-        self,
-        *,
-        async_api: bool,
-        run_manager: Optional[
-            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
-        ] = None,
-        response_format: Optional[Union[dict, type]] = None,
-        **kwargs: Any,
-    ) -> bool:
-        if isinstance(response_format, type) and is_basemodel_subclass(response_format):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            return False
-
-        return super()._should_stream(
-            async_api=async_api, run_manager=run_manager, **kwargs
-        )
-
     @deprecated(
         since="0.2.1",
         alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools",
@@ -1531,6 +1565,25 @@ class BaseChatOpenAI(BaseChatModel):
                 filtered[k] = v
         return filtered
 
+    def _get_generation_chunk_from_completion(
+        self, completion: openai.BaseModel
+    ) -> ChatGenerationChunk:
+        """Get chunk from completion (e.g., from final completion of a stream)."""
+        chat_result = self._create_chat_result(completion)
+        chat_message = chat_result.generations[0].message
+        if isinstance(chat_message, AIMessage):
+            usage_metadata = chat_message.usage_metadata
+        else:
+            usage_metadata = None
+        message = AIMessageChunk(
+            content="",
+            additional_kwargs=chat_message.additional_kwargs,
+            usage_metadata=usage_metadata,
+        )
+        return ChatGenerationChunk(
+            message=message, generation_info=chat_result.llm_output
+        )
+
 
 class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.