Compare commits


1 commit

Author          SHA1         Message    Date
Chester Curme   df0dcde2d1   refactor   2024-05-26 13:57:01 -04:00


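The commit extracts the duplicated per-chunk handling in `BaseChatOpenAI._stream` and `_astream` into a module-level helper, `_get_generation_chunk`: each streaming loop now converts a raw response chunk with a single call, skips `None` results, and updates `default_chunk_class` from the returned message.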
@@ -483,38 +483,17 @@ class BaseChatOpenAI(BaseChatModel):
         for chunk in response:
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                if token_usage := chunk.get("usage"):
-                    usage_metadata = UsageMetadata(
-                        input_tokens=token_usage.get("prompt_tokens", 0),
-                        output_tokens=token_usage.get("completion_tokens", 0),
-                        total_tokens=token_usage.get("total_tokens", 0),
-                    )
-                    chunk = ChatGenerationChunk(
-                        message=default_chunk_class(
-                            content="", usage_metadata=usage_metadata
-                        )
-                    )
-                else:
-                    continue
-            else:
-                choice = chunk["choices"][0]
-                if choice["delta"] is None:
-                    continue
-                chunk = _convert_delta_to_message_chunk(
-                    choice["delta"], default_chunk_class
-                )
-                generation_info = {}
-                if finish_reason := choice.get("finish_reason"):
-                    generation_info["finish_reason"] = finish_reason
-                logprobs = choice.get("logprobs")
-                if logprobs:
-                    generation_info["logprobs"] = logprobs
-                default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info or None
-                )
+            chunk = _get_generation_chunk(
+                chunk, default_chunk_class=default_chunk_class
+            )
+            if chunk is None:
+                continue
+            default_chunk_class = chunk.message.__class__
             if run_manager:
                 if chunk.generation_info:
                     logprobs = chunk.generation_info.get("logprobs")
                 else:
                     logprobs = None
                 run_manager.on_llm_new_token(
                     chunk.text, chunk=chunk, logprobs=logprobs
                 )
@@ -602,40 +581,19 @@ class BaseChatOpenAI(BaseChatModel):
         async for chunk in response:
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                if token_usage := chunk.get("usage"):
-                    usage_metadata = UsageMetadata(
-                        input_tokens=token_usage.get("prompt_tokens", 0),
-                        output_tokens=token_usage.get("completion_tokens", 0),
-                        total_tokens=token_usage.get("total_tokens", 0),
-                    )
-                    chunk = ChatGenerationChunk(
-                        message=default_chunk_class(
-                            content="", usage_metadata=usage_metadata
-                        )
-                    )
-                else:
-                    continue
-            else:
-                choice = chunk["choices"][0]
-                if choice["delta"] is None:
-                    continue
-                chunk = _convert_delta_to_message_chunk(
-                    choice["delta"], default_chunk_class
-                )
-                generation_info = {}
-                if finish_reason := choice.get("finish_reason"):
-                    generation_info["finish_reason"] = finish_reason
-                logprobs = choice.get("logprobs")
-                if logprobs:
-                    generation_info["logprobs"] = logprobs
-                default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info or None
-                )
+            chunk = _get_generation_chunk(
+                chunk, default_chunk_class=default_chunk_class
+            )
+            if chunk is None:
+                continue
+            default_chunk_class = chunk.message.__class__
             if run_manager:
                 if chunk.generation_info:
                     logprobs = chunk.generation_info.get("logprobs")
                 else:
                     logprobs = None
                 await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                    chunk.text, chunk=chunk, logprobs=logprobs
                 )
             yield chunk
@@ -1212,3 +1170,39 @@ def _lc_invalid_tool_call_to_openai_tool_call(
             "arguments": invalid_tool_call["args"],
         },
     }
+
+
+def _get_generation_chunk(
+    chunk: dict, default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+) -> Optional[ChatGenerationChunk]:
+    """Get ChatGenerationChunk from client response."""
+    if len(chunk["choices"]) == 0:
+        if token_usage := chunk.get("usage"):
+            usage_metadata = UsageMetadata(
+                input_tokens=token_usage.get("prompt_tokens", 0),
+                output_tokens=token_usage.get("completion_tokens", 0),
+                total_tokens=token_usage.get("total_tokens", 0),
+            )
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+        else:
+            return None
+    else:
+        choice = chunk["choices"][0]
+        if choice["delta"] is None:
+            return None
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {}
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+    return generation_chunk
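
Because the extracted helper is a pure function (raw chunk dict in, `ChatGenerationChunk` or `None` out), the sync and async loops above can share identical post-processing. The sketch below mirrors its three-way control flow with plain dataclasses standing in for langchain's `ChatGenerationChunk` and `AIMessageChunk`; the stand-in types, the `get_generation_chunk` name, and the sample payloads are illustrative, not the library's API:

    # Minimal sketch of the per-chunk dispatch that _get_generation_chunk
    # centralizes. The dataclasses are stand-ins, not langchain's real types;
    # the payloads imitate the shapes an OpenAI streaming response emits.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class MessageChunk:  # stand-in for AIMessageChunk
        content: str
        usage_metadata: Optional[dict] = None


    @dataclass
    class GenerationChunk:  # stand-in for ChatGenerationChunk
        message: MessageChunk
        generation_info: Optional[dict] = None


    def get_generation_chunk(chunk: dict) -> Optional[GenerationChunk]:
        """One of three outcomes per chunk: usage-only, content delta, or None."""
        if len(chunk["choices"]) == 0:
            # No choices: either a trailing usage-only chunk, or nothing usable.
            if token_usage := chunk.get("usage"):
                usage = {
                    "input_tokens": token_usage.get("prompt_tokens", 0),
                    "output_tokens": token_usage.get("completion_tokens", 0),
                    "total_tokens": token_usage.get("total_tokens", 0),
                }
                return GenerationChunk(MessageChunk(content="", usage_metadata=usage))
            return None
        choice = chunk["choices"][0]
        if choice["delta"] is None:
            return None
        info = {}
        if finish_reason := choice.get("finish_reason"):
            info["finish_reason"] = finish_reason
        return GenerationChunk(
            MessageChunk(content=choice["delta"].get("content") or ""),
            generation_info=info or None,
        )


    # The same helper serves both loops; only the iteration differs.
    for raw in [
        {"choices": [{"delta": {"content": "Hi"}, "finish_reason": None}]},
        {"choices": [{"delta": {}, "finish_reason": "stop"}]},
        {"choices": [], "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}},
        {"choices": []},  # no usage and no choices -> None, caller continues
    ]:
        print(get_generation_chunk(raw))

On the real classes the callers then run `default_chunk_class = chunk.message.__class__`, which is why the helper returns the wrapped message instead of mutating loop state itself.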