diff --git a/libs/partners/openai/langchain_openai/__init__.py b/libs/partners/openai/langchain_openai/__init__.py index c70be734874..81596180d59 100644 --- a/libs/partners/openai/langchain_openai/__init__.py +++ b/libs/partners/openai/langchain_openai/__init__.py @@ -1,7 +1,7 @@ """Module for OpenAI integrations.""" from langchain_openai._version import __version__ -from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI +from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, ChatOpenAICodex from langchain_openai.chat_models._client_utils import StreamChunkTimeoutError from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai.llms import AzureOpenAI, OpenAI @@ -12,6 +12,7 @@ __all__ = [ "AzureOpenAI", "AzureOpenAIEmbeddings", "ChatOpenAI", + "ChatOpenAICodex", "OpenAI", "OpenAIEmbeddings", "StreamChunkTimeoutError", diff --git a/libs/partners/openai/langchain_openai/chat_models/__init__.py b/libs/partners/openai/langchain_openai/chat_models/__init__.py index e43102ffabc..7e4f50348cc 100644 --- a/libs/partners/openai/langchain_openai/chat_models/__init__.py +++ b/libs/partners/openai/langchain_openai/chat_models/__init__.py @@ -2,5 +2,6 @@ from langchain_openai.chat_models.azure import AzureChatOpenAI from langchain_openai.chat_models.base import ChatOpenAI +from langchain_openai.chat_models.codex import ChatOpenAICodex -__all__ = ["AzureChatOpenAI", "ChatOpenAI"] +__all__ = ["AzureChatOpenAI", "ChatOpenAI", "ChatOpenAICodex"] diff --git a/libs/partners/openai/langchain_openai/chat_models/codex.py b/libs/partners/openai/langchain_openai/chat_models/codex.py new file mode 100644 index 00000000000..4644d94c2d6 --- /dev/null +++ b/libs/partners/openai/langchain_openai/chat_models/codex.py @@ -0,0 +1,497 @@ +"""`ChatOpenAICodex`: OAuth-backed chat model for ChatGPT subscription auth. + +Wraps `ChatOpenAI` to target the ChatGPT codex backend +(`https://chatgpt.com/backend-api/codex`) and supplies refresh-aware +`Authorization` and `ChatGPT-Account-Id` headers from a +`ChatGPTOAuthTokenProvider`. + +The standard `ChatOpenAI` (API-key) flow is untouched. +""" + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING, Any + +from langchain_core.language_models.chat_models import LangSmithParams +from langchain_core.messages import BaseMessage, ChatMessage, SystemMessage +from pydantic import Field, model_validator + +from langchain_openai.chat_models.base import ChatOpenAI +from langchain_openai.chatgpt_oauth import ( + ChatGPTOAuthTokenProvider, + FileChatGPTOAuthTokenProvider, +) + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from langchain_core.callbacks import AsyncCallbackManagerForLLMRun + from langchain_core.language_models import LanguageModelInput + from langchain_core.outputs import ChatGenerationChunk, ChatResult + + +logger = logging.getLogger(__name__) + + +CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" +ORIGINATOR_HEADER = "originator" +ORIGINATOR_VALUE = "langchain" +"""Built-in default for the `originator` header value. + +Identifies requests as coming from `langchain-openai`. Override per-instance +via the `originator` field or globally via the `LANGCHAIN_CODEX_ORIGINATOR` +env var. +""" +ORIGINATOR_ENV_VAR = "LANGCHAIN_CODEX_ORIGINATOR" +ACCOUNT_ID_HEADER = "ChatGPT-Account-Id" +_INSTRUCTION_ROLES = frozenset({"system", "developer"}) + + +def _default_originator() -> str: + """Resolve the `originator` header default, honoring the env-var override.""" + return os.environ.get(ORIGINATOR_ENV_VAR) or ORIGINATOR_VALUE + + +def _maybe_has_system_messages(input_: Any) -> bool: + """Return `True` if `input_` *could* contain a system-role message. + + Cheap structural probe used to skip the full `_convert_input` pipeline + when there is no chance the lift logic will fire. False positives only + cost an extra conversion; false negatives would silently skip the lift, + so the probe is biased toward `True` for unknown shapes. + """ + if isinstance(input_, str): + return False + if isinstance(input_, BaseMessage): + return _is_instruction_message(input_) + if isinstance(input_, (list, tuple)): + for item in input_: + if isinstance(item, BaseMessage) and _is_instruction_message(item): + return True + if isinstance(item, dict) and item.get("role") in _INSTRUCTION_ROLES: + return True + if ( + isinstance(item, tuple) + and item + and isinstance(item[0], str) + and item[0] in _INSTRUCTION_ROLES + ): + return True + return False + # `PromptValue` or any future shape — be safe and run the slow path. + return True + + +def _is_instruction_message(message: BaseMessage) -> bool: + return isinstance(message, SystemMessage) or ( + isinstance(message, ChatMessage) and message.role in _INSTRUCTION_ROLES + ) + + +def _flatten_system_message_content(system_messages: list[BaseMessage]) -> str: + """Join system/developer message content into a single `instructions` string. + + Codex rejects system-role entries in the input list, so their content + is lifted into the top-level `instructions` field. Content that uses + list-of-content-blocks form is accepted only when every block is + `{"type": "text", ...}`; anything else cannot be flattened into the + string-typed `instructions` field. + + Raises: + ValueError: A system/developer message carries a non-text content block. + """ + parts: list[str] = [] + for index, message in enumerate(system_messages): + message_name = type(message).__name__ + content = message.content + if isinstance(content, str): + parts.append(content) + continue + if not isinstance(content, list): + msg = ( + f"`{message_name}` at index {index} has unsupported content " + f"type {type(content).__name__!r}; only `str` and " + "list-of-text-blocks are accepted by `ChatOpenAICodex`." + ) + raise ValueError(msg) + text_parts: list[str] = [] + for block_index, block in enumerate(content): + if not isinstance(block, dict) or block.get("type") != "text": + msg = ( + f"`{message_name}` at index {index} contains a " + f"non-text content block at position {block_index} " + "(Codex `instructions` is a string field — only " + '`{"type": "text", "text": "..."}` blocks can be ' + "lifted into it). Move the non-text content to a " + "`HumanMessage`, or pass plain instructions via the " + "constructor or `instructions=` kwarg." + ) + raise ValueError(msg) + text_value = block.get("text", "") + if not isinstance(text_value, str): + msg = ( + f"`{message_name}` at index {index} has a text block " + f"at position {block_index} whose `text` is not a " + "string." + ) + raise ValueError(msg) + text_parts.append(text_value) + parts.append("".join(text_parts)) + return "\n\n".join(parts) + + +DEFAULT_INSTRUCTIONS = "You are ChatGPT, a large language model trained by OpenAI." +"""Generic fallback for the Responses-API `instructions` field. + +The Codex backend rejects any request missing a top-level `instructions` +value (400 `Instructions are required`), so this constant keeps zero-config +construction working. **Most callers should override it** with their own +prompt — see `ChatOpenAICodex.instructions` for the resolution rules. +""" +_FORCED_VALUES: dict[str, Any] = { + "use_responses_api": True, + "store": False, + "streaming": True, +} +"""Values forced onto every `ChatOpenAICodex` instance. + +These are the wire-level constraints the Codex backend imposes: + +- `use_responses_api=True`: Codex is only reachable through the Responses + API surface. +- `store=False`: the backend rejects `store=true` + (`400 'Store must be set to false'`). +- `streaming=True`: the backend rejects non-streaming requests + (`400 'Stream must be set to true'`). Pinning this routes `invoke` + through `_stream` so a streaming request is always sent and chunks + are aggregated back into a single message for the caller. + +`output_version` is intentionally **not** forced — it is a client-side +`AIMessage` projection (see `ChatOpenAI.output_version`) that never +appears in the request payload, so callers can pick `"v0"`, `"v1"`, or +`"responses/v1"` freely. + +`base_url` (and its `openai_api_base` alias) is also pinned — to +`CHATGPT_CODEX_BASE_URL` — under the same raise-don't-rewrite contract. +It is enforced separately in the validator rather than listed here +because a caller-controlled endpoint combined with the OAuth bearer +token would be a token-exfiltration vector; see the validator for the +rationale. +""" + + +class ChatOpenAICodex(ChatOpenAI): + """`ChatOpenAI` variant authed by a ChatGPT OAuth subscription. + + Routes requests to `https://chatgpt.com/backend-api/codex` and forces + the wire-level fields the Codex backend requires + (`use_responses_api=True`, `store=False`, `streaming=True`). These + values are forced — passing a conflicting value to the constructor + raises. `output_version` (a client-side `AIMessage` projection) is + not forced; pick whichever projection you want. Authorization and + `ChatGPT-Account-Id` headers are taken from `token_provider` on every + request so a freshly-refreshed access token is always used. + + Example: + ```python + from langchain_openai import ChatOpenAICodex + from langchain_openai.chatgpt_oauth import login_chatgpt + + # One-time setup. The returned provider writes to the default store + # at `~/.langchain/chatgpt-auth.json`, which `ChatOpenAICodex` also + # reads from by default — so subsequent constructions need no + # explicit `token_provider`. + login_chatgpt() + model = ChatOpenAICodex( + model="gpt-5.5", + instructions="You are a senior Python reviewer. Be terse.", + ) + response = model.invoke("hello") + ``` + + !!! tip "Override `instructions`" + + The Codex backend requires a top-level `instructions` value on every + request. A generic default keeps zero-config use working, but most + callers should override it via the constructor (above) or per call + (`model.invoke(..., instructions=...)`). See the field's docstring + for the full resolution rules. + + !!! note + + Token storage is handled by `FileChatGPTOAuthTokenProvider`, which + defaults to `~/.langchain/chatgpt-auth.json` so it does not collide + with the Codex CLI / VS Code session at `~/.codex/auth.json`. + + !!! note "Always streams over the wire" + + The Codex backend only accepts streaming requests, so `streaming=True` + is forced. `invoke` still returns a single aggregated `AIMessage` — + chunks are collected internally — but the underlying HTTP request is + a stream either way. Expect every call to show up as a streamed + request in network logs and LangSmith traces. + """ + + token_provider: Any = Field(default=None, exclude=True) + """Refresh-aware ChatGPT OAuth token provider. + + Must implement the `ChatGPTOAuthTokenProvider` protocol. If `None`, a + `FileChatGPTOAuthTokenProvider` rooted at the default store path is + constructed. + """ + + originator: str | None = Field(default_factory=_default_originator) + """Value sent in the `originator` request header, or `None` to omit it. + + Identifies the client making the request. Defaults to `"langchain"` so + OpenAI telemetry attributes calls to this package. Downstream consumers + (e.g., a framework built on top of `ChatOpenAICodex`) can override this + to identify themselves instead, or set `None` to suppress the header. + + Resolution order (first match wins): + + 1. Per-call `extra_headers={"originator": "..."}` (always trumps the + field; pass an explicit value to override on a single call). + 2. Constructor / kwarg value (`ChatOpenAICodex(originator="my-app")`). + 3. The `LANGCHAIN_CODEX_ORIGINATOR` env var, if set and non-empty. + 4. `ORIGINATOR_VALUE` (`"langchain"`). + + Setting `originator=None` disables the header entirely; the constructor + default never resolves to `None`. + """ + + instructions: str = Field(default=DEFAULT_INSTRUCTIONS) + """System prompt sent in the Responses-API `instructions` field. + + `instructions` is a *top-level* field of the Responses API request — it + is not a chat message. The Codex backend rejects any request where this + field is missing or empty (400 `Instructions are required`) **and** + rejects any `SystemMessage` entry in the input list + (400 `System messages are not allowed`). To bridge those constraints + transparently, `ChatOpenAICodex` resolves `instructions` per call with + this precedence (highest wins): + + 1. Explicit `instructions=` kwarg on `invoke` / `stream`. + 2. Concatenated content of any `SystemMessage` entries in the input + list — joined with `"\\n\\n"` and stripped from the input before + sending. Set the explicit kwarg in (1) to override. + 3. This constructor field (defaults to a generic ChatGPT prompt). + + The Codex backend is stateless for this client (`store=False` is + forced), so `instructions` is sent on every request and can be changed + between calls — useful for switching persona / tooling mid-conversation: + + ```python + model = ChatOpenAICodex( + model="gpt-5.5", + instructions="You are a senior Python reviewer. Be terse.", + ) + model.invoke("review this diff…") + model.invoke( + "now translate the review to French", + instructions="You are a translator.", + ) + ``` + + `SystemMessage` content that uses list-of-content-blocks form is + accepted only if every block is `{"type": "text", ...}`; any other + block type raises `ValueError` since it cannot be flattened into the + string-typed `instructions` field. + """ + + @model_validator(mode="before") + @classmethod + def _apply_codex_defaults(cls, values: dict[str, Any]) -> dict[str, Any]: + """Apply Codex-specific defaults before the parent validator runs.""" + if not isinstance(values, dict): + return values + for key, forced in _FORCED_VALUES.items(): + supplied = values.get(key) + if supplied is not None and supplied != forced: + msg = ( + f"`ChatOpenAICodex` requires `{key}={forced!r}`; " + f"got `{key}={supplied!r}`. Use `ChatOpenAI` if you " + "need to customize this." + ) + raise ValueError(msg) + values[key] = forced + # Pin `base_url` (and its legacy `openai_api_base` alias) to the Codex + # endpoint. The OAuth bearer token is wired in as `api_key` below, so a + # caller-controlled `base_url` would otherwise exfiltrate the token to + # an attacker-chosen host. Reject any non-matching override rather than + # silently rewriting it, mirroring the `_FORCED_VALUES` contract. + for key in ("base_url", "openai_api_base"): + supplied = values.get(key) + if supplied is not None and supplied != CHATGPT_CODEX_BASE_URL: + msg = ( + f"`ChatOpenAICodex` requires `{key}={CHATGPT_CODEX_BASE_URL!r}`; " + f"got `{key}={supplied!r}`. Use `ChatOpenAI` if you need to " + "target a different endpoint." + ) + raise ValueError(msg) + values[key] = CHATGPT_CODEX_BASE_URL + + provider = values.get("token_provider") + if provider is None: + provider = FileChatGPTOAuthTokenProvider.from_default_store() + values["token_provider"] = provider + if not isinstance(provider, ChatGPTOAuthTokenProvider): + msg = ( + "`token_provider` must implement the " + "`ChatGPTOAuthTokenProvider` protocol." + ) + raise TypeError(msg) + + # The OAuth `token_provider` is the sole auth source: its access token + # is wired into the OpenAI SDK as `api_key` below. A caller-supplied + # `api_key` (or its `openai_api_key` alias) would silently win over the + # OAuth bearer, leaving the model in a conflicting state — so reject it + # (raise-don't-rewrite, mirroring the `base_url` handling above). An + # `OPENAI_API_KEY` env var is not consulted: the field's default + # factory never runs because `api_key` is always set here. + for key in ("api_key", "openai_api_key"): + if values.get(key) is not None: + msg = ( + f"`ChatOpenAICodex` manages authentication via " + f"`token_provider`; drop the explicit `{key}=`. Use " + "`ChatOpenAI` if you want API-key authentication." + ) + raise ValueError(msg) + values["api_key"] = _SyncTokenCallable(provider) + return values + + def _codex_headers_sync(self) -> dict[str, str]: + token = self.token_provider.get_token() + return self._build_headers(token.account_id) + + def _build_headers(self, account_id: str | None) -> dict[str, str]: + headers: dict[str, str] = {} + if account_id: + headers[ACCOUNT_ID_HEADER] = account_id + if self.originator is not None: + headers[ORIGINATOR_HEADER] = self.originator + return headers + + def _merge_codex_headers( + self, payload: dict[str, Any], headers: dict[str, str] + ) -> dict[str, Any]: + # Caller-supplied `extra_headers` win over our Codex defaults so + # users can override (e.g., to send a different `originator`). + if not headers: + return payload + merged = {**headers, **(payload.get("extra_headers") or {})} + payload["extra_headers"] = merged + return payload + + def _get_request_payload( + self, + input_: LanguageModelInput, + *, + stop: list[str] | None = None, + **kwargs: Any, + ) -> dict: + """Build the request payload and attach Codex auth headers. + + Lifts any `SystemMessage` content out of the input list into the + top-level `instructions` field, since Codex rejects `SystemMessage` + chat turns. See the `instructions` field docstring for the + precedence rules. + + Fast path: when the input can't carry a `SystemMessage`, skip the + local conversion and delegate `input_` straight to super — that + way `_convert_input` only runs once (inside super) instead of once + here and again there. + """ + payload_input: LanguageModelInput = input_ + if _maybe_has_system_messages(input_): + messages = self._convert_input(input_).to_messages() + system_messages = [m for m in messages if _is_instruction_message(m)] + if system_messages: + non_system = [m for m in messages if not _is_instruction_message(m)] + lifted = _flatten_system_message_content(system_messages) + explicit = kwargs.get("instructions") + if explicit is not None: + logger.warning( + "Both `instructions=` and a `SystemMessage` were " + "provided; the explicit `instructions=` kwarg wins " + "and the `SystemMessage` content is discarded for " + "this call. Discarded length: %d.", + len(lifted), + ) + else: + kwargs["instructions"] = lifted + payload_input = non_system + + payload = super()._get_request_payload(payload_input, stop=stop, **kwargs) + # The Codex backend rejects requests without `instructions` — populate + # the field's value if the caller didn't supply one. An explicit empty + # string from the caller is preserved (the backend will reject it, but + # silently overwriting it would hide a programming error). + if payload.get("instructions") is None: + payload["instructions"] = self.instructions + return self._merge_codex_headers(payload, self._codex_headers_sync()) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + # Prime the cache via async refresh so the sync header build that + # happens inside `super()._agenerate` does not block the event loop. + await self.token_provider.aget_token() + return await super()._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + + async def _astream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + await self.token_provider.aget_token() + async for chunk in super()._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk + + def _get_ls_params( + self, stop: list[str] | None = None, **kwargs: Any + ) -> LangSmithParams: + params = super()._get_ls_params(stop=stop, **kwargs) + params["ls_provider"] = "openai-codex" + return params + + @property + def _llm_type(self) -> str: + return "openai-codex-chat" + + @classmethod + def is_lc_serializable(cls) -> bool: + """`ChatOpenAICodex` is not serializable (holds a live token provider).""" + return False + + +class _SyncTokenCallable: + """Sync callable wrapper around a token provider for the OpenAI SDK. + + The OpenAI Python SDK accepts a callable returning a string for `api_key`. + Wrapping the provider lets the SDK fetch a freshly-refreshed access token + on every request without exposing the provider's other methods. + """ + + __slots__ = ("_provider",) + + def __init__(self, provider: ChatGPTOAuthTokenProvider) -> None: + self._provider = provider + + def __call__(self) -> str: + return self._provider.get_access_token() + + +__all__ = ["ChatOpenAICodex"] diff --git a/libs/partners/openai/langchain_openai/chatgpt_oauth.py b/libs/partners/openai/langchain_openai/chatgpt_oauth.py new file mode 100644 index 00000000000..b64d2376261 --- /dev/null +++ b/libs/partners/openai/langchain_openai/chatgpt_oauth.py @@ -0,0 +1,1058 @@ +"""ChatGPT OAuth helpers for `ChatOpenAICodex`. + +Implements OAuth 2.0 Authorization Code Flow with PKCE against the OpenAI +auth endpoints used by Codex/ChatGPT subscription auth, plus a small file-backed +token store and refresh logic. + +These helpers exist to keep login and token management *separate* from model +invocation. `ChatOpenAICodex` only consumes a `ChatGPTOAuthTokenProvider`. + +!!! warning + This is provider-specific subscription auth and is independent from the + standard OpenAI API-key flow used by `ChatOpenAI`. Refresh-token rotation + against `~/.codex/auth.json` can break Codex CLI / VS Code sessions, so + the default store lives at `~/.langchain/chatgpt-auth.json`. +""" + +from __future__ import annotations + +import asyncio +import base64 +import contextlib +import hashlib +import html +import http.server +import ipaddress +import json +import logging +import os +import secrets +import threading +import time +import urllib.parse +import webbrowser +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable + +import httpx + +if TYPE_CHECKING: + from collections.abc import Iterator + +logger = logging.getLogger(__name__) + + +CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +CHATGPT_AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" +CHATGPT_TOKEN_URL = "https://auth.openai.com/oauth/token" # noqa: S105 +CHATGPT_DEVICE_CODE_URL = "https://auth.openai.com/api/accounts/deviceauth/usercode" +CHATGPT_DEVICE_TOKEN_URL = "https://auth.openai.com/api/accounts/deviceauth/token" # noqa: S105 +CHATGPT_DEVICE_REDIRECT_URI = "https://auth.openai.com/deviceauth/callback" +CHATGPT_AUTH_CLAIMS_NAMESPACE = "https://api.openai.com/auth" +DEFAULT_REDIRECT_HOST = "localhost" +DEFAULT_REDIRECT_PORT = 1455 +DEFAULT_REDIRECT_PATH = "/auth/callback" +DEFAULT_SCOPE = "openid profile email offline_access" +DEFAULT_REFRESH_SKEW = timedelta(minutes=5) +DEFAULT_STORE_PATH = Path.home() / ".langchain" / "chatgpt-auth.json" + + +@dataclass(frozen=True) +class ChatGPTToken: + """A ChatGPT OAuth token bundle. + + `expires_at` is timezone-aware. The JWT-derived optionals (`account_id`, + `plan_type`, `user_id`) are populated when decodable from the `id_token`; + `id_token` itself is the raw token, not derived from it. Secret-bearing + fields (`access_token`, `refresh_token`, `id_token`) are excluded from the + default `repr` so the token does not leak into logs or tracebacks. + + Instances are frozen: the constructor invariants below hold for the life of + the object, which matters because providers cache and share a single token + and replace it wholesale on refresh rather than mutating fields in place. + """ + + access_token: str = field(repr=False) + refresh_token: str = field(repr=False) + expires_at: datetime + account_id: str | None = None + plan_type: str | None = None + user_id: str | None = None + id_token: str | None = field(default=None, repr=False) + + def __post_init__(self) -> None: + """Validate non-empty secrets and timezone-aware `expires_at`.""" + if not self.access_token: + msg = "`access_token` must be a non-empty string." + raise ValueError(msg) + if not self.refresh_token: + msg = "`refresh_token` must be a non-empty string." + raise ValueError(msg) + if self.expires_at.tzinfo is None: + msg = "`expires_at` must be timezone-aware (UTC)." + raise ValueError(msg) + + def is_expired(self, *, skew: timedelta = DEFAULT_REFRESH_SKEW) -> bool: + """Return `True` if the token is past (or within `skew` of) expiry.""" + return datetime.now(timezone.utc) >= (self.expires_at - skew) + + +class ChatGPTOAuthRefreshError(RuntimeError): + """Raised when a refresh-token grant fails irrecoverably. + + Typically signals that the stored refresh token has been revoked or has + expired; the caller should re-run `login_chatgpt()` (or the device-code + equivalent) to obtain a new bundle. + """ + + +@runtime_checkable +class ChatGPTOAuthTokenProvider(Protocol): + """Refresh-aware token source consumed by `ChatOpenAICodex`.""" + + def get_token(self) -> ChatGPTToken: + """Return a current token, refreshing if necessary.""" + ... + + async def aget_token(self) -> ChatGPTToken: + """Async variant of `get_token`. + + Implementations must offer the same locking and refresh guarantees + as `get_token`: concurrent callers must not race on token storage. + """ + ... + + def get_access_token(self) -> str: + """Return only the access token string (sync callable for SDKs).""" + ... + + async def aget_access_token(self) -> str: + """Return only the access token string (async callable for SDKs).""" + ... + + +def _b64url_decode_segment(segment: str) -> bytes: + """Decode a single base64url JWT segment, handling missing padding.""" + padding = "=" * (-len(segment) % 4) + return base64.urlsafe_b64decode(segment + padding) + + +def decode_jwt_claims(token: str) -> dict[str, Any]: + """Decode a JWT's payload without signature verification. + + !!! danger + This is for *local claim extraction only*. Never use the returned + claims for security or authorization decisions. + + Args: + token: A JWT (`header.payload.signature`). + + Returns: + Decoded payload as a dict. Returns an empty dict if the token is + malformed. + """ + if not token or token.count(".") < 2: + return {} + try: + _, payload, _ = token.split(".", 2) + return json.loads(_b64url_decode_segment(payload)) + except (ValueError, json.JSONDecodeError, UnicodeDecodeError): + return {} + + +def _extract_chatgpt_claims(id_token: str | None) -> dict[str, str | None]: + """Pull the ChatGPT account/plan/user IDs out of an ID-token JWT.""" + out: dict[str, str | None] = { + "account_id": None, + "plan_type": None, + "user_id": None, + } + if not id_token: + return out + claims = decode_jwt_claims(id_token) + auth = claims.get(CHATGPT_AUTH_CLAIMS_NAMESPACE) or {} + if isinstance(auth, dict): + out["account_id"] = auth.get("chatgpt_account_id") + out["plan_type"] = auth.get("chatgpt_plan_type") + out["user_id"] = auth.get("chatgpt_user_id") + if out["account_id"] is None: + # A present-but-unparseable id_token (or one missing the namespaced + # auth claim) silently drops the `ChatGPT-Account-Id` header, which + # surfaces later as an opaque backend rejection. Leave a breadcrumb. + logger.debug( + "No `chatgpt_account_id` claim extracted from the ChatGPT " + "id_token; the `ChatGPT-Account-Id` header will be omitted." + ) + return out + + +def _expires_at_from_response(payload: dict[str, Any]) -> datetime: + raw = payload.get("expires_in") + try: + expires_in = int(raw) if raw is not None else 0 + except (TypeError, ValueError) as exc: + msg = f"OAuth token response had invalid `expires_in`: {raw!r}" + raise ChatGPTOAuthRefreshError(msg) from exc + if expires_in <= 0: + msg = ( + "OAuth token response had missing or non-positive `expires_in`; " + "refusing to store an immediately-expired token." + ) + raise ChatGPTOAuthRefreshError(msg) + return datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + +def _token_from_response( + payload: dict[str, Any], + *, + fallback_refresh_token: str | None = None, +) -> ChatGPTToken: + """Build a `ChatGPTToken` from an OAuth token-endpoint response.""" + if not payload.get("access_token"): + msg = "OAuth token response did not include an `access_token`." + raise ChatGPTOAuthRefreshError(msg) + id_token = payload.get("id_token") + claims = _extract_chatgpt_claims(id_token) + refresh_token = payload.get("refresh_token") or fallback_refresh_token + if not refresh_token: + msg = ( + "OAuth token response did not include a `refresh_token` and no " + "prior refresh token was available; re-run `login_chatgpt()`." + ) + raise ChatGPTOAuthRefreshError(msg) + return ChatGPTToken( + access_token=payload["access_token"], + refresh_token=refresh_token, + expires_at=_expires_at_from_response(payload), + account_id=claims["account_id"], + plan_type=claims["plan_type"], + user_id=claims["user_id"], + id_token=id_token, + ) + + +def _serialize_token(token: ChatGPTToken) -> dict[str, Any]: + return { + "access_token": token.access_token, + "refresh_token": token.refresh_token, + "expires_at": token.expires_at.astimezone(timezone.utc).isoformat(), + "account_id": token.account_id, + "plan_type": token.plan_type, + "user_id": token.user_id, + "id_token": token.id_token, + } + + +def _deserialize_token(data: dict[str, Any]) -> ChatGPTToken: + expires_at_raw = data.get("expires_at") + if isinstance(expires_at_raw, str): + expires_at = datetime.fromisoformat(expires_at_raw) + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + elif isinstance(expires_at_raw, (int, float)): + expires_at = datetime.fromtimestamp(expires_at_raw, tz=timezone.utc) + else: + msg = "Stored token is missing `expires_at`." + raise ValueError(msg) + return ChatGPTToken( + access_token=data["access_token"], + refresh_token=data["refresh_token"], + expires_at=expires_at, + account_id=data.get("account_id"), + plan_type=data.get("plan_type"), + user_id=data.get("user_id"), + id_token=data.get("id_token"), + ) + + +def _chmod_warn(path: Path, mode: int) -> None: + """Best-effort `chmod` that logs (but does not raise) on failure. + + On filesystems without POSIX perms (Windows, some FUSE/SMB mounts) the + file may end up world-readable. Logging surfaces that to operators so + they don't silently trust the "private perms" claim of the caller. + """ + try: + os.chmod(path, mode) # noqa: PTH101 + except (OSError, NotImplementedError) as exc: + logger.warning( + "Failed to set permissions %o on %s: %s — token store may not " + "have private permissions on this filesystem.", + mode, + path, + exc, + ) + + +def _atomic_write_private_json(path: Path, data: dict[str, Any]) -> None: + """Write `data` as JSON to `path` with 0600 perms (where supported).""" + parent = path.parent + parent.mkdir(parents=True, exist_ok=True) + _chmod_warn(parent, 0o700) + tmp = path.with_suffix(path.suffix + ".tmp") + payload = json.dumps(data, indent=2, sort_keys=True) + flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + fd = os.open(tmp, flags, 0o600) + try: + with os.fdopen(fd, "w", encoding="utf-8") as fh: + fh.write(payload) + except Exception: + with contextlib.suppress(OSError): + tmp.unlink() + raise + tmp.replace(path) + _chmod_warn(path, 0o600) + + +@contextlib.contextmanager +def _file_lock(path: Path) -> Iterator[None]: + """Best-effort cross-platform file lock around refresh + write. + + On POSIX this acquires an exclusive `fcntl.flock` on a sibling + `.lock` file. On Windows (or any platform where `fcntl` is + unavailable) the lock degrades to a no-op and a warning is logged so + callers know that cross-process safety is best-effort. + """ + lock_path = path.with_suffix(path.suffix + ".lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + fd = os.open(lock_path, os.O_CREAT | os.O_RDWR, 0o600) + locked = False + try: + try: + import fcntl + except ImportError: + logger.warning( + "fcntl is unavailable on this platform; ChatGPT token store " + "at %s is not protected against cross-process races.", + path, + ) + else: + try: + fcntl.flock(fd, fcntl.LOCK_EX) + locked = True + except OSError as exc: + logger.warning( + "fcntl.flock failed on %s: %s — token store is not " + "protected against cross-process races.", + lock_path, + exc, + ) + yield + finally: + if locked: + try: + import fcntl + + fcntl.flock(fd, fcntl.LOCK_UN) + except (ImportError, OSError) as exc: + logger.warning("Failed to release file lock on %s: %s", lock_path, exc) + os.close(fd) + + +def _redact(value: str | None) -> str: + if not value: + return "" + return f"" + + +def _parse_oauth_error(resp: httpx.Response) -> tuple[str | None, str]: + """Return `(error_code, body_excerpt)` from an OAuth error response.""" + try: + payload = resp.json() + except (ValueError, json.JSONDecodeError): + return None, resp.text[:500] + if isinstance(payload, dict): + error = payload.get("error") + description = payload.get("error_description") or "" + excerpt = f"{error}: {description}".strip(": ") or resp.text[:500] + return (error if isinstance(error, str) else None), excerpt + return None, resp.text[:500] + + +def _raise_for_oauth_response(url: str, resp: httpx.Response) -> None: + if resp.status_code < 400: + return + error_code, excerpt = _parse_oauth_error(resp) + if error_code == "invalid_grant": + msg = ( + "ChatGPT refresh token is no longer valid (`invalid_grant`). " + "Re-run `login_chatgpt()` to obtain a new token." + ) + raise ChatGPTOAuthRefreshError(msg) + msg = f"OAuth request to {url} failed with status {resp.status_code}: {excerpt}" + raise RuntimeError(msg) + + +def _post_form( + url: str, + data: dict[str, str], + *, + timeout: float = 30.0, +) -> dict[str, Any]: + """POST a form payload and return the parsed JSON body.""" + with httpx.Client(timeout=timeout) as client: + resp = client.post( + url, + data=data, + headers={"Accept": "application/json"}, + ) + _raise_for_oauth_response(url, resp) + return resp.json() + + +_DEVICE_POLL_PENDING_ERRORS = frozenset({"authorization_pending", "slow_down"}) + + +def _post_device_poll_form( + url: str, + data: dict[str, str], + *, + timeout: float = 30.0, +) -> dict[str, Any]: + """POST a device-code poll and return expected pending error payloads.""" + with httpx.Client(timeout=timeout) as client: + resp = client.post( + url, + data=data, + headers={"Accept": "application/json"}, + ) + if resp.status_code < 400: + return resp.json() + error_code, _ = _parse_oauth_error(resp) + if error_code in _DEVICE_POLL_PENDING_ERRORS: + return resp.json() + _raise_for_oauth_response(url, resp) + return resp.json() + + +async def _apost_form( + url: str, + data: dict[str, str], + *, + timeout: float = 30.0, +) -> dict[str, Any]: + """POST a form payload asynchronously and return the parsed JSON body.""" + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + url, + data=data, + headers={"Accept": "application/json"}, + ) + _raise_for_oauth_response(url, resp) + return resp.json() + + +@dataclass +class FileChatGPTOAuthTokenProvider: + """File-backed `ChatGPTOAuthTokenProvider`. + + Stores tokens at `path` (defaults to `DEFAULT_STORE_PATH`) with private + permissions and refreshes them on read when they are within + `refresh_skew` of expiry. Refresh token rotation is preserved across + writes: if the OAuth response omits `refresh_token`, the existing one is + reused. + + !!! warning + The default path is intentionally distinct from `~/.codex/auth.json` + so that refresh-token rotation here does not invalidate Codex CLI / + VS Code sessions. + """ + + path: Path = field(default_factory=lambda: DEFAULT_STORE_PATH) + client_id: str = CHATGPT_CLIENT_ID + token_url: str = CHATGPT_TOKEN_URL + refresh_skew: timedelta = DEFAULT_REFRESH_SKEW + timeout: float = 30.0 + _cached: ChatGPTToken | None = field(default=None, init=False, repr=False) + _lock: threading.Lock = field( + default_factory=threading.Lock, init=False, repr=False + ) + + @classmethod + def from_default_store(cls) -> FileChatGPTOAuthTokenProvider: + """Construct a provider with all defaults (path, client ID, etc.). + + Equivalent to `FileChatGPTOAuthTokenProvider()`; the alias exists as + a discoverable entry point for callers reading the default-path + contract from the module docstring. + """ + return cls() + + def _read_from_disk(self) -> ChatGPTToken | None: + """Return the stored token, or `None` if no store exists. + + Raises `RuntimeError` (rather than returning `None`) if the file + exists but cannot be parsed — that way the user is not told to + "re-login" when the actual fix is to repair or remove a corrupt + store at `self.path`. + """ + if not self.path.exists(): + return None + try: + raw_text = self.path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError) as exc: + msg = ( + f"Failed to read ChatGPT token store at {self.path}: {exc}. " + "Repair file permissions/encoding or delete the file and " + "re-run `login_chatgpt()`." + ) + raise RuntimeError(msg) from exc + try: + data = json.loads(raw_text) + except json.JSONDecodeError as exc: + msg = ( + f"ChatGPT token store at {self.path} is not valid JSON: " + f"{exc}. Delete the file and re-run `login_chatgpt()`." + ) + raise RuntimeError(msg) from exc + try: + return _deserialize_token(data) + except (KeyError, ValueError) as exc: + msg = ( + f"ChatGPT token store at {self.path} is missing required " + f"fields ({exc}). Delete the file and re-run " + "`login_chatgpt()`." + ) + raise RuntimeError(msg) from exc + + def _write_to_disk(self, token: ChatGPTToken) -> None: + _atomic_write_private_json(self.path, _serialize_token(token)) + + def save(self, token: ChatGPTToken) -> None: + """Persist `token` to disk and cache it in memory.""" + with self._lock, _file_lock(self.path): + self._write_to_disk(token) + self._cached = token + + def _build_refresh_payload(self, refresh_token: str) -> dict[str, str]: + return { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": self.client_id, + } + + def _apply_refresh_response( + self, response: dict[str, Any], previous_refresh: str + ) -> ChatGPTToken: + token = _token_from_response(response, fallback_refresh_token=previous_refresh) + self._write_to_disk(token) + self._cached = token + return token + + def _refresh_sync(self, existing: ChatGPTToken) -> ChatGPTToken: + logger.debug( + "Refreshing ChatGPT access token (refresh_token=%s).", + _redact(existing.refresh_token), + ) + response = _post_form( + self.token_url, + self._build_refresh_payload(existing.refresh_token), + timeout=self.timeout, + ) + return self._apply_refresh_response(response, existing.refresh_token) + + def _load_existing(self) -> ChatGPTToken: + existing = self._cached or self._read_from_disk() + if existing is None: + msg = ( + f"No ChatGPT OAuth token found at {self.path}. Run " + "`langchain_openai.chatgpt_oauth.login_chatgpt()` first." + ) + raise FileNotFoundError(msg) + return existing + + def _load_existing_before_refresh(self) -> ChatGPTToken: + existing = self._load_existing() + if not existing.is_expired(skew=self.refresh_skew): + return existing + disk_token = self._read_from_disk() + if disk_token is not None: + self._cached = disk_token + return disk_token + return existing + + def get_token(self) -> ChatGPTToken: + """Return a fresh token, refreshing on disk if needed. + + Raises: + FileNotFoundError: No token store exists at `self.path`; run + `login_chatgpt()` first. + ChatGPTOAuthRefreshError: The stored refresh token was rejected + (e.g. revoked or expired); re-run `login_chatgpt()`. + """ + with self._lock, _file_lock(self.path): + existing = self._load_existing_before_refresh() + if not existing.is_expired(skew=self.refresh_skew): + self._cached = existing + return existing + return self._refresh_sync(existing) + + async def aget_token(self) -> ChatGPTToken: + """Async variant of `get_token` with the same locking guarantees. + + The thread lock and cross-process file lock are acquired off the + event loop via `asyncio.to_thread` so concurrent async callers do + not race on `_cached` or on the on-disk token bundle. The HTTP + refresh runs synchronously inside that worker thread; this avoids + nesting event loops while still keeping the cross-process lock + held for the entire refresh + write window. + + Raises: + FileNotFoundError: No token store exists at `self.path`; run + `login_chatgpt()` first. + ChatGPTOAuthRefreshError: The stored refresh token was rejected + (e.g. revoked or expired); re-run `login_chatgpt()`. + """ + return await asyncio.to_thread(self._aget_token_locked_blocking) + + def _aget_token_locked_blocking(self) -> ChatGPTToken: + with self._lock, _file_lock(self.path): + existing = self._load_existing_before_refresh() + if not existing.is_expired(skew=self.refresh_skew): + self._cached = existing + return existing + return self._refresh_sync(existing) + + def get_access_token(self) -> str: + """Return only the access-token string.""" + return self.get_token().access_token + + async def aget_access_token(self) -> str: + """Return only the access-token string (async).""" + token = await self.aget_token() + return token.access_token + + +def _generate_pkce_pair() -> tuple[str, str]: + """Return a `(code_verifier, code_challenge)` pair using S256.""" + verifier = ( + base64.urlsafe_b64encode(secrets.token_bytes(64)).rstrip(b"=").decode("ascii") + ) + digest = hashlib.sha256(verifier.encode("ascii")).digest() + challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return verifier, challenge + + +def _build_authorize_url( + *, + client_id: str, + redirect_uri: str, + state: str, + code_challenge: str, + scope: str = DEFAULT_SCOPE, + extra_params: dict[str, str] | None = None, +) -> str: + params = { + "client_id": client_id, + "response_type": "code", + "redirect_uri": redirect_uri, + "scope": scope, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": state, + } + if extra_params: + params.update(extra_params) + return f"{CHATGPT_AUTHORIZE_URL}?{urllib.parse.urlencode(params)}" + + +class _CallbackHandler(http.server.BaseHTTPRequestHandler): + server_result: dict[str, str] = {} + callback_path: str = DEFAULT_REDIRECT_PATH + + def do_GET(self) -> None: + parsed = urllib.parse.urlparse(self.path) + if parsed.path != self.callback_path: + # Surface path mismatches: otherwise a misconfigured + # `callback_path` looks identical to "still waiting" and only + # ends in a generic timeout. (Path only — never the query, which + # carries the auth code.) + logger.debug( + "Ignoring callback request for unexpected path %r (expected %r).", + parsed.path, + self.callback_path, + ) + self.send_response(404) + self.end_headers() + return + query = urllib.parse.parse_qs(parsed.query) + for key in ("code", "state", "error", "error_description"): + value = query.get(key) + if value: + self.server_result[key] = value[0] + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.end_headers() + error = self.server_result.get("error") + if error: + error_description = self.server_result.get("error_description") + logger.error( + "ChatGPT OAuth callback returned error %r (%s)", + error, + error_description or "no description", + ) + if error_description: + description = f"{error_description} (error: {error})" + else: + description = ( + f"ChatGPT returned error '{error}'. Close this tab and " + "try `login_chatgpt()` again." + ) + body = _oauth_error_html(description) + else: + body = _oauth_success_html( + "ChatGPT sign-in complete. You can close this browser tab " + "and return to your terminal.", + ) + self.wfile.write(body.encode("utf-8")) + + def log_message(self, format: str, *args: Any) -> None: # noqa: A002 + # Don't leak callback URLs (which contain auth codes) into stderr. + return + + +def _oauth_success_html(message: str) -> str: + return _oauth_result_html( + title="ChatGPT sign-in complete", + heading="You're signed in", + message=message, + status="success", + ) + + +def _oauth_error_html(message: str) -> str: + return _oauth_result_html( + title="ChatGPT sign-in failed", + heading="Sign-in failed", + message=message, + status="error", + ) + + +def _oauth_result_html( + *, + title: str, + heading: str, + message: str, + status: Literal["success", "error"], +) -> str: + accent = "#137333" if status == "success" else "#b3261e" + background = "#eef7f0" if status == "success" else "#fceeee" + mark = "✓" if status == "success" else "!" + escaped_title = html.escape(title) + escaped_heading = html.escape(heading) + escaped_message = html.escape(message) + return ( + '' + '' + f"{escaped_title}" + "" + '
' + f'
' + f"{mark}
" + f"

{escaped_heading}

{escaped_message}

" + "
" + "" + ) + + +def _wait_for_callback( + *, + host: str, + port: int, + callback_path: str, + timeout: float, +) -> dict[str, str]: + class _BoundCallbackHandler(_CallbackHandler): + server_result: dict[str, str] = {} + + _BoundCallbackHandler.callback_path = callback_path + try: + server = http.server.HTTPServer((host, port), _BoundCallbackHandler) + except OSError as exc: + msg = ( + f"Could not bind ChatGPT OAuth callback server on " + f"http://{host}:{port}: {exc}. Free the port or pass `port=` " + "to `login_chatgpt()` with an unused port." + ) + raise RuntimeError(msg) from exc + server.timeout = 1.0 + deadline = time.monotonic() + timeout + try: + while time.monotonic() < deadline: + server.handle_request() + if _BoundCallbackHandler.server_result.get( + "code" + ) or _BoundCallbackHandler.server_result.get("error"): + return dict(_BoundCallbackHandler.server_result) + finally: + server.server_close() + msg = f"Timed out waiting for ChatGPT OAuth callback on http://{host}:{port}" + raise TimeoutError(msg) + + +def _validate_loopback_host(host: str) -> None: + """Reject non-loopback callback hosts. + + The callback server receives the OAuth authorization `code` in the request + URL. Binding it to a non-loopback interface (e.g. `0.0.0.0`) would expose + that code on the local network, so only loopback hosts are permitted — + RFC 8252 §8.3 expects a loopback redirect for native-app PKCE flows. + + Args: + host: The callback host passed to `login_chatgpt`. + + Raises: + ValueError: `host` is not `localhost` or a loopback IP address. + """ + if host == "localhost": + return + try: + is_loopback = ipaddress.ip_address(host).is_loopback + except ValueError: + # Not an IP literal and not `localhost`; can't prove it's loopback. + is_loopback = False + if not is_loopback: + msg = ( + f"`host={host!r}` is not a loopback address. The OAuth callback " + "server receives the authorization code in the request URL, so it " + "must bind to a loopback interface (`localhost`, `127.0.0.1`, or " + "`::1`) to avoid exposing the code on the network." + ) + raise ValueError(msg) + + +def login_chatgpt( + *, + store_path: Path | None = None, + client_id: str = CHATGPT_CLIENT_ID, + host: str = DEFAULT_REDIRECT_HOST, + port: int = DEFAULT_REDIRECT_PORT, + callback_path: str = DEFAULT_REDIRECT_PATH, + scope: str = DEFAULT_SCOPE, + open_browser: bool = True, + timeout: float = 300.0, +) -> FileChatGPTOAuthTokenProvider: + """Run the ChatGPT OAuth 2.0 Authorization Code Flow with PKCE. + + Starts a loopback callback server, optionally opens a browser to the + OpenAI authorize endpoint (when `open_browser=True`; the URL is always + printed as a fallback), exchanges the returned code for tokens, and + persists them via `FileChatGPTOAuthTokenProvider`. + + Args: + store_path: Where to persist the token. Defaults to + `DEFAULT_STORE_PATH`. + client_id: OAuth client ID (defaults to Codex/ChatGPT client). + host: Local callback host. Must be a loopback address. + port: Local callback port. + callback_path: Local callback path. + scope: OAuth scope string. + open_browser: Whether to launch the system browser. + timeout: Seconds to wait for the callback. + + Returns: + A `FileChatGPTOAuthTokenProvider` ready for use by + `ChatOpenAICodex`. + + Raises: + ValueError: `host` is not a loopback address. + RuntimeError: The callback server could not bind, the `state` did not + match (CSRF), the provider returned an OAuth error, or no + authorization code was returned. + TimeoutError: No callback was received within `timeout` seconds. + + See Also: + `login_chatgpt_device`: Headless fallback for environments without a + browser or the ability to bind a localhost callback port. + """ + _validate_loopback_host(host) + redirect_uri = f"http://{host}:{port}{callback_path}" + state = secrets.token_urlsafe(32) + verifier, challenge = _generate_pkce_pair() + authorize_url = _build_authorize_url( + client_id=client_id, + redirect_uri=redirect_uri, + state=state, + code_challenge=challenge, + scope=scope, + ) + + # Surface the URL prominently so users can complete sign-in manually if + # the browser launch fails or the environment is headless. + print( # noqa: T201 + f"\nChatGPT sign-in: open the following URL in a browser:\n {authorize_url}\n" + ) + logger.info("Opening ChatGPT sign-in flow at %s", CHATGPT_AUTHORIZE_URL) + if open_browser: + try: + webbrowser.open(authorize_url) + except webbrowser.Error as exc: + logger.warning( + "Could not launch a browser: %s. Copy the URL above instead.", + exc, + ) + + result = _wait_for_callback( + host=host, port=port, callback_path=callback_path, timeout=timeout + ) + # Validate `state` first: a CSRF mismatch is a security signal and must + # fail closed before any other branch (including server-reported errors) + # is considered. + if result.get("state") != state: + msg = "ChatGPT OAuth callback state mismatch." + raise RuntimeError(msg) + if "error" in result: + description = result.get("error_description", "") + msg = f"ChatGPT OAuth callback returned error: {result['error']} {description}" + raise RuntimeError(msg) + code = result.get("code") + if not code: + msg = "ChatGPT OAuth callback did not include an authorization code." + raise RuntimeError(msg) + + response = _post_form( + CHATGPT_TOKEN_URL, + { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "code_verifier": verifier, + }, + ) + token = _token_from_response(response) + provider = FileChatGPTOAuthTokenProvider( + path=store_path or DEFAULT_STORE_PATH, client_id=client_id + ) + provider.save(token) + return provider + + +def login_chatgpt_device( + *, + store_path: Path | None = None, + client_id: str = CHATGPT_CLIENT_ID, + poll_interval: float = 5.0, + timeout: float = 600.0, +) -> FileChatGPTOAuthTokenProvider: + """Run the ChatGPT device-code OAuth flow. + + This is the headless fallback for environments without a browser. The + function prints a verification URL and user code, polls for completion, + then exchanges the resulting code via the OAuth token endpoint using + `CHATGPT_DEVICE_REDIRECT_URI`. + + Args: + store_path: Where to persist the token. Defaults to + `DEFAULT_STORE_PATH`. + client_id: OAuth client ID (defaults to Codex/ChatGPT client). + poll_interval: Seconds between polls. + timeout: Total seconds to wait. + + Returns: + A configured `FileChatGPTOAuthTokenProvider`. + + Raises: + RuntimeError: The device-code response was missing required fields, or + device authorization failed with a terminal error. + TimeoutError: Authorization was not completed within `timeout` seconds. + + See Also: + `login_chatgpt`: Browser-based loopback flow preferred when a local + browser and free callback port are available. + """ + _verifier, challenge = _generate_pkce_pair() + start = _post_form( + CHATGPT_DEVICE_CODE_URL, + { + "client_id": client_id, + "scope": DEFAULT_SCOPE, + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + ) + device_code = start.get("device_code") + user_code = start.get("user_code") + verification_uri = start.get("verification_uri") or start.get( + "verification_uri_complete" + ) + if not (device_code and user_code and verification_uri): + msg = "ChatGPT device-code response missing required fields." + raise RuntimeError(msg) + logger.info( + "Open %s in a browser and enter user code: %s", verification_uri, user_code + ) + + deadline = time.monotonic() + timeout + authorization_code: str | None = None + current_interval = poll_interval + while time.monotonic() < deadline: + poll = _post_device_poll_form( + CHATGPT_DEVICE_TOKEN_URL, + {"client_id": client_id, "device_code": device_code}, + ) + if poll.get("authorization_code"): + authorization_code = poll["authorization_code"] + break + error = poll.get("error") + if error == "slow_down": + # RFC 8628 §3.5: bump the poll interval by 5 seconds to comply + # with server-side rate limiting; otherwise we risk being banned. + current_interval += 5 + elif error and error != "authorization_pending": + msg = f"Device authorization failed: {error}" + raise RuntimeError(msg) + time.sleep(current_interval) + if not authorization_code: + msg = "Timed out waiting for ChatGPT device authorization." + raise TimeoutError(msg) + + response = _post_form( + CHATGPT_TOKEN_URL, + { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CHATGPT_DEVICE_REDIRECT_URI, + "client_id": client_id, + "code_verifier": _verifier, + }, + ) + token = _token_from_response(response) + provider = FileChatGPTOAuthTokenProvider( + path=store_path or DEFAULT_STORE_PATH, client_id=client_id + ) + provider.save(token) + return provider + + +__all__ = [ + "CHATGPT_AUTHORIZE_URL", + "CHATGPT_CLIENT_ID", + "CHATGPT_TOKEN_URL", + "ChatGPTOAuthRefreshError", + "ChatGPTOAuthTokenProvider", + "ChatGPTToken", + "FileChatGPTOAuthTokenProvider", + "decode_jwt_claims", + "login_chatgpt", + "login_chatgpt_device", +] diff --git a/libs/partners/openai/scripts/RECORD_CODEX_CASSETTES.md b/libs/partners/openai/scripts/RECORD_CODEX_CASSETTES.md new file mode 100644 index 00000000000..3dba9ee5f9c --- /dev/null +++ b/libs/partners/openai/scripts/RECORD_CODEX_CASSETTES.md @@ -0,0 +1,105 @@ +# Recording `ChatOpenAICodex` VCR cassettes + +`ChatOpenAICodex` authenticates with a ChatGPT subscription OAuth bundle, so +its integration tests cannot run in PR CI without a live login. The workflow +below records cassettes once locally with a real token, scrubs OAuth secrets +before they hit disk, and commits the cassettes so CI replays them through +`_test_vcr.yml` (the `vcr-tests` job in `check_diffs.yml`). + +## Prerequisites + +1. A ChatGPT subscription (Plus / Pro / Team / Enterprise) — required for + `chatgpt.com/backend-api/codex`. +2. A token bundle on disk. Generate one with: + + ```bash + uv run --group test python -c \ + "from langchain_openai.chatgpt_oauth import login_chatgpt; login_chatgpt()" + ``` + + The default store is `~/.langchain/chatgpt-auth.json`. It is intentionally + distinct from `~/.codex/auth.json` so the Codex CLI / VS Code session is + not invalidated by refresh-token rotation here. +3. An integration test that instantiates `ChatOpenAICodex` and is marked + `@pytest.mark.vcr`. Tests written against the API-key `ChatOpenAI` will + not exercise the Codex backend — only Codex-specific tests should be + passed to the script. + +## Record + +From `libs/partners/openai/`: + +```bash +# Record every VCR-marked integration test (default). +scripts/record_codex_cassettes.sh + +# Record one file or one test. +scripts/record_codex_cassettes.sh tests/integration_tests/chat_models/test_codex.py +scripts/record_codex_cassettes.sh \ + tests/integration_tests/chat_models/test_codex.py::test_invoke + +# Forward extra pytest args. +PYTEST_EXTRA="-k streaming -x" scripts/record_codex_cassettes.sh +``` + +The script: + +1. Verifies the token store exists. +2. Force-refreshes the access token *outside* pytest so VCR never sees the + `auth.openai.com/oauth/token` roundtrip. A revoked refresh token surfaces + here, not after a long recording run. +3. Runs `pytest --record-mode=once -m vcr ` so missing cassettes are + created and existing ones replayed. +4. `zgrep`s every cassette for bearer tokens, JWTs, refresh-grant bodies, + leaked API keys, and ChatGPT account-id claims. Any match aborts with + a non-zero exit and a per-file report — do **not** commit those cassettes. +5. Prints a diff of cassette files that were touched. + +## What gets scrubbed automatically + +`tests/conftest.py` redacts: + +- Every request and response header (catches `Authorization: Bearer …`, + `ChatGPT-Account-Id`, cookies, organization IDs). +- Request URIs (no per-account URL parameters land in cassettes). +- OAuth secret fields in JSON request/response bodies: `access_token`, + `refresh_token`, `id_token`, `code`, `code_verifier`, `device_code`, + `client_secret`. +- The same fields in urlencoded form bodies (refresh-grant POSTs). +- Any JWT-shaped string anywhere in a body (`eyJ…`). + +The recording script's post-scan exists to catch anything the above misses. + +## Override the token store + +```bash +CHATGPT_AUTH_FILE=/tmp/codex-test-token.json \ + scripts/record_codex_cassettes.sh +``` + +Useful when recording against a dedicated test account so refresh rotation +doesn't churn your personal `~/.langchain/chatgpt-auth.json`. + +## Commit + +```bash +git -C libs/partners/openai status tests/cassettes/ +git -C libs/partners/openai diff --stat tests/cassettes/ +``` + +Spot-check at least one new cassette: `gunzip -c tests/cassettes/.yaml.gz | less`. +Verify every `Authorization` header reads `**REDACTED**` and no `eyJ…` strings +remain. + +## CI playback + +`.github/workflows/check_diffs.yml` routes openai changes through +`_test_vcr.yml`, which runs: + +```bash +make test_vcr # uv run pytest --record-mode=none -m vcr tests/integration_tests/ +``` + +`--record-mode=none` makes pytest fail rather than make a live network call, +so a missing or stale cassette is a hard failure — exactly the signal you +want. diff --git a/libs/partners/openai/scripts/record_codex_cassettes.sh b/libs/partners/openai/scripts/record_codex_cassettes.sh new file mode 100755 index 00000000000..09d12624eb8 --- /dev/null +++ b/libs/partners/openai/scripts/record_codex_cassettes.sh @@ -0,0 +1,181 @@ +#!/usr/bin/env bash +# Record VCR cassettes for `ChatOpenAICodex` integration tests against a +# real ChatGPT OAuth subscription account, then verify no OAuth secrets +# survived into the on-disk cassettes. +# +# Usage: +# scripts/record_codex_cassettes.sh # all integration_tests +# scripts/record_codex_cassettes.sh tests/integration_tests/chat_models/test_codex.py +# scripts/record_codex_cassettes.sh tests/integration_tests/chat_models/test_codex.py::test_invoke +# +# Env: +# CHATGPT_AUTH_FILE Override the token store path. Defaults to +# `$HOME/.langchain/chatgpt-auth.json`. +# PYTEST_EXTRA Extra args forwarded to pytest (e.g. `-x -k codex`). +# +# Exits non-zero if the token preflight fails, if pytest fails, or if the +# post-recording leak scan finds OAuth secrets in any cassette. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PKG_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" +CASSETTE_DIR="${PKG_DIR}/tests/cassettes" +TOKEN_FILE="${CHATGPT_AUTH_FILE:-${HOME}/.langchain/chatgpt-auth.json}" + +if [[ ! -f "${TOKEN_FILE}" ]]; then + echo "error: ChatGPT OAuth token store not found at ${TOKEN_FILE}." >&2 + echo " Run \`python -c 'from langchain_openai.chatgpt_oauth import login_chatgpt; login_chatgpt()'\` first." >&2 + exit 2 +fi + +cd "${PKG_DIR}" + +# Reserve all tempfiles upfront and register a single cumulative cleanup +# trap so a failure between stages can't leak files (notably the leak +# report, which would contain the very tokens we just scrubbed). Mode 600 +# on the leak report seals it against other users sharing the host while +# the script runs. +PRE_SNAPSHOT="$(mktemp)" +POST_SNAPSHOT="$(mktemp)" +LEAK_REPORT="$(mktemp)" +chmod 600 "${LEAK_REPORT}" "${PRE_SNAPSHOT}" "${POST_SNAPSHOT}" +trap 'rm -f "${PRE_SNAPSHOT}" "${POST_SNAPSHOT}" "${LEAK_REPORT}"' EXIT + +# Portable cassette snapshot: GNU `find -printf` is unavailable on BSD/macOS +# and BSD `stat -f` is unavailable on Linux. A short python one-liner +# avoids the platform split and surfaces real errors instead of swallowing +# them behind `|| true`. +snapshot_cassettes() { + local dest="$1" + if [[ ! -d "${CASSETTE_DIR}" ]]; then + : > "${dest}" + return 0 + fi + python - "${CASSETTE_DIR}" >"${dest}" <<'PY' +import os +import sys +from pathlib import Path + +root = Path(sys.argv[1]) +entries = [] +for path in root.rglob("*.yaml.gz"): + try: + mtime = path.stat().st_mtime + except OSError as exc: + print(f"warning: failed to stat {path}: {exc}", file=sys.stderr) + continue + entries.append(f"{path} {mtime}") +entries.sort() +sys.stdout.write("\n".join(entries)) +if entries: + sys.stdout.write("\n") +PY +} + +# Preflight: force a refresh now so VCR doesn't record an `auth.openai.com` +# token roundtrip mid-test. A stale or revoked refresh token surfaces here +# rather than after a long test run. +echo "==> Refreshing ChatGPT OAuth token (${TOKEN_FILE})" +CHATGPT_AUTH_FILE="${TOKEN_FILE}" uv run --group test python - <<'PY' +import os +import sys +from pathlib import Path + +from langchain_openai.chatgpt_oauth import ( + ChatGPTOAuthRefreshError, + FileChatGPTOAuthTokenProvider, +) + +path = Path(os.environ["CHATGPT_AUTH_FILE"]) +provider = FileChatGPTOAuthTokenProvider(path=path) +try: + token = provider.get_token() +except ChatGPTOAuthRefreshError as exc: + print(f"refresh failed: {exc}", file=sys.stderr) + sys.exit(3) +print(f"ok — token valid until {token.expires_at.isoformat()}") +PY + +snapshot_cassettes "${PRE_SNAPSHOT}" + +# Default target: full integration suite. Override with positional args. +if [[ "$#" -eq 0 ]]; then + set -- tests/integration_tests/ +fi + +# `ChatOpenAICodex` constructs an `api_key` callable from its token provider, +# but `ChatOpenAI`'s pydantic init still requires *some* non-empty value. +# The placeholder is intentionally < 20 chars after `sk-` so the leak-scan +# pattern below (which targets real API keys at >= 20 chars) doesn't +# false-positive on it. +export OPENAI_API_KEY="${OPENAI_API_KEY:-sk-codex-placeholder}" + +echo "==> Recording cassettes for: $*" +# `--record-mode=once` writes any missing cassette and replays existing ones. +# `-m vcr` limits the run to VCR-marked tests so unrelated live tests don't +# fire unexpectedly. `PYTEST_EXTRA` is word-split via the shell (the +# `disable=SC2086` is intentional) so embedded spaces split into separate +# args — pass complex flags (e.g., `-k "name with spaces"`) as positional +# args to the script itself instead. +# shellcheck disable=SC2086 +uv run --group test --group test_integration pytest \ + --record-mode=once \ + -m vcr \ + -v --tb=short \ + ${PYTEST_EXTRA:-} \ + "$@" + +# Leak scan: zgrep the post-state cassettes for any pattern that would +# indicate an OAuth secret slipped past the conftest scrubbers. All +# patterns are passed to a single zgrep invocation per file so each +# cassette is decompressed exactly once. Patterns are ERE (zgrep -E): +# brace quantifiers use `{N,}` without backslashes. +echo "==> Scanning cassettes for OAuth secret leaks" +LEAK_ZGREP_ARGS=( + -e 'Bearer ey' # bearer token in a captured header value + -e 'eyJ[A-Za-z0-9_-]{20,}\.' # JWT-shaped payload (access/id/refresh tokens) + -e '"refresh_token"[[:space:]]*:[[:space:]]*"[^*"]' + -e '"access_token"[[:space:]]*:[[:space:]]*"[^*"]' + -e '"id_token"[[:space:]]*:[[:space:]]*"[^*"]' + -e 'refresh_token=[^&*]' # urlencoded refresh-grant body + -e 'sk-[A-Za-z0-9]{20,}' # leaked API key (>= 20 chars after sk-) + -e 'chatgpt_account_id' # account-id JWT claim from an id_token payload +) + +leak_found=0 +if [[ -d "${CASSETTE_DIR}" ]]; then + while IFS= read -r -d '' cassette; do + if zgrep -aHE "${LEAK_ZGREP_ARGS[@]}" "${cassette}" >> "${LEAK_REPORT}" 2>/dev/null; then + leak_found=1 + fi + done < <(find "${CASSETTE_DIR}" -type f -name '*.yaml.gz' -print0) +fi + +if [[ "${leak_found}" -ne 0 ]]; then + echo "error: OAuth secret leak detected in cassettes:" >&2 + cat "${LEAK_REPORT}" >&2 + echo >&2 + echo "Do NOT commit these cassettes. Re-run after extending the" >&2 + echo "scrubber in tests/conftest.py to cover the leaking field." >&2 + exit 4 +fi + +snapshot_cassettes "${POST_SNAPSHOT}" + +# Summarize what changed so the user knows which cassettes to inspect. +# Capture diff output to a variable so we can branch on `diff`'s exit +# code rather than on a piped one (`set -o pipefail` would otherwise +# flip the sense of the test). +echo "==> Cassette changes:" +if diff_out=$(diff -u "${PRE_SNAPSHOT}" "${POST_SNAPSHOT}"); then + echo "(no changes detected)" +else + # Strip the unified-diff `---`/`+++` headers; keep the `+`/`-` body lines. + printf '%s\n' "${diff_out}" | grep -E '^[+-][^+-]' || true +fi + +echo +echo "Recording complete. Inspect the diff before committing:" +echo " git -C ${PKG_DIR} status tests/cassettes/" +echo " git -C ${PKG_DIR} diff --stat tests/cassettes/" diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_abatch.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_abatch.yaml.gz new file mode 100644 index 00000000000..e88d21a91ad Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_abatch.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_agent_loop[model0].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_agent_loop[model0].yaml.gz new file mode 100644 index 00000000000..f823b362224 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_agent_loop[model0].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_agent_loop[model1].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_agent_loop[model1].yaml.gz new file mode 100644 index 00000000000..ac9229c60f6 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_agent_loop[model1].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_ainvoke.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_ainvoke.yaml.gz new file mode 100644 index 00000000000..ca0bf53ec9c Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_ainvoke.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_ainvoke_with_model_override.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_ainvoke_with_model_override.yaml.gz new file mode 100644 index 00000000000..66fc00db61e Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_ainvoke_with_model_override.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_anthropic_inputs.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_anthropic_inputs.yaml.gz new file mode 100644 index 00000000000..d510e93c6ce Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_anthropic_inputs.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream[model0].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream[model0].yaml.gz new file mode 100644 index 00000000000..f7edb0ab364 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream[model0].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream[model1].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream[model1].yaml.gz new file mode 100644 index 00000000000..05ecd8f7584 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream[model1].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream_events_v3.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream_events_v3.yaml.gz new file mode 100644 index 00000000000..70fc6a63a79 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream_events_v3.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream_with_model_override.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream_with_model_override.yaml.gz new file mode 100644 index 00000000000..bbbb6e435a3 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_astream_with_model_override.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_audio_inputs.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_audio_inputs.yaml.gz new file mode 100644 index 00000000000..94c03d6f57e Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_audio_inputs.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_batch.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_batch.yaml.gz new file mode 100644 index 00000000000..ec5100254c6 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_batch.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_bind_runnables_as_tools.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_bind_runnables_as_tools.yaml.gz new file mode 100644 index 00000000000..de9d7970d22 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_bind_runnables_as_tools.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_conversation.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_conversation.yaml.gz new file mode 100644 index 00000000000..3a66a5436c3 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_conversation.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_double_messages_conversation.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_double_messages_conversation.yaml.gz new file mode 100644 index 00000000000..32150de7252 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_double_messages_conversation.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_image_inputs.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_image_inputs.yaml.gz new file mode 100644 index 00000000000..a438023ecaf Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_image_inputs.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_image_tool_message.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_image_tool_message.yaml.gz new file mode 100644 index 00000000000..acf29a0e3b6 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_image_tool_message.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_invoke.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_invoke.yaml.gz new file mode 100644 index 00000000000..929eba37ce4 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_invoke.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_invoke_with_model_override.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_invoke_with_model_override.yaml.gz new file mode 100644 index 00000000000..c5f66a1f65a Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_invoke_with_model_override.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_json_mode.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_json_mode.yaml.gz new file mode 100644 index 00000000000..a37fa73e724 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_json_mode.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_message_with_name.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_message_with_name.yaml.gz new file mode 100644 index 00000000000..fae2f4b55f3 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_message_with_name.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_pdf_inputs.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_pdf_inputs.yaml.gz new file mode 100644 index 00000000000..0466bfbf487 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_pdf_inputs.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_pdf_tool_message.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_pdf_tool_message.yaml.gz new file mode 100644 index 00000000000..973dd866380 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_pdf_tool_message.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream[model0].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream[model0].yaml.gz new file mode 100644 index 00000000000..af1f824fa6c Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream[model0].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream[model1].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream[model1].yaml.gz new file mode 100644 index 00000000000..8d208565710 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream[model1].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_events_v3.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_events_v3.yaml.gz new file mode 100644 index 00000000000..fb7c736bbce Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_events_v3.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_time.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_time.yaml.gz new file mode 100644 index 00000000000..6ccb893b783 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_time.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_with_model_override.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_with_model_override.yaml.gz new file mode 100644 index 00000000000..00cac1d37c3 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_stream_with_model_override.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_few_shot_examples.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_few_shot_examples.yaml.gz new file mode 100644 index 00000000000..bfdf1ebb33f Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_few_shot_examples.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[json_schema].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[json_schema].yaml.gz new file mode 100644 index 00000000000..90f18407274 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[json_schema].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[pydantic].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[pydantic].yaml.gz new file mode 100644 index 00000000000..118aed53809 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[pydantic].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[typeddict].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[typeddict].yaml.gz new file mode 100644 index 00000000000..aefbe293707 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output[typeddict].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[json_schema].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[json_schema].yaml.gz new file mode 100644 index 00000000000..14d87212cce Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[json_schema].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[pydantic].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[pydantic].yaml.gz new file mode 100644 index 00000000000..7c246854fd9 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[pydantic].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[typeddict].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[typeddict].yaml.gz new file mode 100644 index 00000000000..6442d269010 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_async[typeddict].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_optional_param.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_optional_param.yaml.gz new file mode 100644 index 00000000000..346764675a1 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_optional_param.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_pydantic_2_v1.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_pydantic_2_v1.yaml.gz new file mode 100644 index 00000000000..365575b17ee Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_structured_output_pydantic_2_v1.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling[model0].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling[model0].yaml.gz new file mode 100644 index 00000000000..cfca9dc2c9f Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling[model0].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling[model1].yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling[model1].yaml.gz new file mode 100644 index 00000000000..0f6d724588f Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling[model1].yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling_async.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling_async.yaml.gz new file mode 100644 index 00000000000..9dc5ad5b107 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling_async.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling_with_no_arguments.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling_with_no_arguments.yaml.gz new file mode 100644 index 00000000000..2b8f557fc01 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_calling_with_no_arguments.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_choice.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_choice.yaml.gz new file mode 100644 index 00000000000..8ab82715870 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_choice.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_error_status.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_error_status.yaml.gz new file mode 100644 index 00000000000..b8ce08f3163 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_error_status.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_histories_list_content.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_histories_list_content.yaml.gz new file mode 100644 index 00000000000..bdc427add86 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_histories_list_content.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_histories_string_content.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_histories_string_content.yaml.gz new file mode 100644 index 00000000000..df26a6270b6 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_tool_message_histories_string_content.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_unicode_tool_call_integration.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_unicode_tool_call_integration.yaml.gz new file mode 100644 index 00000000000..c3d4dd77888 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_unicode_tool_call_integration.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_usage_metadata.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_usage_metadata.yaml.gz new file mode 100644 index 00000000000..725f5e4c3ae Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_usage_metadata.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_usage_metadata_streaming.yaml.gz b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_usage_metadata_streaming.yaml.gz new file mode 100644 index 00000000000..6f21ac23f04 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/TestChatOpenAICodexStandard.test_usage_metadata_streaming.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_agent_loop.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_agent_loop.yaml.gz new file mode 100644 index 00000000000..577570e6b57 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_agent_loop.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_agent_loop_streaming.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_agent_loop_streaming.yaml.gz new file mode 100644 index 00000000000..b0a6718e6ee Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_agent_loop_streaming.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_custom_tool.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_custom_tool.yaml.gz new file mode 100644 index 00000000000..bdb0d792033 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_custom_tool.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_function_calling.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_function_calling.yaml.gz new file mode 100644 index 00000000000..b7683678796 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_function_calling.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_invoke.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_invoke.yaml.gz new file mode 100644 index 00000000000..952a16b3114 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_invoke.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_invoke_async.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_invoke_async.yaml.gz new file mode 100644 index 00000000000..a55120191b9 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_invoke_async.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_invoke_lifts_system_message_into_instructions.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_invoke_lifts_system_message_into_instructions.yaml.gz new file mode 100644 index 00000000000..f48573384de Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_invoke_lifts_system_message_into_instructions.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_invoke_with_instructions_override.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_invoke_with_instructions_override.yaml.gz new file mode 100644 index 00000000000..674ae35918d Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_invoke_with_instructions_override.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_multi_turn_no_tools.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_multi_turn_no_tools.yaml.gz new file mode 100644 index 00000000000..b2ba6a7e361 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_multi_turn_no_tools.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_reasoning.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_reasoning.yaml.gz new file mode 100644 index 00000000000..f1d2d537b6d Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_reasoning.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_reasoning_summary_streaming.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_reasoning_summary_streaming.yaml.gz new file mode 100644 index 00000000000..e78662b0d46 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_reasoning_summary_streaming.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_stream.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_stream.yaml.gz new file mode 100644 index 00000000000..e750960323d Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_stream.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_stream_async.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_stream_async.yaml.gz new file mode 100644 index 00000000000..16bf1e37326 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_stream_async.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_stream_events_v3.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_stream_events_v3.yaml.gz new file mode 100644 index 00000000000..8e7f6f4d352 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_stream_events_v3.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_stream_events_v3_async.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_stream_events_v3_async.yaml.gz new file mode 100644 index 00000000000..7b5b9499020 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_stream_events_v3_async.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_structured_output_pydantic.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_structured_output_pydantic.yaml.gz new file mode 100644 index 00000000000..17a3c23636c Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_structured_output_pydantic.yaml.gz differ diff --git a/libs/partners/openai/tests/cassettes/test_codex_structured_output_typed_dict.yaml.gz b/libs/partners/openai/tests/cassettes/test_codex_structured_output_typed_dict.yaml.gz new file mode 100644 index 00000000000..79914b77a3b Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_codex_structured_output_typed_dict.yaml.gz differ diff --git a/libs/partners/openai/tests/conftest.py b/libs/partners/openai/tests/conftest.py index e8159df1384..dc045c1c99d 100644 --- a/libs/partners/openai/tests/conftest.py +++ b/libs/partners/openai/tests/conftest.py @@ -1,29 +1,189 @@ +import inspect import json +import logging +import re from typing import Any import pytest from langchain_tests.conftest import CustomPersister, CustomSerializer, base_vcr_config from vcr import VCR # type: ignore[import-untyped] +logger = logging.getLogger(__name__) + _EXTRA_HEADERS = [ ("openai-organization", "PLACEHOLDER"), ("user-agent", "PLACEHOLDER"), ("x-openai-client-user-agent", "PLACEHOLDER"), + # ChatGPT OAuth subscription auth: the catch-all redactor below already + # wipes every header, but list these explicitly so anyone reading the + # config knows they are covered. + ("chatgpt-account-id", "PLACEHOLDER"), + ("cookie", "PLACEHOLDER"), + ("set-cookie", "PLACEHOLDER"), ] +# OAuth secret-bearing fields. Redacted in request and response bodies as +# defense-in-depth against an OAuth token-endpoint roundtrip getting captured +# mid-cassette. +_OAUTH_SECRET_FIELDS = frozenset( + { + "access_token", + "refresh_token", + "id_token", + "code", + "code_verifier", + "device_code", + "client_secret", + } +) +_JWT_PATTERN = re.compile(rb"eyJ[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+") + +# Binary content magic bytes. If any of these prefix the body we skip the +# scrub stack — JWTs (and every other OAuth secret we redact) are ASCII, +# so a confirmed-binary body can't carry one. Matters for performance: +# audio/PDF/image cassette responses can be many hundreds of KB and the +# UTF-8 + regex passes are O(N) on every record/replay. +_BINARY_MAGIC_PREFIXES: tuple[bytes, ...] = ( + b"\x89PNG", # PNG + b"\xff\xd8\xff", # JPEG + b"GIF8", # GIF + b"RIFF", # WAV / WEBP container + b"OggS", # Ogg + b"ID3", # MP3 with ID3 tag + b"\xff\xfb", # MP3 without tag + b"%PDF", # PDF + b"PK\x03\x04", # ZIP / DOCX / etc. +) + + +def _scrub_form_body(body: bytes) -> bytes: + """Redact OAuth secret fields in a urlencoded form body.""" + text = body.decode("utf-8", errors="replace") + parts = text.split("&") + redacted: list[str] = [] + for part in parts: + key, sep, _ = part.partition("=") + if sep and key in _OAUTH_SECRET_FIELDS: + redacted.append(f"{key}=**REDACTED**") + else: + redacted.append(part) + return "&".join(redacted).encode("utf-8") + + +def _walk_and_redact(node: Any) -> bool: + """Redact OAuth secret fields anywhere in a parsed JSON tree in place. + + Returns `True` if any field was redacted. Walks dicts and lists + recursively so nested OAuth payloads (e.g., `{"data": {"refresh_token": + "..."}}` or `[{"access_token": "..."}]`) are scrubbed too. + """ + redacted = False + if isinstance(node, dict): + for key, value in node.items(): + if key in _OAUTH_SECRET_FIELDS and isinstance(value, (str, int, float)): + node[key] = "**REDACTED**" + redacted = True + elif isinstance(value, (dict, list)): + redacted = _walk_and_redact(value) or redacted + elif isinstance(node, list): + for item in node: + if isinstance(item, (dict, list)): + redacted = _walk_and_redact(item) or redacted + return redacted + + +def _scrub_json_body(body: bytes) -> tuple[bytes, bool]: + """Redact OAuth secrets in a JSON body recursively. + + Returns `(scrubbed_bytes, was_json)`. `was_json` is `True` whenever the + body parsed successfully, so callers can skip the form-body fallback + (form-decoding a JSON body would split on `&` in string values and + mangle them). The returned bytes match the input byte-for-byte when no + field was redacted, so unrelated cassette diffs don't churn from JSON + whitespace differences. + """ + try: + payload = json.loads(body) + except (ValueError, json.JSONDecodeError): + return body, False + if not isinstance(payload, (dict, list)): + return body, True + if not _walk_and_redact(payload): + return body, True + return json.dumps(payload).encode("utf-8"), True + + +def _scrub_oauth_secrets(body: bytes | str | None) -> bytes | str | None: + """Best-effort scrubber for OAuth secrets in request/response bodies. + + Text bodies (JSON or urlencoded form) get a full structured scrub. + Bodies that begin with a known binary magic prefix (PNG, JPEG, PDF, + audio, etc.) pass through untouched — JWTs and the rest of + `_OAUTH_SECRET_FIELDS` are ASCII, so a confirmed-binary payload can't + carry one. Skipping these saves O(N) UTF-8 + regex work on cassette + responses that can run into the hundreds of KB. + + Partially-malformed UTF-8 (a corrupted text response) still gets the + JWT regex pass via `errors="replace"`, so a token-endpoint reply with + a single bad byte can't slip a JWT past us. + """ + if not body: + return body + if isinstance(body, bytes): + if body.startswith(_BINARY_MAGIC_PREFIXES): + return body + try: + text = body.decode("utf-8") + except UnicodeDecodeError: + # Not text and no known binary prefix — could be a partially- + # malformed text response. Run the JWT regex over a lossy decode + # so a token-endpoint reply with a bad byte can't sneak through. + # The lossy decode is *only* used for matching — the original + # bytes are returned if no JWT was found. + fallback_text = body.decode("utf-8", errors="replace").encode("utf-8") + scrubbed_fallback = _JWT_PATTERN.sub(b"**REDACTED-JWT**", fallback_text) + if scrubbed_fallback == fallback_text: + return body + logger.warning( + "Scrubbed a JWT-shaped token from a non-UTF-8 response body " + "(%d bytes). Originating endpoint likely returned text " + "with a bad byte; recording was preserved via lossy decode.", + len(body), + ) + return scrubbed_fallback + else: + text = body + payload_bytes = text.encode("utf-8") + scrubbed, was_json = _scrub_json_body(payload_bytes) + if not was_json: + scrubbed = _scrub_form_body(payload_bytes) + # Final pass: blanket-redact any JWT-shaped string that survived (e.g. + # JWT embedded in a free-form error message). + scrubbed = _JWT_PATTERN.sub(b"**REDACTED-JWT**", scrubbed) + return scrubbed.decode("utf-8") if isinstance(body, str) else scrubbed + def remove_request_headers(request: Any) -> Any: - """Remove sensitive headers from the request.""" + """Remove sensitive headers and OAuth secrets from the request.""" for k in request.headers: request.headers[k] = "**REDACTED**" request.uri = "**REDACTED**" + request.body = _scrub_oauth_secrets(request.body) return request def remove_response_headers(response: dict) -> dict: - """Remove sensitive headers from the response.""" + """Remove sensitive headers and OAuth secrets from the response.""" for k in response["headers"]: response["headers"][k] = "**REDACTED**" + # Pinning vcrpy's internal `body["string"]` shape: if vcrpy ever + # switches to a different key the scrub silently no-ops, so the + # script's post-recording leak scan is the load-bearing backstop. + body = response.get("body") + if isinstance(body, dict): + body_value = body.get("string") + if body_value is not None: + body["string"] = _scrub_oauth_secrets(body_value) return response @@ -44,6 +204,28 @@ def vcr_config() -> dict: return config +def _normalize_tool_descriptions(body: Any) -> Any: + """Strip common leading whitespace from `tools[*].description` strings. + + Python 3.13+ runs `inspect.cleandoc` on docstrings at compile time, but + cassettes recorded on Python 3.12 (or earlier) preserve the raw indented + form. Normalize the `description` field — populated from a tool's + docstring by `@tool` — so cassettes recorded on any Python version match + requests issued by any other version. Scoped to `tools[*].description` + to avoid mutating user-visible message content. + """ + if not isinstance(body, dict): + return body + tools = body.get("tools") + if isinstance(tools, list): + for entry in tools: + if isinstance(entry, dict): + description = entry.get("description") + if isinstance(description, str): + entry["description"] = inspect.cleandoc(description) + return body + + def _json_body_matcher(r1: Any, r2: Any) -> None: """Match request bodies as parsed JSON, ignoring key order.""" b1 = r1.body or b"" @@ -53,8 +235,8 @@ def _json_body_matcher(r1: Any, r2: Any) -> None: if isinstance(b2, bytes): b2 = b2.decode("utf-8") try: - j1 = json.loads(b1) - j2 = json.loads(b2) + j1 = _normalize_tool_descriptions(json.loads(b1)) + j2 = _normalize_tool_descriptions(json.loads(b2)) except (json.JSONDecodeError, ValueError): assert b1 == b2, f"body mismatch (non-JSON):\n{b1}\n!=\n{b2}" return diff --git a/libs/partners/openai/tests/integration_tests/chat_models/conftest.py b/libs/partners/openai/tests/integration_tests/chat_models/conftest.py new file mode 100644 index 00000000000..b4b31653920 --- /dev/null +++ b/libs/partners/openai/tests/integration_tests/chat_models/conftest.py @@ -0,0 +1,116 @@ +"""Shared fixtures for chat-model integration tests. + +The `ChatOpenAICodex` integration tests run under VCR cassette playback in +CI (`make test_vcr`), but its `FileChatGPTOAuthTokenProvider` still tries +to read `~/.langchain/chatgpt-auth.json` from disk on every request to +build the `Authorization` header. CI has no such file, so every Codex +test would fail with `FileNotFoundError` before VCR ever replays the +cassette. + +This fixture monkey-patches the on-disk token methods with an in-memory +fake token whenever a Codex test module is running, so cassette replay +works without a live OAuth login. The patch is scoped by module name and +leaves non-Codex tests in the directory untouched. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any + +import pytest + +from langchain_openai import chatgpt_oauth + + +@pytest.fixture(autouse=True) +def _clear_openai_base_url_env( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +) -> None: + """Strip ambient OpenAI endpoint env vars for VCR tests only. + + `ChatOpenAI` only default-enables `stream_usage` when no custom base URL is + configured (see the `OPENAI_BASE_URL` / `OPENAI_API_BASE` handling at + construction time). A developer whose shell exports one of these vars (e.g. + pointing at a gateway) would otherwise see `stream_usage` silently dropped, + so the request body omits `stream_options` and no longer matches the + recorded cassette's `json_body` — VCR then attempts a live call and fails + with `APIConnectionError`. + + Scoped to `@pytest.mark.vcr` tests: cassettes are recorded against the + canonical `api.openai.com` host, so cassette playback (and re-recording) + must ignore an ambient gateway endpoint. Live integration tests (no + cassette) are left untouched so a developer can still route them through a + gateway via these env vars. + """ + if request.node.get_closest_marker("vcr") is None: + return + for name in ("OPENAI_BASE_URL", "OPENAI_API_BASE"): + monkeypatch.delenv(name, raising=False) + + +def _vcr_record_mode(config: pytest.Config) -> str | None: + """Return pytest-recording's configured record mode, if available.""" + for option in ("record_mode", "--record-mode"): + try: + value: Any = config.getoption(option, default=None) + except ValueError: + continue + if value is not None: + return str(value) + return None + + +@pytest.fixture(autouse=True) +def _fake_codex_oauth_token( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +) -> None: + """Stub `FileChatGPTOAuthTokenProvider` token reads for Codex VCR tests.""" + if "codex" not in request.module.__name__: + return + if _vcr_record_mode(request.config) != "none": + return + + fake_token = chatgpt_oauth.ChatGPTToken( + access_token="vcr-fake-access-token", # noqa: S106 + refresh_token="vcr-fake-refresh-token", # noqa: S106 + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + account_id="vcr-fake-account-id", + ) + + def _get_token( + self: chatgpt_oauth.FileChatGPTOAuthTokenProvider, + ) -> chatgpt_oauth.ChatGPTToken: + return fake_token + + async def _aget_token( + self: chatgpt_oauth.FileChatGPTOAuthTokenProvider, + ) -> chatgpt_oauth.ChatGPTToken: + return fake_token + + def _get_access_token( + self: chatgpt_oauth.FileChatGPTOAuthTokenProvider, + ) -> str: + return fake_token.access_token + + async def _aget_access_token( + self: chatgpt_oauth.FileChatGPTOAuthTokenProvider, + ) -> str: + return fake_token.access_token + + monkeypatch.setattr( + chatgpt_oauth.FileChatGPTOAuthTokenProvider, "get_token", _get_token + ) + monkeypatch.setattr( + chatgpt_oauth.FileChatGPTOAuthTokenProvider, "aget_token", _aget_token + ) + monkeypatch.setattr( + chatgpt_oauth.FileChatGPTOAuthTokenProvider, + "get_access_token", + _get_access_token, + ) + monkeypatch.setattr( + chatgpt_oauth.FileChatGPTOAuthTokenProvider, + "aget_access_token", + _aget_access_token, + ) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_codex.py b/libs/partners/openai/tests/integration_tests/chat_models/test_codex.py new file mode 100644 index 00000000000..94a49c3f620 --- /dev/null +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_codex.py @@ -0,0 +1,369 @@ +"""Integration tests for `ChatOpenAICodex`. + +These tests exercise the ChatGPT subscription OAuth path against the +`https://chatgpt.com/backend-api/codex` endpoint. They are recorded with +VCR (cassettes live alongside, under `tests/cassettes/`) so CI replays +them in `--record-mode=none` without a live token. + +`ChatOpenAICodex` forces `use_responses_api=True`, `store=False`, and +`streaming=True` at the wire level (`output_version` is a client-side +projection and is *not* forced). The cassettes here are recorded with a +single `output_version` for stability; per-projection coverage already +lives in `test_responses_api.py`. + +Override the model with `CODEX_MODEL=` when recording against a +different account / plan tier. +""" + +from __future__ import annotations + +import json +import os +from typing import TYPE_CHECKING, cast + +import pytest +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.tools import tool +from pydantic import BaseModel +from typing_extensions import TypedDict + +from langchain_openai import ChatOpenAICodex, custom_tool + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Awaitable, Iterator + + from langchain_core.language_models.chat_model_stream import ( + AsyncChatModelStream, + ChatModelStream, + ) + +pytestmark = pytest.mark.vcr + +MODEL_NAME = os.getenv("CODEX_MODEL", "gpt-5.5") +TERSE_INSTRUCTIONS = "You are terse. Answer in five words or fewer." + + +def _check_response(response: BaseMessage | None) -> None: + """Assert the response carries the minimum Responses-API shape. + + Looser than the `test_responses_api.py` equivalent — Codex responses + don't always populate `service_tier`, so we don't require it here. + """ + assert isinstance(response, AIMessage) + assert isinstance(response.content, list) + text_content = response.text + assert isinstance(text_content, str) + assert text_content + assert response.usage_metadata + assert response.usage_metadata["input_tokens"] > 0 + assert response.usage_metadata["output_tokens"] > 0 + assert response.usage_metadata["total_tokens"] > 0 + assert response.response_metadata["model_name"] + + +def _aggregate(stream: Iterator[BaseMessage]) -> AIMessageChunk: + """Drain a sync chunk stream and return the aggregated `AIMessageChunk`. + + Typed against `BaseMessage` (the broadest output of `Runnable.stream`) + so that `bound.stream(...)` — which mypy infers as `Iterator[AIMessage]` + after `bind_tools` — is accepted without a cast. The `isinstance` + guard inside the loop enforces the runtime contract. + """ + aggregated: BaseMessageChunk | None = None + for chunk in stream: + assert isinstance(chunk, AIMessageChunk) + aggregated = chunk if aggregated is None else aggregated + chunk + assert isinstance(aggregated, AIMessageChunk) + return aggregated + + +async def _aaggregate(stream: AsyncIterator[BaseMessage]) -> AIMessageChunk: + """Drain an async chunk stream and return the aggregated `AIMessageChunk`.""" + aggregated: BaseMessageChunk | None = None + async for chunk in stream: + assert isinstance(chunk, AIMessageChunk) + aggregated = chunk if aggregated is None else aggregated + chunk + assert isinstance(aggregated, AIMessageChunk) + return aggregated + + +# --------------------------------------------------------------------------- +# Basic invoke / stream / async surface +# --------------------------------------------------------------------------- + + +def test_codex_invoke() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + response = llm.invoke("Say hi in one word.") + _check_response(response) + + +def test_codex_invoke_lifts_system_message_into_instructions() -> None: + """`SystemMessage` content is lifted into top-level `instructions`. + + Codex rejects `SystemMessage` chat turns; `ChatOpenAICodex` works + around this by moving the `SystemMessage` content into the + `instructions` field and stripping it from the input list. The model + should respect the lifted instruction (here: respond with HELLO). + """ + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + response = llm.invoke( + [ + SystemMessage("Respond with exactly one word: HELLO. No punctuation."), + HumanMessage("Greet me."), + ] + ) + _check_response(response) + assert "hello" in response.text.lower() + + +def test_codex_invoke_with_instructions_override() -> None: + """Per-call `instructions=` overrides the constructor value for one call.""" + llm = ChatOpenAICodex( + model=MODEL_NAME, instructions="You are an English assistant." + ) + response = llm.invoke( + "Greet me.", + instructions=( + "You are a French assistant. Respond only in French, in five words " + "or fewer." + ), + ) + _check_response(response) + + +async def test_codex_invoke_async() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + response = await llm.ainvoke("Say hi in one word.") + _check_response(response) + + +def test_codex_stream() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + _check_response(_aggregate(llm.stream("Count to three."))) + + +async def test_codex_stream_async() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + _check_response(await _aaggregate(llm.astream("Count to three."))) + + +def test_codex_stream_events_v3() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + stream = cast("ChatModelStream", llm.stream_events("Count to three.", version="v3")) + response = stream.output + _check_response(response) + + +async def test_codex_stream_events_v3_async() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + stream = await cast( + "Awaitable[AsyncChatModelStream]", + llm.astream_events("Count to three.", version="v3"), + ) + response = await stream + _check_response(response) + + +# --------------------------------------------------------------------------- +# Multi-turn conversation +# --------------------------------------------------------------------------- + + +def test_codex_multi_turn_no_tools() -> None: + """Pass full chat history (the backend is stateless for this client).""" + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + first = llm.invoke("My name is Bobo.") + assert isinstance(first, AIMessage) + second = llm.invoke( + [ + HumanMessage("My name is Bobo."), + first, + HumanMessage("What is my name?"), + ] + ) + _check_response(second) + assert "bobo" in second.text.lower() + + +# --------------------------------------------------------------------------- +# Function calling / agent loop +# --------------------------------------------------------------------------- + + +def test_codex_function_calling() -> None: + @tool + def multiply(x: int, y: int) -> int: + """Return x * y.""" + return x * y + + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + bound = llm.bind_tools([multiply]) + + ai_msg = cast(AIMessage, bound.invoke("What is 5 times 4?")) + assert len(ai_msg.tool_calls) == 1 + assert ai_msg.tool_calls[0]["name"] == "multiply" + assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"} + + aggregated = _aggregate(bound.stream("What is 5 times 4?")) + assert len(aggregated.tool_calls) == 1 + assert aggregated.tool_calls[0]["name"] == "multiply" + assert set(aggregated.tool_calls[0]["args"]) == {"x", "y"} + + +def test_codex_agent_loop() -> None: + """Tool call → tool message → final answer (three round trips).""" + + @tool + def get_weather(location: str) -> str: + """Get the weather for a location.""" + return "It's sunny." + + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + bound = llm.bind_tools([get_weather]) + + user_msg = HumanMessage("What is the weather in San Francisco, CA?") + tool_call_msg = cast(AIMessage, bound.invoke([user_msg])) + assert tool_call_msg.tool_calls + tool_call = tool_call_msg.tool_calls[0] + tool_msg = get_weather.invoke(tool_call) + assert isinstance(tool_msg, ToolMessage) + + final = bound.invoke([user_msg, tool_call_msg, tool_msg]) + assert isinstance(final, AIMessage) + + +def test_codex_agent_loop_streaming() -> None: + @tool + def get_weather(location: str) -> str: + """Get the weather for a location.""" + return "It's sunny." + + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + bound = llm.bind_tools([get_weather]) + + user_msg = HumanMessage("What is the weather in San Francisco, CA?") + tool_call_msg = _aggregate(bound.stream([user_msg])) + assert tool_call_msg.tool_calls + tool_msg = get_weather.invoke(tool_call_msg.tool_calls[0]) + assert isinstance(tool_msg, ToolMessage) + + final = _aggregate(bound.stream([user_msg, tool_call_msg, tool_msg])) + assert isinstance(final, AIMessage) + + +def test_codex_custom_tool() -> None: + @custom_tool + def execute_code(code: str) -> str: + """Execute Python code and return the result.""" + return "27" + + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS).bind_tools( + [execute_code] + ) + + input_message = { + "role": "user", + "content": "Use the execute_code tool to evaluate 3**3.", + } + tool_call_msg = cast(AIMessage, llm.invoke([input_message])) + assert tool_call_msg.tool_calls + tool_msg = execute_code.invoke(tool_call_msg.tool_calls[0]) + response = llm.invoke([input_message, tool_call_msg, tool_msg]) + assert isinstance(response, AIMessage) + + +# --------------------------------------------------------------------------- +# Reasoning +# --------------------------------------------------------------------------- + + +def test_codex_reasoning() -> None: + """`reasoning={'effort': 'low'}` produces a reasoning block in content.""" + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + response = llm.invoke("What is 2 + 2?", reasoning={"effort": "low"}) + assert isinstance(response, AIMessage) + block_types = [ + block["type"] for block in response.content if isinstance(block, dict) + ] + assert "reasoning" in block_types or "text" in block_types + + +def test_codex_reasoning_summary_streaming() -> None: + """`reasoning.summary='auto'` carries a populated summary list.""" + llm = ChatOpenAICodex( + model=MODEL_NAME, + instructions=TERSE_INSTRUCTIONS, + reasoning={"effort": "medium", "summary": "auto"}, + ) + aggregated = _aggregate( + llm.stream("What was the tallest building in the year 2000?") + ) + + reasoning_blocks = [ + block + for block in aggregated.content + if isinstance(block, dict) and block["type"] == "reasoning" + ] + # A non-trivial prompt should produce exactly one reasoning content block + # with at least one summary entry; if Codex stops streaming summaries + # this assertion regresses loudly instead of silently passing. + assert len(reasoning_blocks) >= 1 + summary = reasoning_blocks[0].get("summary") + assert isinstance(summary, list) + assert summary, "reasoning block emitted an empty `summary` list" + for summary_block in summary: + assert isinstance(summary_block, dict) + assert isinstance(summary_block.get("type"), str) + assert isinstance(summary_block.get("text"), str) + assert summary_block["text"] + + +# --------------------------------------------------------------------------- +# Structured output +# --------------------------------------------------------------------------- + + +class Foo(BaseModel): + """A trivial pydantic schema used to exercise `response_format`.""" + + # Docstring is intentional: Pydantic emits it as `description` in the + # JSON schema sent to Codex, which the cassettes pin on. + response: str + + +class FooDict(TypedDict): + """A trivial TypedDict schema used to exercise `response_format`.""" + + response: str + + +def test_codex_structured_output_pydantic() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + response = llm.invoke("Say hi.", response_format=Foo) + parsed = Foo(**json.loads(response.text)) + assert parsed == response.additional_kwargs["parsed"] + assert parsed.response + + +def test_codex_structured_output_typed_dict() -> None: + llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS) + response = llm.invoke("Say hi.", response_format=FooDict) + parsed = json.loads(response.text) + assert parsed == response.additional_kwargs["parsed"] + assert isinstance(parsed["response"], str) + assert parsed["response"] + + +# Header behavior (originator default/override, ChatGPT-Account-Id presence) +# is covered by the unit suite — VCR cassettes redact every recorded header +# value, so an integration test here couldn't distinguish a header-present +# round trip from a header-absent one anyway. diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_codex_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_codex_standard.py new file mode 100644 index 00000000000..069612d0237 --- /dev/null +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_codex_standard.py @@ -0,0 +1,157 @@ +"""Standard LangChain interface tests for `ChatOpenAICodex`. + +Drives the full `ChatModelIntegrationTests` suite against the Codex +backend. The module-level `pytestmark = pytest.mark.vcr` makes every +inherited test record/replay through VCR so cassettes recorded once +locally replay without a live OAuth token in CI. + +Capability flags below reflect what the Codex ChatGPT-subscription +endpoint currently exposes (image inputs, PDF inputs, audio inputs, +JSON-mode `response_format`, Anthropic-format inputs all work). The +divergence from `ChatOpenAI`'s defaults is intentional — Codex is a +subset surface, so a few `ChatModelIntegrationTests` cases are xfailed +with documented reasons. +""" + +from __future__ import annotations + +import os +from typing import Any, Literal, cast + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_tests.integration_tests import ChatModelIntegrationTests + +from langchain_openai import ChatOpenAICodex + +pytestmark = pytest.mark.vcr + +MODEL_NAME = os.getenv("CODEX_MODEL", "gpt-5.5") +TERSE_INSTRUCTIONS = "You are terse. Answer in five words or fewer." + + +class TestChatOpenAICodexStandard(ChatModelIntegrationTests): + """Standard chat-model integration suite, configured for Codex. + + Capability properties below override the upstream defaults to match + what the Codex backend supports (or fails on) — flip any to `False` + if Codex regresses on the corresponding capability, then re-record + cassettes via the recording script. + """ + + @property + def chat_model_class(self) -> type[BaseChatModel]: + return ChatOpenAICodex + + @property + def chat_model_params(self) -> dict: + return {"model": MODEL_NAME, "instructions": TERSE_INSTRUCTIONS} + + @property + def model_override_value(self) -> str | None: + # The Codex ChatGPT-subscription backend exposes a single model + # to this client; reuse `MODEL_NAME` so the override path exercises + # the per-call `model=` plumbing without depending on a second + # account-eligible model. + return MODEL_NAME + + # -- Capability flags ------------------------------------------------- + # All currently confirmed working against the ChatGPT-subscription + # Codex backend at recording time. Flip a flag to `False` and + # re-record if Codex stops accepting a capability. + + @property + def supports_image_inputs(self) -> bool: + return True + + @property + def supports_image_urls(self) -> bool: + return True + + @property + def supports_image_tool_message(self) -> bool: + return True + + @property + def supports_pdf_inputs(self) -> bool: + return True + + @property + def supports_pdf_tool_message(self) -> bool: + return True + + @property + def supports_audio_inputs(self) -> bool: + return True + + @property + def supports_json_mode(self) -> bool: + return True + + @property + def supports_anthropic_inputs(self) -> bool: + return True + + @property + def enable_vcr_tests(self) -> bool: + return True + + @property + def supported_usage_metadata_details( + self, + ) -> dict[ + Literal["invoke", "stream"], + list[ + Literal[ + "audio_input", + "audio_output", + "reasoning_output", + "cache_read_input", + "cache_creation_input", + ] + ], + ]: + # The Codex backend reports `reasoning_output` for reasoning-enabled + # models; cache and audio metadata are not surfaced through the + # subscription endpoint. + return {"invoke": ["reasoning_output"], "stream": ["reasoning_output"]} + + # -- Helpers used by the shared suite --------------------------------- + + def invoke_with_reasoning_output(self, *, stream: bool = False) -> AIMessage: + llm = ChatOpenAICodex( + model=MODEL_NAME, + instructions=TERSE_INSTRUCTIONS, + reasoning={"effort": "medium", "summary": "auto"}, + ) + prompt = "What was the 3rd highest building in 2000?" + return _invoke(llm, prompt, stream) + + # -- Codex-specific xfails -------------------------------------------- + + @pytest.mark.xfail(reason="Codex backend does not honor `stop` sequences.") + def test_stop_sequence(self, model: BaseChatModel) -> None: + super().test_stop_sequence(model) + + @pytest.mark.xfail( + reason=( + "Few-shot helper generates a fresh tool-call UUID per invocation, " + "so the recorded request body never matches on replay. Tracked " + "separately; this test needs a deterministic call_id or a " + "matcher exception in conftest before it can be cassette-backed." + ) + ) + def test_structured_few_shot_examples( + self, model: BaseChatModel, my_adder_tool: Any + ) -> None: + super().test_structured_few_shot_examples(model, my_adder_tool) + + +def _invoke(llm: ChatOpenAICodex, prompt: str, stream: bool) -> AIMessage: + if stream: + full = None + for chunk in llm.stream(prompt): + full = full + chunk if full else chunk # type: ignore[operator] + return cast(AIMessage, full) + return cast(AIMessage, llm.invoke(prompt)) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/conftest.py b/libs/partners/openai/tests/unit_tests/chat_models/conftest.py new file mode 100644 index 00000000000..ddf19832030 --- /dev/null +++ b/libs/partners/openai/tests/unit_tests/chat_models/conftest.py @@ -0,0 +1,18 @@ +"""Shared fixtures for chat model unit tests.""" + +import pytest + + +@pytest.fixture(autouse=True) +def _clear_openai_base_url_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Strip ambient OpenAI endpoint env vars so unit tests stay deterministic. + + `ChatOpenAI` only default-enables `stream_usage` when no custom base URL is + configured (see the `OPENAI_BASE_URL` / `OPENAI_API_BASE` handling at + construction time). A developer whose shell exports one of these vars (e.g. + pointing at a gateway) would otherwise see `stream_usage` silently dropped, + breaking serialization snapshots that assume the default. Removing them here + keeps results consistent with CI, where the vars are unset. + """ + for name in ("OPENAI_BASE_URL", "OPENAI_API_BASE"): + monkeypatch.delenv(name, raising=False) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py b/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py new file mode 100644 index 00000000000..73bc8d9ee0d --- /dev/null +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_codex.py @@ -0,0 +1,536 @@ +"""Unit tests for `ChatOpenAICodex`.""" +# ruff: noqa: S106, S107 + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any + +import pytest +from langchain_core.messages import ChatMessage, HumanMessage, SystemMessage + +from langchain_openai import ChatOpenAICodex +from langchain_openai.chat_models.base import ChatOpenAI +from langchain_openai.chat_models.codex import ( + ACCOUNT_ID_HEADER, + CHATGPT_CODEX_BASE_URL, + DEFAULT_INSTRUCTIONS, + ORIGINATOR_ENV_VAR, + ORIGINATOR_HEADER, + ORIGINATOR_VALUE, + _SyncTokenCallable, +) +from langchain_openai.chatgpt_oauth import ChatGPTToken + + +class FakeTokenProvider: + """Minimal `ChatGPTOAuthTokenProvider` for tests.""" + + def __init__( + self, + access_token: str = "at-1", + account_id: str | None = "acct-1", + ) -> None: + self.access_token = access_token + self.account_id = account_id + self.calls = 0 + self.async_calls = 0 + + def get_token(self) -> ChatGPTToken: + self.calls += 1 + return ChatGPTToken( + access_token=self.access_token, + refresh_token="rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + account_id=self.account_id, + ) + + async def aget_token(self) -> ChatGPTToken: + self.async_calls += 1 + return self.get_token() + + def get_access_token(self) -> str: + return self.get_token().access_token + + async def aget_access_token(self) -> str: + token = await self.aget_token() + return token.access_token + + +def _build_model(**overrides: Any) -> ChatOpenAICodex: + provider = overrides.pop("token_provider", None) or FakeTokenProvider() + return ChatOpenAICodex( + model=overrides.pop("model", "gpt-5.2-codex"), + token_provider=provider, + **overrides, + ) + + +def test_defaults_route_to_chatgpt_codex_backend() -> None: + model = _build_model() + assert model.openai_api_base == CHATGPT_CODEX_BASE_URL + assert model.use_responses_api is True + assert model.store is False + assert model.streaming is True + # `output_version` is a client-side projection and isn't forced — the + # base default applies unless the caller picks one explicitly. + + +def test_uses_callable_api_key_from_token_provider() -> None: + """The SDK-facing `api_key` must resolve to the provider's current token.""" + provider = FakeTokenProvider(access_token="abc") + model = _build_model(token_provider=provider) + secret = model.openai_api_key + assert secret is not None + if callable(secret): + assert secret() == "abc" + else: + # `ChatOpenAI.validate_environment` wraps callables in `SecretStr`. + assert secret.get_secret_value() == "abc" + assert provider.calls >= 1 + + +@pytest.mark.parametrize("field", ["base_url", "openai_api_base"]) +def test_base_url_override_is_rejected(field: str) -> None: + """Reject caller-supplied `base_url` / `openai_api_base`. + + The OAuth bearer token is wired in as `api_key`, so accepting an arbitrary + base URL would let an attacker (or a misconfigured serialized config) + exfiltrate the token to a host of their choice. + """ + with pytest.raises(ValueError, match=r"requires `(?:base_url|openai_api_base)="): + _build_model(**{field: "https://attacker.example.com/codex"}) + + +def test_explicit_base_url_matching_codex_endpoint_is_accepted() -> None: + """Passing the canonical Codex endpoint explicitly is still allowed.""" + model = _build_model(base_url=CHATGPT_CODEX_BASE_URL) + assert model.openai_api_base == CHATGPT_CODEX_BASE_URL + + +@pytest.mark.parametrize("field", ["api_key", "openai_api_key"]) +def test_explicit_api_key_is_rejected(field: str) -> None: + """Reject a caller-supplied `api_key` / `openai_api_key`. + + Auth is owned by `token_provider`; a caller-supplied key would silently + win over the OAuth bearer, so it must fail loudly rather than leave the + model in a conflicting state. + """ + with pytest.raises(ValueError, match=r"manages authentication via"): + _build_model(**{field: "sk-should-not-be-allowed"}) + + +def test_request_payload_injects_account_id_and_originator_headers() -> None: + provider = FakeTokenProvider(account_id="acct-42") + model = _build_model(token_provider=provider) + payload = model._get_request_payload([HumanMessage("hi")]) + headers = payload["extra_headers"] + assert headers[ACCOUNT_ID_HEADER] == "acct-42" + assert headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE + + +def test_request_payload_omits_account_id_when_unknown() -> None: + provider = FakeTokenProvider(account_id=None) + model = _build_model(token_provider=provider) + payload = model._get_request_payload([HumanMessage("hi")]) + headers = payload["extra_headers"] + assert ACCOUNT_ID_HEADER not in headers + assert headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE + + +def test_request_payload_can_disable_originator_header() -> None: + """`originator=None` omits the header entirely.""" + provider = FakeTokenProvider(account_id=None) + model = _build_model(token_provider=provider, originator=None) + payload = model._get_request_payload([HumanMessage("hi")]) + assert "extra_headers" not in payload + + +def test_constructor_originator_overrides_default() -> None: + """An explicit `originator=` constructor value replaces the package default.""" + model = _build_model(originator="my-app/1.2") + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["extra_headers"][ORIGINATOR_HEADER] == "my-app/1.2" + + +def test_env_var_overrides_default_originator( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """`LANGCHAIN_CODEX_ORIGINATOR` sets the default when no constructor value.""" + monkeypatch.setenv(ORIGINATOR_ENV_VAR, "env-app/2.0") + model = _build_model() + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["extra_headers"][ORIGINATOR_HEADER] == "env-app/2.0" + + +def test_constructor_originator_wins_over_env_var( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Explicit constructor value beats the env var (resolution order).""" + monkeypatch.setenv(ORIGINATOR_ENV_VAR, "env-app/2.0") + model = _build_model(originator="ctor-app/9.9") + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["extra_headers"][ORIGINATOR_HEADER] == "ctor-app/9.9" + + +def test_empty_env_var_falls_back_to_package_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An empty `LANGCHAIN_CODEX_ORIGINATOR` is treated as unset.""" + monkeypatch.setenv(ORIGINATOR_ENV_VAR, "") + model = _build_model() + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["extra_headers"][ORIGINATOR_HEADER] == ORIGINATOR_VALUE + + +def test_caller_extra_headers_override_originator_field() -> None: + """A per-call `extra_headers` originator beats the model's field value.""" + model = _build_model(originator="ctor-app") + payload = model._get_request_payload( + [HumanMessage("hi")], extra_headers={ORIGINATOR_HEADER: "call-app"} + ) + assert payload["extra_headers"][ORIGINATOR_HEADER] == "call-app" + + +def test_request_payload_pulls_fresh_account_id_each_call() -> None: + provider = FakeTokenProvider() + model = _build_model(token_provider=provider) + before = provider.calls + model._get_request_payload([HumanMessage("hi")]) + model._get_request_payload([HumanMessage("hi again")]) + assert provider.calls >= before + 2 + + +def test_invalid_token_provider_rejected() -> None: + with pytest.raises(TypeError): + ChatOpenAICodex(model="gpt-5.2-codex", token_provider="not-a-provider") + + +def test_conflicting_use_responses_api_raises() -> None: + with pytest.raises(ValueError, match="use_responses_api"): + ChatOpenAICodex( + model="gpt-5.2-codex", + token_provider=FakeTokenProvider(), + use_responses_api=False, + ) + + +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_explicit_output_version_is_respected(output_version: str) -> None: + """`output_version` is a client-side projection — any value is allowed. + + Also pins the wire-level invariant: regardless of the constructor + choice, `output_version` never appears in the outbound payload (it's + consumed entirely by the response projection layer). + """ + model = ChatOpenAICodex( + model="gpt-5.2-codex", + token_provider=FakeTokenProvider(), + output_version=output_version, + ) + assert model.output_version == output_version + payload = model._get_request_payload([HumanMessage("hi")]) + assert "output_version" not in payload + + +def test_conflicting_store_raises() -> None: + with pytest.raises(ValueError, match="store"): + ChatOpenAICodex( + model="gpt-5.2-codex", + token_provider=FakeTokenProvider(), + store=True, + ) + + +def test_conflicting_streaming_raises() -> None: + with pytest.raises(ValueError, match="streaming"): + ChatOpenAICodex( + model="gpt-5.2-codex", + token_provider=FakeTokenProvider(), + streaming=False, + ) + + +def test_request_payload_sends_store_false_and_stream_true() -> None: + """The Codex backend 400s with `store=True` or non-streaming requests.""" + model = _build_model() + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["store"] is False + assert payload["stream"] is True + + +def test_request_payload_sets_default_instructions() -> None: + """The Codex backend 400s without `instructions`; default must be injected.""" + model = _build_model() + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["instructions"] == DEFAULT_INSTRUCTIONS + + +def test_request_payload_respects_constructor_instructions() -> None: + model = _build_model(instructions="custom system prompt") + payload = model._get_request_payload([HumanMessage("hi")]) + assert payload["instructions"] == "custom system prompt" + + +def test_request_payload_respects_per_call_instructions_override() -> None: + """An `instructions` kwarg at invoke time wins over the model default.""" + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [HumanMessage("hi")], instructions="call-level" + ) + assert payload["instructions"] == "call-level" + + +def test_request_payload_preserves_explicit_empty_instructions() -> None: + """An explicit empty `instructions=""` must not be overwritten silently. + + The backend will reject it, but silently replacing it with the default + would hide the caller's bug. + """ + model = _build_model(instructions="model-level") + payload = model._get_request_payload([HumanMessage("hi")], instructions="") + assert payload["instructions"] == "" + + +def test_system_message_is_lifted_into_top_level_instructions() -> None: + """`SystemMessage` content overrides the constructor `instructions`. + + Codex rejects `SystemMessage` chat turns (400 "System messages are + not allowed"), so `ChatOpenAICodex` lifts their content into the + top-level `instructions` field and strips them from the input list. + """ + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [SystemMessage("from-system-message"), HumanMessage("hi")] + ) + assert payload["instructions"] == "from-system-message" + input_messages = payload["input"] + assert all(entry.get("role") != "system" for entry in input_messages) + assert any( + entry.get("role") == "user" and "hi" in str(entry.get("content")) + for entry in input_messages + ) + + +@pytest.mark.parametrize("role", ["system", "developer"]) +def test_chat_message_instruction_roles_are_lifted(role: str) -> None: + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [ + ChatMessage(content=f"from-{role}", role=role), + HumanMessage("hi"), + ] + ) + assert payload["instructions"] == f"from-{role}" + assert all(entry.get("role") != role for entry in payload["input"]) + assert [entry.get("role") for entry in payload["input"]] == ["user"] + + +def test_back_to_back_system_messages_join_in_input_order() -> None: + """Adjacent `SystemMessage` entries are concatenated with `"\\n\\n"`.""" + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [ + SystemMessage("first"), + SystemMessage("second"), + HumanMessage("hi"), + ] + ) + assert payload["instructions"] == "first\n\nsecond" + + +def test_interleaved_system_messages_are_still_lifted() -> None: + """`SystemMessage`s anywhere in the input are lifted into `instructions`. + + Codex is stateless per call and has no equivalent of an in-line system + turn, so positional intent (e.g., switching persona mid-conversation) + can't be preserved. All `SystemMessage`s are concatenated in input + order and stripped; the remaining (non-system) messages keep their + relative order. + """ + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [ + HumanMessage("hi"), + SystemMessage("midway-system"), + HumanMessage("bye"), + SystemMessage("tail-system"), + ] + ) + assert payload["instructions"] == "midway-system\n\ntail-system" + contents = [entry.get("content") for entry in payload["input"]] + assert all(entry.get("role") != "system" for entry in payload["input"]) + assert contents == ["hi", "bye"] + + +def test_explicit_instructions_kwarg_wins_over_system_message() -> None: + """A per-call `instructions=` kwarg always beats lifted `SystemMessage`.""" + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [SystemMessage("from-system-message"), HumanMessage("hi")], + instructions="from-kwarg", + ) + assert payload["instructions"] == "from-kwarg" + + +def test_explicit_instructions_kwarg_overrides_logs_warning( + caplog: pytest.LogCaptureFixture, +) -> None: + """The override path emits a warning so callers can audit the conflict.""" + model = _build_model(instructions="model-level") + with caplog.at_level("WARNING", logger="langchain_openai.chat_models.codex"): + model._get_request_payload( + [SystemMessage("discarded"), HumanMessage("hi")], + instructions="from-kwarg", + ) + assert any( + "explicit `instructions=` kwarg wins" in record.getMessage() + for record in caplog.records + ) + + +def test_system_message_with_text_blocks_is_lifted() -> None: + """List-of-text-blocks `SystemMessage` content flattens cleanly.""" + model = _build_model(instructions="model-level") + payload = model._get_request_payload( + [ + SystemMessage( + [ + {"type": "text", "text": "part one. "}, + {"type": "text", "text": "part two."}, + ] + ), + HumanMessage("hi"), + ] + ) + assert payload["instructions"] == "part one. part two." + + +def test_system_message_with_non_text_block_raises() -> None: + """Non-text content blocks can't be flattened into `instructions`.""" + model = _build_model(instructions="model-level") + with pytest.raises(ValueError, match="non-text content block"): + model._get_request_payload( + [ + SystemMessage( + [ + {"type": "text", "text": "ok"}, + {"type": "image_url", "image_url": "http://example/x.png"}, + ] + ), + HumanMessage("hi"), + ] + ) + + +def test_system_message_with_non_string_text_value_raises() -> None: + """A text block whose `text` isn't a string is a programming error.""" + from langchain_openai.chat_models.codex import _flatten_system_message_content + + # Constructed directly (bypassing the `SystemMessage` content type so + # we can hit the helper's defensive type check). + bad = SystemMessage.model_construct(content=[{"type": "text", "text": 42}]) + with pytest.raises(ValueError, match=r"text.*not a string"): + _flatten_system_message_content([bad]) + + +def test_system_message_with_unsupported_content_type_raises() -> None: + """Content that's neither str nor list (e.g., int) is rejected upfront.""" + from langchain_openai.chat_models.codex import _flatten_system_message_content + + bad = SystemMessage.model_construct(content=123) # type: ignore[arg-type] + with pytest.raises(ValueError, match="unsupported content type"): + _flatten_system_message_content([bad]) + + +def test_caller_headers_win_over_codex_defaults() -> None: + """Caller-supplied `extra_headers` must override the Codex injections.""" + provider = FakeTokenProvider(account_id="acct-1") + model = _build_model(token_provider=provider) + payload = model._get_request_payload( + [HumanMessage("hi")], + extra_headers={ORIGINATOR_HEADER: "custom-app"}, + ) + headers = payload["extra_headers"] + assert headers[ORIGINATOR_HEADER] == "custom-app" + # The auto-injected account ID still rides along when not overridden. + assert headers[ACCOUNT_ID_HEADER] == "acct-1" + + +async def test_agenerate_primes_async_token_cache( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """`_agenerate` must call `aget_token` so refresh doesn't block the loop.""" + provider = FakeTokenProvider() + model = _build_model(token_provider=provider) + + async def _fake_super_agenerate(*_a: Any, **_k: Any) -> Any: + return "sentinel" + + monkeypatch.setattr(ChatOpenAI, "_agenerate", _fake_super_agenerate) + before = provider.async_calls + result = await model._agenerate([HumanMessage("hi")]) + assert result == "sentinel" + assert provider.async_calls == before + 1 + + +async def test_astream_primes_async_token_cache_and_yields_headers_via_payload( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """`_astream` primes the async cache and the sync payload still attaches headers. + + The sync payload builder (`_get_request_payload`) is what actually + stamps `ChatGPT-Account-Id` + `originator` on outbound requests; the + async path's only job is to refresh the token off the event loop + before the sync builder reads from the (now warm) cache. Verifies + both halves of that contract. + """ + provider = FakeTokenProvider(account_id="acct-async") + model = _build_model(token_provider=provider) + + async def _fake_super_astream(*_a: Any, **_k: Any) -> Any: + yield "chunk" + + monkeypatch.setattr(ChatOpenAI, "_astream", _fake_super_astream) + before = provider.async_calls + received = [chunk async for chunk in model._astream([HumanMessage("hi")])] + assert received == ["chunk"] + assert provider.async_calls == before + 1 + + # Sync payload (which is what the SDK ultimately serializes) carries + # the codex headers even after the async refresh path primed them. + payload = model._get_request_payload([HumanMessage("hi")]) + headers = payload["extra_headers"] + assert headers[ACCOUNT_ID_HEADER] == "acct-async" + assert headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE + + +def test_callable_api_key_returns_provider_token() -> None: + """The `api_key` callable wired into the SDK must yield the access token.""" + provider = FakeTokenProvider(access_token="abc-123") + model = _build_model(token_provider=provider) + # ChatOpenAI converts callable api_keys into a `SecretStr` wrapping + # whatever the callable returns; resolving it should return the + # provider's current access token. + secret = model.openai_api_key + assert secret is not None + if callable(secret): + assert secret() == "abc-123" + else: + assert secret.get_secret_value() == "abc-123" + + +def test_ls_params_uses_codex_provider_tag() -> None: + model = _build_model() + params = model._get_ls_params() + assert params["ls_provider"] == "openai-codex" + + +def test_is_not_serializable_due_to_live_token_provider() -> None: + assert ChatOpenAICodex.is_lc_serializable() is False + + +def test_sync_token_callable_delegates() -> None: + provider = FakeTokenProvider(access_token="zzz") + callable_ = _SyncTokenCallable(provider) + assert callable_() == "zzz" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_imports.py b/libs/partners/openai/tests/unit_tests/chat_models/test_imports.py index ef3ae2fb3e8..e622bfdbe02 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_imports.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_imports.py @@ -1,6 +1,6 @@ from langchain_openai.chat_models import __all__ -EXPECTED_ALL = ["ChatOpenAI", "AzureChatOpenAI"] +EXPECTED_ALL = ["ChatOpenAI", "AzureChatOpenAI", "ChatOpenAICodex"] def test_all_imports() -> None: diff --git a/libs/partners/openai/tests/unit_tests/test_chatgpt_oauth.py b/libs/partners/openai/tests/unit_tests/test_chatgpt_oauth.py new file mode 100644 index 00000000000..a7e671b7dd7 --- /dev/null +++ b/libs/partners/openai/tests/unit_tests/test_chatgpt_oauth.py @@ -0,0 +1,1031 @@ +"""Unit tests for `langchain_openai.chatgpt_oauth`.""" +# ruff: noqa: S105, S106 + +from __future__ import annotations + +import base64 +import dataclasses +import hashlib +import json +import os +from datetime import datetime, timedelta, timezone, tzinfo +from pathlib import Path +from typing import Any, Literal, overload + +import httpx +import pytest +from typing_extensions import Self + +from langchain_openai import chatgpt_oauth as oauth_module +from langchain_openai.chatgpt_oauth import ( + CHATGPT_AUTH_CLAIMS_NAMESPACE, + CHATGPT_TOKEN_URL, + ChatGPTOAuthRefreshError, + ChatGPTToken, + FileChatGPTOAuthTokenProvider, + _build_authorize_url, + _CallbackHandler, + _generate_pkce_pair, + _serialize_token, + _token_from_response, + _validate_loopback_host, + _wait_for_callback, + decode_jwt_claims, + login_chatgpt, + login_chatgpt_device, +) + + +def _make_jwt(payload: dict[str, Any]) -> str: + """Build an unsigned JWT for tests.""" + + def b64(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + header = b64(json.dumps({"alg": "none", "typ": "JWT"}).encode()) + body = b64(json.dumps(payload).encode()) + sig = b64(b"sig") + return f"{header}.{body}.{sig}" + + +def test_decode_jwt_claims_extracts_namespaced_chatgpt_claims() -> None: + jwt = _make_jwt( + { + "sub": "user-1", + CHATGPT_AUTH_CLAIMS_NAMESPACE: { + "chatgpt_account_id": "acct-123", + "chatgpt_plan_type": "plus", + "chatgpt_user_id": "user-1", + }, + } + ) + claims = decode_jwt_claims(jwt) + assert claims["sub"] == "user-1" + auth = claims[CHATGPT_AUTH_CLAIMS_NAMESPACE] + assert auth["chatgpt_account_id"] == "acct-123" + assert auth["chatgpt_plan_type"] == "plus" + + +def test_decode_jwt_claims_handles_malformed_input() -> None: + assert decode_jwt_claims("") == {} + assert decode_jwt_claims("not-a-jwt") == {} + assert decode_jwt_claims("a.b") == {} + + +def test_token_from_response_extracts_claims_and_falls_back_to_existing_refresh() -> ( + None +): + id_token = _make_jwt( + { + CHATGPT_AUTH_CLAIMS_NAMESPACE: { + "chatgpt_account_id": "acct-9", + "chatgpt_plan_type": "pro", + "chatgpt_user_id": "user-9", + } + } + ) + response = { + "access_token": "new-at", + "expires_in": 3600, + "id_token": id_token, + # No refresh_token returned: must fall back to existing. + } + token = _token_from_response(response, fallback_refresh_token="old-rt") + assert token.access_token == "new-at" + assert token.refresh_token == "old-rt" + assert token.account_id == "acct-9" + assert token.plan_type == "pro" + assert token.user_id == "user-9" + assert token.id_token == id_token + # expires_at is in the future + assert token.expires_at > datetime.now(timezone.utc) + + +def test_token_is_expired_uses_skew() -> None: + now = datetime.now(timezone.utc) + token = ChatGPTToken( + access_token="x", + refresh_token="y", + expires_at=now + timedelta(minutes=1), + ) + assert token.is_expired(skew=timedelta(minutes=5)) is True + assert token.is_expired(skew=timedelta(seconds=0)) is False + + +def test_file_provider_persists_token_with_private_perms(tmp_path: Path) -> None: + store = tmp_path / "chatgpt-auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + token = ChatGPTToken( + access_token="at", + refresh_token="rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + account_id="acct-1", + plan_type="plus", + ) + provider.save(token) + + assert store.exists() + if os.name != "nt": + mode = store.stat().st_mode & 0o777 + assert mode == 0o600 + + raw = json.loads(store.read_text()) + assert raw["access_token"] == "at" + assert raw["account_id"] == "acct-1" + + fresh = FileChatGPTOAuthTokenProvider(path=store) + reloaded = fresh.get_token() + assert reloaded.access_token == "at" + assert reloaded.account_id == "acct-1" + + +def test_file_provider_get_token_does_not_refresh_when_valid( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + valid_token = ChatGPTToken( + access_token="at", + refresh_token="rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + provider.save(valid_token) + + def _explode(*args: Any, **kwargs: Any) -> dict[str, Any]: + msg = "should not refresh" + raise AssertionError(msg) + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _explode) + out = provider.get_token() + assert out.access_token == "at" + + +def test_file_provider_get_token_refreshes_when_expired( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + expired = ChatGPTToken( + access_token="old-at", + refresh_token="old-rt", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=10), + ) + provider.save(expired) + + calls: list[dict[str, Any]] = [] + + new_id_token = _make_jwt( + {CHATGPT_AUTH_CLAIMS_NAMESPACE: {"chatgpt_account_id": "acct-after-refresh"}} + ) + + def _fake_post(url: str, data: dict[str, str], **_: Any) -> dict[str, Any]: + calls.append({"url": url, "data": data}) + return { + "access_token": "new-at", + "expires_in": 3600, + "id_token": new_id_token, + } + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _fake_post) + + refreshed = provider.get_token() + assert refreshed.access_token == "new-at" + assert refreshed.refresh_token == "old-rt" + assert refreshed.account_id == "acct-after-refresh" + assert len(calls) == 1 + assert calls[0]["data"] == { + "grant_type": "refresh_token", + "refresh_token": "old-rt", + "client_id": provider.client_id, + } + persisted = json.loads(store.read_text()) + assert persisted["access_token"] == "new-at" + assert persisted["refresh_token"] == "old-rt" + + +def test_file_provider_reloads_expired_cached_token_before_refresh( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + expired = ChatGPTToken( + access_token="old-at", + refresh_token="old-rt", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=10), + ) + provider.save(expired) + rotated = ChatGPTToken( + access_token="rotated-at", + refresh_token="rotated-rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + FileChatGPTOAuthTokenProvider(path=store).save(rotated) + + def _explode(*args: Any, **kwargs: Any) -> dict[str, Any]: + msg = "should use disk token instead of refreshing stale cache" + raise AssertionError(msg) + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _explode) + + out = provider.get_token() + assert out.access_token == "rotated-at" + assert out.refresh_token == "rotated-rt" + + +def test_file_provider_raises_when_no_token_exists(tmp_path: Path) -> None: + provider = FileChatGPTOAuthTokenProvider(path=tmp_path / "missing.json") + with pytest.raises(FileNotFoundError): + provider.get_token() + + +def test_serialize_roundtrip_preserves_fields() -> None: + token = ChatGPTToken( + access_token="a", + refresh_token="b", + expires_at=datetime(2030, 1, 1, tzinfo=timezone.utc), + account_id="acct", + plan_type="plus", + user_id="u1", + id_token="id", + ) + serialized = _serialize_token(token) + assert serialized["access_token"] == "a" + assert serialized["expires_at"].endswith("+00:00") + parsed = json.loads(json.dumps(serialized)) + assert parsed["account_id"] == "acct" + + +def test_build_authorize_url_includes_pkce_and_state() -> None: + verifier, challenge = _generate_pkce_pair() + assert verifier != challenge + url = _build_authorize_url( + client_id="app_x", + redirect_uri="http://localhost:1455/auth/callback", + state="s1", + code_challenge=challenge, + ) + assert "client_id=app_x" in url + assert "code_challenge_method=S256" in url + assert "state=s1" in url + assert "scope=openid+profile+email+offline_access" in url + assert "redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback" in url + + +def test_pkce_pair_challenge_is_s256_of_verifier() -> None: + """Regression guard: challenge must be base64url(SHA256(verifier)).""" + verifier, challenge = _generate_pkce_pair() + expected = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) + .rstrip(b"=") + .decode("ascii") + ) + assert challenge == expected + + +def test_chatgpt_token_repr_does_not_leak_secrets() -> None: + token = ChatGPTToken( + access_token="super-secret-at", + refresh_token="super-secret-rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + id_token="super-secret-id", + account_id="acct-1", + ) + text = repr(token) + assert "super-secret-at" not in text + assert "super-secret-rt" not in text + assert "super-secret-id" not in text + assert "acct-1" in text + + +def test_chatgpt_token_rejects_empty_or_naive_fields() -> None: + with pytest.raises(ValueError, match="access_token"): + ChatGPTToken( + access_token="", + refresh_token="rt", + expires_at=datetime.now(timezone.utc), + ) + with pytest.raises(ValueError, match="refresh_token"): + ChatGPTToken( + access_token="at", + refresh_token="", + expires_at=datetime.now(timezone.utc), + ) + with pytest.raises(ValueError, match="timezone-aware"): + ChatGPTToken( + access_token="at", + refresh_token="rt", + expires_at=datetime(2030, 1, 1), # noqa: DTZ001 + ) + + +def test_chatgpt_token_is_frozen() -> None: + """The token's construction-time invariants must hold for its lifetime. + + Providers cache and share a single instance, so post-construction mutation + (which would bypass `__post_init__`) must be impossible. + """ + token = ChatGPTToken( + access_token="at", + refresh_token="rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + with pytest.raises(dataclasses.FrozenInstanceError): + token.access_token = "" # type: ignore[misc] + + +@pytest.mark.parametrize("host", ["localhost", "127.0.0.1", "127.0.0.2", "::1"]) +def test_validate_loopback_host_accepts_loopback(host: str) -> None: + # Loopback hosts pass validation (no exception raised). + _validate_loopback_host(host) + + +@pytest.mark.parametrize( + "host", + ["0.0.0.0", "10.0.0.5", "example.com", "192.168.1.10"], # noqa: S104 +) +def test_validate_loopback_host_rejects_non_loopback(host: str) -> None: + with pytest.raises(ValueError, match="loopback"): + _validate_loopback_host(host) + + +def test_login_chatgpt_rejects_non_loopback_host(tmp_path: Path) -> None: + """A non-loopback `host` must fail before the callback server binds.""" + with pytest.raises(ValueError, match="loopback"): + login_chatgpt( + store_path=tmp_path / "x.json", + host="0.0.0.0", # noqa: S104 + open_browser=False, + ) + + +def test_token_from_response_raises_on_missing_expires_in() -> None: + with pytest.raises(ChatGPTOAuthRefreshError, match="expires_in"): + _token_from_response( + {"access_token": "a", "refresh_token": "b"}, + fallback_refresh_token=None, + ) + + +def test_token_from_response_raises_on_missing_refresh_token() -> None: + with pytest.raises(ChatGPTOAuthRefreshError, match="refresh_token"): + _token_from_response( + {"access_token": "a", "expires_in": 3600}, + fallback_refresh_token=None, + ) + + +def test_token_from_response_raises_on_missing_access_token() -> None: + with pytest.raises(ChatGPTOAuthRefreshError, match="access_token"): + _token_from_response( + {"expires_in": 3600, "refresh_token": "rt"}, + fallback_refresh_token=None, + ) + + +def test_corrupt_token_store_raises_actionable_error(tmp_path: Path) -> None: + store = tmp_path / "auth.json" + store.write_text("{not valid json") + provider = FileChatGPTOAuthTokenProvider(path=store) + with pytest.raises(RuntimeError, match="not valid JSON"): + provider.get_token() + + +def test_missing_expires_at_in_store_raises_actionable_error(tmp_path: Path) -> None: + store = tmp_path / "auth.json" + store.write_text(json.dumps({"access_token": "at", "refresh_token": "rt"})) + provider = FileChatGPTOAuthTokenProvider(path=store) + with pytest.raises(RuntimeError, match="missing required"): + provider.get_token() + + +def test_invalid_grant_refresh_raises_typed_error( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + provider.save( + ChatGPTToken( + access_token="old-at", + refresh_token="old-rt", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=10), + ) + ) + + def _fake_post(*_: Any, **__: Any) -> dict[str, Any]: + msg = "ChatGPT refresh token is no longer valid (`invalid_grant`)." + raise ChatGPTOAuthRefreshError(msg) + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _fake_post) + with pytest.raises(ChatGPTOAuthRefreshError, match="invalid_grant"): + provider.get_token() + # The on-disk token must be preserved so a follow-up `login_chatgpt()` + # is the only thing needed. + persisted = json.loads(store.read_text()) + assert persisted["refresh_token"] == "old-rt" + assert persisted["access_token"] == "old-at" + + +def test_refresh_failure_preserves_stored_token( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + provider.save( + ChatGPTToken( + access_token="keep-at", + refresh_token="keep-rt", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=1), + ) + ) + + def _explode(*_: Any, **__: Any) -> dict[str, Any]: + msg = "transient network failure" + raise RuntimeError(msg) + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _explode) + with pytest.raises(RuntimeError, match="transient"): + provider.get_token() + persisted = json.loads(store.read_text()) + assert persisted["refresh_token"] == "keep-rt" + assert persisted["access_token"] == "keep-at" + + +def test_aget_token_refreshes_when_expired( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + provider.save( + ChatGPTToken( + access_token="old-at", + refresh_token="old-rt", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=1), + ) + ) + + def _fake_post(_url: str, _data: dict[str, str], **_: Any) -> dict[str, Any]: + return {"access_token": "new-at", "expires_in": 3600} + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _fake_post) + import asyncio + + refreshed = asyncio.run(provider.aget_token()) + assert refreshed.access_token == "new-at" + assert refreshed.refresh_token == "old-rt" + persisted = json.loads(store.read_text()) + assert persisted["access_token"] == "new-at" + + +def test_aget_access_token_returns_access_string( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + store = tmp_path / "auth.json" + provider = FileChatGPTOAuthTokenProvider(path=store) + provider.save( + ChatGPTToken( + access_token="at-x", + refresh_token="rt", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + ) + + def _explode(*_: Any, **__: Any) -> dict[str, Any]: + msg = "should not refresh" + raise AssertionError(msg) + + monkeypatch.setattr("langchain_openai.chatgpt_oauth._post_form", _explode) + import asyncio + + assert asyncio.run(provider.aget_access_token()) == "at-x" + + +def test_token_is_expired_uses_skew_with_frozen_clock( + monkeypatch: pytest.MonkeyPatch, +) -> None: + frozen = datetime(2030, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + class _FrozenDatetime(datetime): + @classmethod + def now(cls, tz: tzinfo | None = None) -> datetime: # type: ignore[override] + return frozen if tz is None else frozen.astimezone(tz) + + monkeypatch.setattr("langchain_openai.chatgpt_oauth.datetime", _FrozenDatetime) + token = ChatGPTToken( + access_token="x", + refresh_token="y", + expires_at=frozen + timedelta(minutes=1), + ) + assert token.is_expired(skew=timedelta(minutes=5)) is True + assert token.is_expired(skew=timedelta(seconds=0)) is False + + +def _make_response(status_code: int, body: dict[str, Any]) -> httpx.Response: + return httpx.Response( + status_code, + content=json.dumps(body).encode(), + headers={"Content-Type": "application/json"}, + ) + + +def test_raise_for_oauth_response_detects_invalid_grant() -> None: + from langchain_openai.chatgpt_oauth import _raise_for_oauth_response + + resp = _make_response( + 400, {"error": "invalid_grant", "error_description": "revoked"} + ) + with pytest.raises(ChatGPTOAuthRefreshError, match="invalid_grant"): + _raise_for_oauth_response(CHATGPT_TOKEN_URL, resp) + + +def test_raise_for_oauth_response_passes_through_other_errors() -> None: + from langchain_openai.chatgpt_oauth import _raise_for_oauth_response + + resp = _make_response(500, {"error": "server_error"}) + with pytest.raises(RuntimeError, match="500"): + _raise_for_oauth_response(CHATGPT_TOKEN_URL, resp) + + +def test_login_chatgpt_full_flow( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """End-to-end happy path using a stubbed callback + token endpoint.""" + posts: list[dict[str, Any]] = [] + pkce_holder: list[tuple[str, str]] = [] + real_pkce = oauth_module._generate_pkce_pair + + def _capturing_pkce() -> tuple[str, str]: + pair = real_pkce() + pkce_holder.append(pair) + return pair + + monkeypatch.setattr(oauth_module, "_generate_pkce_pair", _capturing_pkce) + # Pre-extract the state the SUT will generate by stubbing `secrets.token_urlsafe` + # so the test can craft a matching callback. + state_value = "state-xyz" + monkeypatch.setattr( + oauth_module.secrets, + "token_urlsafe", + lambda _n=32: state_value, + ) + + def _fake_wait_for_callback(**_: Any) -> dict[str, str]: + return {"code": "auth-code-1", "state": state_value} + + monkeypatch.setattr(oauth_module, "_wait_for_callback", _fake_wait_for_callback) + # Prevent any browser launch / URL print noise. + monkeypatch.setattr(oauth_module.webbrowser, "open", lambda _url: True) + + def _fake_post(url: str, data: dict[str, str], **_: Any) -> dict[str, Any]: + posts.append({"url": url, "data": data}) + return { + "access_token": "at-new", + "refresh_token": "rt-new", + "expires_in": 3600, + } + + monkeypatch.setattr(oauth_module, "_post_form", _fake_post) + + store = tmp_path / "auth.json" + provider = login_chatgpt(store_path=store, open_browser=False) + + assert posts[0]["url"] == CHATGPT_TOKEN_URL + sent = posts[0]["data"] + assert sent["grant_type"] == "authorization_code" + assert sent["code"] == "auth-code-1" + # The verifier the token endpoint sees must match the one paired with + # the challenge sent to the authorize endpoint. + assert sent["code_verifier"] == pkce_holder[0][0] + persisted = json.loads(store.read_text()) + assert persisted["access_token"] == "at-new" + assert persisted["refresh_token"] == "rt-new" + assert provider.path == store + + +def test_login_chatgpt_raises_on_state_mismatch( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(oauth_module.secrets, "token_urlsafe", lambda _n=32: "expected") + monkeypatch.setattr( + oauth_module, + "_wait_for_callback", + lambda **_: {"code": "c", "state": "ATTACKER"}, + ) + monkeypatch.setattr(oauth_module, "_post_form", lambda *_a, **_k: {}) + with pytest.raises(RuntimeError, match="state mismatch"): + login_chatgpt(store_path=tmp_path / "x.json", open_browser=False) + + +def test_login_chatgpt_state_check_runs_before_error_branch( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """If both state and error are present, state mismatch must win.""" + monkeypatch.setattr(oauth_module.secrets, "token_urlsafe", lambda _n=32: "expected") + monkeypatch.setattr( + oauth_module, + "_wait_for_callback", + lambda **_: { + "state": "ATTACKER", + "error": "access_denied", + "error_description": "user clicked deny", + }, + ) + monkeypatch.setattr(oauth_module, "_post_form", lambda *_a, **_k: {}) + with pytest.raises(RuntimeError, match="state mismatch"): + login_chatgpt(store_path=tmp_path / "x.json", open_browser=False) + + +def test_login_chatgpt_raises_when_code_missing( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(oauth_module.secrets, "token_urlsafe", lambda _n=32: "s") + monkeypatch.setattr( + oauth_module, + "_wait_for_callback", + lambda **_: {"state": "s"}, + ) + with pytest.raises(RuntimeError, match="authorization code"): + login_chatgpt(store_path=tmp_path / "x.json", open_browser=False) + + +def test_login_chatgpt_skips_browser_when_disabled( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + opened: list[str] = [] + + def _track_open(url: str) -> bool: + opened.append(url) + return True + + monkeypatch.setattr(oauth_module.webbrowser, "open", _track_open) + monkeypatch.setattr(oauth_module.secrets, "token_urlsafe", lambda _n=32: "s") + monkeypatch.setattr( + oauth_module, + "_wait_for_callback", + lambda **_: {"code": "c", "state": "s"}, + ) + monkeypatch.setattr( + oauth_module, + "_post_form", + lambda *_a, **_k: { + "access_token": "a", + "refresh_token": "r", + "expires_in": 3600, + }, + ) + login_chatgpt(store_path=tmp_path / "x.json", open_browser=False) + assert opened == [] + + +def test_login_chatgpt_device_honors_slow_down( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + posts: list[dict[str, Any]] = [] + polls: list[dict[str, Any]] = [] + sleeps: list[float] = [] + post_responses: list[dict[str, Any]] = [ + { + "device_code": "dev", + "user_code": "user", + "verification_uri": "https://example.com", + }, + { + "access_token": "at", + "refresh_token": "rt", + "expires_in": 3600, + }, + ] + poll_responses: list[dict[str, Any]] = [ + {"error": "authorization_pending"}, + {"error": "slow_down"}, + {"authorization_code": "auth-code"}, + ] + post_iter = iter(post_responses) + poll_iter = iter(poll_responses) + + def _fake_post(url: str, data: dict[str, str], **_: Any) -> dict[str, Any]: + posts.append({"url": url, "data": data}) + return next(post_iter) + + def _fake_poll(url: str, data: dict[str, str], **_: Any) -> dict[str, Any]: + polls.append({"url": url, "data": data}) + return next(poll_iter) + + def _track_sleep(seconds: float) -> None: + sleeps.append(seconds) + + monkeypatch.setattr(oauth_module, "_post_form", _fake_post) + monkeypatch.setattr(oauth_module, "_post_device_poll_form", _fake_poll) + monkeypatch.setattr(oauth_module.time, "sleep", _track_sleep) + + login_chatgpt_device(store_path=tmp_path / "x.json", poll_interval=2.0) + + assert len(polls) == 3 + # First sleep at base interval, then bumped by +5 after `slow_down`. + assert sleeps[0] == pytest.approx(2.0) + assert sleeps[1] == pytest.approx(7.0) + + +def test_login_chatgpt_device_raises_on_fatal_error( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr( + oauth_module, + "_post_form", + lambda *_a, **_k: { + "device_code": "d", + "user_code": "u", + "verification_uri": "https://example.com", + }, + ) + monkeypatch.setattr( + oauth_module, + "_post_device_poll_form", + lambda *_a, **_k: {"error": "access_denied"}, + ) + monkeypatch.setattr(oauth_module.time, "sleep", lambda _s: None) + with pytest.raises(RuntimeError, match="access_denied"): + login_chatgpt_device(store_path=tmp_path / "x.json") + + +def test_login_chatgpt_device_times_out( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr( + oauth_module, + "_post_form", + lambda *_a, **_k: { + "device_code": "d", + "user_code": "u", + "verification_uri": "https://example.com", + }, + ) + monkeypatch.setattr( + oauth_module, + "_post_device_poll_form", + lambda *_a, **_k: {"error": "authorization_pending"}, + ) + monkeypatch.setattr(oauth_module.time, "sleep", lambda _s: None) + # Force the monotonic clock to immediately blow past the deadline. + times = iter([0.0, 0.0, 9999.0]) + monkeypatch.setattr(oauth_module.time, "monotonic", lambda: next(times)) + with pytest.raises(TimeoutError): + login_chatgpt_device( + store_path=tmp_path / "x.json", poll_interval=0.0, timeout=1.0 + ) + + +def test_post_device_poll_form_returns_pending_400_payload( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeClient: + def __init__(self, **_: Any) -> None: + pass + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_args: object) -> None: + pass + + def post(self, *_args: Any, **_kwargs: Any) -> httpx.Response: + return _make_response(400, {"error": "authorization_pending"}) + + monkeypatch.setattr(oauth_module.httpx, "Client", _FakeClient) + + payload = oauth_module._post_device_poll_form( + "https://example.com/poll", {"device_code": "d"} + ) + assert payload == {"error": "authorization_pending"} + + +def test_post_device_poll_form_raises_fatal_400( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeClient: + def __init__(self, **_: Any) -> None: + pass + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_args: object) -> None: + pass + + def post(self, *_args: Any, **_kwargs: Any) -> httpx.Response: + return _make_response(400, {"error": "access_denied"}) + + monkeypatch.setattr(oauth_module.httpx, "Client", _FakeClient) + + with pytest.raises(RuntimeError, match="access_denied"): + oauth_module._post_device_poll_form( + "https://example.com/poll", {"device_code": "d"} + ) + + +def test_callback_handler_extracts_code_and_state() -> None: + result = _run_callback_handler( + path="/auth/callback?code=abc&state=xyz", + ) + assert result == {"code": "abc", "state": "xyz"} + + +def test_callback_handler_404s_unrelated_paths() -> None: + result = _run_callback_handler(path="/favicon.ico") + assert result is None + + +def test_callback_handler_extracts_error() -> None: + result = _run_callback_handler( + path="/auth/callback?error=access_denied&error_description=nope", + ) + assert result == {"error": "access_denied", "error_description": "nope"} + + +def test_callback_handler_success_renders_success_page() -> None: + result, body = _run_callback_handler( + path="/auth/callback?code=abc&state=xyz", + capture_body=True, + ) + assert result == {"code": "abc", "state": "xyz"} + # The apostrophe in "You're" is HTML-escaped by `html.escape`. + assert "You're signed in" in body + assert "ChatGPT sign-in complete" in body + assert "Sign-in failed" not in body + + +def test_callback_handler_error_renders_error_page() -> None: + result, body = _run_callback_handler( + path="/auth/callback?error=access_denied&error_description=user+declined", + capture_body=True, + ) + assert result == {"error": "access_denied", "error_description": "user declined"} + assert "Sign-in failed" in body + assert "user declined" in body + # Provider error code is surfaced for debuggability. + assert "access_denied" in body + assert "You're signed in" not in body + + +def test_callback_handler_error_without_description_surfaces_code() -> None: + """The provider's `error` code must reach the user when no description.""" + result, body = _run_callback_handler( + path="/auth/callback?error=invalid_scope", + capture_body=True, + ) + assert result == {"error": "invalid_scope"} + assert "Sign-in failed" in body + assert "invalid_scope" in body + + +def test_callback_handler_escapes_html_in_error_description() -> None: + """Reflected XSS regression: `error_description` must be HTML-escaped.""" + result, body = _run_callback_handler( + path=( + "/auth/callback?error=oops&error_description=" + "%3Cscript%3Ealert(1)%3C%2Fscript%3E" + ), + capture_body=True, + ) + assert result == { + "error": "oops", + "error_description": "", + } + assert "" not in body + assert "<script>alert(1)</script>" in body + + +def test_callback_handler_error_logs_server_side( + caplog: pytest.LogCaptureFixture, +) -> None: + """Operators need a server-side record of provider OAuth failures.""" + with caplog.at_level("ERROR", logger="langchain_openai.chatgpt_oauth"): + _run_callback_handler( + path="/auth/callback?error=access_denied&error_description=nope", + ) + assert any( + "access_denied" in rec.message and rec.levelname == "ERROR" + for rec in caplog.records + ) + + +def test_wait_for_callback_times_out(monkeypatch: pytest.MonkeyPatch) -> None: + # Stub out HTTPServer so no real socket is bound — the timeout path + # doesn't need a working server. + class _FakeServer: + timeout = 0.0 + + def __init__(self, *_a: Any, **_k: Any) -> None: + pass + + def handle_request(self) -> None: + return + + def server_close(self) -> None: + return + + monkeypatch.setattr(oauth_module.http.server, "HTTPServer", _FakeServer) + # Force the loop to never satisfy the deadline. + times = iter([0.0, 0.0, 9999.0]) + monkeypatch.setattr(oauth_module.time, "monotonic", lambda: next(times)) + with pytest.raises(TimeoutError): + _wait_for_callback( + host="127.0.0.1", + port=0, + callback_path="/auth/callback", + timeout=1.0, + ) + + +@overload +def _run_callback_handler( + *, path: str, capture_body: Literal[False] = False +) -> dict[str, str] | None: ... + + +@overload +def _run_callback_handler( + *, path: str, capture_body: Literal[True] +) -> tuple[dict[str, str] | None, str]: ... + + +def _run_callback_handler( + *, path: str, capture_body: bool = False +) -> dict[str, str] | None | tuple[dict[str, str] | None, str]: + """Drive `_CallbackHandler.do_GET` in-process without binding a socket. + + Bypasses `BaseHTTPRequestHandler.__init__` (which reads a real socket) + and overrides `send_response`/`send_header`/`end_headers` so the + response is captured in a `BytesIO`. Returns the populated + `server_result` if the callback was matched, or `None` if the handler + 404'd. When `capture_body=True`, returns a `(result, body)` tuple + where `body` is the decoded response body. + """ + import io + + class _BoundCallbackHandler(_CallbackHandler): + server_result: dict[str, str] = {} + + def __init__(self, request_path: str) -> None: + # Skip BaseHTTPRequestHandler.__init__: it expects a real socket. + self.path = request_path + self.command = "GET" + self.request_version = "HTTP/1.1" + self.client_address = ("127.0.0.1", 0) + self.status_code: int | None = None + self.body_buffer = io.BytesIO() + self.wfile = self.body_buffer + + def send_response(self, code: int, message: str | None = None) -> None: + self.status_code = code + + def send_header(self, keyword: str, value: str) -> None: + return + + def end_headers(self) -> None: + return + + _BoundCallbackHandler.callback_path = "/auth/callback" + handler = _BoundCallbackHandler(path) + handler.do_GET() + captured_body = handler.body_buffer.getvalue().decode("utf-8") + result: dict[str, str] | None = ( + None + if handler.status_code == 404 + else dict(_BoundCallbackHandler.server_result) + ) + if capture_body: + return result, captured_body + return result + + +def test_file_lock_logs_warning_when_fcntl_unavailable( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """Simulate Windows by making `import fcntl` fail inside `_file_lock`.""" + import builtins + + real_import = builtins.__import__ + + def _no_fcntl(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "fcntl": + msg = "simulated" + raise ImportError(msg) + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _no_fcntl) + target = tmp_path / "auth.json" + with ( + caplog.at_level("WARNING", logger="langchain_openai.chatgpt_oauth"), + oauth_module._file_lock(target), + ): + pass + assert any("fcntl is unavailable" in rec.message for rec in caplog.records) diff --git a/libs/partners/openai/tests/unit_tests/test_imports.py b/libs/partners/openai/tests/unit_tests/test_imports.py index 59994381974..d21b6ba9dba 100644 --- a/libs/partners/openai/tests/unit_tests/test_imports.py +++ b/libs/partners/openai/tests/unit_tests/test_imports.py @@ -4,6 +4,7 @@ EXPECTED_ALL = [ "__version__", "OpenAI", "ChatOpenAI", + "ChatOpenAICodex", "OpenAIEmbeddings", "AzureOpenAI", "AzureChatOpenAI",