Mason Daugherty 2025-07-08 21:50:58 -04:00
parent e7e8391fd0
commit ca10f0eee5
No known key found for this signature in database
16 changed files with 146 additions and 106 deletions

View File

@@ -9,6 +9,7 @@ Logic is largely replicated from openai._base_client.
 from __future__ import annotations
 import asyncio
+import contextlib
 import os
 from functools import lru_cache
 from typing import Any, Optional
@@ -23,10 +24,8 @@ class _SyncHttpxClientWrapper(openai.DefaultHttpxClient):
         if self.is_closed:
             return
-        try:
+        with contextlib.suppress(Exception):
             self.close()
-        except Exception:  # noqa: S110
-            pass
 class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient):
@@ -36,11 +35,9 @@ class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient):
         if self.is_closed:
             return
-        try:
-            # TODO(someday): support non asyncio runtimes here
+        # TODO(someday): support non asyncio runtimes here
+        with contextlib.suppress(Exception):
             asyncio.get_running_loop().create_task(self.aclose())
-        except Exception:  # noqa: S110
-            pass
 def _build_sync_httpx_client(
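The two wrapper changes above replace a try/except-pass (flagged by ruff's S110) with contextlib.suppress; the two forms ignore exceptions in the same way. A minimal standalone sketch of the pattern (the helper and resource names are made up, not from this file):

import contextlib

def close_quietly(resource) -> None:
    # Same effect as try: resource.close() / except Exception: pass,
    # but without the empty except body that S110 complains about.
    with contextlib.suppress(Exception):
        resource.close()

contextlib.suppress also accepts several exception types at once, e.g. contextlib.suppress(OSError, ValueError).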

View File

@@ -70,7 +70,8 @@ _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
 def _convert_to_v03_ai_message(
-    message: AIMessage, has_reasoning: bool = False
+    message: AIMessage,
+    has_reasoning: bool = False,  # noqa: FBT001, FBT002
 ) -> AIMessage:
     """Mutate an AIMessage to the old-style v0.3 format."""
     if isinstance(message.content, list):

View File

@@ -5,12 +5,11 @@ from __future__ import annotations
 import logging
 import os
 from collections.abc import AsyncIterator, Awaitable, Iterator
-from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 import openai
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import LangSmithParams
-from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable
 from langchain_core.utils import from_env, secret_from_env
@@ -28,12 +27,6 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
 _DictOrPydantic = Union[dict, _BM]
-class _AllReturnType(TypedDict):
-    raw: BaseMessage
-    parsed: Optional[_DictOrPydantic]
-    parsing_error: Optional[BaseException]
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and is_basemodel_subclass(obj)

View File

@@ -369,7 +369,7 @@ def _convert_delta_to_message_chunk(
         )
     if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role, id=id_)
-    return default_class(content=content, id=id_)  # type: ignore
+    return default_class(content=content, id=id_)  # type: ignore[call-arg]
 def _update_token_usage(
@@ -396,7 +396,7 @@ def _update_token_usage(
             k: _update_token_usage(overall_token_usage.get(k, 0), v)
             for k, v in new_usage.items()
         }
-    warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
+    warnings.warn(f"Unexpected type for token usage: {type(new_usage)}", stacklevel=3)
     return new_usage
@@ -410,7 +410,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
             "langchain-openai==0.3. To use `with_structured_output` with this model, "
             'specify `method="function_calling"`.'
         )
-        warnings.warn(message)
+        warnings.warn(message, stacklevel=3)
         raise e
     if "Invalid schema for response_format" in e.message:
         message = (
@@ -420,7 +420,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
             "See supported schemas: "
             "https://platform.openai.com/docs/guides/structured-outputs#supported-schemas"
         )
-        warnings.warn(message)
+        warnings.warn(message, stacklevel=3)
         raise e
     raise
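The stacklevel values being threaded into warnings.warn throughout this file control which stack frame the warning is attributed to: 1 points at the warn() call itself, 2 at its caller, 3 one frame further up. A small illustration with hypothetical names (not code from this diff):

import warnings

def _inner() -> None:
    # stacklevel=2 makes the warning report the line in caller(), not this line.
    warnings.warn("something is off", stacklevel=2)

def caller() -> None:
    _inner()  # with stacklevel=2 the emitted warning points at this line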
@@ -434,12 +434,6 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
 _DictOrPydantic = Union[dict, _BM]
-class _AllReturnType(TypedDict):
-    raw: BaseMessage
-    parsed: Optional[_DictOrPydantic]
-    parsing_error: Optional[BaseException]
 class BaseChatOpenAI(BaseChatModel):
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
@@ -677,7 +671,7 @@ class BaseChatOpenAI(BaseChatModel):
     @model_validator(mode="before")
     @classmethod
     def validate_temperature(cls, values: dict[str, Any]) -> Any:
-        """Currently o1 models only allow temperature=1."""
+        """o1 models only allow temperature=1."""
         model = values.get("model_name") or values.get("model") or ""
         if model.startswith("o1") and "temperature" not in values:
             values["temperature"] = 1
@@ -987,7 +981,9 @@ class BaseChatOpenAI(BaseChatModel):
                 yield generation_chunk
     def _should_stream_usage(
-        self, stream_usage: Optional[bool] = None, **kwargs: Any
+        self,
+        stream_usage: Optional[bool] = None,  # noqa: FBT001
+        **kwargs: Any,
     ) -> bool:
         """Determine whether to include usage metadata in streaming output.
@@ -1026,7 +1022,8 @@ class BaseChatOpenAI(BaseChatModel):
            if self.include_response_headers:
                warnings.warn(
                    "Cannot currently include response headers when response_format is "
-                    "specified."
+                    "specified.",
+                    stacklevel=2,
                )
            payload.pop("stream")
            response_stream = self.root_client.beta.chat.completions.stream(**payload)
@@ -1093,7 +1090,8 @@ class BaseChatOpenAI(BaseChatModel):
            if self.include_response_headers:
                warnings.warn(
                    "Cannot currently include response headers when response_format is "
-                    "specified."
+                    "specified.",
+                    stacklevel=2,
                )
            payload.pop("stream")
            try:
@@ -1251,7 +1249,8 @@ class BaseChatOpenAI(BaseChatModel):
            if self.include_response_headers:
                warnings.warn(
                    "Cannot currently include response headers when response_format is "
-                    "specified."
+                    "specified.",
+                    stacklevel=2,
                )
            payload.pop("stream")
            response_stream = self.root_async_client.beta.chat.completions.stream(
@@ -1322,7 +1321,8 @@ class BaseChatOpenAI(BaseChatModel):
            if self.include_response_headers:
                warnings.warn(
                    "Cannot currently include response headers when response_format is "
-                    "specified."
+                    "specified.",
+                    stacklevel=2,
                )
            payload.pop("stream")
            try:
@@ -1433,7 +1433,7 @@ class BaseChatOpenAI(BaseChatModel):
         if self.custom_get_token_ids is not None:
             return self.custom_get_token_ids(text)
         # tiktoken NOT supported for Python 3.7 or below
-        if sys.version_info[1] <= 7:
+        if sys.version_info[1] <= 7:  # noqa: YTT203
             return super().get_token_ids(text)
         _, encoding_model = self._get_encoding_model()
         return encoding_model.encode(text)
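The noqa: YTT203 markers keep the existing minor-version check while acknowledging the flake8-2020 warning: sys.version_info[1] compares the minor number alone, so a hypothetical Python 4.0 would slip past a "<= 7" test. The conventional alternative, shown only for comparison, is a tuple comparison:

import sys

# Fragile: looks at the minor version alone (what YTT203 flags).
tiktoken_unsupported = sys.version_info[1] <= 7

# Robust: compares the full version tuple.
tiktoken_unsupported = sys.version_info < (3, 8)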
@@ -1464,9 +1464,10 @@ class BaseChatOpenAI(BaseChatModel):
         # TODO: Count bound tools as part of input.
         if tools is not None:
             warnings.warn(
-                "Counting tokens in tool schemas is not yet supported. Ignoring tools."
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
+                stacklevel=2,
             )
-        if sys.version_info[1] <= 7:
+        if sys.version_info[1] <= 7:  # noqa: YTT203
             return super().get_num_tokens_from_messages(messages)
         model, encoding = self._get_encoding_model()
         if model.startswith("gpt-3.5-turbo-0301"):
@@ -1519,7 +1520,8 @@ class BaseChatOpenAI(BaseChatModel):
                     elif val["type"] == "file":
                         warnings.warn(
                             "Token counts for file inputs are not supported. "
-                            "Ignoring file inputs."
+                            "Ignoring file inputs.",
+                            stacklevel=2,
                         )
                     else:
                         msg = f"Unrecognized content block type\n\n{val}"
@@ -1545,7 +1547,7 @@ class BaseChatOpenAI(BaseChatModel):
         self,
         functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         function_call: Optional[
-            Union[_FunctionCall, str, Literal["auto", "none"]]
+            Union[_FunctionCall, str, Literal["auto", "none"]]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -1601,7 +1603,7 @@ class BaseChatOpenAI(BaseChatModel):
         tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]  # noqa: PYI051
         ] = None,
         strict: Optional[bool] = None,
         parallel_tool_calls: Optional[bool] = None,
@@ -1840,7 +1842,8 @@ class BaseChatOpenAI(BaseChatModel):
                "Received a Pydantic BaseModel V1 schema. This is not supported by "
                'method="json_schema". Please use method="function_calling" '
                "or specify schema via JSON Schema or Pydantic V2 BaseModel. "
-                'Overriding to method="function_calling".'
+                'Overriding to method="function_calling".',
+                stacklevel=2,
            )
            method = "function_calling"
        # Check for incompatible model
@@ -1855,7 +1858,8 @@ class BaseChatOpenAI(BaseChatModel):
                f"see supported models here: "
                f"https://platform.openai.com/docs/guides/structured-outputs#supported-models. "  # noqa: E501
                "To fix this warning, set `method='function_calling'. "
-                "Overriding to method='function_calling'."
+                "Overriding to method='function_calling'.",
+                stacklevel=2,
            )
            method = "function_calling"
@@ -3762,7 +3766,7 @@ def _convert_responses_chunk_to_generation_chunk(
     current_sub_index: int,  # index of content block in output item
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
-    has_reasoning: bool = False,
+    has_reasoning: bool = False,  # noqa: FBT001, FBT002
     output_version: Literal["v0", "responses/v1"] = "v0",
 ) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
     def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:

View File

@@ -20,7 +20,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):  # type: ignore[override]
     To access AzureOpenAI embedding models you'll need to create an Azure account,
     get an API key, and install the `langchain-openai` integration package.
-    You’ll need to have an Azure OpenAI instance deployed.
+    You'll need to have an Azure OpenAI instance deployed.
     You can deploy a version on Azure Portal following this
     [guide](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal).

View File

@@ -21,7 +21,7 @@ def _process_batched_chunked_embeddings(
     tokens: list[Union[list[int], str]],
     batched_embeddings: list[list[float]],
     indices: list[int],
-    skip_empty: bool,
+    skip_empty: bool,  # noqa: FBT001
 ) -> list[Optional[list[float]]]:
     # for each text, this is the list of embeddings (list of list of floats)
     # corresponding to the chunks of the text
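skip_empty is a positional boolean parameter, which flake8-boolean-trap (FBT001) flags because a call like _process(..., True) does not say what True means; the commit opts for a noqa rather than changing the signature. The rule's preferred shape is a keyword-only flag, roughly like this hypothetical helper (not the function above):

def filter_values(values: list[int], *, skip_empty: bool = False) -> list[int]:
    # The bare * forces callers to write filter_values(data, skip_empty=True),
    # which reads unambiguously at the call site.
    return [v for v in values if v or not skip_empty]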
@@ -267,7 +267,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
+                    Please confirm that {field_name} is what you intended.""",
+                    stacklevel=2,
                )
                extra[field_name] = values.pop(field_name)

View File

@@ -282,6 +282,8 @@ class BaseOpenAI(BaseLLM):
         Args:
             prompts: The prompts to pass into the model.
             stop: Optional list of stop words to use when generating.
+            run_manager: Optional callback manager to use for callbacks.
+            kwargs: Additional keyword arguments to pass to the model.
         Returns:
             The full LLM output.
@@ -483,7 +485,7 @@ class BaseOpenAI(BaseLLM):
         if self.custom_get_token_ids is not None:
             return self.custom_get_token_ids(text)
         # tiktoken NOT supported for Python < 3.8
-        if sys.version_info[1] < 8:
+        if sys.version_info[1] < 8:  # noqa: YTT203
             return super().get_num_tokens(text)
         model_name = self.tiktoken_model_name or self.model_name

View File

@@ -64,11 +64,67 @@ ignore_missing_imports = true
 target-version = "py39"
 [tool.ruff.lint]
+select = [
+    "A",      # flake8-builtins
+    "B",      # flake8-bugbear
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flake8-flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # flake8-implicit-str-concat
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # flake8-perf
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-raise
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",     # pydocstyle: Missing docstring in public module
+    "D101",     # pydocstyle: Missing docstring in public class
+    "D102",     # pydocstyle: Missing docstring in public method
+    "D103",     # pydocstyle: Missing docstring in public function
+    "D104",     # pydocstyle: Missing docstring in public package
+    "D105",     # pydocstyle: Missing docstring in magic method
+    "D107",     # pydocstyle: Missing docstring in __init__
+    "D203",     # Messes with the formatter
+    "D407",     # pydocstyle: Missing-dashed-underline-after-section
+    "COM812",   # Messes with the formatter
+    "ISC001",   # Messes with the formatter
+    "D213",     # Messes with the formatter
+    "PERF203",  # Rarely useful
+    "S112",     # Rarely useful
+    "RUF012",   # Doesn't play well with Pydantic
+    "SLF001",   # Private member access
+    "UP007",    # pyupgrade: non-pep604-annotation-union
+    "UP045",    # pyupgrade: non-pep604-annotation-optional
+]
+unfixable = ["B028"]  # People should intentionally tune the stacklevel
 [tool.ruff.format]
 docstring-code-format = true
+skip-magic-trailing-comma = true
 [tool.coverage.run]
 omit = ["tests/*"]
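Several of the newly selected rule families target concrete code patterns. A few illustrative lines connecting rule codes to what they catch (hypothetical code, not from this package):

def fetch(url: str, verbose: bool = False) -> str:  # FBT002: boolean default in a positional parameter
    print("fetching", url)  # T201: `print` found
    try:
        return do_request(url)  # hypothetical helper
    except Exception:  # S110: try/except with a bare `pass` body
        pass
    return ""

With this [tool.ruff.lint] table in pyproject.toml, ruff picks the configuration up automatically when run from the package root.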

View File

@@ -68,7 +68,7 @@ def test_chat_openai_model() -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_chat_openai_system_message(use_responses_api: bool) -> None:
+def test_chat_openai_system_message(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test ChatOpenAI wrapper with system message."""
     chat = ChatOpenAI(use_responses_api=use_responses_api, max_tokens=MAX_TOKEN_COUNT)  # type: ignore[call-arg]
     system_message = SystemMessage(content="You are to chat with the user.")
@@ -110,7 +110,7 @@ def test_chat_openai_multiple_completions() -> None:
 @pytest.mark.scheduled
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_chat_openai_streaming(use_responses_api: bool) -> None:
+def test_chat_openai_streaming(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
@@ -206,7 +206,7 @@ async def test_async_chat_openai_bind_functions() -> None:
 @pytest.mark.scheduled
 @pytest.mark.parametrize("use_responses_api", [False, True])
-async def test_openai_abatch_tags(use_responses_api: bool) -> None:
+async def test_openai_abatch_tags(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test batch tokens from ChatOpenAI."""
     llm = ChatOpenAI(max_tokens=MAX_TOKEN_COUNT, use_responses_api=use_responses_api)  # type: ignore[call-arg]
@@ -273,7 +273,7 @@ def test_stream() -> None:
 async def test_astream() -> None:
     """Test streaming tokens from OpenAI."""
-    async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
+    async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:  # noqa: FBT001
         full: Optional[BaseMessageChunk] = None
         chunks_with_token_counts = 0
         chunks_with_response_metadata = 0
@@ -447,7 +447,7 @@ def test_tool_use() -> None:
             gathered = message
             first = False
         else:
-            gathered = gathered + message  # type: ignore
+            gathered = gathered + message  # type: ignore[assignment]
     assert isinstance(gathered, AIMessageChunk)
     assert isinstance(gathered.tool_call_chunks, list)
     assert len(gathered.tool_call_chunks) == 1
@@ -463,7 +463,7 @@ def test_tool_use() -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_manual_tool_call_msg(use_responses_api: bool) -> None:
+def test_manual_tool_call_msg(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test passing in manually construct tool call message."""
     llm = ChatOpenAI(
         model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
@@ -504,12 +504,12 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
         ),
         ToolMessage("sally_green_hair", tool_call_id="foo"),
     ]
-    with pytest.raises(Exception):
+    with pytest.raises(Exception):  # noqa: B017
         llm_with_tool.invoke(msgs)
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_bind_tools_tool_choice(use_responses_api: bool) -> None:
+def test_bind_tools_tool_choice(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test passing in manually construct tool call message."""
     llm = ChatOpenAI(
         model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
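The noqa: B017 above silences flake8-bugbear's complaint that pytest.raises(Exception) is too broad to prove anything specific. Where the expected error type is known, the stricter form looks like this (hypothetical test, not from this suite):

import pytest

def test_rejects_bad_payload() -> None:
    # Pin the exception type and, optionally, part of its message.
    with pytest.raises(ValueError, match="invalid payload"):
        parse_payload("{not json}")  # hypothetical function under test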
@@ -574,7 +574,7 @@ def test_openai_proxy() -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_openai_response_headers(use_responses_api: bool) -> None:
+def test_openai_response_headers(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test ChatOpenAI response headers."""
     chat_openai = ChatOpenAI(
         include_response_headers=True, use_responses_api=use_responses_api
@@ -598,7 +598,7 @@ def test_openai_response_headers(use_responses_api: bool) -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-async def test_openai_response_headers_async(use_responses_api: bool) -> None:
+async def test_openai_response_headers_async(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test ChatOpenAI response headers."""
     chat_openai = ChatOpenAI(
         include_response_headers=True, use_responses_api=use_responses_api
@@ -686,7 +686,7 @@ def test_image_token_counting_png() -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_tool_calling_strict(use_responses_api: bool) -> None:
+def test_tool_calling_strict(use_responses_api: bool) -> None:  # noqa: FBT001
     """Test tool calling with strict=True.
     Responses API appears to have fewer constraints on schema when strict=True.
@@ -719,7 +719,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
     # Test stream
     full: Optional[BaseMessageChunk] = None
     for chunk in model_with_tools.stream(query):
-        full = chunk if full is None else full + chunk  # type: ignore
+        full = chunk if full is None else full + chunk  # type: ignore[assignment]
     assert isinstance(full, AIMessage)
     _validate_tool_call_message(full)
@@ -736,7 +736,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
 def test_structured_output_strict(
     model: str,
     method: Literal["function_calling", "json_schema"],
-    use_responses_api: bool,
+    use_responses_api: bool,  # noqa: FBT001
 ) -> None:
     """Test to verify structured output with strict=True."""
     from pydantic import BaseModel as BaseModelProper
@@ -775,7 +775,9 @@ def test_structured_output_strict(
 @pytest.mark.parametrize("use_responses_api", [False, True])
 @pytest.mark.parametrize(("model", "method"), [("gpt-4o-2024-08-06", "json_schema")])
 def test_nested_structured_output_strict(
-    model: str, method: Literal["json_schema"], use_responses_api: bool
+    model: str,
+    method: Literal["json_schema"],
+    use_responses_api: bool,  # noqa: FBT001
 ) -> None:
     """Test to verify structured output with strict=True for nested object."""
     from typing import TypedDict
@@ -817,7 +819,8 @@ def test_nested_structured_output_strict(
     ],
 )
 def test_json_schema_openai_format(
-    strict: bool, method: Literal["json_schema", "function_calling"]
+    strict: bool,  # noqa: FBT001
+    method: Literal["json_schema", "function_calling"],
 ) -> None:
     """Test we can pass in OpenAI schema format specifying strict."""
     llm = ChatOpenAI(model="gpt-4o-mini")
@@ -960,7 +963,7 @@ def test_prediction_tokens() -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-def test_stream_o_series(use_responses_api: bool) -> None:
+def test_stream_o_series(use_responses_api: bool) -> None:  # noqa: FBT001
     list(
         ChatOpenAI(model="o3-mini", use_responses_api=use_responses_api).stream(
             "how are you"
@@ -969,7 +972,7 @@ def test_stream_o_series(use_responses_api: bool) -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
-async def test_astream_o_series(use_responses_api: bool) -> None:
+async def test_astream_o_series(use_responses_api: bool) -> None:  # noqa: FBT001
     async for _ in ChatOpenAI(
         model="o3-mini", use_responses_api=use_responses_api
     ).astream("how are you"):
@@ -1016,7 +1019,7 @@ async def test_astream_response_format() -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
 @pytest.mark.parametrize("use_max_completion_tokens", [True, False])
-def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None:
+def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None:  # noqa: FBT001
     if use_max_completion_tokens:
         kwargs: dict = {"max_completion_tokens": MAX_TOKEN_COUNT}
     else:

View File

@@ -123,7 +123,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
         _ = model.invoke([message])
-def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
+def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:  # noqa: FBT001
     if stream:
         full = None
         for chunk in llm.stream(input_):

View File

@@ -334,7 +334,7 @@ def test_stateful_api() -> None:
         "what's my name", previous_response_id=response.response_metadata["id"]
     )
     assert isinstance(second_response.content, list)
-    assert "bobo" in second_response.content[0]["text"].lower()  # type: ignore
+    assert "bobo" in second_response.content[0]["text"].lower()  # type: ignore[index]
 def test_route_from_model_kwargs() -> None:

View File

@@ -49,7 +49,7 @@ class TestOpenAIResponses(TestOpenAIStandard):
         return _invoke(llm, input_, stream)
-def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
+def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:  # noqa: FBT001
     if stream:
         full = None
         for chunk in llm.stream(input_):

View File

@@ -16,7 +16,6 @@ DEPLOYMENT_NAME = os.environ.get(
     "AZURE_OPENAI_DEPLOYMENT_NAME",
     os.environ.get("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", ""),
 )
-print
 def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
@@ -111,7 +110,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
             api_key=OPENAI_API_KEY,
             azure_endpoint=OPENAI_API_BASE,
             azure_deployment=DEPLOYMENT_NAME,
-        )  # type: ignore
+        )
         .embeddings.create(input="", model="text-embedding-ada-002")
         .data[0]
         .embedding

View File

@@ -80,8 +80,8 @@ def test_structured_output_old_model() -> None:
     ).with_structured_output(Output)
     # assert tool calling was used instead of json_schema
-    assert "tools" in llm.steps[0].kwargs  # type: ignore
-    assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+    assert "tools" in llm.steps[0].kwargs  # type: ignore[attr-defined]
+    assert "response_format" not in llm.steps[0].kwargs  # type: ignore[attr-defined]
 def test_max_completion_tokens_in_payload() -> None:

View File

@@ -326,7 +326,7 @@ class MockSyncContextManager:
 GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]}
 {"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
-{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"，"}}]}
+{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]}
 {"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]}
 {"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]}
 {"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]}
@@ -339,12 +339,7 @@ GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","cr
 @pytest.fixture
 def mock_glm4_completion() -> list:
     list_chunk_data = GLM4_STREAM_META.split("\n")
-    result_list = []
-    for msg in list_chunk_data:
-        if msg != "[DONE]":
-            result_list.append(json.loads(msg))
-    return result_list
+    return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
 async def test_glm4_astream(mock_glm4_completion: list) -> None:
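This fixture (and the two matching ones below) collapse an accumulate-in-a-loop into a single list comprehension, the kind of rewrite ruff's PERF rules (e.g. PERF401) suggest; both forms build the same list. A self-contained sketch of the equivalence, with made-up data:

import json

raw = '{"a": 1}\n{"b": 2}\n[DONE]'
lines = raw.split("\n")

# Before: imperative accumulation.
result = []
for msg in lines:
    if msg != "[DONE]":
        result.append(json.loads(msg))

# After: one comprehension, same contents.
assert result == [json.loads(msg) for msg in lines if msg != "[DONE]"]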
@@ -360,7 +355,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None:
     usage_metadata: Optional[UsageMetadata] = None
     with patch.object(llm, "async_client", mock_client):
-        async for chunk in llm.astream("你的名字叫什么？只回答名字"):
+        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
                 usage_metadata = chunk.usage_metadata
@@ -385,7 +380,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None:
     usage_metadata: Optional[UsageMetadata] = None
     with patch.object(llm, "client", mock_client):
-        for chunk in llm.stream("你的名字叫什么？只回答名字"):
+        for chunk in llm.stream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
                 usage_metadata = chunk.usage_metadata
@@ -402,7 +397,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
 {"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
 {"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
 {"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
-{"choices":[{"delta":{"content":"，","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
+{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
 {"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
 {"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
 {"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
@@ -420,12 +415,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
 @pytest.fixture
 def mock_deepseek_completion() -> list[dict]:
     list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
-    result_list = []
-    for msg in list_chunk_data:
-        if msg != "[DONE]":
-            result_list.append(json.loads(msg))
-    return result_list
+    return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
 async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
@@ -440,7 +430,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
     usage_chunk = mock_deepseek_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
     with patch.object(llm, "async_client", mock_client):
-        async for chunk in llm.astream("你的名字叫什么？只回答名字"):
+        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
                 usage_metadata = chunk.usage_metadata
@@ -464,7 +454,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None:
     usage_chunk = mock_deepseek_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
     with patch.object(llm, "client", mock_client):
-        for chunk in llm.stream("你的名字叫什么？只回答名字"):
+        for chunk in llm.stream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
                 usage_metadata = chunk.usage_metadata
@@ -488,12 +478,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
 @pytest.fixture
 def mock_openai_completion() -> list[dict]:
     list_chunk_data = OPENAI_STREAM_DATA.split("\n")
-    result_list = []
-    for msg in list_chunk_data:
-        if msg != "[DONE]":
-            result_list.append(json.loads(msg))
-    return result_list
+    return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
 async def test_openai_astream(mock_openai_completion: list) -> None:
@@ -508,7 +493,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None:
     usage_chunk = mock_openai_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
     with patch.object(llm, "async_client", mock_client):
-        async for chunk in llm.astream("你的名字叫什么？只回答名字"):
+        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
                 usage_metadata = chunk.usage_metadata
@@ -532,7 +517,7 @@ def test_openai_stream(mock_openai_completion: list) -> None:
     usage_chunk = mock_openai_completion[-1]
     usage_metadata: Optional[UsageMetadata] = None
     with patch.object(llm, "client", mock_client):
-        for chunk in llm.stream("你的名字叫什么？只回答名字"):
+        for chunk in llm.stream("你的名字叫什么?只回答名字"):
             assert isinstance(chunk, AIMessageChunk)
             if chunk.usage_metadata is not None:
                 usage_metadata = chunk.usage_metadata
@@ -805,7 +790,7 @@ class MakeASandwich(BaseModel):
     ],
 )
 @pytest.mark.parametrize("strict", [True, False, None])
-def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
+def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:  # noqa: FBT001
     """Test passing in manually construct tool call message."""
     llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
     llm.bind_tools(
@@ -822,8 +807,8 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
 def test_with_structured_output(
     schema: Union[type, dict[str, Any], None],
     method: Literal["function_calling", "json_mode", "json_schema"],
-    include_raw: bool,
-    strict: Optional[bool],
+    include_raw: bool,  # noqa: FBT001
+    strict: Optional[bool],  # noqa: FBT001
 ) -> None:
     """Test passing in manually construct tool call message."""
     if method == "json_mode":
@@ -1044,7 +1029,8 @@ def test__convert_to_openai_response_format() -> None:
 @pytest.mark.parametrize("method", ["function_calling", "json_schema"])
 @pytest.mark.parametrize("strict", [True, None])
 def test_structured_output_strict(
-    method: Literal["function_calling", "json_schema"], strict: Optional[bool]
+    method: Literal["function_calling", "json_schema"],
+    strict: Optional[bool],  # noqa: FBT001
 ) -> None:
     """Test to verify structured output with strict=True."""
     llm = ChatOpenAI(model="gpt-4o-2024-08-06")
@@ -1167,8 +1153,8 @@ def test_structured_output_old_model() -> None:
     with pytest.warns(match="Cannot use method='json_schema'"):
         llm = ChatOpenAI(model="gpt-4").with_structured_output(Output)
     # assert tool calling was used instead of json_schema
-    assert "tools" in llm.steps[0].kwargs  # type: ignore
-    assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+    assert "tools" in llm.steps[0].kwargs  # type: ignore[attr-defined]
+    assert "response_format" not in llm.steps[0].kwargs  # type: ignore[attr-defined]
 def test_structured_outputs_parser() -> None:

View File

@@ -32,7 +32,6 @@ def test_embed_documents_with_custom_chunk_size() -> None:
         result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
         _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
-        mock_create.call_args
         mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
         mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
@@ -52,7 +51,6 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
         result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
-        mock_create.call_args
         mock_create.assert_any_call(input=texts[0:3], **embeddings._invocation_params)
         mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)