fix(openai): avoid sync token reads in Codex streaming (#38128)

Codex streaming now builds request headers from the async token path
instead of refreshing asynchronously and later reading the token
synchronously during payload construction. That keeps
`_ChatOpenAICodex._astream` off the sync token path while preserving the
`ChatGPT-Account-Id` and `originator` headers needed by Codex requests.
This commit is contained in:
Mason Daugherty
2026-06-13 01:26:48 -04:00
committed by GitHub
parent 454e19588c
commit 11429a9e1c
2 changed files with 32 additions and 21 deletions

View File

@@ -438,6 +438,7 @@ class _ChatOpenAICodex(ChatOpenAI):
way `_convert_input` only runs once (inside super) instead of once
here and again there.
"""
codex_headers = kwargs.pop("_codex_headers", None)
payload_input: LanguageModelInput = input_
if _maybe_has_system_messages(input_):
messages = self._convert_input(input_).to_messages()
@@ -465,7 +466,10 @@ class _ChatOpenAICodex(ChatOpenAI):
# silently overwriting it would hide a programming error).
if payload.get("instructions") is None:
payload["instructions"] = self.instructions
return self._merge_codex_headers(payload, self._codex_headers_sync())
headers = (
codex_headers if codex_headers is not None else self._codex_headers_sync()
)
return self._merge_codex_headers(payload, headers)
async def _agenerate(
self,
@@ -488,7 +492,8 @@ class _ChatOpenAICodex(ChatOpenAI):
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
await self.token_provider.aget_token()
token = await self.token_provider.aget_token()
kwargs["_codex_headers"] = self._build_headers(token.account_id)
async for chunk in super()._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
):

View File

@@ -486,35 +486,41 @@ async def test_agenerate_primes_async_token_cache(
assert provider.async_calls == before + 1
async def test_astream_primes_async_token_cache_and_yields_headers_via_payload(
async def test_astream_builds_codex_headers_without_sync_token_read(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""`_astream` primes the async cache and the sync payload still attaches headers.
"""`_astream` must not call sync `get_token` on the event loop."""
The sync payload builder (`_get_request_payload`) is what actually
stamps `ChatGPT-Account-Id` + `originator` on outbound requests; the
async path's only job is to refresh the token off the event loop
before the sync builder reads from the (now warm) cache. Verifies
both halves of that contract.
"""
provider = FakeTokenProvider(account_id="acct-async")
class AsyncOnlyTokenProvider(FakeTokenProvider):
async def aget_token(self) -> _ChatGPTToken:
self.async_calls += 1
return _ChatGPTToken(
access_token=self.access_token,
refresh_token="rt",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
account_id=self.account_id,
)
provider = AsyncOnlyTokenProvider(account_id="acct-async")
model = _build_model(token_provider=provider)
kwargs_seen: dict[str, Any] = {}
async def _fake_super_astream(*_a: Any, **_k: Any) -> Any:
async def _fake_super_astream(*_a: Any, **kwargs: Any) -> Any:
kwargs_seen.update(kwargs)
yield "chunk"
monkeypatch.setattr(ChatOpenAI, "_astream", _fake_super_astream)
before = provider.async_calls
before_sync = provider.calls
before_async = provider.async_calls
received = [chunk async for chunk in model._astream([HumanMessage("hi")])]
assert received == ["chunk"]
assert provider.async_calls == before + 1
# Sync payload (which is what the SDK ultimately serializes) carries
# the codex headers even after the async refresh path primed them.
payload = model._get_request_payload([HumanMessage("hi")])
headers = payload["extra_headers"]
assert headers[ACCOUNT_ID_HEADER] == "acct-async"
assert headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE
assert received == ["chunk"]
assert provider.async_calls == before_async + 1
assert provider.calls == before_sync
assert kwargs_seen["_codex_headers"] == {
ACCOUNT_ID_HEADER: "acct-async",
ORIGINATOR_HEADER: ORIGINATOR_VALUE,
}
def test_callable_api_key_returns_provider_token() -> None: