mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
core[patch], openai[patch]: Chat openai stream logprobs (#16218)
This commit is contained in: parent 6f7a414955, commit 84bf5787a7
libs/core/langchain_core/outputs/chat_generation.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal

 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils._merge import merge_dicts


 class ChatGeneration(Generation):

@@ -53,14 +54,13 @@ class ChatGenerationChunk(ChatGeneration):

     def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
         if isinstance(other, ChatGenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                other.generation_info or {},
             )
             return ChatGenerationChunk(
                 message=self.message + other.message,
-                generation_info=generation_info,
+                generation_info=generation_info or None,
             )
         else:
             raise TypeError(
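With this hunk, ChatGenerationChunk.__add__ deep-merges the generation_info of consecutive stream chunks via merge_dicts instead of letting the right-hand dict shallow-overwrite duplicate keys, so per-token payloads such as logprobs accumulate across chunks. A minimal sketch of the new behavior (the logprobs payloads below are invented for illustration, not taken from the diff):

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

left = ChatGenerationChunk(
    message=AIMessageChunk(content="Hello"),
    generation_info={"logprobs": {"content": [{"token": "Hello"}]}},
)
right = ChatGenerationChunk(
    message=AIMessageChunk(content=" world"),
    generation_info={
        "logprobs": {"content": [{"token": " world"}]},
        "finish_reason": "stop",
    },
)

merged = left + right
# merge_dicts concatenates the nested "content" lists instead of overwriting
# them, so the aggregated chunk keeps logprobs for every streamed token.
assert merged.generation_info == {
    "logprobs": {"content": [{"token": "Hello"}, {"token": " world"}]},
    "finish_reason": "stop",
}
assert merged.text == "Hello world"

Under the old {**left, **right} merge, only the last chunk's logprobs would have survived aggregation.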
libs/core/langchain_core/outputs/generation.py

@@ -3,6 +3,7 @@ from __future__ import annotations

 from typing import Any, Dict, List, Literal, Optional

 from langchain_core.load import Serializable
+from langchain_core.utils._merge import merge_dicts


 class Generation(Serializable):

@@ -40,14 +41,13 @@ class GenerationChunk(Generation):

     def __add__(self, other: GenerationChunk) -> GenerationChunk:
         if isinstance(other, GenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                other.generation_info or {},
             )
             return GenerationChunk(
                 text=self.text + other.text,
-                generation_info=generation_info,
+                generation_info=generation_info or None,
             )
         else:
             raise TypeError(
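The same change is applied to plain GenerationChunk, and the `or None` guard preserves the old convention that chunks without any generation_info aggregate to None rather than to an empty dict. A tiny sketch (values invented):

from langchain_core.outputs import GenerationChunk

a = GenerationChunk(text="foo")
b = GenerationChunk(text="bar")

combined = a + b
# merge_dicts({}, {}) == {}, and `{} or None` is None, so chunks without
# generation_info still aggregate to generation_info=None instead of {}.
assert combined.text == "foobar"
assert combined.generation_info is None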
libs/core/langchain_core/utils/_merge.py (new file, 44 lines)

@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+
+def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
+    """Merge two dicts, handling specific scenarios where a key exists in both
+    dictionaries but has a value of None in 'left'. In such cases, the method uses the
+    value from 'right' for that key in the merged dictionary.
+
+    Example:
+        If left = {"function_call": {"arguments": None}} and
+        right = {"function_call": {"arguments": "{\n"}}
+        then, after merging, for the key "function_call",
+        the value from 'right' is used,
+        resulting in merged = {"function_call": {"arguments": "{\n"}}.
+    """
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif merged[k] is None and v:
+            merged[k] = v
+        elif v is None:
+            continue
+        elif merged[k] == v:
+            continue
+        elif type(merged[k]) != type(v):
+            raise TypeError(
+                f'additional_kwargs["{k}"] already exists in this message,'
+                " but with a different type."
+            )
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_dicts(merged[k], v)
+        elif isinstance(merged[k], list):
+            merged[k] = merged[k] + v
+        else:
+            raise TypeError(
+                f"Additional kwargs key {k} already exists in left dict and value has "
+                f"unsupported type {type(merged[k])}."
+            )
+    return merged
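Beyond the None-on-the-left case covered by the docstring, merge_dicts also concatenates string values, extends lists, and recursively merges nested dicts, which is what lets streamed deltas accumulate cleanly. A usage sketch with invented inputs:

from langchain_core.utils._merge import merge_dicts

# Two consecutive streamed deltas (made up): the second continues the
# arguments string of the first and fills in a finish_reason.
left = {
    "function_call": {"name": "get_weather", "arguments": '{"city": "Par'},
    "finish_reason": None,
}
right = {
    "function_call": {"arguments": 'is"}'},
    "finish_reason": "stop",
}

merged = merge_dicts(left, right)
# Strings are concatenated, nested dicts merged recursively, and the None on
# the left is replaced by the truthy value on the right.
assert merged == {
    "function_call": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    "finish_reason": "stop",
}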
libs/partners/openai/langchain_openai/chat_models/base.py

@@ -404,15 +404,19 @@ class ChatOpenAI(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
-            finish_reason = choice.get("finish_reason")
-            generation_info = (
-                dict(finish_reason=finish_reason) if finish_reason is not None else None
-            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
             yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)

     def _generate(
         self,
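_stream now records each choice's logprobs (when the API returns them) in the chunk's generation_info and also forwards them to on_llm_new_token. One way to observe this is a custom callback handler. The sketch below is illustrative rather than part of this PR: the handler name is hypothetical, an OPENAI_API_KEY is assumed, and logprobs=True is passed through model_kwargs.

from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_openai import ChatOpenAI


class LogprobsCollector(BaseCallbackHandler):
    """Collects the per-chunk logprobs that _stream now attaches."""

    def __init__(self) -> None:
        self.token_logprobs: List[Dict[str, Any]] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # With this change, each streamed ChatGenerationChunk carries
        # {"logprobs": ...} in generation_info when the API returns logprobs.
        if chunk is not None and chunk.generation_info:
            lp = chunk.generation_info.get("logprobs")
            if lp:
                self.token_logprobs.append(lp)


collector = LogprobsCollector()
llm = ChatOpenAI(streaming=True, callbacks=[collector], model_kwargs={"logprobs": True})
llm.invoke("Say hi")
print(len(collector.token_logprobs))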
@@ -492,15 +496,21 @@ class ChatOpenAI(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
-            finish_reason = choice.get("finish_reason")
-            generation_info = (
-                dict(finish_reason=finish_reason) if finish_reason is not None else None
-            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
             yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )

     async def _agenerate(
         self,
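_astream gets the same treatment; the only user-visible difference is that the callback is awaited, so an async handler is used. A compact async counterpart of the collector sketch above, with the same caveats (illustrative names, API key assumed):

import asyncio
from typing import Any, List, Optional, Union

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_openai import ChatOpenAI


class AsyncLogprobsCollector(AsyncCallbackHandler):
    def __init__(self) -> None:
        self.token_logprobs: List[Any] = []

    async def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # _astream now awaits on_llm_new_token with chunk + logprobs, mirroring _stream.
        if chunk is not None and chunk.generation_info:
            self.token_logprobs.append(chunk.generation_info.get("logprobs"))


async def main() -> None:
    collector = AsyncLogprobsCollector()
    llm = ChatOpenAI(streaming=True, callbacks=[collector], model_kwargs={"logprobs": True})
    await llm.ainvoke("Say hi")
    print(len(collector.token_logprobs))


asyncio.run(main())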
ChatOpenAI integration tests

@@ -391,3 +391,37 @@ def test_invoke() -> None:

     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+def test_logprobs() -> None:
+    llm = ChatOpenAI()
+    result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True)
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+async def test_async_logprobs() -> None:
+    llm = ChatOpenAI()
+    result = await llm.agenerate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+def test_logprobs_streaming() -> None:
+    llm = ChatOpenAI()
+    result = llm.generate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+async def test_async_logprobs_streaming() -> None:
+    llm = ChatOpenAI()
+    result = await llm.agenerate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
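The new tests only assert that a "content" key is present under logprobs. For reference, the OpenAI chat completions API reports logprobs as {"content": [...]}, where each entry typically carries token, logprob, bytes, and top_logprobs fields. A hedged sketch of reading them from a generate() result (prompt and field access are illustrative; exact payload shape is whatever the API returns):

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()
result = llm.generate([[HumanMessage(content="Say hi")]], logprobs=True)
info = result.generations[0][0].generation_info or {}

for entry in info.get("logprobs", {}).get("content", []):
    # Each entry is expected to look roughly like
    # {"token": "...", "logprob": -0.01, "bytes": [...], "top_logprobs": [...]}.
    print(entry["token"], entry["logprob"])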