From 11429a9e1c37efccd1da92246d6f5b878a2af74a Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Sat, 13 Jun 2026 01:26:48 -0400 Subject: [PATCH] fix(openai): avoid sync token reads in Codex streaming (#38128) Codex streaming now builds request headers from the async token path instead of refreshing asynchronously and later reading the token synchronously during payload construction. That keeps `_ChatOpenAICodex._astream` off the sync token path while preserving the `ChatGPT-Account-Id` and `originator` headers needed by Codex requests. --- .../langchain_openai/chat_models/codex.py | 9 +++- .../unit_tests/chat_models/test_codex.py | 44 +++++++++++-------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/codex.py b/libs/partners/openai/langchain_openai/chat_models/codex.py index b35fd1102f3..564149c34d8 100644 --- a/libs/partners/openai/langchain_openai/chat_models/codex.py +++ b/libs/partners/openai/langchain_openai/chat_models/codex.py @@ -438,6 +438,7 @@ class _ChatOpenAICodex(ChatOpenAI): way `_convert_input` only runs once (inside super) instead of once here and again there. """ + codex_headers = kwargs.pop("_codex_headers", None) payload_input: LanguageModelInput = input_ if _maybe_has_system_messages(input_): messages = self._convert_input(input_).to_messages() @@ -465,7 +466,10 @@ class _ChatOpenAICodex(ChatOpenAI): # silently overwriting it would hide a programming error). if payload.get("instructions") is None: payload["instructions"] = self.instructions - return self._merge_codex_headers(payload, self._codex_headers_sync()) + headers = ( + codex_headers if codex_headers is not None else self._codex_headers_sync() + ) + return self._merge_codex_headers(payload, headers) async def _agenerate( self, @@ -488,7 +492,8 @@ class _ChatOpenAICodex(ChatOpenAI): run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - await self.token_provider.aget_token() + token = await self.token_provider.aget_token() + kwargs["_codex_headers"] = self._build_headers(token.account_id) async for chunk in super()._astream( messages, stop=stop, run_manager=run_manager, **kwargs ): diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py b/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py index c98cb808c9e..f2161a64b03 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py @@ -486,35 +486,41 @@ async def test_agenerate_primes_async_token_cache( assert provider.async_calls == before + 1 -async def test_astream_primes_async_token_cache_and_yields_headers_via_payload( +async def test_astream_builds_codex_headers_without_sync_token_read( monkeypatch: pytest.MonkeyPatch, ) -> None: - """`_astream` primes the async cache and the sync payload still attaches headers. + """`_astream` must not call sync `get_token` on the event loop.""" - The sync payload builder (`_get_request_payload`) is what actually - stamps `ChatGPT-Account-Id` + `originator` on outbound requests; the - async path's only job is to refresh the token off the event loop - before the sync builder reads from the (now warm) cache. Verifies - both halves of that contract. - """ - provider = FakeTokenProvider(account_id="acct-async") + class AsyncOnlyTokenProvider(FakeTokenProvider): + async def aget_token(self) -> _ChatGPTToken: + self.async_calls += 1 + return _ChatGPTToken( + access_token=self.access_token, + refresh_token="rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + account_id=self.account_id, + ) + + provider = AsyncOnlyTokenProvider(account_id="acct-async") model = _build_model(token_provider=provider) + kwargs_seen: dict[str, Any] = {} - async def _fake_super_astream(*_a: Any, **_k: Any) -> Any: + async def _fake_super_astream(*_a: Any, **kwargs: Any) -> Any: + kwargs_seen.update(kwargs) yield "chunk" monkeypatch.setattr(ChatOpenAI, "_astream", _fake_super_astream) - before = provider.async_calls + before_sync = provider.calls + before_async = provider.async_calls received = [chunk async for chunk in model._astream([HumanMessage("hi")])] - assert received == ["chunk"] - assert provider.async_calls == before + 1 - # Sync payload (which is what the SDK ultimately serializes) carries - # the codex headers even after the async refresh path primed them. - payload = model._get_request_payload([HumanMessage("hi")]) - headers = payload["extra_headers"] - assert headers[ACCOUNT_ID_HEADER] == "acct-async" - assert headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE + assert received == ["chunk"] + assert provider.async_calls == before_async + 1 + assert provider.calls == before_sync + assert kwargs_seen["_codex_headers"] == { + ACCOUNT_ID_HEADER: "acct-async", + ORIGINATOR_HEADER: ORIGINATOR_VALUE, + } def test_callable_api_key_returns_provider_token() -> None: