core, openai[patch]: prefer provider-assigned IDs when aggregating message chunks (#31080)

When aggregating AIMessageChunks in a stream, core prefers the leftmost
non-null ID. This is problematic because:
- When an ID is null, core assigns a default of `f"run-{run_manager.run_id}"`
- The desired meaningful ID might not be available until midway through
the stream, as is the case for the OpenAI Responses API.

For the OpenAI Responses API, we assign message IDs to the top-level
`AIMessage.id`. This works in `.(a)invoke`, but during `.(a)stream` the
IDs get overwritten by the defaults assigned in langchain-core. These
IDs
[must](https://community.openai.com/t/how-to-solve-badrequesterror-400-item-rs-of-type-reasoning-was-provided-without-its-required-following-item-error-in-responses-api/1151686/9)
be available on the AIMessage object to support passing reasoning items
back to the API (e.g., if not using OpenAI's `previous_response_id`
feature). We could add them elsewhere, but seeing as we've already made
the decision to store them in `.id` during `.(a)invoke`, addressing the
issue in core lets us fix the problem with no interface changes.
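A minimal sketch of the failure mode (chunk contents and IDs are illustrative):

```python
from langchain_core.messages import AIMessageChunk

# Early chunks carry no provider ID yet, so core fills in a default.
first = AIMessageChunk(content="Hel", id="run-3a7c")  # LangChain-assigned default
# The provider-assigned ID only arrives on a later chunk (Responses API).
later = AIMessageChunk(content="lo", id="msg_abc123")  # provider-assigned

merged = first + later
# Before this change: merged.id == "run-3a7c"   (leftmost non-null wins; provider ID lost)
# After this change:  merged.id == "msg_abc123" (provider-assigned ID preferred)
```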
Author: ccurme
Committed: 2025-05-02 11:18:18 -04:00 (via GitHub)
Commit: 26ad239669, parent: 72f905a436
5 changed files with 42 additions and 14 deletions

`libs/core/langchain_core/language_models/chat_models.py`

```diff
@@ -58,6 +58,7 @@ from langchain_core.messages import (
     is_data_content_block,
     message_chunk_to_message,
 )
+from langchain_core.messages.ai import _LC_ID_PREFIX
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -493,7 +494,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 input_messages = _normalize_messages(messages)
                 for chunk in self._stream(input_messages, stop=stop, **kwargs):
                     if chunk.message.id is None:
-                        chunk.message.id = f"run-{run_manager.run_id}"
+                        chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                     chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                     run_manager.on_llm_new_token(
                         cast("str", chunk.message.content), chunk=chunk
@@ -583,7 +584,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                     **kwargs,
                 ):
                     if chunk.message.id is None:
-                        chunk.message.id = f"run-{run_manager.run_id}"
+                        chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                     chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                     await run_manager.on_llm_new_token(
                         cast("str", chunk.message.content), chunk=chunk
@@ -1001,7 +1002,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 if run_manager:
                     if chunk.message.id is None:
-                        chunk.message.id = f"run-{run_manager.run_id}"
+                        chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                     run_manager.on_llm_new_token(
                         cast("str", chunk.message.content), chunk=chunk
                     )
@@ -1017,7 +1018,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         # Add response metadata to each generation
         for idx, generation in enumerate(result.generations):
             if run_manager and generation.message.id is None:
-                generation.message.id = f"run-{run_manager.run_id}-{idx}"
+                generation.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
             )
@@ -1073,7 +1074,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 if run_manager:
                     if chunk.message.id is None:
-                        chunk.message.id = f"run-{run_manager.run_id}"
+                        chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
                     await run_manager.on_llm_new_token(
                         cast("str", chunk.message.content), chunk=chunk
                     )
@@ -1089,7 +1090,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         # Add response metadata to each generation
         for idx, generation in enumerate(result.generations):
             if run_manager and generation.message.id is None:
-                generation.message.id = f"run-{run_manager.run_id}-{idx}"
+                generation.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
             )
```
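All six default-ID assignment sites now derive the fallback from the shared `_LC_ID_PREFIX` constant (defined below in `langchain_core.messages.ai`) rather than a hardcoded `"run-"` literal, so the chunk-aggregation logic can recognize LangChain-assigned defaults by prefix. Note that since the constant itself ends in a hyphen, these defaults come out as `run--<run_id>`; the `startswith(_LC_ID_PREFIX)` check matches either way.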

`libs/core/langchain_core/messages/ai.py`

```diff
@@ -36,6 +36,9 @@ from langchain_core.utils.usage import _dict_int_op
 
 logger = logging.getLogger(__name__)
 
+_LC_ID_PREFIX = "run-"
+
+
 class InputTokenDetails(TypedDict, total=False):
     """Breakdown of input token counts.
 
@@ -418,10 +421,19 @@ def add_ai_message_chunks(
         usage_metadata = None
 
     id = None
-    for id_ in [left.id] + [o.id for o in others]:
-        if id_:
+    candidates = [left.id] + [o.id for o in others]
+    # first pass: pick the first non-run-* id
+    for id_ in candidates:
+        if id_ and not id_.startswith(_LC_ID_PREFIX):
             id = id_
             break
+    else:
+        # second pass: no provider-assigned id found, just take the first non-null
+        for id_ in candidates:
+            if id_:
+                id = id_
+                break
 
     return left.__class__(
         example=left.example,
         content=content,
```
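Note the `for ... else`: the `else` arm runs only when the loop completes without hitting `break`, i.e. exactly when the first pass found no provider-assigned ID. An equivalent formulation without `for ... else` (a sketch, not the committed code; the helper name `_pick_id` is made up for illustration):

```python
from typing import Optional

_LC_ID_PREFIX = "run-"  # mirrors the constant defined above

def _pick_id(candidates: list) -> Optional[str]:
    """Equivalent two-pass selection without ``for ... else``."""
    for candidate in candidates:
        # Prefer the first ID that lacks the LangChain-assigned prefix.
        if candidate and not candidate.startswith(_LC_ID_PREFIX):
            return candidate  # provider-assigned ID wins
    # No provider-assigned ID found: fall back to the first non-null one.
    return next((c for c in candidates if c), None)
```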

`libs/core/tests/unit_tests/test_messages.py`

```diff
@@ -178,6 +178,22 @@ def test_message_chunks() -> None:
     assert AIMessageChunk(content="") + left == left
     assert right + AIMessageChunk(content="") == right
 
+    # Test ID order of precedence
+    null_id = AIMessageChunk(content="", id=None)
+    default_id = AIMessageChunk(
+        content="", id="run-abc123"
+    )  # LangChain-assigned run ID
+    meaningful_id = AIMessageChunk(content="", id="msg_def456")  # provider-assigned ID
+
+    assert (null_id + default_id).id == "run-abc123"
+    assert (default_id + null_id).id == "run-abc123"
+
+    assert (null_id + meaningful_id).id == "msg_def456"
+    assert (meaningful_id + null_id).id == "msg_def456"
+
+    assert (default_id + meaningful_id).id == "msg_def456"
+    assert (meaningful_id + default_id).id == "msg_def456"
+
 
 def test_chat_message_chunks() -> None:
     assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
```

`libs/partners/openai/langchain_openai/chat_models/base.py`

```diff
@@ -3127,6 +3127,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
             reasoning_items = []
             if reasoning := lc_msg.additional_kwargs.get("reasoning"):
                 reasoning_items.append(_pop_summary_index_from_reasoning(reasoning))
+            input_.extend(reasoning_items)
             # Function calls
             function_calls = []
             if tool_calls := msg.pop("tool_calls", None):
@@ -3185,13 +3186,11 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                         pass
             msg["content"] = new_blocks
             if msg["content"]:
+                if lc_msg.id and lc_msg.id.startswith("msg_"):
+                    msg["id"] = lc_msg.id
                 input_.append(msg)
             input_.extend(function_calls)
-            if computer_calls:
-                # Hack: we only add reasoning items if computer calls are present. See:
-                # https://community.openai.com/t/how-to-solve-badrequesterror-400-item-rs-of-type-reasoning-was-provided-without-its-required-following-item-error-in-responses-api/1151686/5
-                input_.extend(reasoning_items)
-                input_.extend(computer_calls)
+            input_.extend(computer_calls)
         elif msg["role"] == "user":
             if isinstance(msg["content"], list):
                 new_blocks = []
```
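The `msg_`-prefixed ID matters because a reasoning item in the Responses API must be followed by the output item it belongs to. Roughly, the input this function now reconstructs for a prior assistant turn looks like the following (shape and IDs are illustrative, not the function's literal output):

```python
# Illustrative shape of the reconstructed Responses API input for one prior
# assistant turn; the item IDs here are fabricated for the example.
input_ = [
    # Reasoning item, restored from AIMessage.additional_kwargs["reasoning"] ...
    {"type": "reasoning", "id": "rs_123", "summary": []},
    # ... followed by the message it belongs to, which the API matches up
    # via the provider-assigned "msg_..." ID preserved on AIMessage.id.
    {
        "type": "message",
        "role": "assistant",
        "id": "msg_abc123",
        "content": [{"type": "output_text", "text": "Hello!", "annotations": []}],
    },
]
```

Since reasoning items are now always emitted (not only when computer calls are present), keeping the `msg_` ID on the message is what lets the API pair the two items.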

`libs/partners/openai/tests/integration_tests/`

```diff
@@ -271,7 +271,7 @@ def test_function_calling_and_structured_output() -> None:
         """return x * y"""
         return x * y
 
-    llm = ChatOpenAI(model=MODEL_NAME)
+    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
     bound_llm = llm.bind_tools([multiply], response_format=Foo, strict=True)
     # Test structured output
     response = llm.invoke("how are ya", response_format=Foo)
```
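With both halves of the change in place, streaming against the Responses API preserves the provider-assigned ID end to end (a sketch; the model name and printed ID are illustrative):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)

# Aggregate the stream the same way langchain-core does internally.
full = None
for chunk in llm.stream("how are ya"):
    full = chunk if full is None else full + chunk

# Previously full.id could end up as the LangChain default ("run-...")
# even though .invoke() returned a "msg_..." ID; now the "msg_..." ID
# survives aggregation, so reasoning items can be passed back on the
# next turn without relying on previous_response_id.
print(full.id)  # e.g. "msg_..."
```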