mirror of https://github.com/hwchase17/langchain.git
@@ -40,6 +40,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.exceptions import ContextOverflowError
 from langchain_core.language_models import (
     LanguageModelInput,
     ModelProfileRegistry,
@@ -450,6 +451,11 @@ def _update_token_usage(


 def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
+    if (
+        "context_length_exceeded" in str(e)
+        or "Input tokens exceed the configured limit" in e.message
+    ):
+        raise ContextOverflowError(e.message) from e
     if (
         "'response_format' of type 'json_schema' is not supported with this model"
     ) in e.message:
@@ -474,6 +480,13 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
     raise


+def _handle_openai_api_error(e: openai.APIError) -> None:
+    error_message = str(e)
+    if "exceeds the context window" in error_message:
+        raise ContextOverflowError(error_message) from e
+    raise
+
+
 def _model_prefers_responses_api(model_name: str | None) -> bool:
     if not model_name:
         return False
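Downstream effect of the two helpers above: callers no longer have to pattern-match OpenAI's error strings themselves. A minimal, hedged sketch (not part of this commit; the model name and prompt are placeholders):

```python
# Sketch only: catching the ContextOverflowError that the helpers above now raise.
from langchain_core.exceptions import ContextOverflowError
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name

try:
    reply = llm.invoke("...an oversized prompt...")
except ContextOverflowError as err:
    # The original openai.BadRequestError / openai.APIError is chained via `from e`.
    print(f"Prompt exceeds the model's context window: {err}")
```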
@@ -1146,49 +1159,54 @@ class BaseChatOpenAI(BaseChatModel):
         self._ensure_sync_client_available()
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
-            raw_context_manager = self.root_client.with_raw_response.responses.create(
-                **payload
-            )
-            context_manager = raw_context_manager.parse()
-            headers = {"headers": dict(raw_context_manager.headers)}
-        else:
-            context_manager = self.root_client.responses.create(**payload)
-            headers = {}
-        original_schema_obj = kwargs.get("response_format")
-
-        with context_manager as response:
-            is_first_chunk = True
-            current_index = -1
-            current_output_index = -1
-            current_sub_index = -1
-            has_reasoning = False
-            for chunk in response:
-                metadata = headers if is_first_chunk else {}
-                (
-                    current_index,
-                    current_output_index,
-                    current_sub_index,
-                    generation_chunk,
-                ) = _convert_responses_chunk_to_generation_chunk(
-                    chunk,
-                    current_index,
-                    current_output_index,
-                    current_sub_index,
-                    schema=original_schema_obj,
-                    metadata=metadata,
-                    has_reasoning=has_reasoning,
-                    output_version=self.output_version,
-                )
-                if generation_chunk:
-                    if run_manager:
-                        run_manager.on_llm_new_token(
-                            generation_chunk.text, chunk=generation_chunk
-                        )
-                    is_first_chunk = False
-                    if "reasoning" in generation_chunk.message.additional_kwargs:
-                        has_reasoning = True
-                    yield generation_chunk
+        try:
+            if self.include_response_headers:
+                raw_context_manager = (
+                    self.root_client.with_raw_response.responses.create(**payload)
+                )
+                context_manager = raw_context_manager.parse()
+                headers = {"headers": dict(raw_context_manager.headers)}
+            else:
+                context_manager = self.root_client.responses.create(**payload)
+                headers = {}
+            original_schema_obj = kwargs.get("response_format")
+
+            with context_manager as response:
+                is_first_chunk = True
+                current_index = -1
+                current_output_index = -1
+                current_sub_index = -1
+                has_reasoning = False
+                for chunk in response:
+                    metadata = headers if is_first_chunk else {}
+                    (
+                        current_index,
+                        current_output_index,
+                        current_sub_index,
+                        generation_chunk,
+                    ) = _convert_responses_chunk_to_generation_chunk(
+                        chunk,
+                        current_index,
+                        current_output_index,
+                        current_sub_index,
+                        schema=original_schema_obj,
+                        metadata=metadata,
+                        has_reasoning=has_reasoning,
+                        output_version=self.output_version,
+                    )
+                    if generation_chunk:
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                generation_chunk.text, chunk=generation_chunk
+                            )
+                        is_first_chunk = False
+                        if "reasoning" in generation_chunk.message.additional_kwargs:
+                            has_reasoning = True
+                        yield generation_chunk
+        except openai.BadRequestError as e:
+            _handle_openai_bad_request(e)
+        except openai.APIError as e:
+            _handle_openai_api_error(e)

     async def _astream_responses(
         self,
@@ -1199,51 +1217,58 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
-            raw_context_manager = (
-                await self.root_async_client.with_raw_response.responses.create(
-                    **payload
-                )
-            )
-            context_manager = raw_context_manager.parse()
-            headers = {"headers": dict(raw_context_manager.headers)}
-        else:
-            context_manager = await self.root_async_client.responses.create(**payload)
-            headers = {}
-        original_schema_obj = kwargs.get("response_format")
-
-        async with context_manager as response:
-            is_first_chunk = True
-            current_index = -1
-            current_output_index = -1
-            current_sub_index = -1
-            has_reasoning = False
-            async for chunk in response:
-                metadata = headers if is_first_chunk else {}
-                (
-                    current_index,
-                    current_output_index,
-                    current_sub_index,
-                    generation_chunk,
-                ) = _convert_responses_chunk_to_generation_chunk(
-                    chunk,
-                    current_index,
-                    current_output_index,
-                    current_sub_index,
-                    schema=original_schema_obj,
-                    metadata=metadata,
-                    has_reasoning=has_reasoning,
-                    output_version=self.output_version,
-                )
-                if generation_chunk:
-                    if run_manager:
-                        await run_manager.on_llm_new_token(
-                            generation_chunk.text, chunk=generation_chunk
-                        )
-                    is_first_chunk = False
-                    if "reasoning" in generation_chunk.message.additional_kwargs:
-                        has_reasoning = True
-                    yield generation_chunk
+        try:
+            if self.include_response_headers:
+                raw_context_manager = (
+                    await self.root_async_client.with_raw_response.responses.create(
+                        **payload
+                    )
+                )
+                context_manager = raw_context_manager.parse()
+                headers = {"headers": dict(raw_context_manager.headers)}
+            else:
+                context_manager = await self.root_async_client.responses.create(
+                    **payload
+                )
+                headers = {}
+            original_schema_obj = kwargs.get("response_format")
+
+            async with context_manager as response:
+                is_first_chunk = True
+                current_index = -1
+                current_output_index = -1
+                current_sub_index = -1
+                has_reasoning = False
+                async for chunk in response:
+                    metadata = headers if is_first_chunk else {}
+                    (
+                        current_index,
+                        current_output_index,
+                        current_sub_index,
+                        generation_chunk,
+                    ) = _convert_responses_chunk_to_generation_chunk(
+                        chunk,
+                        current_index,
+                        current_output_index,
+                        current_sub_index,
+                        schema=original_schema_obj,
+                        metadata=metadata,
+                        has_reasoning=has_reasoning,
+                        output_version=self.output_version,
+                    )
+                    if generation_chunk:
+                        if run_manager:
+                            await run_manager.on_llm_new_token(
+                                generation_chunk.text, chunk=generation_chunk
+                            )
+                        is_first_chunk = False
+                        if "reasoning" in generation_chunk.message.additional_kwargs:
+                            has_reasoning = True
+                        yield generation_chunk
+        except openai.BadRequestError as e:
+            _handle_openai_bad_request(e)
+        except openai.APIError as e:
+            _handle_openai_api_error(e)

     def _should_stream_usage(
         self, stream_usage: bool | None = None, **kwargs: Any
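Because both Responses-API streaming generators above now wrap request creation and chunk iteration in try/except, the same exception type surfaces while consuming a stream. A hedged usage sketch (placeholder model and prompt, not from this commit):

```python
# Sketch only: the overflow error propagates out of the stream iterator as well.
from langchain_core.exceptions import ContextOverflowError
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)  # placeholder config

try:
    for chunk in llm.stream("...an oversized prompt..."):
        print(chunk.content, end="", flush=True)
except ContextOverflowError as err:
    print(f"\nInput too large for the context window: {err}")
```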
@@ -1282,24 +1307,26 @@ class BaseChatOpenAI(BaseChatModel):
         default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}

-        if "response_format" in payload:
-            if self.include_response_headers:
-                warnings.warn(
-                    "Cannot currently include response headers when response_format is "
-                    "specified."
-                )
-            payload.pop("stream")
-            response_stream = self.root_client.beta.chat.completions.stream(**payload)
-            context_manager = response_stream
-        else:
-            if self.include_response_headers:
-                raw_response = self.client.with_raw_response.create(**payload)
-                response = raw_response.parse()
-                base_generation_info = {"headers": dict(raw_response.headers)}
-            else:
-                response = self.client.create(**payload)
-            context_manager = response
         try:
+            if "response_format" in payload:
+                if self.include_response_headers:
+                    warnings.warn(
+                        "Cannot currently include response headers when "
+                        "response_format is specified."
+                    )
+                payload.pop("stream")
+                response_stream = self.root_client.beta.chat.completions.stream(
+                    **payload
+                )
+                context_manager = response_stream
+            else:
+                if self.include_response_headers:
+                    raw_response = self.client.with_raw_response.create(**payload)
+                    response = raw_response.parse()
+                    base_generation_info = {"headers": dict(raw_response.headers)}
+                else:
+                    response = self.client.create(**payload)
+                context_manager = response
             with context_manager as response:
                 is_first_chunk = True
                 for chunk in response:
@@ -1324,6 +1351,8 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
+        except openai.APIError as e:
+            _handle_openai_api_error(e)
         if hasattr(response, "get_final_completion") and "response_format" in payload:
             final_completion = response.get_final_completion()
             generation_chunk = self._get_generation_chunk_from_completion(
@@ -1349,15 +1378,10 @@ class BaseChatOpenAI(BaseChatModel):
         try:
             if "response_format" in payload:
                 payload.pop("stream")
-                try:
-                    raw_response = (
-                        self.root_client.chat.completions.with_raw_response.parse(
-                            **payload
-                        )
-                    )
-                    response = raw_response.parse()
-                except openai.BadRequestError as e:
-                    _handle_openai_bad_request(e)
+                raw_response = (
+                    self.root_client.chat.completions.with_raw_response.parse(**payload)
+                )
+                response = raw_response.parse()
             elif self._use_responses_api(payload):
                 original_schema_obj = kwargs.get("response_format")
                 if original_schema_obj and _is_pydantic_class(original_schema_obj):
@@ -1380,6 +1404,10 @@ class BaseChatOpenAI(BaseChatModel):
             else:
                 raw_response = self.client.with_raw_response.create(**payload)
                 response = raw_response.parse()
+        except openai.BadRequestError as e:
+            _handle_openai_bad_request(e)
+        except openai.APIError as e:
+            _handle_openai_api_error(e)
         except Exception as e:
             if raw_response is not None and hasattr(raw_response, "http_response"):
                 e.response = raw_response.http_response  # type: ignore[attr-defined]
@@ -1523,28 +1551,28 @@ class BaseChatOpenAI(BaseChatModel):
         default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}

-        if "response_format" in payload:
-            if self.include_response_headers:
-                warnings.warn(
-                    "Cannot currently include response headers when response_format is "
-                    "specified."
-                )
-            payload.pop("stream")
-            response_stream = self.root_async_client.beta.chat.completions.stream(
-                **payload
-            )
-            context_manager = response_stream
-        else:
-            if self.include_response_headers:
-                raw_response = await self.async_client.with_raw_response.create(
-                    **payload
-                )
-                response = raw_response.parse()
-                base_generation_info = {"headers": dict(raw_response.headers)}
-            else:
-                response = await self.async_client.create(**payload)
-            context_manager = response
         try:
+            if "response_format" in payload:
+                if self.include_response_headers:
+                    warnings.warn(
+                        "Cannot currently include response headers when "
+                        "response_format is specified."
+                    )
+                payload.pop("stream")
+                response_stream = self.root_async_client.beta.chat.completions.stream(
+                    **payload
+                )
+                context_manager = response_stream
+            else:
+                if self.include_response_headers:
+                    raw_response = await self.async_client.with_raw_response.create(
+                        **payload
+                    )
+                    response = raw_response.parse()
+                    base_generation_info = {"headers": dict(raw_response.headers)}
+                else:
+                    response = await self.async_client.create(**payload)
+                context_manager = response
             async with context_manager as response:
                 is_first_chunk = True
                 async for chunk in response:
@@ -1569,6 +1597,8 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
+        except openai.APIError as e:
+            _handle_openai_api_error(e)
         if hasattr(response, "get_final_completion") and "response_format" in payload:
             final_completion = await response.get_final_completion()
             generation_chunk = self._get_generation_chunk_from_completion(
@@ -1593,13 +1623,10 @@ class BaseChatOpenAI(BaseChatModel):
         try:
             if "response_format" in payload:
                 payload.pop("stream")
-                try:
-                    raw_response = await self.root_async_client.chat.completions.with_raw_response.parse(  # noqa: E501
-                        **payload
-                    )
-                    response = raw_response.parse()
-                except openai.BadRequestError as e:
-                    _handle_openai_bad_request(e)
+                raw_response = await self.root_async_client.chat.completions.with_raw_response.parse(  # noqa: E501
+                    **payload
+                )
+                response = raw_response.parse()
             elif self._use_responses_api(payload):
                 original_schema_obj = kwargs.get("response_format")
                 if original_schema_obj and _is_pydantic_class(original_schema_obj):
@@ -1628,6 +1655,10 @@ class BaseChatOpenAI(BaseChatModel):
                     **payload
                 )
                 response = raw_response.parse()
+        except openai.BadRequestError as e:
+            _handle_openai_bad_request(e)
+        except openai.APIError as e:
+            _handle_openai_api_error(e)
         except Exception as e:
             if raw_response is not None and hasattr(raw_response, "http_response"):
                 e.response = raw_response.http_response  # type: ignore[attr-defined]
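Aside, not part of the diff: a minimal sketch of exercising the new helper directly, assuming it remains importable from langchain_openai.chat_models.base (the module changed above). The BadRequestError construction mirrors the test fixtures that follow.

```python
# Sketch only: an overflow-flavored BadRequestError is re-raised as ContextOverflowError.
from unittest.mock import MagicMock

import openai
import pytest
from langchain_core.exceptions import ContextOverflowError

from langchain_openai.chat_models.base import _handle_openai_bad_request  # assumed path

overflow = openai.BadRequestError(
    message="Input tokens exceed the configured limit of 272000 tokens.",
    response=MagicMock(status_code=400),
    body={"error": {"code": "context_length_exceeded"}},
)

with pytest.raises(ContextOverflowError):
    _handle_openai_bad_request(overflow)
```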
@@ -10,7 +10,9 @@ from typing import Any, Literal, cast
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
+import openai
 import pytest
+from langchain_core.exceptions import ContextOverflowError
 from langchain_core.load import dumps, loads
 from langchain_core.messages import (
     AIMessage,
@@ -3231,3 +3233,146 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None:
         pass
     except ValueError as e:
         pytest.fail(f"This is a wrong behavior. Error details: {e}")
+
+
+# Test fixtures for context overflow error tests
+_CONTEXT_OVERFLOW_ERROR_BODY = {
+    "error": {
+        "message": (
+            "Input tokens exceed the configured limit of 272000 tokens. Your messages "
+            "resulted in 300007 tokens. Please reduce the length of the messages."
+        ),
+        "type": "invalid_request_error",
+        "param": "messages",
+        "code": "context_length_exceeded",
+    }
+}
+_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR = openai.BadRequestError(
+    message=_CONTEXT_OVERFLOW_ERROR_BODY["error"]["message"],
+    response=MagicMock(status_code=400),
+    body=_CONTEXT_OVERFLOW_ERROR_BODY,
+)
+_CONTEXT_OVERFLOW_API_ERROR = openai.APIError(
+    message=(
+        "Your input exceeds the context window of this model. Please adjust your input "
+        "and try again."
+    ),
+    request=MagicMock(),
+    body=None,
+)
+
+
+def test_context_overflow_error_invoke_sync() -> None:
+    """Test context overflow error on invoke (sync, chat completions API)."""
+    llm = ChatOpenAI()
+
+    with (  # noqa: PT012
+        patch.object(llm.client, "with_raw_response") as mock_client,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
+        llm.invoke([HumanMessage(content="test")])
+
+    assert "Input tokens exceed the configured limit" in str(exc_info.value)
+
+
+def test_context_overflow_error_invoke_sync_responses_api() -> None:
+    """Test context overflow error on invoke (sync, responses API)."""
+    llm = ChatOpenAI(use_responses_api=True)
+
+    with (  # noqa: PT012
+        patch.object(llm.root_client.responses, "with_raw_response") as mock_client,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
+        llm.invoke([HumanMessage(content="test")])
+
+    assert "Input tokens exceed the configured limit" in str(exc_info.value)
+
+
+async def test_context_overflow_error_invoke_async() -> None:
+    """Test context overflow error on invoke (async, chat completions API)."""
+    llm = ChatOpenAI()
+
+    with (  # noqa: PT012
+        patch.object(llm.async_client, "with_raw_response") as mock_client,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
+        await llm.ainvoke([HumanMessage(content="test")])
+
+    assert "Input tokens exceed the configured limit" in str(exc_info.value)
+
+
+async def test_context_overflow_error_invoke_async_responses_api() -> None:
+    """Test context overflow error on invoke (async, responses API)."""
+    llm = ChatOpenAI(use_responses_api=True)
+
+    with (  # noqa: PT012
+        patch.object(
+            llm.root_async_client.responses, "with_raw_response"
+        ) as mock_client,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
+        await llm.ainvoke([HumanMessage(content="test")])
+
+    assert "Input tokens exceed the configured limit" in str(exc_info.value)
+
+
+def test_context_overflow_error_stream_sync() -> None:
+    """Test context overflow error on stream (sync, chat completions API)."""
+    llm = ChatOpenAI()
+
+    with (  # noqa: PT012
+        patch.object(llm.client, "create") as mock_create,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
+        list(llm.stream([HumanMessage(content="test")]))
+
+    assert "Input tokens exceed the configured limit" in str(exc_info.value)
+
+
+def test_context_overflow_error_stream_sync_responses_api() -> None:
+    """Test context overflow error on stream (sync, responses API)."""
+    llm = ChatOpenAI(use_responses_api=True)
+
+    with (  # noqa: PT012
+        patch.object(llm.root_client.responses, "create") as mock_create,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR
+        list(llm.stream([HumanMessage(content="test")]))
+
+    assert "exceeds the context window" in str(exc_info.value)
+
+
+async def test_context_overflow_error_stream_async() -> None:
+    """Test context overflow error on stream (async, chat completions API)."""
+    llm = ChatOpenAI()
+
+    with (  # noqa: PT012
+        patch.object(llm.async_client, "create") as mock_create,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
+        async for _ in llm.astream([HumanMessage(content="test")]):
+            pass
+
+    assert "Input tokens exceed the configured limit" in str(exc_info.value)
+
+
+async def test_context_overflow_error_stream_async_responses_api() -> None:
+    """Test context overflow error on stream (async, responses API)."""
+    llm = ChatOpenAI(use_responses_api=True)
+
+    with (  # noqa: PT012
+        patch.object(llm.root_async_client.responses, "create") as mock_create,
+        pytest.raises(ContextOverflowError) as exc_info,
+    ):
+        mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR
+        async for _ in llm.astream([HumanMessage(content="test")]):
+            pass
+
+    assert "exceeds the context window" in str(exc_info.value)
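One recovery pattern this change enables (a hedged sketch, not from this commit): trim the history and retry when the new exception fires. The `trim_messages` parameters below are illustrative only.

```python
# Sketch only: trim-and-retry on context overflow. The token budget is a placeholder.
from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import HumanMessage, trim_messages
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
history = [HumanMessage(content="...a very long conversation...")]

try:
    result = llm.invoke(history)
except ContextOverflowError:
    # Keep only the most recent messages that fit, then retry once.
    shorter = trim_messages(
        history,
        max_tokens=100_000,
        token_counter=llm,
        strategy="last",
    )
    result = llm.invoke(shorter)
```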