Chester Curme 2025-02-07 10:55:58 -05:00
parent 5690575f13
commit 33e7d91f1a


@@ -317,11 +317,21 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+    chunk: dict,
+    default_chunk_class: Type,
+    base_generation_info: Optional[Dict],
+    chunk_object: Any,
 ) -> Optional[ChatGenerationChunk]:
     if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
         return None
-    token_usage = chunk.get("usage")
+    if chunk.get("type") == "content.done" and hasattr(chunk_object, "parsed"):
+        return ChatGenerationChunk(
+            message=default_chunk_class(
+                content="", additional_kwargs={"parsed": chunk_object.parsed}
+            )
+        )
+    token_usage = chunk.get("usage") or chunk.get("chunk", {}).get("usage")
     choices = (
         chunk.get("choices", [])
         # from beta.chat.completions.stream
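A minimal, self-contained sketch of the dispatch this hunk introduces (illustrative stand-ins only, not the langchain-openai implementation; `Chunk` and `convert_event` are hypothetical names): `content.delta` events are skipped, and a `content.done` event whose SDK object carries a `parsed` attribute becomes an empty-content chunk exposing the parsed object under `additional_kwargs["parsed"]`.

```python
# Illustrative sketch only; Chunk/convert_event are stand-ins, not library APIs.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class Chunk:
    content: str = ""
    additional_kwargs: dict = field(default_factory=dict)


def convert_event(event_dict: dict, event_obj: Any) -> Optional[Chunk]:
    if event_dict.get("type") == "content.delta":
        return None  # text deltas are handled by the normal choices path
    if event_dict.get("type") == "content.done" and hasattr(event_obj, "parsed"):
        # Final structured-output event: surface the parsed object on the chunk.
        return Chunk(content="", additional_kwargs={"parsed": event_obj.parsed})
    return Chunk(content=str(event_dict.get("content", "")))


class _Done:
    type = "content.done"
    parsed = {"answer": 42}  # stand-in for a parsed Pydantic instance


print(convert_event({"type": "content.done"}, _Done()).additional_kwargs)
# {'parsed': {'answer': 42}}
```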
@@ -725,12 +735,15 @@ class BaseChatOpenAI(BaseChatModel):
             with context_manager as response:
                 is_first_chunk = True
                 for chunk in response:
-                    if not isinstance(chunk, dict):
-                        chunk = chunk.model_dump()
+                    if isinstance(chunk, dict):
+                        chunk_dict = chunk
+                    else:
+                        chunk_dict = chunk.model_dump()
                     generation_chunk = _convert_chunk_to_generation_chunk(
-                        chunk,
+                        chunk_dict,
                         default_chunk_class,
                         base_generation_info if is_first_chunk else {},
+                        chunk,
                     )
                     if generation_chunk is None:
                         continue
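A short sketch of the dict/object dual-pass pattern used above (`normalize` is a hypothetical helper, not library code): the event is reduced to a dict for the usual field lookups, while the original SDK object is kept alongside it because attribute-only data such as `parsed` may not survive `model_dump()`.

```python
# Illustrative only: `normalize` is a hypothetical helper, not library code.
from typing import Any, Tuple


def normalize(event: Any) -> Tuple[dict, Any]:
    """Return (dict_form, original_object) for a streamed event."""
    event_dict = event if isinstance(event, dict) else event.model_dump()
    # Keep the original object too: converters read dict fields for routine
    # parsing and fall back to object attributes (e.g. `parsed`) when needed.
    return event_dict, event
```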
@@ -746,16 +759,6 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
-        if hasattr(response, "get_final_completion") and "response_format" in payload:
-            final_completion = response.get_final_completion()
-            generation_chunk = self._get_generation_chunk_from_completion(
-                final_completion
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
 
     def _generate(
         self,
@@ -893,12 +896,15 @@ class BaseChatOpenAI(BaseChatModel):
             async with context_manager as response:
                 is_first_chunk = True
                 async for chunk in response:
-                    if not isinstance(chunk, dict):
-                        chunk = chunk.model_dump()
+                    if isinstance(chunk, dict):
+                        chunk_dict = chunk
+                    else:
+                        chunk_dict = chunk.model_dump()
                     generation_chunk = _convert_chunk_to_generation_chunk(
-                        chunk,
+                        chunk_dict,
                         default_chunk_class,
                         base_generation_info if is_first_chunk else {},
+                        chunk,
                     )
                     if generation_chunk is None:
                         continue
@@ -914,16 +920,6 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
-        if hasattr(response, "get_final_completion") and "response_format" in payload:
-            final_completion = await response.get_final_completion()
-            generation_chunk = self._get_generation_chunk_from_completion(
-                final_completion
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
 
     async def _agenerate(
         self,
@@ -1475,28 +1471,6 @@ class BaseChatOpenAI(BaseChatModel):
                 filtered[k] = v
         return filtered
 
-    def _get_generation_chunk_from_completion(
-        self, completion: openai.BaseModel
-    ) -> ChatGenerationChunk:
-        """Get chunk from completion (e.g., from final completion of a stream)."""
-        chat_result = self._create_chat_result(completion)
-        chat_message = chat_result.generations[0].message
-        if isinstance(chat_message, AIMessage):
-            usage_metadata = chat_message.usage_metadata
-            # Skip tool_calls, already sent as chunks
-            if "tool_calls" in chat_message.additional_kwargs:
-                chat_message.additional_kwargs.pop("tool_calls")
-        else:
-            usage_metadata = None
-        message = AIMessageChunk(
-            content="",
-            additional_kwargs=chat_message.additional_kwargs,
-            usage_metadata=usage_metadata,
-        )
-        return ChatGenerationChunk(
-            message=message, generation_info=chat_result.llm_output
-        )
 
 class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.