mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 02:11:09 +00:00
core, openai[patch]: prefer provider-assigned IDs when aggregating message chunks (#31080)
When aggregating AIMessageChunks in a stream, core prefers the leftmost non-null ID. This is problematic because: - Core assigns IDs when they are null to `f"run-{run_manager.run_id}"` - The desired meaningful ID might not be available until midway through the stream, as is the case for the OpenAI Responses API. For the OpenAI Responses API, we assign message IDs to the top-level `AIMessage.id`. This works in `.(a)invoke`, but during `.(a)stream` the IDs get overwritten by the defaults assigned in langchain-core. These IDs [must](https://community.openai.com/t/how-to-solve-badrequesterror-400-item-rs-of-type-reasoning-was-provided-without-its-required-following-item-error-in-responses-api/1151686/9) be available on the AIMessage object to support passing reasoning items back to the API (e.g., if not using OpenAI's `previous_response_id` feature). We could add them elsewhere, but seeing as we've already made the decision to store them in `.id` during `.(a)invoke`, addressing the issue in core lets us fix the problem with no interface changes.
This commit is contained in:
parent
72f905a436
commit
26ad239669
@ -58,6 +58,7 @@ from langchain_core.messages import (
|
||||
is_data_content_block,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from langchain_core.messages.ai import _LC_ID_PREFIX
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@ -493,7 +494,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
input_messages = _normalize_messages(messages)
|
||||
for chunk in self._stream(input_messages, stop=stop, **kwargs):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"run-{run_manager.run_id}"
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
@ -583,7 +584,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
**kwargs,
|
||||
):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"run-{run_manager.run_id}"
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
@ -1001,7 +1002,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
if run_manager:
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"run-{run_manager.run_id}"
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
@ -1017,7 +1018,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
# Add response metadata to each generation
|
||||
for idx, generation in enumerate(result.generations):
|
||||
if run_manager and generation.message.id is None:
|
||||
generation.message.id = f"run-{run_manager.run_id}-{idx}"
|
||||
generation.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
|
||||
generation.message.response_metadata = _gen_info_and_msg_metadata(
|
||||
generation
|
||||
)
|
||||
@ -1073,7 +1074,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
if run_manager:
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"run-{run_manager.run_id}"
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
@ -1089,7 +1090,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
# Add response metadata to each generation
|
||||
for idx, generation in enumerate(result.generations):
|
||||
if run_manager and generation.message.id is None:
|
||||
generation.message.id = f"run-{run_manager.run_id}-{idx}"
|
||||
generation.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
|
||||
generation.message.response_metadata = _gen_info_and_msg_metadata(
|
||||
generation
|
||||
)
|
||||
|
@ -36,6 +36,9 @@ from langchain_core.utils.usage import _dict_int_op
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_LC_ID_PREFIX = "run-"
|
||||
|
||||
|
||||
class InputTokenDetails(TypedDict, total=False):
|
||||
"""Breakdown of input token counts.
|
||||
|
||||
@ -418,10 +421,19 @@ def add_ai_message_chunks(
|
||||
usage_metadata = None
|
||||
|
||||
id = None
|
||||
for id_ in [left.id] + [o.id for o in others]:
|
||||
if id_:
|
||||
candidates = [left.id] + [o.id for o in others]
|
||||
# first pass: pick the first non‐run-* id
|
||||
for id_ in candidates:
|
||||
if id_ and not id_.startswith(_LC_ID_PREFIX):
|
||||
id = id_
|
||||
break
|
||||
else:
|
||||
# second pass: no provider-assigned id found, just take the first non‐null
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
id = id_
|
||||
break
|
||||
|
||||
return left.__class__(
|
||||
example=left.example,
|
||||
content=content,
|
||||
|
@ -178,6 +178,22 @@ def test_message_chunks() -> None:
|
||||
assert AIMessageChunk(content="") + left == left
|
||||
assert right + AIMessageChunk(content="") == right
|
||||
|
||||
# Test ID order of precedence
|
||||
null_id = AIMessageChunk(content="", id=None)
|
||||
default_id = AIMessageChunk(
|
||||
content="", id="run-abc123"
|
||||
) # LangChain-assigned run ID
|
||||
meaningful_id = AIMessageChunk(content="", id="msg_def456") # provider-assigned ID
|
||||
|
||||
assert (null_id + default_id).id == "run-abc123"
|
||||
assert (default_id + null_id).id == "run-abc123"
|
||||
|
||||
assert (null_id + meaningful_id).id == "msg_def456"
|
||||
assert (meaningful_id + null_id).id == "msg_def456"
|
||||
|
||||
assert (default_id + meaningful_id).id == "msg_def456"
|
||||
assert (meaningful_id + default_id).id == "msg_def456"
|
||||
|
||||
|
||||
def test_chat_message_chunks() -> None:
|
||||
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
|
||||
|
@ -3127,6 +3127,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
reasoning_items = []
|
||||
if reasoning := lc_msg.additional_kwargs.get("reasoning"):
|
||||
reasoning_items.append(_pop_summary_index_from_reasoning(reasoning))
|
||||
input_.extend(reasoning_items)
|
||||
# Function calls
|
||||
function_calls = []
|
||||
if tool_calls := msg.pop("tool_calls", None):
|
||||
@ -3185,13 +3186,11 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
pass
|
||||
msg["content"] = new_blocks
|
||||
if msg["content"]:
|
||||
if lc_msg.id and lc_msg.id.startswith("msg_"):
|
||||
msg["id"] = lc_msg.id
|
||||
input_.append(msg)
|
||||
input_.extend(function_calls)
|
||||
if computer_calls:
|
||||
# Hack: we only add reasoning items if computer calls are present. See:
|
||||
# https://community.openai.com/t/how-to-solve-badrequesterror-400-item-rs-of-type-reasoning-was-provided-without-its-required-following-item-error-in-responses-api/1151686/5
|
||||
input_.extend(reasoning_items)
|
||||
input_.extend(computer_calls)
|
||||
input_.extend(computer_calls)
|
||||
elif msg["role"] == "user":
|
||||
if isinstance(msg["content"], list):
|
||||
new_blocks = []
|
||||
|
@ -271,7 +271,7 @@ def test_function_calling_and_structured_output() -> None:
|
||||
"""return x * y"""
|
||||
return x * y
|
||||
|
||||
llm = ChatOpenAI(model=MODEL_NAME)
|
||||
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
|
||||
bound_llm = llm.bind_tools([multiply], response_format=Foo, strict=True)
|
||||
# Test structured output
|
||||
response = llm.invoke("how are ya", response_format=Foo)
|
||||
|
Loading…
Reference in New Issue
Block a user