mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
feat(openai): add ChatGPT OAuth-backed ChatOpenAICodex chat model (#37569)
[Docs](https://github.com/langchain-ai/docs/pull/4115)

Adds a new `ChatOpenAICodex` chat model and a small `chatgpt_oauth` module so users can authenticate with their ChatGPT subscription (OAuth 2.0 Authorization Code Flow with PKCE) and route Responses-API requests to the ChatGPT Codex backend at `https://chatgpt.com/backend-api/codex`. Login and token persistence live behind a refresh-aware `ChatGPTOAuthTokenProvider` protocol so they stay decoupled from model invocation. The existing API-key `ChatOpenAI` behavior is untouched.

By default the file-backed provider writes to `~/.langchain/chatgpt-auth.json` to avoid stomping on Codex CLI / VS Code sessions at `~/.codex/auth.json`. No new required dependencies are introduced (uses stdlib + `httpx`).

```python
from langchain_openai import ChatOpenAICodex
from langchain_openai.chatgpt_oauth import login_chatgpt

login_chatgpt()
model = ChatOpenAICodex(model="gpt-5.5")
response = model.invoke("hello")
```

_Opened collaboratively by Mason Daugherty and open-swe._

---------

Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com>
Co-authored-by: Mason Daugherty <61371264+mdrxy@users.noreply.github.com>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Module for OpenAI integrations."""
|
||||
|
||||
from langchain_openai._version import __version__
|
||||
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, ChatOpenAICodex
|
||||
from langchain_openai.chat_models._client_utils import StreamChunkTimeoutError
|
||||
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
from langchain_openai.llms import AzureOpenAI, OpenAI
|
||||
@@ -12,6 +12,7 @@ __all__ = [
|
||||
"AzureOpenAI",
|
||||
"AzureOpenAIEmbeddings",
|
||||
"ChatOpenAI",
|
||||
"ChatOpenAICodex",
|
||||
"OpenAI",
|
||||
"OpenAIEmbeddings",
|
||||
"StreamChunkTimeoutError",
|
||||
|
||||
@@ -2,5 +2,6 @@
|
||||
|
||||
from langchain_openai.chat_models.azure import AzureChatOpenAI
|
||||
from langchain_openai.chat_models.base import ChatOpenAI
|
||||
from langchain_openai.chat_models.codex import ChatOpenAICodex
|
||||
|
||||
__all__ = ["AzureChatOpenAI", "ChatOpenAI"]
|
||||
__all__ = ["AzureChatOpenAI", "ChatOpenAI", "ChatOpenAICodex"]
|
||||
|
||||
497
libs/partners/openai/langchain_openai/chat_models/codex.py
Normal file
497
libs/partners/openai/langchain_openai/chat_models/codex.py
Normal file
@@ -0,0 +1,497 @@
|
||||
"""`ChatOpenAICodex`: OAuth-backed chat model for ChatGPT subscription auth.
|
||||
|
||||
Wraps `ChatOpenAI` to target the ChatGPT codex backend
|
||||
(`https://chatgpt.com/backend-api/codex`) and supplies refresh-aware
|
||||
`Authorization` and `ChatGPT-Account-Id` headers from a
|
||||
`ChatGPTOAuthTokenProvider`.
|
||||
|
||||
The standard `ChatOpenAI` (API-key) flow is untouched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.messages import BaseMessage, ChatMessage, SystemMessage
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from langchain_openai.chat_models.base import ChatOpenAI
|
||||
from langchain_openai.chatgpt_oauth import (
|
||||
ChatGPTOAuthTokenProvider,
|
||||
FileChatGPTOAuthTokenProvider,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# The only endpoint `ChatOpenAICodex` will talk to; the model validator rejects
# any other `base_url` / `openai_api_base`.
CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
# Name of the request header that carries the client identifier.
ORIGINATOR_HEADER = "originator"
ORIGINATOR_VALUE = "langchain"
"""Built-in default for the `originator` header value.

Identifies requests as coming from `langchain-openai`. Override per-instance
via the `originator` field or globally via the `LANGCHAIN_CODEX_ORIGINATOR`
env var.
"""
# Environment variable read by `_default_originator`; a non-empty value
# replaces `ORIGINATOR_VALUE` as the field default.
ORIGINATOR_ENV_VAR = "LANGCHAIN_CODEX_ORIGINATOR"
# Header populated from the OAuth token's `account_id` when it is truthy.
ACCOUNT_ID_HEADER = "ChatGPT-Account-Id"
# Role names whose messages are lifted out of the input list and into the
# top-level `instructions` field.
_INSTRUCTION_ROLES = frozenset({"system", "developer"})
|
||||
|
||||
|
||||
def _default_originator() -> str:
    """Pick the default value for the `originator` request header.

    Returns:
        The `LANGCHAIN_CODEX_ORIGINATOR` environment variable when it is set
        to a non-empty string, otherwise the built-in `ORIGINATOR_VALUE`.
    """
    env_override = os.environ.get(ORIGINATOR_ENV_VAR)
    if env_override:
        return env_override
    return ORIGINATOR_VALUE
|
||||
|
||||
|
||||
def _item_may_be_instruction(item: Any) -> bool:
    """Return `True` if one list entry could convert to an instruction message."""
    if isinstance(item, BaseMessage):
        return _is_instruction_message(item)
    if isinstance(item, dict):
        # langchain-core's message coercion reads a dict's role from `role`
        # and falls back to `type`, so both keys have to be probed. The
        # `isinstance(..., str)` guard keeps an unhashable value from raising
        # `TypeError` on the frozenset membership test.
        for key in ("role", "type"):
            role = item.get(key)
            if isinstance(role, str) and role in _INSTRUCTION_ROLES:
                return True
        return False
    if isinstance(item, (list, tuple)):
        # `(role, content)` pairs are accepted as any two-item sequence, so a
        # list pair (e.g. after a JSON round-trip) must be treated like a tuple.
        return (
            bool(item) and isinstance(item[0], str) and item[0] in _INSTRUCTION_ROLES
        )
    return False


def _maybe_has_system_messages(input_: Any) -> bool:
    """Return `True` if `input_` *could* contain a system-role message.

    Cheap structural probe used to skip the full `_convert_input` pipeline
    when there is no chance the lift logic will fire. False positives only
    cost an extra conversion; false negatives would silently skip the lift,
    so the probe is biased toward `True` for unknown shapes.

    Entries of a list/tuple input are recognized in these shapes:

    - `BaseMessage` instances that satisfy `_is_instruction_message`.
    - Dicts whose `role` **or** `type` value is an instruction role.
    - `(role, content)` pairs, given as a tuple or a list, whose first item
      is an instruction role.

    Args:
        input_: Raw model input (string, single message, sequence of
            message-like items, or any other shape such as a `PromptValue`).

    Returns:
        `False` when the input is a plain string, a non-instruction message,
        or a sequence with no instruction-shaped entry; `True` otherwise.
    """
    if isinstance(input_, str):
        return False
    if isinstance(input_, BaseMessage):
        return _is_instruction_message(input_)
    if isinstance(input_, (list, tuple)):
        return any(_item_may_be_instruction(item) for item in input_)
    # `PromptValue` or any future shape — be safe and run the slow path.
    return True
|
||||
|
||||
|
||||
def _is_instruction_message(message: BaseMessage) -> bool:
    """Tell whether `message` carries system/developer instructions.

    Returns:
        `True` for any `SystemMessage`, and for a `ChatMessage` whose `role`
        is one of `_INSTRUCTION_ROLES`; `False` for everything else.
    """
    if isinstance(message, SystemMessage):
        return True
    if isinstance(message, ChatMessage):
        return message.role in _INSTRUCTION_ROLES
    return False
|
||||
|
||||
|
||||
def _flatten_system_message_content(system_messages: list[BaseMessage]) -> str:
|
||||
"""Join system/developer message content into a single `instructions` string.
|
||||
|
||||
Codex rejects system-role entries in the input list, so their content
|
||||
is lifted into the top-level `instructions` field. Content that uses
|
||||
list-of-content-blocks form is accepted only when every block is
|
||||
`{"type": "text", ...}`; anything else cannot be flattened into the
|
||||
string-typed `instructions` field.
|
||||
|
||||
Raises:
|
||||
ValueError: A system/developer message carries a non-text content block.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for index, message in enumerate(system_messages):
|
||||
message_name = type(message).__name__
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
parts.append(content)
|
||||
continue
|
||||
if not isinstance(content, list):
|
||||
msg = (
|
||||
f"`{message_name}` at index {index} has unsupported content "
|
||||
f"type {type(content).__name__!r}; only `str` and "
|
||||
"list-of-text-blocks are accepted by `ChatOpenAICodex`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
text_parts: list[str] = []
|
||||
for block_index, block in enumerate(content):
|
||||
if not isinstance(block, dict) or block.get("type") != "text":
|
||||
msg = (
|
||||
f"`{message_name}` at index {index} contains a "
|
||||
f"non-text content block at position {block_index} "
|
||||
"(Codex `instructions` is a string field — only "
|
||||
'`{"type": "text", "text": "..."}` blocks can be '
|
||||
"lifted into it). Move the non-text content to a "
|
||||
"`HumanMessage`, or pass plain instructions via the "
|
||||
"constructor or `instructions=` kwarg."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
text_value = block.get("text", "")
|
||||
if not isinstance(text_value, str):
|
||||
msg = (
|
||||
f"`{message_name}` at index {index} has a text block "
|
||||
f"at position {block_index} whose `text` is not a "
|
||||
"string."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
text_parts.append(text_value)
|
||||
parts.append("".join(text_parts))
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
DEFAULT_INSTRUCTIONS = "You are ChatGPT, a large language model trained by OpenAI."
"""Generic fallback for the Responses-API `instructions` field.

The Codex backend rejects any request missing a top-level `instructions`
value (400 `Instructions are required`), so this constant keeps zero-config
construction working. **Most callers should override it** with their own
prompt — see `ChatOpenAICodex.instructions` for the resolution rules.
"""
# Enforced by `ChatOpenAICodex._apply_codex_defaults`: a caller-supplied value
# that is not `None` and differs from the entry here raises `ValueError`;
# otherwise the forced value is written into the constructor input.
_FORCED_VALUES: dict[str, Any] = {
    "use_responses_api": True,
    "store": False,
    "streaming": True,
}
"""Values forced onto every `ChatOpenAICodex` instance.

These are the wire-level constraints the Codex backend imposes:

- `use_responses_api=True`: Codex is only reachable through the Responses
    API surface.
- `store=False`: the backend rejects `store=true`
    (`400 'Store must be set to false'`).
- `streaming=True`: the backend rejects non-streaming requests
    (`400 'Stream must be set to true'`). Pinning this routes `invoke`
    through `_stream` so a streaming request is always sent and chunks
    are aggregated back into a single message for the caller.

`output_version` is intentionally **not** forced — it is a client-side
`AIMessage` projection (see `ChatOpenAI.output_version`) that never
appears in the request payload, so callers can pick `"v0"`, `"v1"`, or
`"responses/v1"` freely.

`base_url` (and its `openai_api_base` alias) is also pinned — to
`CHATGPT_CODEX_BASE_URL` — under the same raise-don't-rewrite contract.
It is enforced separately in the validator rather than listed here
because a caller-controlled endpoint combined with the OAuth bearer
token would be a token-exfiltration vector; see the validator for the
rationale.
"""
|
||||
|
||||
|
||||
class ChatOpenAICodex(ChatOpenAI):
    """`ChatOpenAI` variant authed by a ChatGPT OAuth subscription.

    Routes requests to `https://chatgpt.com/backend-api/codex` and forces
    the wire-level fields the Codex backend requires
    (`use_responses_api=True`, `store=False`, `streaming=True`). These
    values are forced — passing a conflicting value to the constructor
    raises. `output_version` (a client-side `AIMessage` projection) is
    not forced; pick whichever projection you want. Authorization and
    `ChatGPT-Account-Id` headers are taken from `token_provider` on every
    request so a freshly-refreshed access token is always used.

    Example:
        ```python
        from langchain_openai import ChatOpenAICodex
        from langchain_openai.chatgpt_oauth import login_chatgpt

        # One-time setup. The returned provider writes to the default store
        # at `~/.langchain/chatgpt-auth.json`, which `ChatOpenAICodex` also
        # reads from by default — so subsequent constructions need no
        # explicit `token_provider`.
        login_chatgpt()
        model = ChatOpenAICodex(
            model="gpt-5.5",
            instructions="You are a senior Python reviewer. Be terse.",
        )
        response = model.invoke("hello")
        ```

    !!! tip "Override `instructions`"

        The Codex backend requires a top-level `instructions` value on every
        request. A generic default keeps zero-config use working, but most
        callers should override it via the constructor (above) or per call
        (`model.invoke(..., instructions=...)`). See the field's docstring
        for the full resolution rules.

    !!! note

        Token storage is handled by `FileChatGPTOAuthTokenProvider`, which
        defaults to `~/.langchain/chatgpt-auth.json` so it does not collide
        with the Codex CLI / VS Code session at `~/.codex/auth.json`.

    !!! note "Always streams over the wire"

        The Codex backend only accepts streaming requests, so `streaming=True`
        is forced. `invoke` still returns a single aggregated `AIMessage` —
        chunks are collected internally — but the underlying HTTP request is
        a stream either way. Expect every call to show up as a streamed
        request in network logs and LangSmith traces.
    """

    token_provider: Any = Field(default=None, exclude=True)
    """Refresh-aware ChatGPT OAuth token provider.

    Must implement the `ChatGPTOAuthTokenProvider` protocol. If `None`, a
    `FileChatGPTOAuthTokenProvider` rooted at the default store path is
    constructed.
    """

    originator: str | None = Field(default_factory=_default_originator)
    """Value sent in the `originator` request header, or `None` to omit it.

    Identifies the client making the request. Defaults to `"langchain"` so
    OpenAI telemetry attributes calls to this package. Downstream consumers
    (e.g., a framework built on top of `ChatOpenAICodex`) can override this
    to identify themselves instead, or set `None` to suppress the header.

    Resolution order (first match wins):

    1. Per-call `extra_headers={"originator": "..."}` (always trumps the
        field; pass an explicit value to override on a single call).
    2. Constructor / kwarg value (`ChatOpenAICodex(originator="my-app")`).
    3. The `LANGCHAIN_CODEX_ORIGINATOR` env var, if set and non-empty.
    4. `ORIGINATOR_VALUE` (`"langchain"`).

    Setting `originator=None` disables the header entirely; the constructor
    default never resolves to `None`.
    """

    instructions: str = Field(default=DEFAULT_INSTRUCTIONS)
    """System prompt sent in the Responses-API `instructions` field.

    `instructions` is a *top-level* field of the Responses API request — it
    is not a chat message. The Codex backend rejects any request where this
    field is missing or empty (400 `Instructions are required`) **and**
    rejects any `SystemMessage` entry in the input list
    (400 `System messages are not allowed`). To bridge those constraints
    transparently, `ChatOpenAICodex` resolves `instructions` per call with
    this precedence (highest wins):

    1. Explicit `instructions=` kwarg on `invoke` / `stream`.
    2. Concatenated content of any `SystemMessage` entries in the input
        list — joined with `"\\n\\n"` and stripped from the input before
        sending. Set the explicit kwarg in (1) to override.
    3. This constructor field (defaults to a generic ChatGPT prompt).

    The Codex backend is stateless for this client (`store=False` is
    forced), so `instructions` is sent on every request and can be changed
    between calls — useful for switching persona / tooling mid-conversation:

    ```python
    model = ChatOpenAICodex(
        model="gpt-5.5",
        instructions="You are a senior Python reviewer. Be terse.",
    )
    model.invoke("review this diff…")
    model.invoke(
        "now translate the review to French",
        instructions="You are a translator.",
    )
    ```

    `SystemMessage` content that uses list-of-content-blocks form is
    accepted only if every block is `{"type": "text", ...}`; any other
    block type raises `ValueError` since it cannot be flattened into the
    string-typed `instructions` field.
    """

    @model_validator(mode="before")
    @classmethod
    def _apply_codex_defaults(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Apply Codex-specific defaults before the parent validator runs.

        Pins the entries of `_FORCED_VALUES` and the Codex base URL, resolves
        the token provider (falling back to the default file store), and
        installs a token-fetching callable as `api_key`.

        Args:
            values: Raw constructor input. Non-dict input is returned
                untouched; dict input is never mutated.

        Returns:
            A new dict with the Codex defaults applied.

        Raises:
            ValueError: A forced field or the base URL was given a
                conflicting value, or an explicit API key was supplied.
            TypeError: `token_provider` does not implement the
                `ChatGPTOAuthTokenProvider` protocol.
        """
        if not isinstance(values, dict):
            return values
        # Work on a shallow copy. Writing into the caller's dict would leave
        # an injected `api_key` behind, so validating the same dict a second
        # time would trip the explicit-`api_key` rejection below.
        values = dict(values)
        for key, forced in _FORCED_VALUES.items():
            supplied = values.get(key)
            if supplied is not None and supplied != forced:
                msg = (
                    f"`ChatOpenAICodex` requires `{key}={forced!r}`; "
                    f"got `{key}={supplied!r}`. Use `ChatOpenAI` if you "
                    "need to customize this."
                )
                raise ValueError(msg)
            values[key] = forced
        # Pin `base_url` (and its legacy `openai_api_base` alias) to the Codex
        # endpoint. The OAuth bearer token is wired in as `api_key` below, so a
        # caller-controlled `base_url` would otherwise exfiltrate the token to
        # an attacker-chosen host. Reject any non-matching override rather than
        # silently rewriting it, mirroring the `_FORCED_VALUES` contract.
        for key in ("base_url", "openai_api_base"):
            supplied = values.get(key)
            if supplied is not None and supplied != CHATGPT_CODEX_BASE_URL:
                msg = (
                    f"`ChatOpenAICodex` requires `{key}={CHATGPT_CODEX_BASE_URL!r}`; "
                    f"got `{key}={supplied!r}`. Use `ChatOpenAI` if you need to "
                    "target a different endpoint."
                )
                raise ValueError(msg)
            values[key] = CHATGPT_CODEX_BASE_URL

        provider = values.get("token_provider")
        if provider is None:
            provider = FileChatGPTOAuthTokenProvider.from_default_store()
            values["token_provider"] = provider
        if not isinstance(provider, ChatGPTOAuthTokenProvider):
            msg = (
                "`token_provider` must implement the "
                "`ChatGPTOAuthTokenProvider` protocol."
            )
            raise TypeError(msg)

        # The OAuth `token_provider` is the sole auth source: its access token
        # is wired into the OpenAI SDK as `api_key` below. A caller-supplied
        # `api_key` (or its `openai_api_key` alias) would silently win over the
        # OAuth bearer, leaving the model in a conflicting state — so reject it
        # (raise-don't-rewrite, mirroring the `base_url` handling above). An
        # `OPENAI_API_KEY` env var is not consulted: the field's default
        # factory never runs because `api_key` is always set here.
        for key in ("api_key", "openai_api_key"):
            if values.get(key) is not None:
                msg = (
                    f"`ChatOpenAICodex` manages authentication via "
                    f"`token_provider`; drop the explicit `{key}=`. Use "
                    "`ChatOpenAI` if you want API-key authentication."
                )
                raise ValueError(msg)
        values["api_key"] = _SyncTokenCallable(provider)
        return values

    def _codex_headers_sync(self) -> dict[str, str]:
        """Fetch the current token synchronously and build the Codex headers."""
        token = self.token_provider.get_token()
        return self._build_headers(token.account_id)

    def _build_headers(self, account_id: str | None) -> dict[str, str]:
        """Build the Codex-specific default request headers.

        Args:
            account_id: Value for the `ChatGPT-Account-Id` header; the header
                is omitted when this is `None` or empty.

        Returns:
            A dict holding the account-id header (if any) and the
            `originator` header (unless `self.originator` is `None`).
        """
        headers: dict[str, str] = {}
        if account_id:
            headers[ACCOUNT_ID_HEADER] = account_id
        if self.originator is not None:
            headers[ORIGINATOR_HEADER] = self.originator
        return headers

    def _merge_codex_headers(
        self, payload: dict[str, Any], headers: dict[str, str]
    ) -> dict[str, Any]:
        """Merge Codex default headers into `payload["extra_headers"]`.

        Caller-supplied `extra_headers` win over the Codex defaults so users
        can override them (e.g., to send a different `originator`). HTTP
        header names are case-insensitive, so the override is matched without
        regard to case: a caller's `Originator` replaces the default
        `originator` rather than being sent alongside it.

        Args:
            payload: Request payload; updated in place.
            headers: Codex default headers. When empty, `payload` is returned
                unchanged.

        Returns:
            The same `payload` object.
        """
        if not headers:
            return payload
        supplied = payload.get("extra_headers") or {}
        overridden = {str(name).lower() for name in supplied}
        merged = {
            name: value
            for name, value in headers.items()
            if name.lower() not in overridden
        }
        merged.update(supplied)
        payload["extra_headers"] = merged
        return payload

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Build the request payload and attach Codex auth headers.

        Lifts any `SystemMessage` content out of the input list into the
        top-level `instructions` field, since Codex rejects `SystemMessage`
        chat turns. See the `instructions` field docstring for the
        precedence rules.

        Fast path: when the input can't carry a `SystemMessage`, skip the
        local conversion and delegate `input_` straight to super — that
        way `_convert_input` only runs once (inside super) instead of once
        here and again there.
        """
        payload_input: LanguageModelInput = input_
        if _maybe_has_system_messages(input_):
            messages = self._convert_input(input_).to_messages()
            system_messages = [m for m in messages if _is_instruction_message(m)]
            if system_messages:
                non_system = [m for m in messages if not _is_instruction_message(m)]
                lifted = _flatten_system_message_content(system_messages)
                explicit = kwargs.get("instructions")
                if explicit is not None:
                    logger.warning(
                        "Both `instructions=` and a `SystemMessage` were "
                        "provided; the explicit `instructions=` kwarg wins "
                        "and the `SystemMessage` content is discarded for "
                        "this call. Discarded length: %d.",
                        len(lifted),
                    )
                else:
                    kwargs["instructions"] = lifted
                payload_input = non_system

        payload = super()._get_request_payload(payload_input, stop=stop, **kwargs)
        # The Codex backend rejects requests without `instructions` — populate
        # the field's value if the caller didn't supply one. An explicit empty
        # string from the caller is preserved (the backend will reject it, but
        # silently overwriting it would hide a programming error).
        if payload.get("instructions") is None:
            payload["instructions"] = self.instructions
        return self._merge_codex_headers(payload, self._codex_headers_sync())

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Await a token fetch, then delegate to the parent `_agenerate`."""
        # Prime the cache via async refresh so the sync header build that
        # happens inside `super()._agenerate` does not block the event loop.
        await self.token_provider.aget_token()
        return await super()._agenerate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Await a token fetch, then re-yield the parent `_astream` chunks."""
        # Same priming step as `_agenerate`: fetch the token asynchronously
        # before the parent implementation starts.
        await self.token_provider.aget_token()
        async for chunk in super()._astream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        ):
            yield chunk

    def _get_ls_params(
        self, stop: list[str] | None = None, **kwargs: Any
    ) -> LangSmithParams:
        """Return the parent's LangSmith params with `ls_provider` overridden."""
        params = super()._get_ls_params(stop=stop, **kwargs)
        params["ls_provider"] = "openai-codex"
        return params

    @property
    def _llm_type(self) -> str:
        """Identifier string for this chat model type."""
        return "openai-codex-chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """`ChatOpenAICodex` is not serializable (holds a live token provider)."""
        return False
|
||||
|
||||
|
||||
class _SyncTokenCallable:
|
||||
"""Sync callable wrapper around a token provider for the OpenAI SDK.
|
||||
|
||||
The OpenAI Python SDK accepts a callable returning a string for `api_key`.
|
||||
Wrapping the provider lets the SDK fetch a freshly-refreshed access token
|
||||
on every request without exposing the provider's other methods.
|
||||
"""
|
||||
|
||||
__slots__ = ("_provider",)
|
||||
|
||||
def __init__(self, provider: ChatGPTOAuthTokenProvider) -> None:
|
||||
self._provider = provider
|
||||
|
||||
def __call__(self) -> str:
|
||||
return self._provider.get_access_token()
|
||||
|
||||
|
||||
__all__ = ["ChatOpenAICodex"]
|
||||
1058
libs/partners/openai/langchain_openai/chatgpt_oauth.py
Normal file
1058
libs/partners/openai/langchain_openai/chatgpt_oauth.py
Normal file
File diff suppressed because it is too large
Load Diff
105
libs/partners/openai/scripts/RECORD_CODEX_CASSETTES.md
Normal file
105
libs/partners/openai/scripts/RECORD_CODEX_CASSETTES.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Recording `ChatOpenAICodex` VCR cassettes
|
||||
|
||||
`ChatOpenAICodex` authenticates with a ChatGPT subscription OAuth bundle, so
|
||||
its integration tests cannot run in PR CI without a live login. The workflow
|
||||
below records cassettes once locally with a real token, scrubs OAuth secrets
|
||||
before they hit disk, and commits the cassettes so CI replays them through
|
||||
`_test_vcr.yml` (the `vcr-tests` job in `check_diffs.yml`).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. A ChatGPT subscription (Plus / Pro / Team / Enterprise) — required for
|
||||
`chatgpt.com/backend-api/codex`.
|
||||
2. A token bundle on disk. Generate one with:
|
||||
|
||||
```bash
|
||||
uv run --group test python -c \
|
||||
"from langchain_openai.chatgpt_oauth import login_chatgpt; login_chatgpt()"
|
||||
```
|
||||
|
||||
The default store is `~/.langchain/chatgpt-auth.json`. It is intentionally
|
||||
distinct from `~/.codex/auth.json` so the Codex CLI / VS Code session is
|
||||
not invalidated by refresh-token rotation here.
|
||||
3. An integration test that instantiates `ChatOpenAICodex` and is marked
|
||||
`@pytest.mark.vcr`. Tests written against the API-key `ChatOpenAI` will
|
||||
not exercise the Codex backend — only Codex-specific tests should be
|
||||
passed to the script.
|
||||
|
||||
## Record
|
||||
|
||||
From `libs/partners/openai/`:
|
||||
|
||||
```bash
|
||||
# Record every VCR-marked integration test (default).
|
||||
scripts/record_codex_cassettes.sh
|
||||
|
||||
# Record one file or one test.
|
||||
scripts/record_codex_cassettes.sh tests/integration_tests/chat_models/test_codex.py
|
||||
scripts/record_codex_cassettes.sh \
|
||||
tests/integration_tests/chat_models/test_codex.py::test_invoke
|
||||
|
||||
# Forward extra pytest args.
|
||||
PYTEST_EXTRA="-k streaming -x" scripts/record_codex_cassettes.sh
|
||||
```
|
||||
|
||||
The script:
|
||||
|
||||
1. Verifies the token store exists.
|
||||
2. Force-refreshes the access token *outside* pytest so VCR never sees the
|
||||
`auth.openai.com/oauth/token` roundtrip. A revoked refresh token surfaces
|
||||
here, not after a long recording run.
|
||||
3. Runs `pytest --record-mode=once -m vcr <target>` so missing cassettes are
|
||||
created and existing ones replayed.
|
||||
4. `zgrep`s every cassette for bearer tokens, JWTs, refresh-grant bodies,
|
||||
leaked API keys, and ChatGPT account-id claims. Any match aborts with
|
||||
a non-zero exit and a per-file report — do **not** commit those cassettes.
|
||||
5. Prints a diff of cassette files that were touched.
|
||||
|
||||
## What gets scrubbed automatically
|
||||
|
||||
`tests/conftest.py` redacts:
|
||||
|
||||
- Every request and response header (catches `Authorization: Bearer …`,
|
||||
`ChatGPT-Account-Id`, cookies, organization IDs).
|
||||
- Request URIs (no per-account URL parameters land in cassettes).
|
||||
- OAuth secret fields in JSON request/response bodies: `access_token`,
|
||||
`refresh_token`, `id_token`, `code`, `code_verifier`, `device_code`,
|
||||
`client_secret`.
|
||||
- The same fields in urlencoded form bodies (refresh-grant POSTs).
|
||||
- Any JWT-shaped string anywhere in a body (`eyJ…`).
|
||||
|
||||
The recording script's post-scan exists to catch anything the above misses.
|
||||
|
||||
## Override the token store
|
||||
|
||||
```bash
|
||||
CHATGPT_AUTH_FILE=/tmp/codex-test-token.json \
|
||||
scripts/record_codex_cassettes.sh
|
||||
```
|
||||
|
||||
Useful when recording against a dedicated test account so refresh rotation
|
||||
doesn't churn your personal `~/.langchain/chatgpt-auth.json`.
|
||||
|
||||
## Commit
|
||||
|
||||
```bash
|
||||
git -C libs/partners/openai status tests/cassettes/
|
||||
git -C libs/partners/openai diff --stat tests/cassettes/
|
||||
```
|
||||
|
||||
Spot-check at least one new cassette: `gunzip -c tests/cassettes/<name>.yaml.gz | less`.
|
||||
Verify every `Authorization` header reads `**REDACTED**` and no `eyJ…` strings
|
||||
remain.
|
||||
|
||||
## CI playback
|
||||
|
||||
`.github/workflows/check_diffs.yml` routes openai changes through
|
||||
`_test_vcr.yml`, which runs:
|
||||
|
||||
```bash
|
||||
make test_vcr # uv run pytest --record-mode=none -m vcr tests/integration_tests/
|
||||
```
|
||||
|
||||
`--record-mode=none` makes pytest fail rather than make a live network call,
|
||||
so a missing or stale cassette is a hard failure — exactly the signal you
|
||||
want.
|
||||
181
libs/partners/openai/scripts/record_codex_cassettes.sh
Executable file
181
libs/partners/openai/scripts/record_codex_cassettes.sh
Executable file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env bash
|
||||
# Record VCR cassettes for `ChatOpenAICodex` integration tests against a
|
||||
# real ChatGPT OAuth subscription account, then verify no OAuth secrets
|
||||
# survived into the on-disk cassettes.
|
||||
#
|
||||
# Usage:
|
||||
# scripts/record_codex_cassettes.sh # all integration_tests
|
||||
# scripts/record_codex_cassettes.sh tests/integration_tests/chat_models/test_codex.py
|
||||
# scripts/record_codex_cassettes.sh tests/integration_tests/chat_models/test_codex.py::test_invoke
|
||||
#
|
||||
# Env:
|
||||
# CHATGPT_AUTH_FILE Override the token store path. Defaults to
|
||||
# `$HOME/.langchain/chatgpt-auth.json`.
|
||||
# PYTEST_EXTRA Extra args forwarded to pytest (e.g. `-x -k codex`).
|
||||
#
|
||||
# Exits non-zero if the token preflight fails, if pytest fails, or if the
|
||||
# post-recording leak scan finds OAuth secrets in any cassette.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PKG_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
CASSETTE_DIR="${PKG_DIR}/tests/cassettes"
|
||||
TOKEN_FILE="${CHATGPT_AUTH_FILE:-${HOME}/.langchain/chatgpt-auth.json}"
|
||||
|
||||
if [[ ! -f "${TOKEN_FILE}" ]]; then
|
||||
echo "error: ChatGPT OAuth token store not found at ${TOKEN_FILE}." >&2
|
||||
echo " Run \`python -c 'from langchain_openai.chatgpt_oauth import login_chatgpt; login_chatgpt()'\` first." >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
cd "${PKG_DIR}"
|
||||
|
||||
# Reserve all tempfiles upfront and register a single cumulative cleanup
|
||||
# trap so a failure between stages can't leak files (notably the leak
|
||||
# report, which would contain the very tokens we just scrubbed). Mode 600
|
||||
# on the leak report seals it against other users sharing the host while
|
||||
# the script runs.
|
||||
PRE_SNAPSHOT="$(mktemp)"
|
||||
POST_SNAPSHOT="$(mktemp)"
|
||||
LEAK_REPORT="$(mktemp)"
|
||||
chmod 600 "${LEAK_REPORT}" "${PRE_SNAPSHOT}" "${POST_SNAPSHOT}"
|
||||
trap 'rm -f "${PRE_SNAPSHOT}" "${POST_SNAPSHOT}" "${LEAK_REPORT}"' EXIT
|
||||
|
||||
# Portable cassette snapshot: GNU `find -printf` is unavailable on BSD/macOS
|
||||
# and BSD `stat -f` is unavailable on Linux. A short python one-liner
|
||||
# avoids the platform split and surfaces real errors instead of swallowing
|
||||
# them behind `|| true`.
|
||||
snapshot_cassettes() {
  # Write a sorted "<path> <mtime>" listing of every `*.yaml.gz` file under
  # CASSETTE_DIR to the file named by $1. When the cassette directory does
  # not exist, the destination is truncated to an empty file.
  local dest="$1"
  if [[ ! -d "${CASSETTE_DIR}" ]]; then
    : > "${dest}"
    return 0
  fi
  # Invoke Python through `uv run`, like the other Python steps in this
  # script: a bare `python` is not guaranteed to be on PATH (many systems
  # ship only `python3`), and under `set -e` that would abort the run.
  uv run --group test python - "${CASSETTE_DIR}" >"${dest}" <<'PY'
import sys
from pathlib import Path

root = Path(sys.argv[1])
entries = []
for path in root.rglob("*.yaml.gz"):
    try:
        mtime = path.stat().st_mtime
    except OSError as exc:
        # A file that cannot be stat'ed is reported on stderr and left out.
        print(f"warning: failed to stat {path}: {exc}", file=sys.stderr)
        continue
    entries.append(f"{path} {mtime}")
entries.sort()
sys.stdout.write("\n".join(entries))
if entries:
    sys.stdout.write("\n")
PY
}
|
||||
|
||||
# Preflight: force a refresh now so VCR doesn't record an `auth.openai.com`
|
||||
# token roundtrip mid-test. A stale or revoked refresh token surfaces here
|
||||
# rather than after a long test run.
|
||||
echo "==> Refreshing ChatGPT OAuth token (${TOKEN_FILE})"
|
||||
CHATGPT_AUTH_FILE="${TOKEN_FILE}" uv run --group test python - <<'PY'
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from langchain_openai.chatgpt_oauth import (
|
||||
ChatGPTOAuthRefreshError,
|
||||
FileChatGPTOAuthTokenProvider,
|
||||
)
|
||||
|
||||
path = Path(os.environ["CHATGPT_AUTH_FILE"])
|
||||
provider = FileChatGPTOAuthTokenProvider(path=path)
|
||||
try:
|
||||
token = provider.get_token()
|
||||
except ChatGPTOAuthRefreshError as exc:
|
||||
print(f"refresh failed: {exc}", file=sys.stderr)
|
||||
sys.exit(3)
|
||||
print(f"ok — token valid until {token.expires_at.isoformat()}")
|
||||
PY
|
||||
|
||||
snapshot_cassettes "${PRE_SNAPSHOT}"
|
||||
|
||||
# Default target: full integration suite. Override with positional args.
|
||||
if [[ "$#" -eq 0 ]]; then
|
||||
set -- tests/integration_tests/
|
||||
fi
|
||||
|
||||
# `ChatOpenAICodex` constructs an `api_key` callable from its token provider,
|
||||
# but `ChatOpenAI`'s pydantic init still requires *some* non-empty value.
|
||||
# The placeholder is intentionally < 20 chars after `sk-` so the leak-scan
|
||||
# pattern below (which targets real API keys at >= 20 chars) doesn't
|
||||
# false-positive on it.
|
||||
export OPENAI_API_KEY="${OPENAI_API_KEY:-sk-codex-placeholder}"
|
||||
|
||||
echo "==> Recording cassettes for: $*"
|
||||
# `--record-mode=once` writes any missing cassette and replays existing ones.
|
||||
# `-m vcr` limits the run to VCR-marked tests so unrelated live tests don't
|
||||
# fire unexpectedly. `PYTEST_EXTRA` is word-split via the shell (the
|
||||
# `disable=SC2086` is intentional) so embedded spaces split into separate
|
||||
# args — pass complex flags (e.g., `-k "name with spaces"`) as positional
|
||||
# args to the script itself instead.
|
||||
# shellcheck disable=SC2086
|
||||
uv run --group test --group test_integration pytest \
|
||||
--record-mode=once \
|
||||
-m vcr \
|
||||
-v --tb=short \
|
||||
${PYTEST_EXTRA:-} \
|
||||
"$@"
|
||||
|
||||
# Leak scan: zgrep the post-state cassettes for any pattern that would
|
||||
# indicate an OAuth secret slipped past the conftest scrubbers. All
|
||||
# patterns are passed to a single zgrep invocation per file so each
|
||||
# cassette is decompressed exactly once. Patterns are ERE (zgrep -E):
|
||||
# brace quantifiers use `{N,}` without backslashes.
|
||||
echo "==> Scanning cassettes for OAuth secret leaks"
|
||||
LEAK_ZGREP_ARGS=(
|
||||
-e 'Bearer ey' # bearer token in a captured header value
|
||||
-e 'eyJ[A-Za-z0-9_-]{20,}\.' # JWT-shaped payload (access/id/refresh tokens)
|
||||
-e '"refresh_token"[[:space:]]*:[[:space:]]*"[^*"]'
|
||||
-e '"access_token"[[:space:]]*:[[:space:]]*"[^*"]'
|
||||
-e '"id_token"[[:space:]]*:[[:space:]]*"[^*"]'
|
||||
-e 'refresh_token=[^&*]' # urlencoded refresh-grant body
|
||||
-e 'sk-[A-Za-z0-9]{20,}' # leaked API key (>= 20 chars after sk-)
|
||||
-e 'chatgpt_account_id' # account-id JWT claim from an id_token payload
|
||||
)
|
||||
|
||||
leak_found=0
scan_failed=0
if [[ -d "${CASSETTE_DIR}" ]]; then
  while IFS= read -r -d '' cassette; do
    # zgrep follows grep's exit contract: 0 = match, 1 = no match, >1 = error
    # (e.g. unreadable or corrupt gzip). An error must not be mistaken for
    # "clean", otherwise an unscannable cassette passes the leak check.
    scan_status=0
    zgrep -aHE "${LEAK_ZGREP_ARGS[@]}" "${cassette}" >> "${LEAK_REPORT}" 2>/dev/null \
      || scan_status=$?
    case "${scan_status}" in
      0) leak_found=1 ;;
      1) ;;
      *)
        echo "error: could not scan ${cassette} (zgrep exit ${scan_status})" >&2
        scan_failed=1
        ;;
    esac
  done < <(find "${CASSETTE_DIR}" -type f -name '*.yaml.gz' -print0)
fi

# Fail closed when a cassette could not be scanned. If a leak was also found,
# fall through so the leak report below is still printed.
if [[ "${scan_failed}" -ne 0 && "${leak_found}" -eq 0 ]]; then
  echo "error: leak scan incomplete; do NOT commit cassettes until every" >&2
  echo "file above can be read (try \`gzip -t <cassette>\`)." >&2
  exit 1
fi
|
||||
|
||||
if [[ "${leak_found}" -ne 0 ]]; then
|
||||
echo "error: OAuth secret leak detected in cassettes:" >&2
|
||||
cat "${LEAK_REPORT}" >&2
|
||||
echo >&2
|
||||
echo "Do NOT commit these cassettes. Re-run after extending the" >&2
|
||||
echo "scrubber in tests/conftest.py to cover the leaking field." >&2
|
||||
exit 4
|
||||
fi
|
||||
|
||||
snapshot_cassettes "${POST_SNAPSHOT}"
|
||||
|
||||
# Summarize what changed so the user knows which cassettes to inspect.
|
||||
# Capture diff output to a variable so we can branch on `diff`'s exit
|
||||
# code rather than on a piped one (`set -o pipefail` would otherwise
|
||||
# flip the sense of the test).
|
||||
echo "==> Cassette changes:"
|
||||
if diff_out=$(diff -u "${PRE_SNAPSHOT}" "${POST_SNAPSHOT}"); then
|
||||
echo "(no changes detected)"
|
||||
else
|
||||
# Strip the unified-diff `---`/`+++` headers; keep the `+`/`-` body lines.
|
||||
printf '%s\n' "${diff_out}" | grep -E '^[+-][^+-]' || true
|
||||
fi
|
||||
|
||||
echo
|
||||
echo "Recording complete. Inspect the diff before committing:"
|
||||
echo " git -C ${PKG_DIR} status tests/cassettes/"
|
||||
echo " git -C ${PKG_DIR} diff --stat tests/cassettes/"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
libs/partners/openai/tests/cassettes/test_codex_invoke.yaml.gz
Normal file
BIN
libs/partners/openai/tests/cassettes/test_codex_invoke.yaml.gz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
libs/partners/openai/tests/cassettes/test_codex_stream.yaml.gz
Normal file
BIN
libs/partners/openai/tests/cassettes/test_codex_stream.yaml.gz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,29 +1,189 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_tests.conftest import CustomPersister, CustomSerializer, base_vcr_config
|
||||
from vcr import VCR # type: ignore[import-untyped]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_EXTRA_HEADERS = [
|
||||
("openai-organization", "PLACEHOLDER"),
|
||||
("user-agent", "PLACEHOLDER"),
|
||||
("x-openai-client-user-agent", "PLACEHOLDER"),
|
||||
# ChatGPT OAuth subscription auth: the catch-all redactor below already
|
||||
# wipes every header, but list these explicitly so anyone reading the
|
||||
# config knows they are covered.
|
||||
("chatgpt-account-id", "PLACEHOLDER"),
|
||||
("cookie", "PLACEHOLDER"),
|
||||
("set-cookie", "PLACEHOLDER"),
|
||||
]
|
||||
|
||||
# OAuth secret-bearing fields. Redacted in request and response bodies as
|
||||
# defense-in-depth against an OAuth token-endpoint roundtrip getting captured
|
||||
# mid-cassette.
|
||||
_OAUTH_SECRET_FIELDS = frozenset(
|
||||
{
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"id_token",
|
||||
"code",
|
||||
"code_verifier",
|
||||
"device_code",
|
||||
"client_secret",
|
||||
}
|
||||
)
|
||||
_JWT_PATTERN = re.compile(rb"eyJ[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+")
|
||||
|
||||
# Binary content magic bytes. If any of these prefix the body we skip the
|
||||
# scrub stack — JWTs (and every other OAuth secret we redact) are ASCII,
|
||||
# so a confirmed-binary body can't carry one. Matters for performance:
|
||||
# audio/PDF/image cassette responses can be many hundreds of KB and the
|
||||
# UTF-8 + regex passes are O(N) on every record/replay.
|
||||
_BINARY_MAGIC_PREFIXES: tuple[bytes, ...] = (
|
||||
b"\x89PNG", # PNG
|
||||
b"\xff\xd8\xff", # JPEG
|
||||
b"GIF8", # GIF
|
||||
b"RIFF", # WAV / WEBP container
|
||||
b"OggS", # Ogg
|
||||
b"ID3", # MP3 with ID3 tag
|
||||
b"\xff\xfb", # MP3 without tag
|
||||
b"%PDF", # PDF
|
||||
b"PK\x03\x04", # ZIP / DOCX / etc.
|
||||
)
|
||||
|
||||
|
||||
def _scrub_form_body(body: bytes) -> bytes:
    """Mask OAuth secret values in a urlencoded (`k=v&k=v`) body.

    The body is decoded leniently (undecodable bytes are replaced), split on
    `&`, and every `key=value` pair whose key is listed in
    `_OAUTH_SECRET_FIELDS` has its value swapped for `**REDACTED**`. Pairs
    without an `=` and pairs with other keys are kept verbatim.
    """

    def _mask(pair: str) -> str:
        name, equals, _value = pair.partition("=")
        if equals and name in _OAUTH_SECRET_FIELDS:
            return f"{name}=**REDACTED**"
        return pair

    decoded = body.decode("utf-8", errors="replace")
    return "&".join(_mask(pair) for pair in decoded.split("&")).encode("utf-8")
|
||||
|
||||
|
||||
def _walk_and_redact(node: Any) -> bool:
    """Mask OAuth secret fields throughout a parsed JSON tree, in place.

    Dicts and lists are traversed recursively, so secrets nested at any
    depth (e.g. `{"data": {"refresh_token": "..."}}` or
    `[{"access_token": "..."}]`) are covered. A dict entry is masked when
    its key is in `_OAUTH_SECRET_FIELDS` and its value is a string or a
    number; container values are descended into instead.

    Returns:
        `True` if at least one value was replaced with `**REDACTED**`.
    """
    if isinstance(node, dict):
        changed = False
        for field, val in node.items():
            if isinstance(val, (dict, list)):
                # Always recurse (no short-circuit) so every branch is scrubbed.
                if _walk_and_redact(val):
                    changed = True
            elif field in _OAUTH_SECRET_FIELDS and isinstance(
                val, (str, int, float)
            ):
                # Overwriting an existing key is safe while iterating.
                node[field] = "**REDACTED**"
                changed = True
        return changed
    if isinstance(node, list):
        # Build the full result list first: `any()` over a generator would
        # stop at the first hit and leave later items unscrubbed.
        outcomes = [
            _walk_and_redact(child)
            for child in node
            if isinstance(child, (dict, list))
        ]
        return any(outcomes)
    return False
|
||||
|
||||
|
||||
def _scrub_json_body(body: bytes) -> tuple[bytes, bool]:
    """Recursively mask OAuth secrets in a JSON body.

    Returns:
        A `(scrubbed_bytes, was_json)` pair. `was_json` is `True` whenever
        the body parsed as JSON, which lets the caller skip the form-body
        fallback (splitting a JSON document on `&` would mangle string
        values). When nothing was redacted the input bytes are returned
        unchanged, so cassettes don't churn on JSON whitespace differences;
        the body is only re-serialized when a field was actually masked.
    """
    try:
        parsed = json.loads(body)
    except ValueError:  # `json.JSONDecodeError` is a `ValueError` subclass
        return body, False
    # Scalars (numbers, strings, null) can't hold keyed secrets.
    needs_rewrite = isinstance(parsed, (dict, list)) and _walk_and_redact(parsed)
    if needs_rewrite:
        return json.dumps(parsed).encode("utf-8"), True
    return body, True
|
||||
|
||||
|
||||
def _scrub_oauth_secrets(body: bytes | str | None) -> bytes | str | None:
    """Best-effort removal of OAuth secrets from a request/response body.

    Handling by body kind:

    - Empty / `None`: returned as-is.
    - Bytes starting with a known binary magic prefix
      (`_BINARY_MAGIC_PREFIXES`): returned untouched, skipping the decode
      and regex passes on payloads that can be hundreds of KB.
    - Bytes that are not valid UTF-8: only the JWT regex is applied, over a
      lossy (`errors="replace"`) re-encoding, so one bad byte in a text
      reply can't hide a JWT. The original bytes are returned when no JWT
      matched; otherwise the lossy, scrubbed bytes are returned and a
      warning is logged.
    - Text (str or valid UTF-8 bytes): scrubbed as JSON, or as a urlencoded
      form when it isn't JSON, followed by a blanket JWT regex pass.

    The return type mirrors the input type (`str` in, `str` out).
    """
    if not body:
        return body

    if isinstance(body, bytes):
        if body.startswith(_BINARY_MAGIC_PREFIXES):
            return body
        try:
            body.decode("utf-8")
        except UnicodeDecodeError:
            # No binary prefix and not clean UTF-8: possibly a text response
            # with a corrupted byte. The lossy form is only used for
            # matching; untouched input is returned if nothing matched.
            lossy = body.decode("utf-8", errors="replace").encode("utf-8")
            cleaned_lossy, hits = _JWT_PATTERN.subn(b"**REDACTED-JWT**", lossy)
            if not hits:
                return body
            logger.warning(
                "Scrubbed a JWT-shaped token from a non-UTF-8 response body "
                "(%d bytes). Originating endpoint likely returned text "
                "with a bad byte; recording was preserved via lossy decode.",
                len(body),
            )
            return cleaned_lossy
        # Valid UTF-8 round-trips byte-for-byte, so use the bytes directly.
        raw = body
    else:
        raw = body.encode("utf-8")

    cleaned, is_json = _scrub_json_body(raw)
    if not is_json:
        cleaned = _scrub_form_body(raw)
    # Last line of defense: any JWT-shaped string that survived the
    # structured scrub (e.g. embedded in a free-form error message).
    cleaned = _JWT_PATTERN.sub(b"**REDACTED-JWT**", cleaned)
    if isinstance(body, str):
        return cleaned.decode("utf-8")
    return cleaned
|
||||
|
||||
|
||||
def remove_request_headers(request: Any) -> Any:
    """Remove sensitive headers and OAuth secrets from the request.

    Every header value and the request URI are overwritten with
    `**REDACTED**`; the body is passed through `_scrub_oauth_secrets`. The
    request is modified in place and returned.
    """
    headers = request.headers
    for name in headers:
        headers[name] = "**REDACTED**"
    request.uri = "**REDACTED**"
    request.body = _scrub_oauth_secrets(request.body)
    return request
|
||||
|
||||
|
||||
def remove_response_headers(response: dict) -> dict:
    """Remove sensitive headers and OAuth secrets from the response.

    Every header value is overwritten with `**REDACTED**` and, when present,
    `response["body"]["string"]` is passed through `_scrub_oauth_secrets`.
    The response dict is modified in place and returned.
    """
    headers = response["headers"]
    for name in headers:
        headers[name] = "**REDACTED**"

    # This relies on vcrpy's internal `body["string"]` layout. If vcrpy ever
    # moves the payload to another key, the scrub below silently does
    # nothing and the recording script's post-run leak scan becomes the
    # only safeguard.
    payload = response.get("body")
    if not isinstance(payload, dict):
        return response
    raw = payload.get("string")
    if raw is not None:
        payload["string"] = _scrub_oauth_secrets(raw)
    return response
|
||||
|
||||
|
||||
@@ -44,6 +204,28 @@ def vcr_config() -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def _normalize_tool_descriptions(body: Any) -> Any:
|
||||
"""Strip common leading whitespace from `tools[*].description` strings.
|
||||
|
||||
Python 3.13+ runs `inspect.cleandoc` on docstrings at compile time, but
|
||||
cassettes recorded on Python 3.12 (or earlier) preserve the raw indented
|
||||
form. Normalize the `description` field — populated from a tool's
|
||||
docstring by `@tool` — so cassettes recorded on any Python version match
|
||||
requests issued by any other version. Scoped to `tools[*].description`
|
||||
to avoid mutating user-visible message content.
|
||||
"""
|
||||
if not isinstance(body, dict):
|
||||
return body
|
||||
tools = body.get("tools")
|
||||
if isinstance(tools, list):
|
||||
for entry in tools:
|
||||
if isinstance(entry, dict):
|
||||
description = entry.get("description")
|
||||
if isinstance(description, str):
|
||||
entry["description"] = inspect.cleandoc(description)
|
||||
return body
|
||||
|
||||
|
||||
def _json_body_matcher(r1: Any, r2: Any) -> None:
|
||||
"""Match request bodies as parsed JSON, ignoring key order."""
|
||||
b1 = r1.body or b""
|
||||
@@ -53,8 +235,8 @@ def _json_body_matcher(r1: Any, r2: Any) -> None:
|
||||
if isinstance(b2, bytes):
|
||||
b2 = b2.decode("utf-8")
|
||||
try:
|
||||
j1 = json.loads(b1)
|
||||
j2 = json.loads(b2)
|
||||
j1 = _normalize_tool_descriptions(json.loads(b1))
|
||||
j2 = _normalize_tool_descriptions(json.loads(b2))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
assert b1 == b2, f"body mismatch (non-JSON):\n{b1}\n!=\n{b2}"
|
||||
return
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Shared fixtures for chat-model integration tests.
|
||||
|
||||
The `ChatOpenAICodex` integration tests run under VCR cassette playback in
|
||||
CI (`make test_vcr`), but its `FileChatGPTOAuthTokenProvider` still tries
|
||||
to read `~/.langchain/chatgpt-auth.json` from disk on every request to
|
||||
build the `Authorization` header. CI has no such file, so every Codex
|
||||
test would fail with `FileNotFoundError` before VCR ever replays the
|
||||
cassette.
|
||||
|
||||
This fixture monkey-patches the on-disk token methods with an in-memory
|
||||
fake token whenever a Codex test module is running, so cassette replay
|
||||
works without a live OAuth login. The patch is scoped by module name and
|
||||
leaves non-Codex tests in the directory untouched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_openai import chatgpt_oauth
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_openai_base_url_env(
    request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Drop ambient OpenAI endpoint env vars, but only for VCR tests.

    `ChatOpenAI` only default-enables `stream_usage` when no custom base URL
    is configured (see the `OPENAI_BASE_URL` / `OPENAI_API_BASE` handling at
    construction time). If a developer's shell exports one of these vars
    (e.g. to point at a gateway), `stream_usage` is silently dropped, the
    request body omits `stream_options`, and it no longer matches the
    recorded cassette's `json_body` — VCR then attempts a live call and
    fails with `APIConnectionError`.

    Only tests marked `@pytest.mark.vcr` are affected: cassettes are
    recorded against the canonical `api.openai.com` host, so playback (and
    re-recording) must ignore an ambient gateway endpoint. Live integration
    tests without a cassette keep the env vars, so they can still be routed
    through a gateway.
    """
    is_vcr_test = request.node.get_closest_marker("vcr") is not None
    if is_vcr_test:
        monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
        monkeypatch.delenv("OPENAI_API_BASE", raising=False)
|
||||
|
||||
|
||||
def _vcr_record_mode(config: pytest.Config) -> str | None:
    """Look up pytest-recording's record mode on the pytest config.

    Both spellings of the option (`record_mode` and `--record-mode`) are
    tried in that order; the first one that yields a non-`None` value wins.

    Returns:
        The configured mode as a string, or `None` when neither lookup
        produced a value.
    """
    for candidate in ("record_mode", "--record-mode"):
        try:
            found: Any = config.getoption(candidate, default=None)
        except ValueError:
            # Option name not known to this pytest session; try the next.
            continue
        if found is None:
            continue
        return str(found)
    return None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _fake_codex_oauth_token(
    request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Stub `FileChatGPTOAuthTokenProvider` token reads for Codex VCR tests.

    The patch is applied only when the requesting test module's name
    contains `codex` *and* the record mode is exactly `"none"` (pure
    replay). In every other case the real provider is left in place.
    """
    in_codex_module = "codex" in request.module.__name__
    if not in_codex_module or _vcr_record_mode(request.config) != "none":
        return

    provider_cls = chatgpt_oauth.FileChatGPTOAuthTokenProvider
    stub_token = chatgpt_oauth.ChatGPTToken(
        access_token="vcr-fake-access-token",  # noqa: S106
        refresh_token="vcr-fake-refresh-token",  # noqa: S106
        expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        account_id="vcr-fake-account-id",
    )

    def _get_token(
        self: chatgpt_oauth.FileChatGPTOAuthTokenProvider,
    ) -> chatgpt_oauth.ChatGPTToken:
        return stub_token

    async def _aget_token(
        self: chatgpt_oauth.FileChatGPTOAuthTokenProvider,
    ) -> chatgpt_oauth.ChatGPTToken:
        return stub_token

    def _get_access_token(
        self: chatgpt_oauth.FileChatGPTOAuthTokenProvider,
    ) -> str:
        return stub_token.access_token

    async def _aget_access_token(
        self: chatgpt_oauth.FileChatGPTOAuthTokenProvider,
    ) -> str:
        return stub_token.access_token

    # Swap all four token accessors (sync + async) for the in-memory stubs;
    # monkeypatch restores the originals when the test finishes.
    replacements = {
        "get_token": _get_token,
        "aget_token": _aget_token,
        "get_access_token": _get_access_token,
        "aget_access_token": _aget_access_token,
    }
    for attr_name, stub in replacements.items():
        monkeypatch.setattr(provider_cls, attr_name, stub)
|
||||
@@ -0,0 +1,369 @@
|
||||
"""Integration tests for `ChatOpenAICodex`.
|
||||
|
||||
These tests exercise the ChatGPT subscription OAuth path against the
|
||||
`https://chatgpt.com/backend-api/codex` endpoint. They are recorded with
|
||||
VCR (cassettes live alongside, under `tests/cassettes/`) so CI replays
|
||||
them in `--record-mode=none` without a live token.
|
||||
|
||||
`ChatOpenAICodex` forces `use_responses_api=True`, `store=False`, and
|
||||
`streaming=True` at the wire level (`output_version` is a client-side
|
||||
projection and is *not* forced). The cassettes here are recorded with a
|
||||
single `output_version` for stability; per-projection coverage already
|
||||
lives in `test_responses_api.py`.
|
||||
|
||||
Override the model with `CODEX_MODEL=<id>` when recording against a
|
||||
different account / plan tier.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_openai import ChatOpenAICodex, custom_tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Awaitable, Iterator
|
||||
|
||||
from langchain_core.language_models.chat_model_stream import (
|
||||
AsyncChatModelStream,
|
||||
ChatModelStream,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.vcr
|
||||
|
||||
MODEL_NAME = os.getenv("CODEX_MODEL", "gpt-5.5")
|
||||
TERSE_INSTRUCTIONS = "You are terse. Answer in five words or fewer."
|
||||
|
||||
|
||||
def _check_response(response: BaseMessage | None) -> None:
    """Assert the minimum Responses-API shape on a model response.

    Checks that the response is an `AIMessage` with list content, non-empty
    text, positive token counts, and a populated `model_name`. Deliberately
    looser than the `test_responses_api.py` equivalent: Codex responses
    don't always populate `service_tier`, so it isn't required here.
    """
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, list)

    text = response.text
    assert isinstance(text, str)
    assert text

    assert response.usage_metadata
    for counter in ("input_tokens", "output_tokens", "total_tokens"):
        assert response.usage_metadata[counter] > 0

    assert response.response_metadata["model_name"]
|
||||
|
||||
|
||||
def _aggregate(stream: Iterator[BaseMessage]) -> AIMessageChunk:
    """Consume a sync chunk stream and return the summed `AIMessageChunk`.

    The parameter is typed as `Iterator[BaseMessage]` (the broadest output
    of `Runnable.stream`) so `bound.stream(...)` — inferred by mypy as
    `Iterator[AIMessage]` after `bind_tools` — is accepted without a cast.
    The per-chunk `isinstance` assertion enforces the real runtime contract;
    an empty stream fails the final assertion.
    """
    total: BaseMessageChunk | None = None
    for piece in stream:
        assert isinstance(piece, AIMessageChunk)
        if total is None:
            total = piece
        else:
            total = total + piece
    assert isinstance(total, AIMessageChunk)
    return total
|
||||
|
||||
|
||||
async def _aaggregate(stream: AsyncIterator[BaseMessage]) -> AIMessageChunk:
    """Consume an async chunk stream and return the summed `AIMessageChunk`.

    Every chunk must be an `AIMessageChunk`; an empty stream fails the final
    assertion.
    """
    total: BaseMessageChunk | None = None
    async for piece in stream:
        assert isinstance(piece, AIMessageChunk)
        if total is None:
            total = piece
        else:
            total = total + piece
    assert isinstance(total, AIMessageChunk)
    return total
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic invoke / stream / async surface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_codex_invoke() -> None:
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
response = llm.invoke("Say hi in one word.")
|
||||
_check_response(response)
|
||||
|
||||
|
||||
def test_codex_invoke_lifts_system_message_into_instructions() -> None:
|
||||
"""`SystemMessage` content is lifted into top-level `instructions`.
|
||||
|
||||
Codex rejects `SystemMessage` chat turns; `ChatOpenAICodex` works
|
||||
around this by moving the `SystemMessage` content into the
|
||||
`instructions` field and stripping it from the input list. The model
|
||||
should respect the lifted instruction (here: respond with HELLO).
|
||||
"""
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
response = llm.invoke(
|
||||
[
|
||||
SystemMessage("Respond with exactly one word: HELLO. No punctuation."),
|
||||
HumanMessage("Greet me."),
|
||||
]
|
||||
)
|
||||
_check_response(response)
|
||||
assert "hello" in response.text.lower()
|
||||
|
||||
|
||||
def test_codex_invoke_with_instructions_override() -> None:
|
||||
"""Per-call `instructions=` overrides the constructor value for one call."""
|
||||
llm = ChatOpenAICodex(
|
||||
model=MODEL_NAME, instructions="You are an English assistant."
|
||||
)
|
||||
response = llm.invoke(
|
||||
"Greet me.",
|
||||
instructions=(
|
||||
"You are a French assistant. Respond only in French, in five words "
|
||||
"or fewer."
|
||||
),
|
||||
)
|
||||
_check_response(response)
|
||||
|
||||
|
||||
async def test_codex_invoke_async() -> None:
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
response = await llm.ainvoke("Say hi in one word.")
|
||||
_check_response(response)
|
||||
|
||||
|
||||
def test_codex_stream() -> None:
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
_check_response(_aggregate(llm.stream("Count to three.")))
|
||||
|
||||
|
||||
async def test_codex_stream_async() -> None:
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
_check_response(await _aaggregate(llm.astream("Count to three.")))
|
||||
|
||||
|
||||
def test_codex_stream_events_v3() -> None:
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
stream = cast("ChatModelStream", llm.stream_events("Count to three.", version="v3"))
|
||||
response = stream.output
|
||||
_check_response(response)
|
||||
|
||||
|
||||
async def test_codex_stream_events_v3_async() -> None:
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
stream = await cast(
|
||||
"Awaitable[AsyncChatModelStream]",
|
||||
llm.astream_events("Count to three.", version="v3"),
|
||||
)
|
||||
response = await stream
|
||||
_check_response(response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-turn conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_codex_multi_turn_no_tools() -> None:
|
||||
"""Pass full chat history (the backend is stateless for this client)."""
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
first = llm.invoke("My name is Bobo.")
|
||||
assert isinstance(first, AIMessage)
|
||||
second = llm.invoke(
|
||||
[
|
||||
HumanMessage("My name is Bobo."),
|
||||
first,
|
||||
HumanMessage("What is my name?"),
|
||||
]
|
||||
)
|
||||
_check_response(second)
|
||||
assert "bobo" in second.text.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Function calling / agent loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_codex_function_calling() -> None:
|
||||
@tool
|
||||
def multiply(x: int, y: int) -> int:
|
||||
"""Return x * y."""
|
||||
return x * y
|
||||
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
bound = llm.bind_tools([multiply])
|
||||
|
||||
ai_msg = cast(AIMessage, bound.invoke("What is 5 times 4?"))
|
||||
assert len(ai_msg.tool_calls) == 1
|
||||
assert ai_msg.tool_calls[0]["name"] == "multiply"
|
||||
assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
|
||||
|
||||
aggregated = _aggregate(bound.stream("What is 5 times 4?"))
|
||||
assert len(aggregated.tool_calls) == 1
|
||||
assert aggregated.tool_calls[0]["name"] == "multiply"
|
||||
assert set(aggregated.tool_calls[0]["args"]) == {"x", "y"}
|
||||
|
||||
|
||||
def test_codex_agent_loop() -> None:
|
||||
"""Tool call → tool message → final answer (three round trips)."""
|
||||
|
||||
@tool
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the weather for a location."""
|
||||
return "It's sunny."
|
||||
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
bound = llm.bind_tools([get_weather])
|
||||
|
||||
user_msg = HumanMessage("What is the weather in San Francisco, CA?")
|
||||
tool_call_msg = cast(AIMessage, bound.invoke([user_msg]))
|
||||
assert tool_call_msg.tool_calls
|
||||
tool_call = tool_call_msg.tool_calls[0]
|
||||
tool_msg = get_weather.invoke(tool_call)
|
||||
assert isinstance(tool_msg, ToolMessage)
|
||||
|
||||
final = bound.invoke([user_msg, tool_call_msg, tool_msg])
|
||||
assert isinstance(final, AIMessage)
|
||||
|
||||
|
||||
def test_codex_agent_loop_streaming() -> None:
|
||||
@tool
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the weather for a location."""
|
||||
return "It's sunny."
|
||||
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
bound = llm.bind_tools([get_weather])
|
||||
|
||||
user_msg = HumanMessage("What is the weather in San Francisco, CA?")
|
||||
tool_call_msg = _aggregate(bound.stream([user_msg]))
|
||||
assert tool_call_msg.tool_calls
|
||||
tool_msg = get_weather.invoke(tool_call_msg.tool_calls[0])
|
||||
assert isinstance(tool_msg, ToolMessage)
|
||||
|
||||
final = _aggregate(bound.stream([user_msg, tool_call_msg, tool_msg]))
|
||||
assert isinstance(final, AIMessage)
|
||||
|
||||
|
||||
def test_codex_custom_tool() -> None:
|
||||
@custom_tool
|
||||
def execute_code(code: str) -> str:
|
||||
"""Execute Python code and return the result."""
|
||||
return "27"
|
||||
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS).bind_tools(
|
||||
[execute_code]
|
||||
)
|
||||
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": "Use the execute_code tool to evaluate 3**3.",
|
||||
}
|
||||
tool_call_msg = cast(AIMessage, llm.invoke([input_message]))
|
||||
assert tool_call_msg.tool_calls
|
||||
tool_msg = execute_code.invoke(tool_call_msg.tool_calls[0])
|
||||
response = llm.invoke([input_message, tool_call_msg, tool_msg])
|
||||
assert isinstance(response, AIMessage)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reasoning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_codex_reasoning() -> None:
|
||||
"""`reasoning={'effort': 'low'}` produces a reasoning block in content."""
|
||||
llm = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
|
||||
response = llm.invoke("What is 2 + 2?", reasoning={"effort": "low"})
|
||||
assert isinstance(response, AIMessage)
|
||||
block_types = [
|
||||
block["type"] for block in response.content if isinstance(block, dict)
|
||||
]
|
||||
assert "reasoning" in block_types or "text" in block_types
|
||||
|
||||
|
||||
def test_codex_reasoning_summary_streaming() -> None:
    """Streaming with `reasoning.summary='auto'` yields a populated summary.

    The streamed chunks are aggregated and the first `reasoning` content block
    must hold a non-empty `summary` list of dicts with string `type` and
    non-empty string `text`.
    """
    codex = ChatOpenAICodex(
        model=MODEL_NAME,
        instructions=TERSE_INSTRUCTIONS,
        reasoning={"effort": "medium", "summary": "auto"},
    )
    chunks = codex.stream("What was the tallest building in the year 2000?")
    aggregated = _aggregate(chunks)

    reasoning_blocks = []
    for item in aggregated.content:
        if isinstance(item, dict) and item["type"] == "reasoning":
            reasoning_blocks.append(item)

    # A non-trivial prompt should produce at least one reasoning content
    # block with at least one summary entry; if Codex stops streaming
    # summaries this fails loudly instead of silently passing.
    assert reasoning_blocks
    summary = reasoning_blocks[0].get("summary")
    assert isinstance(summary, list)
    assert summary, "reasoning block emitted an empty `summary` list"
    for entry in summary:
        assert isinstance(entry, dict)
        assert isinstance(entry.get("type"), str)
        assert isinstance(entry.get("text"), str)
        assert entry["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structured output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Foo(BaseModel):
    """A trivial pydantic schema used to exercise `response_format`."""

    # Docstring is intentional: Pydantic emits it as `description` in the
    # JSON schema sent to Codex, which the cassettes pin on. Rewording it
    # therefore changes the request and requires re-recording.
    #
    # Single string field; the pydantic structured-output test only checks
    # that it round-trips and is non-empty.
    response: str
|
||||
|
||||
|
||||
class FooDict(TypedDict):
    """A trivial TypedDict schema used to exercise `response_format`."""

    # NOTE(review): this docstring is presumably forwarded as the schema
    # description in the request — confirm before rewording it, since the
    # recorded cassettes may depend on the exact text.
    #
    # Single string field; the TypedDict structured-output test checks that
    # it comes back as a non-empty `str`.
    response: str
|
||||
|
||||
|
||||
def test_codex_structured_output_pydantic() -> None:
    """`response_format=Foo` returns JSON text that matches the parsed `Foo`.

    The response text is decoded and validated into `Foo`, compared with
    `additional_kwargs["parsed"]`, and its `response` field must be non-empty.
    """
    codex = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
    reply = codex.invoke("Say hi.", response_format=Foo)

    decoded = json.loads(reply.text)
    from_text = Foo(**decoded)
    assert from_text == reply.additional_kwargs["parsed"]
    assert from_text.response
|
||||
|
||||
|
||||
def test_codex_structured_output_typed_dict() -> None:
    """`response_format=FooDict` returns JSON text equal to the parsed dict.

    The decoded text must equal `additional_kwargs["parsed"]` and expose a
    non-empty string under `"response"`.
    """
    codex = ChatOpenAICodex(model=MODEL_NAME, instructions=TERSE_INSTRUCTIONS)
    reply = codex.invoke("Say hi.", response_format=FooDict)

    decoded = json.loads(reply.text)
    assert decoded == reply.additional_kwargs["parsed"]
    answer = decoded["response"]
    assert isinstance(answer, str)
    assert answer
|
||||
|
||||
|
||||
# Header behavior (originator default/override, ChatGPT-Account-Id presence)
|
||||
# is covered by the unit suite — VCR cassettes redact every recorded header
|
||||
# value, so an integration test here couldn't distinguish a header-present
|
||||
# round trip from a header-absent one anyway.
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Standard LangChain interface tests for `ChatOpenAICodex`.
|
||||
|
||||
Drives the full `ChatModelIntegrationTests` suite against the Codex
|
||||
backend. The module-level `pytestmark = pytest.mark.vcr` makes every
|
||||
inherited test record/replay through VCR so cassettes recorded once
|
||||
locally replay without a live OAuth token in CI.
|
||||
|
||||
Capability flags below reflect what the Codex ChatGPT-subscription
|
||||
endpoint currently exposes (image inputs, PDF inputs, audio inputs,
|
||||
JSON-mode `response_format`, Anthropic-format inputs all work). The
|
||||
divergence from `ChatOpenAI`'s defaults is intentional — Codex is a
|
||||
subset surface, so a few `ChatModelIntegrationTests` cases are xfailed
|
||||
with documented reasons.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_openai import ChatOpenAICodex
|
||||
|
||||
pytestmark = pytest.mark.vcr
|
||||
|
||||
MODEL_NAME = os.getenv("CODEX_MODEL", "gpt-5.5")
|
||||
TERSE_INSTRUCTIONS = "You are terse. Answer in five words or fewer."
|
||||
|
||||
|
||||
class TestChatOpenAICodexStandard(ChatModelIntegrationTests):
    """Standard chat-model integration suite, configured for Codex.

    Capability properties below override the upstream defaults to match
    what the Codex backend supports (or fails on) — flip any to `False`
    if Codex regresses on the corresponding capability, then re-record
    cassettes via the recording script.
    """

    @property
    def chat_model_class(self) -> type[BaseChatModel]:
        """Return the chat model class this suite runs against."""
        return ChatOpenAICodex

    @property
    def chat_model_params(self) -> dict:
        """Return the model name and instructions used by this suite."""
        return {"model": MODEL_NAME, "instructions": TERSE_INSTRUCTIONS}

    @property
    def model_override_value(self) -> str | None:
        """Return the model name used for the per-call `model=` override."""
        # The Codex ChatGPT-subscription backend exposes a single model
        # to this client; reuse `MODEL_NAME` so the override path exercises
        # the per-call `model=` plumbing without depending on a second
        # account-eligible model.
        return MODEL_NAME

    # -- Capability flags -------------------------------------------------
    # All currently confirmed working against the ChatGPT-subscription
    # Codex backend at recording time. Flip a flag to `False` and
    # re-record if Codex stops accepting a capability.

    @property
    def supports_image_inputs(self) -> bool:
        """Declare image inputs as supported."""
        return True

    @property
    def supports_image_urls(self) -> bool:
        """Declare image URLs as supported."""
        return True

    @property
    def supports_image_tool_message(self) -> bool:
        """Declare images inside tool messages as supported."""
        return True

    @property
    def supports_pdf_inputs(self) -> bool:
        """Declare PDF inputs as supported."""
        return True

    @property
    def supports_pdf_tool_message(self) -> bool:
        """Declare PDFs inside tool messages as supported."""
        return True

    @property
    def supports_audio_inputs(self) -> bool:
        """Declare audio inputs as supported."""
        return True

    @property
    def supports_json_mode(self) -> bool:
        """Declare JSON mode as supported."""
        return True

    @property
    def supports_anthropic_inputs(self) -> bool:
        """Declare Anthropic-format inputs as supported."""
        return True

    @property
    def enable_vcr_tests(self) -> bool:
        """Turn on the suite's VCR-backed tests."""
        return True

    @property
    def supported_usage_metadata_details(
        self,
    ) -> dict[
        Literal["invoke", "stream"],
        list[
            Literal[
                "audio_input",
                "audio_output",
                "reasoning_output",
                "cache_read_input",
                "cache_creation_input",
            ]
        ],
    ]:
        """Return the usage-metadata detail keys expected per call mode."""
        # The Codex backend reports `reasoning_output` for reasoning-enabled
        # models; cache and audio metadata are not surfaced through the
        # subscription endpoint.
        return {"invoke": ["reasoning_output"], "stream": ["reasoning_output"]}

    # -- Helpers used by the shared suite ---------------------------------

    def invoke_with_reasoning_output(self, *, stream: bool = False) -> AIMessage:
        """Call a reasoning-enabled model and return the resulting message.

        Args:
            stream: When true, the prompt is streamed and the chunks are
                combined; otherwise a single `invoke` call is made.

        Returns:
            The message produced by the module-level `_invoke` helper.
        """
        llm = ChatOpenAICodex(
            model=MODEL_NAME,
            instructions=TERSE_INSTRUCTIONS,
            reasoning={"effort": "medium", "summary": "auto"},
        )
        prompt = "What was the 3rd highest building in 2000?"
        return _invoke(llm, prompt, stream)

    # -- Codex-specific xfails --------------------------------------------

    @pytest.mark.xfail(reason="Codex backend does not honor `stop` sequences.")
    def test_stop_sequence(self, model: BaseChatModel) -> None:
        """Run the inherited stop-sequence test; marked as an expected failure."""
        super().test_stop_sequence(model)

    @pytest.mark.xfail(
        reason=(
            "Few-shot helper generates a fresh tool-call UUID per invocation, "
            "so the recorded request body never matches on replay. Tracked "
            "separately; this test needs a deterministic call_id or a "
            "matcher exception in conftest before it can be cassette-backed."
        )
    )
    def test_structured_few_shot_examples(
        self, model: BaseChatModel, my_adder_tool: Any
    ) -> None:
        """Run the inherited few-shot test; marked as an expected failure."""
        super().test_structured_few_shot_examples(model, my_adder_tool)
|
||||
|
||||
|
||||
def _invoke(llm: ChatOpenAICodex, prompt: str, stream: bool) -> AIMessage:
|
||||
if stream:
|
||||
full = None
|
||||
for chunk in llm.stream(prompt):
|
||||
full = full + chunk if full else chunk # type: ignore[operator]
|
||||
return cast(AIMessage, full)
|
||||
return cast(AIMessage, llm.invoke(prompt))
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Shared fixtures for chat model unit tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_openai_base_url_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Remove ambient OpenAI endpoint env vars for every unit test.

    `ChatOpenAI` only default-enables `stream_usage` when no custom base URL
    is configured (it reads `OPENAI_BASE_URL` / `OPENAI_API_BASE` at
    construction time). If a developer's shell exports either variable — for
    example to point at a gateway — `stream_usage` would be dropped silently
    and serialization snapshots that assume the default would break. Deleting
    both keeps local runs consistent with CI, where they are unset.
    """
    # `raising=False`: a variable that is already absent is not an error.
    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
    monkeypatch.delenv("OPENAI_API_BASE", raising=False)
|
||||
536
libs/partners/openai/tests/unit_tests/chat_models/test_codex.py
Normal file
536
libs/partners/openai/tests/unit_tests/chat_models/test_codex.py
Normal file
@@ -0,0 +1,536 @@
|
||||
"""Unit tests for `ChatOpenAICodex`."""
|
||||
# ruff: noqa: S106, S107
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ChatMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_openai import ChatOpenAICodex
|
||||
from langchain_openai.chat_models.base import ChatOpenAI
|
||||
from langchain_openai.chat_models.codex import (
|
||||
ACCOUNT_ID_HEADER,
|
||||
CHATGPT_CODEX_BASE_URL,
|
||||
DEFAULT_INSTRUCTIONS,
|
||||
ORIGINATOR_ENV_VAR,
|
||||
ORIGINATOR_HEADER,
|
||||
ORIGINATOR_VALUE,
|
||||
_SyncTokenCallable,
|
||||
)
|
||||
from langchain_openai.chatgpt_oauth import ChatGPTToken
|
||||
|
||||
|
||||
class FakeTokenProvider:
    """In-memory stand-in for a `ChatGPTOAuthTokenProvider`.

    Serves a fixed access token and account id, and counts how many times the
    sync and async token getters are invoked so tests can assert on it.
    """

    def __init__(
        self,
        access_token: str = "at-1",
        account_id: str | None = "acct-1",
    ) -> None:
        self.access_token = access_token
        self.account_id = account_id
        # `calls` ticks on every `get_token`; `async_calls` on every
        # `aget_token`.
        self.calls = 0
        self.async_calls = 0

    def get_token(self) -> ChatGPTToken:
        """Return a token expiring one hour from now and bump `calls`."""
        self.calls += 1
        expiry = datetime.now(timezone.utc) + timedelta(hours=1)
        return ChatGPTToken(
            access_token=self.access_token,
            refresh_token="rt",
            expires_at=expiry,
            account_id=self.account_id,
        )

    async def aget_token(self) -> ChatGPTToken:
        """Bump `async_calls`, then delegate to `get_token`.

        Because it delegates, `calls` is incremented as well.
        """
        self.async_calls += 1
        return self.get_token()

    def get_access_token(self) -> str:
        """Return only the access-token string of a freshly built token."""
        token = self.get_token()
        return token.access_token

    async def aget_access_token(self) -> str:
        """Async counterpart of `get_access_token`."""
        return (await self.aget_token()).access_token
|
||||
|
||||
|
||||
def _build_model(**overrides: Any) -> ChatOpenAICodex:
    """Construct a `ChatOpenAICodex` for tests with sensible defaults.

    Args:
        **overrides: Constructor kwargs. `token_provider` defaults to a new
            `FakeTokenProvider` (also when a falsy value is passed) and
            `model` defaults to `"gpt-5.2-codex"`; everything else is
            forwarded untouched.

    Returns:
        The constructed model.
    """
    token_provider = overrides.pop("token_provider", None)
    if not token_provider:
        token_provider = FakeTokenProvider()
    model_name = overrides.pop("model", "gpt-5.2-codex")
    return ChatOpenAICodex(
        model=model_name,
        token_provider=token_provider,
        **overrides,
    )
|
||||
|
||||
|
||||
def test_defaults_route_to_chatgpt_codex_backend() -> None:
    """A default-built model targets the Codex base URL with fixed flags."""
    codex = _build_model()
    # `output_version` is intentionally not asserted: it is a client-side
    # projection and isn't forced — the base default applies unless the
    # caller picks one explicitly.
    assert codex.openai_api_base == CHATGPT_CODEX_BASE_URL
    assert codex.use_responses_api is True
    assert codex.store is False
    assert codex.streaming is True
|
||||
|
||||
|
||||
def test_uses_callable_api_key_from_token_provider() -> None:
    """The model's `openai_api_key` resolves to the provider's access token.

    Also checks that resolving it consulted the provider at least once.
    """
    provider = FakeTokenProvider(access_token="abc")
    secret = _build_model(token_provider=provider).openai_api_key
    assert secret is not None
    # `ChatOpenAI.validate_environment` wraps callables in `SecretStr`, so
    # both shapes are accepted here.
    resolved = secret() if callable(secret) else secret.get_secret_value()
    assert resolved == "abc"
    assert provider.calls >= 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("field", ["base_url", "openai_api_base"])
def test_base_url_override_is_rejected(field: str) -> None:
    """A foreign `base_url` / `openai_api_base` raises `ValueError`.

    The OAuth bearer token is wired in as `api_key`, so honoring an arbitrary
    base URL would let an attacker (or a misconfigured serialized config)
    send the token to a host of their choosing.
    """
    hostile_kwargs = {field: "https://attacker.example.com/codex"}
    expected_message = r"requires `(?:base_url|openai_api_base)="
    with pytest.raises(ValueError, match=expected_message):
        _build_model(**hostile_kwargs)
|
||||
|
||||
|
||||
def test_explicit_base_url_matching_codex_endpoint_is_accepted() -> None:
    """Spelling out the canonical Codex endpoint as `base_url` is allowed."""
    codex = _build_model(base_url=CHATGPT_CODEX_BASE_URL)
    assert codex.openai_api_base == CHATGPT_CODEX_BASE_URL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("field", ["api_key", "openai_api_key"])
def test_explicit_api_key_is_rejected(field: str) -> None:
    """A caller-supplied `api_key` / `openai_api_key` raises `ValueError`.

    Authentication belongs to `token_provider`; an explicit key would quietly
    take precedence over the OAuth bearer, so construction must fail loudly
    instead of leaving the model in a conflicting state.
    """
    conflicting_kwargs = {field: "sk-should-not-be-allowed"}
    with pytest.raises(ValueError, match=r"manages authentication via"):
        _build_model(**conflicting_kwargs)
|
||||
|
||||
|
||||
def test_request_payload_injects_account_id_and_originator_headers() -> None:
    """The payload's `extra_headers` carry the account id and originator."""
    codex = _build_model(token_provider=FakeTokenProvider(account_id="acct-42"))
    request = codex._get_request_payload([HumanMessage("hi")])
    sent_headers = request["extra_headers"]
    assert sent_headers[ACCOUNT_ID_HEADER] == "acct-42"
    assert sent_headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE
|
||||
|
||||
|
||||
def test_request_payload_omits_account_id_when_unknown() -> None:
    """With `account_id=None` only the originator header is attached."""
    codex = _build_model(token_provider=FakeTokenProvider(account_id=None))
    request = codex._get_request_payload([HumanMessage("hi")])
    sent_headers = request["extra_headers"]
    assert ACCOUNT_ID_HEADER not in sent_headers
    assert sent_headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE
|
||||
|
||||
|
||||
def test_request_payload_can_disable_originator_header() -> None:
    """No `extra_headers` key at all when originator and account id are `None`."""
    codex = _build_model(
        token_provider=FakeTokenProvider(account_id=None),
        originator=None,
    )
    request = codex._get_request_payload([HumanMessage("hi")])
    assert "extra_headers" not in request
|
||||
|
||||
|
||||
def test_constructor_originator_overrides_default() -> None:
    """A constructor `originator=` value is what lands in the header."""
    codex = _build_model(originator="my-app/1.2")
    sent_headers = codex._get_request_payload([HumanMessage("hi")])["extra_headers"]
    assert sent_headers[ORIGINATOR_HEADER] == "my-app/1.2"
|
||||
|
||||
|
||||
def test_env_var_overrides_default_originator(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The originator env var supplies the header when no constructor value."""
    monkeypatch.setenv(ORIGINATOR_ENV_VAR, "env-app/2.0")
    codex = _build_model()
    sent_headers = codex._get_request_payload([HumanMessage("hi")])["extra_headers"]
    assert sent_headers[ORIGINATOR_HEADER] == "env-app/2.0"
|
||||
|
||||
|
||||
def test_constructor_originator_wins_over_env_var(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With both set, the constructor value beats the env var."""
    monkeypatch.setenv(ORIGINATOR_ENV_VAR, "env-app/2.0")
    codex = _build_model(originator="ctor-app/9.9")
    sent_headers = codex._get_request_payload([HumanMessage("hi")])["extra_headers"]
    assert sent_headers[ORIGINATOR_HEADER] == "ctor-app/9.9"
|
||||
|
||||
|
||||
def test_empty_env_var_falls_back_to_package_default(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An empty originator env var behaves like an unset one."""
    monkeypatch.setenv(ORIGINATOR_ENV_VAR, "")
    codex = _build_model()
    sent_headers = codex._get_request_payload([HumanMessage("hi")])["extra_headers"]
    assert sent_headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE
|
||||
|
||||
|
||||
def test_caller_extra_headers_override_originator_field() -> None:
    """A per-call originator in `extra_headers` beats the model's field."""
    codex = _build_model(originator="ctor-app")
    per_call_headers = {ORIGINATOR_HEADER: "call-app"}
    request = codex._get_request_payload(
        [HumanMessage("hi")],
        extra_headers=per_call_headers,
    )
    assert request["extra_headers"][ORIGINATOR_HEADER] == "call-app"
|
||||
|
||||
|
||||
def test_request_payload_pulls_fresh_account_id_each_call() -> None:
    """Each payload build consults the token provider at least once."""
    provider = FakeTokenProvider()
    codex = _build_model(token_provider=provider)
    baseline = provider.calls
    for text in ("hi", "hi again"):
        codex._get_request_payload([HumanMessage(text)])
    # Two payloads were built, so at least two more provider hits.
    assert provider.calls - baseline >= 2
|
||||
|
||||
|
||||
def test_invalid_token_provider_rejected() -> None:
    """A plain string passed as `token_provider` raises `TypeError`."""
    bogus_provider = "not-a-provider"
    with pytest.raises(TypeError):
        ChatOpenAICodex(model="gpt-5.2-codex", token_provider=bogus_provider)
|
||||
|
||||
|
||||
def test_conflicting_use_responses_api_raises() -> None:
    """`use_responses_api=False` is rejected with a `ValueError`."""
    provider = FakeTokenProvider()
    with pytest.raises(ValueError, match="use_responses_api"):
        ChatOpenAICodex(
            model="gpt-5.2-codex",
            token_provider=provider,
            use_responses_api=False,
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_explicit_output_version_is_respected(output_version: str) -> None:
    """Any `output_version` is kept on the model and never sent on the wire.

    `output_version` is a client-side projection, so every value is allowed.
    Whatever the constructor choice, it must not show up in the outbound
    payload — it is consumed entirely by the response projection layer.
    """
    codex = ChatOpenAICodex(
        model="gpt-5.2-codex",
        token_provider=FakeTokenProvider(),
        output_version=output_version,
    )
    assert codex.output_version == output_version

    request = codex._get_request_payload([HumanMessage("hi")])
    assert "output_version" not in request
|
||||
|
||||
|
||||
def test_conflicting_store_raises() -> None:
    """`store=True` is rejected with a `ValueError`."""
    provider = FakeTokenProvider()
    with pytest.raises(ValueError, match="store"):
        ChatOpenAICodex(
            model="gpt-5.2-codex",
            token_provider=provider,
            store=True,
        )
|
||||
|
||||
|
||||
def test_conflicting_streaming_raises() -> None:
    """`streaming=False` is rejected with a `ValueError`."""
    provider = FakeTokenProvider()
    with pytest.raises(ValueError, match="streaming"):
        ChatOpenAICodex(
            model="gpt-5.2-codex",
            token_provider=provider,
            streaming=False,
        )
|
||||
|
||||
|
||||
def test_request_payload_sends_store_false_and_stream_true() -> None:
    """The payload pins `store=False` and `stream=True`.

    The Codex backend answers 400 to `store=True` or non-streaming requests.
    """
    request = _build_model()._get_request_payload([HumanMessage("hi")])
    assert request["store"] is False
    assert request["stream"] is True
|
||||
|
||||
|
||||
def test_request_payload_sets_default_instructions() -> None:
    """`DEFAULT_INSTRUCTIONS` is injected when the caller sets none.

    The Codex backend answers 400 to requests without `instructions`.
    """
    request = _build_model()._get_request_payload([HumanMessage("hi")])
    assert request["instructions"] == DEFAULT_INSTRUCTIONS
|
||||
|
||||
|
||||
def test_request_payload_respects_constructor_instructions() -> None:
    """Constructor-level `instructions` are sent as given."""
    codex = _build_model(instructions="custom system prompt")
    request = codex._get_request_payload([HumanMessage("hi")])
    assert request["instructions"] == "custom system prompt"
|
||||
|
||||
|
||||
def test_request_payload_respects_per_call_instructions_override() -> None:
    """A per-call `instructions` kwarg beats the model-level value."""
    codex = _build_model(instructions="model-level")
    request = codex._get_request_payload(
        [HumanMessage("hi")],
        instructions="call-level",
    )
    assert request["instructions"] == "call-level"
|
||||
|
||||
|
||||
def test_request_payload_preserves_explicit_empty_instructions() -> None:
    """A per-call `instructions=""` is forwarded, not replaced.

    The backend will reject the empty value, but swapping in the default
    behind the caller's back would mask their bug.
    """
    codex = _build_model(instructions="model-level")
    request = codex._get_request_payload([HumanMessage("hi")], instructions="")
    assert request["instructions"] == ""
|
||||
|
||||
|
||||
def test_system_message_is_lifted_into_top_level_instructions() -> None:
    """`SystemMessage` content replaces the constructor `instructions`.

    Codex rejects `SystemMessage` chat turns (400 "System messages are not
    allowed"), so `ChatOpenAICodex` moves their content into the top-level
    `instructions` field and removes them from the input list.
    """
    codex = _build_model(instructions="model-level")
    conversation = [SystemMessage("from-system-message"), HumanMessage("hi")]
    request = codex._get_request_payload(conversation)
    assert request["instructions"] == "from-system-message"

    entries = request["input"]
    # No system turn survives in the input...
    assert all(entry.get("role") != "system" for entry in entries)
    # ...while the user turn is still there.
    user_turn_found = False
    for entry in entries:
        if entry.get("role") == "user" and "hi" in str(entry.get("content")):
            user_turn_found = True
            break
    assert user_turn_found
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", ["system", "developer"])
def test_chat_message_instruction_roles_are_lifted(role: str) -> None:
    """`ChatMessage`s with a system/developer role move into `instructions`.

    Only the user turn may remain in the payload's `input`.
    """
    codex = _build_model(instructions="model-level")
    conversation = [
        ChatMessage(content=f"from-{role}", role=role),
        HumanMessage("hi"),
    ]
    request = codex._get_request_payload(conversation)
    assert request["instructions"] == f"from-{role}"
    assert all(entry.get("role") != role for entry in request["input"])
    remaining_roles = [entry.get("role") for entry in request["input"]]
    assert remaining_roles == ["user"]
|
||||
|
||||
|
||||
def test_back_to_back_system_messages_join_in_input_order() -> None:
    """Adjacent `SystemMessage`s are joined in order, one blank line apart."""
    codex = _build_model(instructions="model-level")
    conversation = [
        SystemMessage("first"),
        SystemMessage("second"),
        HumanMessage("hi"),
    ]
    request = codex._get_request_payload(conversation)
    assert request["instructions"] == "first\n\nsecond"
|
||||
|
||||
|
||||
def test_interleaved_system_messages_are_still_lifted() -> None:
    """`SystemMessage`s at any position are lifted into `instructions`.

    Codex is stateless per call and has no in-line system turn, so positional
    intent (such as switching persona mid-conversation) cannot be kept. Every
    `SystemMessage` is concatenated in input order and removed; the other
    messages keep their relative order.
    """
    codex = _build_model(instructions="model-level")
    conversation = [
        HumanMessage("hi"),
        SystemMessage("midway-system"),
        HumanMessage("bye"),
        SystemMessage("tail-system"),
    ]
    request = codex._get_request_payload(conversation)
    assert request["instructions"] == "midway-system\n\ntail-system"

    entries = request["input"]
    remaining_contents = [entry.get("content") for entry in entries]
    assert all(entry.get("role") != "system" for entry in entries)
    assert remaining_contents == ["hi", "bye"]
|
||||
|
||||
|
||||
def test_explicit_instructions_kwarg_wins_over_system_message() -> None:
    """A per-call `instructions=` kwarg beats lifted `SystemMessage` text."""
    codex = _build_model(instructions="model-level")
    conversation = [SystemMessage("from-system-message"), HumanMessage("hi")]
    request = codex._get_request_payload(conversation, instructions="from-kwarg")
    assert request["instructions"] == "from-kwarg"
|
||||
|
||||
|
||||
def test_explicit_instructions_kwarg_overrides_logs_warning(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Overriding a `SystemMessage` via the kwarg logs a warning.

    The warning lets callers audit the conflict between the two sources.
    """
    codex = _build_model(instructions="model-level")
    codex_logger = "langchain_openai.chat_models.codex"
    with caplog.at_level("WARNING", logger=codex_logger):
        codex._get_request_payload(
            [SystemMessage("discarded"), HumanMessage("hi")],
            instructions="from-kwarg",
        )

    needle = "explicit `instructions=` kwarg wins"
    assert any(needle in record.getMessage() for record in caplog.records)
|
||||
|
||||
|
||||
def test_system_message_with_text_blocks_is_lifted() -> None:
    """A `SystemMessage` made of text blocks is flattened into one string."""
    codex = _build_model(instructions="model-level")
    text_blocks = [
        {"type": "text", "text": "part one. "},
        {"type": "text", "text": "part two."},
    ]
    request = codex._get_request_payload(
        [SystemMessage(text_blocks), HumanMessage("hi")]
    )
    assert request["instructions"] == "part one. part two."
|
||||
|
||||
|
||||
def test_system_message_with_non_text_block_raises() -> None:
    """A non-text block in a `SystemMessage` raises `ValueError`.

    Such blocks cannot be flattened into the `instructions` string.
    """
    codex = _build_model(instructions="model-level")
    mixed_blocks = [
        {"type": "text", "text": "ok"},
        {"type": "image_url", "image_url": "http://example/x.png"},
    ]
    with pytest.raises(ValueError, match="non-text content block"):
        codex._get_request_payload(
            [SystemMessage(mixed_blocks), HumanMessage("hi")]
        )
|
||||
|
||||
|
||||
def test_system_message_with_non_string_text_value_raises() -> None:
    """A text block whose `text` is not a string raises `ValueError`."""
    from langchain_openai.chat_models.codex import _flatten_system_message_content

    # Built with `model_construct` to bypass the `SystemMessage` content type
    # so the helper's own defensive type check is what gets exercised.
    malformed_blocks = [{"type": "text", "text": 42}]
    malformed = SystemMessage.model_construct(content=malformed_blocks)
    with pytest.raises(ValueError, match=r"text.*not a string"):
        _flatten_system_message_content([malformed])
|
||||
|
||||
|
||||
def test_system_message_with_unsupported_content_type_raises() -> None:
    """Content that is neither `str` nor `list` (here an int) is rejected."""
    from langchain_openai.chat_models.codex import _flatten_system_message_content

    malformed = SystemMessage.model_construct(content=123)  # type: ignore[arg-type]
    with pytest.raises(ValueError, match="unsupported content type"):
        _flatten_system_message_content([malformed])
|
||||
|
||||
|
||||
def test_caller_headers_win_over_codex_defaults() -> None:
    """Caller `extra_headers` override injected ones; the rest still ride along."""
    codex = _build_model(token_provider=FakeTokenProvider(account_id="acct-1"))
    request = codex._get_request_payload(
        [HumanMessage("hi")],
        extra_headers={ORIGINATOR_HEADER: "custom-app"},
    )
    sent_headers = request["extra_headers"]
    assert sent_headers[ORIGINATOR_HEADER] == "custom-app"
    # The auto-injected account ID is kept because the caller didn't set it.
    assert sent_headers[ACCOUNT_ID_HEADER] == "acct-1"
|
||||
|
||||
|
||||
async def test_agenerate_primes_async_token_cache(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """`_agenerate` hits `aget_token` exactly once before delegating.

    Going through the async getter keeps a token refresh from blocking the
    event loop. The parent `_agenerate` is stubbed out.
    """
    provider = FakeTokenProvider()
    codex = _build_model(token_provider=provider)

    async def _stub_agenerate(*_args: Any, **_kwargs: Any) -> Any:
        return "sentinel"

    monkeypatch.setattr(ChatOpenAI, "_agenerate", _stub_agenerate)
    async_calls_before = provider.async_calls
    outcome = await codex._agenerate([HumanMessage("hi")])
    assert outcome == "sentinel"
    assert provider.async_calls - async_calls_before == 1
|
||||
|
||||
|
||||
async def test_astream_primes_async_token_cache_and_yields_headers_via_payload(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """`_astream` primes the async token path; the sync payload adds headers.

    `_get_request_payload` (sync) is what stamps `ChatGPT-Account-Id` and
    `originator` on outbound requests. The async path only has to refresh the
    token off the event loop before the sync builder reads the warm cache.
    Both halves of that contract are checked; the parent `_astream` is
    stubbed out.
    """
    provider = FakeTokenProvider(account_id="acct-async")
    codex = _build_model(token_provider=provider)

    async def _stub_astream(*_args: Any, **_kwargs: Any) -> Any:
        yield "chunk"

    monkeypatch.setattr(ChatOpenAI, "_astream", _stub_astream)
    async_calls_before = provider.async_calls
    streamed = [piece async for piece in codex._astream([HumanMessage("hi")])]
    assert streamed == ["chunk"]
    assert provider.async_calls - async_calls_before == 1

    # The sync payload — what the SDK ultimately serializes — still carries
    # the Codex headers after the async refresh path has run.
    sent_headers = codex._get_request_payload([HumanMessage("hi")])["extra_headers"]
    assert sent_headers[ACCOUNT_ID_HEADER] == "acct-async"
    assert sent_headers[ORIGINATOR_HEADER] == ORIGINATOR_VALUE
|
||||
|
||||
|
||||
def test_callable_api_key_returns_provider_token() -> None:
    """The `api_key` wired into the SDK resolves to the provider's token."""
    provider = FakeTokenProvider(access_token="abc-123")
    secret = _build_model(token_provider=provider).openai_api_key
    assert secret is not None
    # `ChatOpenAI` may turn a callable api_key into a `SecretStr` wrapping
    # whatever the callable returns; either way, resolving it must give the
    # provider's current access token.
    resolved = secret() if callable(secret) else secret.get_secret_value()
    assert resolved == "abc-123"
|
||||
|
||||
|
||||
def test_ls_params_uses_codex_provider_tag() -> None:
    """`_get_ls_params` reports `ls_provider` as `"openai-codex"`."""
    ls_params = _build_model()._get_ls_params()
    assert ls_params["ls_provider"] == "openai-codex"
|
||||
|
||||
|
||||
def test_is_not_serializable_due_to_live_token_provider() -> None:
    """`ChatOpenAICodex.is_lc_serializable()` returns `False`."""
    serializable = ChatOpenAICodex.is_lc_serializable()
    assert serializable is False
|
||||
|
||||
|
||||
def test_sync_token_callable_delegates() -> None:
    """Calling a `_SyncTokenCallable` returns the provider's access token."""
    token_getter = _SyncTokenCallable(FakeTokenProvider(access_token="zzz"))
    assert token_getter() == "zzz"
|
||||
@@ -1,6 +1,6 @@
|
||||
from langchain_openai.chat_models import __all__
|
||||
|
||||
EXPECTED_ALL = ["ChatOpenAI", "AzureChatOpenAI"]
|
||||
EXPECTED_ALL = ["ChatOpenAI", "AzureChatOpenAI", "ChatOpenAICodex"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
|
||||
1031
libs/partners/openai/tests/unit_tests/test_chatgpt_oauth.py
Normal file
1031
libs/partners/openai/tests/unit_tests/test_chatgpt_oauth.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ EXPECTED_ALL = [
|
||||
"__version__",
|
||||
"OpenAI",
|
||||
"ChatOpenAI",
|
||||
"ChatOpenAICodex",
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAI",
|
||||
"AzureChatOpenAI",
|
||||
|
||||
Reference in New Issue
Block a user