core[patch], openai[patch]: Chat openai stream logprobs (#16218)

This commit is contained in:
Bagatur
2024-01-19 09:16:09 -08:00
committed by GitHub
parent 6f7a414955
commit 84bf5787a7
5 changed files with 110 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts
class ChatGeneration(Generation):
@@ -53,14 +54,13 @@ class ChatGenerationChunk(ChatGeneration):
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
generation_info = merge_dicts(
self.generation_info or {},
other.generation_info or {},
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
generation_info=generation_info or None,
)
else:
raise TypeError(

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from langchain_core.load import Serializable
from langchain_core.utils._merge import merge_dicts
class Generation(Serializable):
@@ -40,14 +41,13 @@ class GenerationChunk(Generation):
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
generation_info = merge_dicts(
self.generation_info or {},
other.generation_info or {},
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
generation_info=generation_info or None,
)
else:
raise TypeError(