Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-26 08:33:49 +00:00
refactor
This commit is contained in:
parent
5690575f13
commit
33e7d91f1a
@@ -317,11 +317,21 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+    chunk: dict,
+    default_chunk_class: Type,
+    base_generation_info: Optional[Dict],
+    chunk_object: Any,
 ) -> Optional[ChatGenerationChunk]:
     if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
         return None
-    token_usage = chunk.get("usage")
+    if chunk.get("type") == "content.done" and hasattr(chunk_object, "parsed"):
+        return ChatGenerationChunk(
+            message=default_chunk_class(
+                content="", additional_kwargs={"parsed": chunk_object.parsed}
+            )
+        )
+
+    token_usage = chunk.get("usage") or chunk.get("chunk", {}).get("usage")
     choices = (
         chunk.get("choices", [])
         # from beta.chat.completions.stream
         or chunk.get("chunk", {}).get("choices", [])
     )
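
The new `chunk_object` parameter threads the raw stream event through alongside its dict form: a `content.done` event from `beta.chat.completions.stream` exposes the structured output on its `.parsed` attribute, which exists only on the object, not in the dict. A minimal runnable sketch of that dispatch, using hypothetical stand-in types (`FakeDoneEvent`, `FakeChunk`) rather than the library's real classes:

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class FakeDoneEvent:
    # hypothetical stand-in for a beta.chat.completions.stream event
    type: str = "content.done"
    parsed: Any = field(default_factory=lambda: {"answer": 42})


@dataclass
class FakeChunk:
    # hypothetical stand-in for a message chunk
    content: str
    additional_kwargs: dict


def convert(chunk: dict, chunk_object: Any) -> Optional[FakeChunk]:
    if chunk.get("type") == "content.delta":
        return None  # content.delta events are skipped in this path
    if chunk.get("type") == "content.done" and hasattr(chunk_object, "parsed"):
        # the parsed structured output rides along in additional_kwargs
        return FakeChunk(content="", additional_kwargs={"parsed": chunk_object.parsed})
    return None


event = FakeDoneEvent()
print(convert({"type": event.type}, event))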
@@ -725,12 +735,15 @@ class BaseChatOpenAI(BaseChatModel):
             with context_manager as response:
                 is_first_chunk = True
                 for chunk in response:
-                    if not isinstance(chunk, dict):
-                        chunk = chunk.model_dump()
+                    if isinstance(chunk, dict):
+                        chunk_dict = chunk
+                    else:
+                        chunk_dict = chunk.model_dump()
                     generation_chunk = _convert_chunk_to_generation_chunk(
-                        chunk,
+                        chunk_dict,
                         default_chunk_class,
                         base_generation_info if is_first_chunk else {},
+                        chunk,
                     )
                     if generation_chunk is None:
                         continue
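
Instead of overwriting `chunk` with its dump, the loop now keeps both views: the dict (`chunk_dict`) for key-based parsing and the original object (`chunk`) for attribute access such as `.parsed`. A small sketch of the pattern, assuming pydantic-v2-style objects with `model_dump()` (which the OpenAI SDK's streamed objects provide):

from pydantic import BaseModel


class Event(BaseModel):
    # hypothetical stand-in for a streamed SDK object
    type: str


for chunk in (Event(type="content.delta"), {"type": "content.done"}):
    # keep the original object (attribute access) and a dict view (key access)
    chunk_dict = chunk if isinstance(chunk, dict) else chunk.model_dump()
    print(type(chunk).__name__, chunk_dict)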
@@ -746,16 +759,6 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
-        if hasattr(response, "get_final_completion") and "response_format" in payload:
-            final_completion = response.get_final_completion()
-            generation_chunk = self._get_generation_chunk_from_completion(
-                final_completion
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
 
     def _generate(
         self,
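
The block deleted here ran after the stream was drained: when a `response_format` was in the payload, it fetched the fully accumulated completion via `get_final_completion()` and emitted it as one trailing chunk. With the `content.done` handling added to `_convert_chunk_to_generation_chunk` above, the parsed result now surfaces inline during iteration, so the post-loop pass has no remaining job. A runnable sketch of the new flow, with a hypothetical event stream standing in for `beta.chat.completions.stream`:

from typing import Iterator, Optional


def events() -> Iterator[dict]:
    # hypothetical event stream
    yield {"type": "content.delta"}
    yield {"type": "content.done", "parsed": {"value": 42}}


def convert(chunk: dict) -> Optional[dict]:
    if chunk["type"] == "content.done":
        return {"parsed": chunk["parsed"]}  # emitted inline, mid-stream
    return None


for event in events():
    if (out := convert(event)) is not None:
        print(out)
# no separate get_final_completion() pass is needed after the loop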
@@ -893,12 +896,15 @@ class BaseChatOpenAI(BaseChatModel):
             async with context_manager as response:
                 is_first_chunk = True
                 async for chunk in response:
-                    if not isinstance(chunk, dict):
-                        chunk = chunk.model_dump()
+                    if isinstance(chunk, dict):
+                        chunk_dict = chunk
+                    else:
+                        chunk_dict = chunk.model_dump()
                     generation_chunk = _convert_chunk_to_generation_chunk(
-                        chunk,
+                        chunk_dict,
                         default_chunk_class,
                         base_generation_info if is_first_chunk else {},
+                        chunk,
                     )
                     if generation_chunk is None:
                         continue
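
The async path gets the identical normalization, and the next hunk deletes the async post-stream block for the same reason as the sync one. For completeness, the same keep-both-views pattern inside an async generator (stand-in events, not the SDK):

import asyncio
from typing import Any, AsyncIterator


async def astream() -> AsyncIterator[Any]:
    # hypothetical async event stream
    for event in ({"type": "content.delta"}, {"type": "content.done"}):
        yield event


async def main() -> None:
    async for chunk in astream():
        chunk_dict = chunk if isinstance(chunk, dict) else chunk.model_dump()
        print(chunk_dict["type"])


asyncio.run(main())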
@@ -914,16 +920,6 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
-        if hasattr(response, "get_final_completion") and "response_format" in payload:
-            final_completion = await response.get_final_completion()
-            generation_chunk = self._get_generation_chunk_from_completion(
-                final_completion
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
 
     async def _agenerate(
         self,
@@ -1475,28 +1471,6 @@ class BaseChatOpenAI(BaseChatModel):
                 filtered[k] = v
         return filtered
 
-    def _get_generation_chunk_from_completion(
-        self, completion: openai.BaseModel
-    ) -> ChatGenerationChunk:
-        """Get chunk from completion (e.g., from final completion of a stream)."""
-        chat_result = self._create_chat_result(completion)
-        chat_message = chat_result.generations[0].message
-        if isinstance(chat_message, AIMessage):
-            usage_metadata = chat_message.usage_metadata
-            # Skip tool_calls, already sent as chunks
-            if "tool_calls" in chat_message.additional_kwargs:
-                chat_message.additional_kwargs.pop("tool_calls")
-        else:
-            usage_metadata = None
-        message = AIMessageChunk(
-            content="",
-            additional_kwargs=chat_message.additional_kwargs,
-            usage_metadata=usage_metadata,
-        )
-        return ChatGenerationChunk(
-            message=message, generation_info=chat_result.llm_output
-        )
-
 
 class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.
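
With both post-stream blocks gone, `_get_generation_chunk_from_completion` had no remaining callers, so it is removed outright. What it did, as a hedged sketch with hypothetical stand-in types: collapse a final completion into an empty-content chunk carrying only usage metadata and leftover extras, dropping `tool_calls` because those were already streamed incrementally.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Msg:
    # hypothetical stand-in for an AIMessage / AIMessageChunk
    content: str
    additional_kwargs: dict = field(default_factory=dict)
    usage_metadata: Optional[dict] = None


def completion_to_chunk(final: Msg) -> Msg:
    extras = dict(final.additional_kwargs)
    extras.pop("tool_calls", None)  # already sent as incremental chunks
    return Msg(content="", additional_kwargs=extras, usage_metadata=final.usage_metadata)


final = Msg("42", {"tool_calls": [{"id": "call_1"}]}, {"total_tokens": 10})
print(completion_to_chunk(final))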