mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
fix(openai): avoid sync token reads in Codex streaming (#38128)
Codex streaming now builds request headers from the async token path instead of refreshing asynchronously and later reading the token synchronously during payload construction. That keeps `_ChatOpenAICodex._astream` off the sync token path while preserving the `ChatGPT-Account-Id` and `originator` headers needed by Codex requests.
This commit is contained in:
@@ -438,6 +438,7 @@ class _ChatOpenAICodex(ChatOpenAI):
|
||||
way `_convert_input` only runs once (inside super) instead of once
|
||||
here and again there.
|
||||
"""
|
||||
codex_headers = kwargs.pop("_codex_headers", None)
|
||||
payload_input: LanguageModelInput = input_
|
||||
if _maybe_has_system_messages(input_):
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
@@ -465,7 +466,10 @@ class _ChatOpenAICodex(ChatOpenAI):
|
||||
# silently overwriting it would hide a programming error).
|
||||
if payload.get("instructions") is None:
|
||||
payload["instructions"] = self.instructions
|
||||
return self._merge_codex_headers(payload, self._codex_headers_sync())
|
||||
headers = (
|
||||
codex_headers if codex_headers is not None else self._codex_headers_sync()
|
||||
)
|
||||
return self._merge_codex_headers(payload, headers)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
@@ -488,7 +492,8 @@ class _ChatOpenAICodex(ChatOpenAI):
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
await self.token_provider.aget_token()
|
||||
token = await self.token_provider.aget_token()
|
||||
kwargs["_codex_headers"] = self._build_headers(token.account_id)
|
||||
async for chunk in super()._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
|
||||
@@ -486,35 +486,41 @@ async def test_agenerate_primes_async_token_cache(
|
||||
assert provider.async_calls == before + 1
|
||||
|
||||
|
||||
async def test_astream_primes_async_token_cache_and_yields_headers_via_payload(
|
||||
async def test_astream_builds_codex_headers_without_sync_token_read(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""`_astream` primes the async cache and the sync payload still attaches headers.
|
||||
"""`_astream` must not call sync `get_token` on the event loop."""
|
||||
|
||||
The sync payload builder (`_get_request_payload`) is what actually
|
||||
stamps `ChatGPT-Account-Id` + `originator` on outbound requests; the
|
||||
async path's only job is to refresh the token off the event loop
|
||||
before the sync builder reads from the (now warm) cache. Verifies
|
||||
both halves of that contract.
|
||||
"""
|
||||
provider = FakeTokenProvider(account_id="acct-async")
|
||||
class AsyncOnlyTokenProvider(FakeTokenProvider):
|
||||
async def aget_token(self) -> _ChatGPTToken:
|
||||
self.async_calls += 1
|
||||
return _ChatGPTToken(
|
||||
access_token=self.access_token,
|
||||
refresh_token="rt",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
account_id=self.account_id,
|
||||
)
|
||||
|
||||
provider = AsyncOnlyTokenProvider(account_id="acct-async")
|
||||
model = _build_model(token_provider=provider)
|
||||
kwargs_seen: dict[str, Any] = {}
|
||||
|
||||
async def _fake_super_astream(*_a: Any, **_k: Any) -> Any:
|
||||
async def _fake_super_astream(*_a: Any, **kwargs: Any) -> Any:
|
||||
kwargs_seen.update(kwargs)
|
||||
yield "chunk"
|
||||
|
||||
monkeypatch.setattr(ChatOpenAI, "_astream", _fake_super_astream)
|
||||
before = provider.async_calls
|
||||
before_sync = provider.calls
|
||||
before_async = provider.async_calls
|
||||
received = [chunk async for chunk in model._astream([HumanMessage("hi")])]
|
||||
assert received == ["chunk"]
|
||||
assert provider.async_calls == before + 1
|
||||
|
||||
# Sync payload (which is what the SDK ultimately serializes) carries
|
||||
# the codex headers even after the async refresh path primed them.
|
||||
payload = model._get_request_payload([HumanMessage("hi")])
|
||||
headers = payload["extra_headers"]
|
||||
assert headers[ACCOUNT_ID_HEADER] == "acct-async"
|
||||
assert headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE
|
||||
assert received == ["chunk"]
|
||||
assert provider.async_calls == before_async + 1
|
||||
assert provider.calls == before_sync
|
||||
assert kwargs_seen["_codex_headers"] == {
|
||||
ACCOUNT_ID_HEADER: "acct-async",
|
||||
ORIGINATOR_HEADER: ORIGINATOR_VALUE,
|
||||
}
|
||||
|
||||
|
||||
def test_callable_api_key_returns_provider_token() -> None:
|
||||
|
||||
Reference in New Issue
Block a user