Mason Daugherty 2025-07-08 21:50:58 -04:00
parent e7e8391fd0
commit ca10f0eee5
16 changed files with 146 additions and 106 deletions

View File

@ -9,6 +9,7 @@ Logic is largely replicated from openai._base_client.
from __future__ import annotations
import asyncio
import contextlib
import os
from functools import lru_cache
from typing import Any, Optional
@ -23,10 +24,8 @@ class _SyncHttpxClientWrapper(openai.DefaultHttpxClient):
if self.is_closed:
return
try:
with contextlib.suppress(Exception):
self.close()
except Exception: # noqa: S110
pass
class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient):
@ -36,11 +35,9 @@ class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient):
if self.is_closed:
return
try:
# TODO(someday): support non asyncio runtimes here
# TODO(someday): support non asyncio runtimes here
with contextlib.suppress(Exception):
asyncio.get_running_loop().create_task(self.aclose())
except Exception: # noqa: S110
pass
def _build_sync_httpx_client(
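Note on the hunk above: replacing a bare try/except-pass with contextlib.suppress(Exception) is the standard fix for flake8-bandit's S110 (try-except-pass) warning. A minimal, self-contained sketch of the equivalence (FakeResource is illustrative, not from this file):

import contextlib


class FakeResource:
    def close(self) -> None:
        raise RuntimeError("already closed")


resource = FakeResource()

# Equivalent to:
#     try:
#         resource.close()
#     except Exception:
#         pass
# but states the intent to swallow the error explicitly.
with contextlib.suppress(Exception):
    resource.close()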

View File

@ -70,7 +70,8 @@ _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
def _convert_to_v03_ai_message(
message: AIMessage, has_reasoning: bool = False
message: AIMessage,
has_reasoning: bool = False, # noqa: FBT001, FBT002
) -> AIMessage:
"""Mutate an AIMessage to the old-style v0.3 format."""
if isinstance(message.content, list):
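The "# noqa: FBT001, FBT002" markers silence flake8-boolean-trap, which objects to positional boolean parameters and boolean defaults. The existing signature is kept as-is; for new code the rule prefers keyword-only flags. A hedged sketch with made-up names:

# Boolean trap: at the call site, convert_positional(msg, True) gives no hint what True means.
def convert_positional(msg: str, has_reasoning: bool = False) -> str:
    return f"{msg} (reasoning={has_reasoning})"


# FBT-friendly alternative: the flag must be spelled out by the caller.
def convert_keyword_only(msg: str, *, has_reasoning: bool = False) -> str:
    return f"{msg} (reasoning={has_reasoning})"


convert_keyword_only("hello", has_reasoning=True)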

View File

@ -5,12 +5,11 @@ from __future__ import annotations
import logging
import os
from collections.abc import AsyncIterator, Awaitable, Iterator
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import from_env, secret_from_env
@ -28,12 +27,6 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
_DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict):
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)

View File

@ -369,7 +369,7 @@ def _convert_delta_to_message_chunk(
)
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_)
return default_class(content=content, id=id_) # type: ignore
return default_class(content=content, id=id_) # type: ignore[call-arg]
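Narrowing the bare "# type: ignore" to "# type: ignore[call-arg]" limits the suppression to that single mypy error code, so any new, unrelated error on the same line still gets reported. A contrived illustration (greet is not from this module):

def greet(name: str) -> str:
    return "hello " + name


# Bare ignore: hides every error mypy could ever raise on this line.
greet(123)  # type: ignore

# Scoped ignore: only the known arg-type complaint is silenced.
greet(123)  # type: ignore[arg-type]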
def _update_token_usage(
@ -396,7 +396,7 @@ def _update_token_usage(
k: _update_token_usage(overall_token_usage.get(k, 0), v)
for k, v in new_usage.items()
}
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}", stacklevel=3)
return new_usage
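The stacklevel arguments added throughout this file control which stack frame a warning is attributed to: stacklevel=2 points at the caller of the function that warns, stacklevel=3 one frame further up. A small standalone illustration:

import warnings


def _inner() -> None:
    # stacklevel=1 (the default) would blame this line; stacklevel=2 blames
    # _outer()'s call to _inner(); stacklevel=3 blames _outer()'s caller.
    warnings.warn("unexpected type for token usage", stacklevel=3)


def _outer() -> None:
    _inner()


_outer()  # with stacklevel=3, the warning is reported at this line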
@ -410,7 +410,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
"langchain-openai==0.3. To use `with_structured_output` with this model, "
'specify `method="function_calling"`.'
)
warnings.warn(message)
warnings.warn(message, stacklevel=3)
raise e
if "Invalid schema for response_format" in e.message:
message = (
@ -420,7 +420,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
"See supported schemas: "
"https://platform.openai.com/docs/guides/structured-outputs#supported-schemas"
)
warnings.warn(message)
warnings.warn(message, stacklevel=3)
raise e
raise
@ -434,12 +434,6 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict):
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
class BaseChatOpenAI(BaseChatModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
@ -677,7 +671,7 @@ class BaseChatOpenAI(BaseChatModel):
@model_validator(mode="before")
@classmethod
def validate_temperature(cls, values: dict[str, Any]) -> Any:
"""Currently o1 models only allow temperature=1."""
"""o1 models only allow temperature=1."""
model = values.get("model_name") or values.get("model") or ""
if model.startswith("o1") and "temperature" not in values:
values["temperature"] = 1
@ -987,7 +981,9 @@ class BaseChatOpenAI(BaseChatModel):
yield generation_chunk
def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
self,
stream_usage: Optional[bool] = None, # noqa: FBT001
**kwargs: Any,
) -> bool:
"""Determine whether to include usage metadata in streaming output.
@ -1026,7 +1022,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
response_stream = self.root_client.beta.chat.completions.stream(**payload)
@ -1093,7 +1090,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
try:
@ -1251,7 +1249,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
response_stream = self.root_async_client.beta.chat.completions.stream(
@ -1322,7 +1321,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
try:
@ -1433,7 +1433,7 @@ class BaseChatOpenAI(BaseChatModel):
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python 3.7 or below
if sys.version_info[1] <= 7:
if sys.version_info[1] <= 7: # noqa: YTT203
return super().get_token_ids(text)
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
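The "# noqa: YTT203" comments acknowledge flake8-2020's complaint that comparing sys.version_info[1] on its own assumes the major version stays at 3 (a hypothetical Python 4.0 would have a minor version of 0 and slip through). The existing check is preserved here; the rule's preferred form compares a tuple slice, roughly:

import sys

# Flagged by YTT203: only the minor version is inspected.
tiktoken_unsupported = sys.version_info[1] <= 7

# Preferred: compare (major, minor) together.
tiktoken_unsupported = sys.version_info[:2] <= (3, 7)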
@ -1464,9 +1464,10 @@ class BaseChatOpenAI(BaseChatModel):
# TODO: Count bound tools as part of input.
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
if sys.version_info[1] <= 7:
if sys.version_info[1] <= 7: # noqa: YTT203
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
if model.startswith("gpt-3.5-turbo-0301"):
@ -1519,7 +1520,8 @@ class BaseChatOpenAI(BaseChatModel):
elif val["type"] == "file":
warnings.warn(
"Token counts for file inputs are not supported. "
"Ignoring file inputs."
"Ignoring file inputs.",
stacklevel=2,
)
else:
msg = f"Unrecognized content block type\n\n{val}"
@ -1545,7 +1547,7 @@ class BaseChatOpenAI(BaseChatModel):
self,
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@ -1601,7 +1603,7 @@ class BaseChatOpenAI(BaseChatModel):
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
Union[dict, str, Literal["auto", "none", "required", "any"], bool] # noqa: PYI051
] = None,
strict: Optional[bool] = None,
parallel_tool_calls: Optional[bool] = None,
@ -1840,7 +1842,8 @@ class BaseChatOpenAI(BaseChatModel):
"Received a Pydantic BaseModel V1 schema. This is not supported by "
'method="json_schema". Please use method="function_calling" '
"or specify schema via JSON Schema or Pydantic V2 BaseModel. "
'Overriding to method="function_calling".'
'Overriding to method="function_calling".',
stacklevel=2,
)
method = "function_calling"
# Check for incompatible model
@ -1855,7 +1858,8 @@ class BaseChatOpenAI(BaseChatModel):
f"see supported models here: "
f"https://platform.openai.com/docs/guides/structured-outputs#supported-models. " # noqa: E501
"To fix this warning, set `method='function_calling'. "
"Overriding to method='function_calling'."
"Overriding to method='function_calling'.",
stacklevel=2,
)
method = "function_calling"
@ -3762,7 +3766,7 @@ def _convert_responses_chunk_to_generation_chunk(
current_sub_index: int, # index of content block in output item
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
has_reasoning: bool = False,
has_reasoning: bool = False, # noqa: FBT001, FBT002
output_version: Literal["v0", "responses/v1"] = "v0",
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:

View File

@ -20,7 +20,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
To access AzureOpenAI embedding models you'll need to create an Azure account,
get an API key, and install the `langchain-openai` integration package.
You’ll need to have an Azure OpenAI instance deployed.
You'll need to have an Azure OpenAI instance deployed.
You can deploy a version on Azure Portal following this
[guide](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal).

View File

@ -21,7 +21,7 @@ def _process_batched_chunked_embeddings(
tokens: list[Union[list[int], str]],
batched_embeddings: list[list[float]],
indices: list[int],
skip_empty: bool,
skip_empty: bool, # noqa: FBT001
) -> list[Optional[list[float]]]:
# for each text, this is the list of embeddings (list of list of floats)
# corresponding to the chunks of the text
@ -267,7 +267,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
Please confirm that {field_name} is what you intended.""",
stacklevel=2,
)
extra[field_name] = values.pop(field_name)

View File

@ -282,6 +282,8 @@ class BaseOpenAI(BaseLLM):
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
run_manager: Optional callback manager to use for callbacks.
kwargs: Additional keyword arguments to pass to the model.
Returns:
The full LLM output.
@ -483,7 +485,7 @@ class BaseOpenAI(BaseLLM):
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python < 3.8
if sys.version_info[1] < 8:
if sys.version_info[1] < 8: # noqa: YTT203
return super().get_num_tokens(text)
model_name = self.tiktoken_model_name or self.model_name

View File

@ -64,11 +64,67 @@ ignore_missing_imports = true
target-version = "py39"
[tool.ruff.lint]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"DOC", # pydoclint
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # flake8-perf
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D203", # Messes with the formatter
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"D213", # Messes with the formatter
"PERF203", # Rarely useful
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -68,7 +68,7 @@ def test_chat_openai_model() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_chat_openai_system_message(use_responses_api: bool) -> None:
def test_chat_openai_system_message(use_responses_api: bool) -> None: # noqa: FBT001
"""Test ChatOpenAI wrapper with system message."""
chat = ChatOpenAI(use_responses_api=use_responses_api, max_tokens=MAX_TOKEN_COUNT) # type: ignore[call-arg]
system_message = SystemMessage(content="You are to chat with the user.")
@ -110,7 +110,7 @@ def test_chat_openai_multiple_completions() -> None:
@pytest.mark.scheduled
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_chat_openai_streaming(use_responses_api: bool) -> None:
def test_chat_openai_streaming(use_responses_api: bool) -> None: # noqa: FBT001
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
@ -206,7 +206,7 @@ async def test_async_chat_openai_bind_functions() -> None:
@pytest.mark.scheduled
@pytest.mark.parametrize("use_responses_api", [False, True])
async def test_openai_abatch_tags(use_responses_api: bool) -> None:
async def test_openai_abatch_tags(use_responses_api: bool) -> None: # noqa: FBT001
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=MAX_TOKEN_COUNT, use_responses_api=use_responses_api) # type: ignore[call-arg]
@ -273,7 +273,7 @@ def test_stream() -> None:
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None: # noqa: FBT001
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
@ -447,7 +447,7 @@ def test_tool_use() -> None:
gathered = message
first = False
else:
gathered = gathered + message # type: ignore
gathered = gathered + message # type: ignore[assignment]
assert isinstance(gathered, AIMessageChunk)
assert isinstance(gathered.tool_call_chunks, list)
assert len(gathered.tool_call_chunks) == 1
@ -463,7 +463,7 @@ def test_tool_use() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_manual_tool_call_msg(use_responses_api: bool) -> None:
def test_manual_tool_call_msg(use_responses_api: bool) -> None: # noqa: FBT001
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
@ -504,12 +504,12 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
llm_with_tool.invoke(msgs)
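B017 (flake8-bugbear) flags pytest.raises(Exception) because a blanket assertion can pass for the wrong reason; the noqa keeps the existing behaviour rather than guessing the concrete error. When the expected exception is known, the rule prefers something like this sketch (divide and the test name are illustrative only):

import pytest


def divide(a: float, b: float) -> float:
    return a / b


def test_divide_by_zero() -> None:
    # Pinning the exception type (and optionally the message) prevents the
    # test from passing on an unrelated failure.
    with pytest.raises(ZeroDivisionError, match="division by zero"):
        divide(1, 0)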
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_bind_tools_tool_choice(use_responses_api: bool) -> None:
def test_bind_tools_tool_choice(use_responses_api: bool) -> None: # noqa: FBT001
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
@ -574,7 +574,7 @@ def test_openai_proxy() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_openai_response_headers(use_responses_api: bool) -> None:
def test_openai_response_headers(use_responses_api: bool) -> None: # noqa: FBT001
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(
include_response_headers=True, use_responses_api=use_responses_api
@ -598,7 +598,7 @@ def test_openai_response_headers(use_responses_api: bool) -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
async def test_openai_response_headers_async(use_responses_api: bool) -> None:
async def test_openai_response_headers_async(use_responses_api: bool) -> None: # noqa: FBT001
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(
include_response_headers=True, use_responses_api=use_responses_api
@ -686,7 +686,7 @@ def test_image_token_counting_png() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_tool_calling_strict(use_responses_api: bool) -> None:
def test_tool_calling_strict(use_responses_api: bool) -> None: # noqa: FBT001
"""Test tool calling with strict=True.
Responses API appears to have fewer constraints on schema when strict=True.
@ -719,7 +719,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
full = chunk if full is None else full + chunk # type: ignore[assignment]
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
@ -736,7 +736,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
def test_structured_output_strict(
model: str,
method: Literal["function_calling", "json_schema"],
use_responses_api: bool,
use_responses_api: bool, # noqa: FBT001
) -> None:
"""Test to verify structured output with strict=True."""
from pydantic import BaseModel as BaseModelProper
@ -775,7 +775,9 @@ def test_structured_output_strict(
@pytest.mark.parametrize("use_responses_api", [False, True])
@pytest.mark.parametrize(("model", "method"), [("gpt-4o-2024-08-06", "json_schema")])
def test_nested_structured_output_strict(
model: str, method: Literal["json_schema"], use_responses_api: bool
model: str,
method: Literal["json_schema"],
use_responses_api: bool, # noqa: FBT001
) -> None:
"""Test to verify structured output with strict=True for nested object."""
from typing import TypedDict
@ -817,7 +819,8 @@ def test_nested_structured_output_strict(
],
)
def test_json_schema_openai_format(
strict: bool, method: Literal["json_schema", "function_calling"]
strict: bool, # noqa: FBT001
method: Literal["json_schema", "function_calling"],
) -> None:
"""Test we can pass in OpenAI schema format specifying strict."""
llm = ChatOpenAI(model="gpt-4o-mini")
@ -960,7 +963,7 @@ def test_prediction_tokens() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_stream_o_series(use_responses_api: bool) -> None:
def test_stream_o_series(use_responses_api: bool) -> None: # noqa: FBT001
list(
ChatOpenAI(model="o3-mini", use_responses_api=use_responses_api).stream(
"how are you"
@ -969,7 +972,7 @@ def test_stream_o_series(use_responses_api: bool) -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
async def test_astream_o_series(use_responses_api: bool) -> None:
async def test_astream_o_series(use_responses_api: bool) -> None: # noqa: FBT001
async for _ in ChatOpenAI(
model="o3-mini", use_responses_api=use_responses_api
).astream("how are you"):
@ -1016,7 +1019,7 @@ async def test_astream_response_format() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
@pytest.mark.parametrize("use_max_completion_tokens", [True, False])
def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None:
def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None: # noqa: FBT001
if use_max_completion_tokens:
kwargs: dict = {"max_completion_tokens": MAX_TOKEN_COUNT}
else:

View File

@ -123,7 +123,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
_ = model.invoke([message])
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage: # noqa: FBT001
if stream:
full = None
for chunk in llm.stream(input_):

View File

@ -334,7 +334,7 @@ def test_stateful_api() -> None:
"what's my name", previous_response_id=response.response_metadata["id"]
)
assert isinstance(second_response.content, list)
assert "bobo" in second_response.content[0]["text"].lower() # type: ignore
assert "bobo" in second_response.content[0]["text"].lower() # type: ignore[index]
def test_route_from_model_kwargs() -> None:

View File

@ -49,7 +49,7 @@ class TestOpenAIResponses(TestOpenAIStandard):
return _invoke(llm, input_, stream)
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage: # noqa: FBT001
if stream:
full = None
for chunk in llm.stream(input_):

View File

@ -16,7 +16,6 @@ DEPLOYMENT_NAME = os.environ.get(
"AZURE_OPENAI_DEPLOYMENT_NAME",
os.environ.get("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", ""),
)
print
def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
@ -111,7 +110,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
api_key=OPENAI_API_KEY,
azure_endpoint=OPENAI_API_BASE,
azure_deployment=DEPLOYMENT_NAME,
) # type: ignore
)
.embeddings.create(input="", model="text-embedding-ada-002")
.data[0]
.embedding

View File

@ -80,8 +80,8 @@ def test_structured_output_old_model() -> None:
).with_structured_output(Output)
# assert tool calling was used instead of json_schema
assert "tools" in llm.steps[0].kwargs # type: ignore
assert "response_format" not in llm.steps[0].kwargs # type: ignore
assert "tools" in llm.steps[0].kwargs # type: ignore[attr-defined]
assert "response_format" not in llm.steps[0].kwargs # type: ignore[attr-defined]
def test_max_completion_tokens_in_payload() -> None:

View File

@ -326,7 +326,7 @@ class MockSyncContextManager:
GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":""}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]}
@ -339,12 +339,7 @@ GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","cr
@pytest.fixture
def mock_glm4_completion() -> list:
list_chunk_data = GLM4_STREAM_META.split("\n")
result_list = []
for msg in list_chunk_data:
if msg != "[DONE]":
result_list.append(json.loads(msg))
return result_list
return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
async def test_glm4_astream(mock_glm4_completion: list) -> None:
@ -360,7 +355,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None:
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么？只回答名字"):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@ -385,7 +380,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None:
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
for chunk in llm.stream("你的名字叫什么？只回答名字"):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@ -402,7 +397,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
{"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
@ -420,12 +415,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
@pytest.fixture
def mock_deepseek_completion() -> list[dict]:
list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
result_list = []
for msg in list_chunk_data:
if msg != "[DONE]":
result_list.append(json.loads(msg))
return result_list
return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
@ -440,7 +430,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
usage_chunk = mock_deepseek_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么？只回答名字"):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@ -464,7 +454,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None:
usage_chunk = mock_deepseek_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
for chunk in llm.stream("你的名字叫什么？只回答名字"):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@ -488,12 +478,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
@pytest.fixture
def mock_openai_completion() -> list[dict]:
list_chunk_data = OPENAI_STREAM_DATA.split("\n")
result_list = []
for msg in list_chunk_data:
if msg != "[DONE]":
result_list.append(json.loads(msg))
return result_list
return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
async def test_openai_astream(mock_openai_completion: list) -> None:
@ -508,7 +493,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None:
usage_chunk = mock_openai_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么？只回答名字"):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@ -532,7 +517,7 @@ def test_openai_stream(mock_openai_completion: list) -> None:
usage_chunk = mock_openai_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
for chunk in llm.stream("你的名字叫什么？只回答名字"):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@ -805,7 +790,7 @@ class MakeASandwich(BaseModel):
],
)
@pytest.mark.parametrize("strict", [True, False, None])
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None: # noqa: FBT001
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.bind_tools(
@ -822,8 +807,8 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
def test_with_structured_output(
schema: Union[type, dict[str, Any], None],
method: Literal["function_calling", "json_mode", "json_schema"],
include_raw: bool,
strict: Optional[bool],
include_raw: bool, # noqa: FBT001
strict: Optional[bool], # noqa: FBT001
) -> None:
"""Test passing in manually construct tool call message."""
if method == "json_mode":
@ -1044,7 +1029,8 @@ def test__convert_to_openai_response_format() -> None:
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
@pytest.mark.parametrize("strict", [True, None])
def test_structured_output_strict(
method: Literal["function_calling", "json_schema"], strict: Optional[bool]
method: Literal["function_calling", "json_schema"],
strict: Optional[bool], # noqa: FBT001
) -> None:
"""Test to verify structured output with strict=True."""
llm = ChatOpenAI(model="gpt-4o-2024-08-06")
@ -1167,8 +1153,8 @@ def test_structured_output_old_model() -> None:
with pytest.warns(match="Cannot use method='json_schema'"):
llm = ChatOpenAI(model="gpt-4").with_structured_output(Output)
# assert tool calling was used instead of json_schema
assert "tools" in llm.steps[0].kwargs # type: ignore
assert "response_format" not in llm.steps[0].kwargs # type: ignore
assert "tools" in llm.steps[0].kwargs # type: ignore[attr-defined]
assert "response_format" not in llm.steps[0].kwargs # type: ignore[attr-defined]
def test_structured_outputs_parser() -> None:

View File

@ -32,7 +32,6 @@ def test_embed_documents_with_custom_chunk_size() -> None:
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
_, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
mock_create.call_args
mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
@ -52,7 +51,6 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
mock_create.call_args
mock_create.assert_any_call(input=texts[0:3], **embeddings._invocation_params)
mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)