core[patch], openai[patch]: Chat openai stream logprobs (#16218)

This commit is contained in:
Bagatur
2024-01-19 09:16:09 -08:00
committed by GitHub
parent 6f7a414955
commit 84bf5787a7
5 changed files with 110 additions and 22 deletions

View File

@@ -404,15 +404,19 @@ class ChatOpenAI(BaseChatModel):
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
def _generate(
self,
@@ -492,15 +496,21 @@ class ChatOpenAI(BaseChatModel):
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
await run_manager.on_llm_new_token(
token=chunk.text, chunk=chunk, logprobs=logprobs
)
async def _agenerate(
self,