Compare commits

...

2 Commits

Author           SHA1        Message     Date
Mason Daugherty  ca10f0eee5  fixes       2025-07-08 21:50:58 -04:00
Mason Daugherty  e7e8391fd0  init fixes  2025-07-08 21:27:25 -04:00
35 changed files with 2011 additions and 1899 deletions

View File

@@ -3,10 +3,10 @@ from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
__all__ = [
"OpenAI",
"ChatOpenAI",
"OpenAIEmbeddings",
"AzureOpenAI",
"AzureChatOpenAI",
"AzureOpenAI",
"AzureOpenAIEmbeddings",
"ChatOpenAI",
"OpenAI",
"OpenAIEmbeddings",
]

View File

@@ -1,4 +1,4 @@
from langchain_openai.chat_models.azure import AzureChatOpenAI
from langchain_openai.chat_models.base import ChatOpenAI
__all__ = ["ChatOpenAI", "AzureChatOpenAI"]
__all__ = ["AzureChatOpenAI", "ChatOpenAI"]

View File

@@ -6,7 +6,10 @@ for each instance of ChatOpenAI.
Logic is largely replicated from openai._base_client.
"""
from __future__ import annotations
import asyncio
import contextlib
import os
from functools import lru_cache
from typing import Any, Optional
@@ -15,30 +18,26 @@ import openai
class _SyncHttpxClientWrapper(openai.DefaultHttpxClient):
"""Borrowed from openai._base_client"""
"""Borrowed from openai._base_client."""
def __del__(self) -> None:
if self.is_closed:
return
try:
with contextlib.suppress(Exception):
self.close()
except Exception: # noqa: S110
pass
class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient):
"""Borrowed from openai._base_client"""
"""Borrowed from openai._base_client."""
def __del__(self) -> None:
if self.is_closed:
return
try:
# TODO(someday): support non asyncio runtimes here
# TODO(someday): support non asyncio runtimes here
with contextlib.suppress(Exception):
asyncio.get_running_loop().create_task(self.aclose())
except Exception: # noqa: S110
pass
def _build_sync_httpx_client(
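The hunk above swaps bare try/except blocks for contextlib.suppress; a minimal standalone sketch of the same idiom (the class and method names here are illustrative, not taken from the diff):

import contextlib


class _Closable:
    def close(self) -> None:
        """Pretend cleanup that may raise once the resource is gone."""
        raise RuntimeError("already closed")

    def __del__(self) -> None:
        # Equivalent to: try: self.close() / except Exception: pass
        with contextlib.suppress(Exception):
            self.close()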

View File

@@ -1,5 +1,4 @@
"""
This module converts between AIMessage output formats for the Responses API.
"""Converts between AIMessage output formats for the Responses API.
ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs:
@@ -16,7 +15,11 @@ ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs
"summary": [{"type": "summary_text", "text": "Reasoning summary"}],
},
"tool_outputs": [
{"type": "web_search_call", "id": "websearch_123", "status": "completed"}
{
"type": "web_search_call",
"id": "websearch_123",
"status": "completed",
}
],
"refusal": "I cannot assist with that.",
},
@@ -54,7 +57,9 @@ content blocks, rather than on the AIMessage.id, which now stores the response I
For backwards compatibility, this module provides functions to convert between the
old and new formats. The functions are used internally by ChatOpenAI.
""" # noqa: E501
"""
from __future__ import annotations
import json
from typing import Union
@@ -65,7 +70,8 @@ _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
def _convert_to_v03_ai_message(
message: AIMessage, has_reasoning: bool = False
message: AIMessage,
has_reasoning: bool = False, # noqa: FBT001, FBT002
) -> AIMessage:
"""Mutate an AIMessage to the old-style v0.3 format."""
if isinstance(message.content, list):
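For reference, the v0.3 shape described in the module docstring above can be reconstructed directly on an AIMessage; a hedged illustrative sketch using only the keys shown in this hunk:

from langchain_core.messages import AIMessage

v03_style = AIMessage(
    content="",
    additional_kwargs={
        "reasoning": {
            "summary": [{"type": "summary_text", "text": "Reasoning summary"}],
        },
        "tool_outputs": [
            {"type": "web_search_call", "id": "websearch_123", "status": "completed"},
        ],
        "refusal": "I cannot assist with that.",
    },
)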

View File

@@ -5,12 +5,11 @@ from __future__ import annotations
import logging
import os
from collections.abc import AsyncIterator, Awaitable, Iterator
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import from_env, secret_from_env
@@ -28,18 +27,12 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
_DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict):
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)
class AzureChatOpenAI(BaseChatOpenAI):
"""Azure OpenAI chat model integration.
r"""Azure OpenAI chat model integration.
Setup:
Head to the https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart?tabs=command-line%2Cpython-new&pivots=programming-language-python
@@ -138,7 +131,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
AIMessage(
content="J'adore programmer.",
usage_metadata={"input_tokens": 28, "output_tokens": 6, "total_tokens": 34},
usage_metadata={
"input_tokens": 28,
"output_tokens": 6,
"total_tokens": 34,
},
response_metadata={
"token_usage": {
"completion_tokens": 6,
@@ -184,8 +181,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
AIMessageChunk(content="ad", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
AIMessageChunk(content="ore", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
AIMessageChunk(content=" la", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
AIMessageChunk(content=" programm", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
AIMessageChunk(content="ation", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
AIMessageChunk(
content=" programm", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f"
)
AIMessageChunk(
content="ation", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f"
)
AIMessageChunk(content=".", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
AIMessageChunk(
content="",
@@ -294,7 +295,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
rating: Optional[int] = Field(
description="How funny the joke is, from 1 to 10"
)
structured_llm = llm.with_structured_output(Joke)
@@ -465,8 +468,8 @@ class AzureChatOpenAI(BaseChatOpenAI):
Example: `https://example-resource.azure.openai.com/`
"""
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
"""A model deployment.
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
"""
@@ -497,27 +500,27 @@ class AzureChatOpenAI(BaseChatOpenAI):
"""
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
model_version: str = ""
"""The version of the model (e.g. "0125" for gpt-3.5-0125).
Azure OpenAI doesn't return model version with the response by default so it must
Azure OpenAI doesn't return model version with the response by default so it must
be manually specified if you want to use this information downstream, e.g. when
calculating costs.
When you specify the version, it will be appended to the model name in the
response. Setting correct version will help you to calculate the cost properly.
Model version is not validated, so make sure you set it correctly to get the
When you specify the version, it will be appended to the model name in the
response. Setting correct version will help you to calculate the cost properly.
Model version is not validated, so make sure you set it correctly to get the
correct cost.
"""
@@ -527,36 +530,36 @@ class AzureChatOpenAI(BaseChatOpenAI):
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""If legacy arg openai_api_base is passed in, try to infer if it is a base_url or
"""If legacy arg openai_api_base is passed in, try to infer if it is a base_url or
azure_endpoint and update client params accordingly.
"""
model_name: Optional[str] = Field(default=None, alias="model") # type: ignore[assignment]
"""Name of the deployed OpenAI model, e.g. "gpt-4o", "gpt-35-turbo", etc.
"""Name of the deployed OpenAI model, e.g. "gpt-4o", "gpt-35-turbo", etc.
Distinct from the Azure deployment name, which is set by the Azure user.
Used for tracing and token counting. Does NOT affect completion.
"""
disabled_params: Optional[dict[str, Any]] = Field(default=None)
"""Parameters of the OpenAI client or chat.completions endpoint that should be
"""Parameters of the OpenAI client or chat.completions endpoint that should be
disabled for the given model.
Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
parameter and the value is either None, meaning that parameter should never be
used, or it's a list of disabled values for the parameter.
For example, older models may not support the 'parallel_tool_calls' parameter at
all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
For example, older models may not support the 'parallel_tool_calls' parameter at
all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
in.
If a parameter is disabled then it will not be used by default in any methods, e.g.
in
in
:meth:`~langchain_openai.chat_models.azure.AzureChatOpenAI.with_structured_output`.
However this does not prevent a user from directly passing in the parameter during
invocation.
By default, unless ``model_name="gpt-4o"`` is specified, then
invocation.
By default, unless ``model_name="gpt-4o"`` is specified, then
'parallel_tool_calls' will be disabled.
"""
@@ -580,9 +583,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
elif self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
msg = "n must be at least 1."
raise ValueError(msg)
if self.n is not None and self.n > 1 and self.streaming:
msg = "n must be 1 when streaming."
raise ValueError(msg)
if self.disabled_params is None:
# As of 09-17-2024 'parallel_tool_calls' param is only supported for gpt-4o.
@@ -602,13 +607,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
openai_api_base = self.openai_api_base
if openai_api_base and self.validate_base_url:
if "/openai" not in openai_api_base:
raise ValueError(
msg = (
"As of openai>=1.0.0, Azure endpoints should be specified via "
"the `azure_endpoint` param not `openai_api_base` "
"(or alias `base_url`)."
)
raise ValueError(msg)
if self.deployment_name:
raise ValueError(
msg = (
"As of openai>=1.0.0, if `azure_deployment` (or alias "
"`deployment_name`) is specified then "
"`base_url` (or alias `openai_api_base`) should not be. "
@@ -620,6 +626,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
"Or you can equivalently specify:\n\n"
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
)
raise ValueError(msg)
client_params: dict = {
"api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint,
@@ -665,10 +672,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
@property
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {
**{"azure_deployment": self.deployment_name},
**super()._identifying_params,
}
return {"azure_deployment": self.deployment_name, **super()._identifying_params}
@property
def _llm_type(self) -> str:
@@ -709,10 +713,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
response = response.model_dump()
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":
raise ValueError(
msg = (
"Azure has not provided the response due to a content filter "
"being triggered"
)
raise ValueError(msg)
if "model" in response:
model = response["model"]
@@ -740,8 +745,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
"""Route to Chat Completions or Responses API."""
if self._use_responses_api({**kwargs, **self.model_kwargs}):
return super()._stream_responses(*args, **kwargs)
else:
return super()._stream(*args, **kwargs)
return super()._stream(*args, **kwargs)
async def _astream(
self, *args: Any, **kwargs: Any
@@ -763,7 +767,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
r"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
@@ -930,7 +934,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
)
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
llm = AzureChatOpenAI(
azure_deployment="...", model="gpt-4o", temperature=0
)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke(
@@ -961,7 +967,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
)
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
llm = AzureChatOpenAI(
azure_deployment="...", model="gpt-4o", temperature=0
)
structured_llm = llm.with_structured_output(
AnswerWithJustification, method="function_calling"
)
@@ -990,7 +998,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
justification: str
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
llm = AzureChatOpenAI(
azure_deployment="...", model="gpt-4o", temperature=0
)
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)
@@ -1022,7 +1032,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
]
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
llm = AzureChatOpenAI(
azure_deployment="...", model="gpt-4o", temperature=0
)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke(
@@ -1119,6 +1131,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
# },
# 'parsing_error': None
# }
""" # noqa: E501
return super().with_structured_output(
schema, method=method, include_raw=include_raw, strict=strict, **kwargs

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import base64
import contextlib
import json
import logging
import os
@@ -138,13 +139,14 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
Returns:
The LangChain message.
"""
role = _dict.get("role")
name = _dict.get("name")
id_ = _dict.get("id")
if role == "user":
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
elif role == "assistant":
if role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
@@ -172,22 +174,19 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role in ("system", "developer"):
if role == "developer":
additional_kwargs = {"__openai_role__": role}
else:
additional_kwargs = {}
if role in ("system", "developer"):
additional_kwargs = {"__openai_role__": role} if role == "developer" else {}
return SystemMessage(
content=_dict.get("content", ""),
name=name,
id=id_,
additional_kwargs=additional_kwargs,
)
elif role == "function":
if role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
)
elif role == "tool":
if role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
@@ -198,8 +197,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
name=name,
id=id_,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
def _format_message_content(content: Any) -> Any:
@@ -214,7 +212,7 @@ def _format_message_content(content: Any) -> Any:
and block["type"] in ("tool_use", "thinking", "reasoning_content")
):
continue
elif isinstance(block, dict) and is_data_content_block(block):
if isinstance(block, dict) and is_data_content_block(block):
formatted_content.append(convert_to_openai_data_block(block))
# Anthropic image blocks
elif (
@@ -255,6 +253,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: dict[str, Any] = {"content": _format_message_content(message.content)}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
@@ -314,7 +313,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
supported_props = {"content", "role", "tool_call_id"}
message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
else:
raise TypeError(f"Got unknown type {message}")
msg = f"Got unknown type {message}"
raise TypeError(msg)
return message_dict
@@ -333,7 +333,7 @@ def _convert_delta_to_message_chunk(
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
with contextlib.suppress(KeyError):
tool_call_chunks = [
tool_call_chunk(
name=rtc["function"].get("name"),
@@ -343,19 +343,17 @@ def _convert_delta_to_message_chunk(
)
for rtc in raw_tool_calls
]
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content, id=id_)
elif role == "assistant" or default_class == AIMessageChunk:
if role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=id_,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
)
elif role in ("system", "developer") or default_class == SystemMessageChunk:
if role in ("system", "developer") or default_class == SystemMessageChunk:
if role == "developer":
additional_kwargs = {"__openai_role__": "developer"}
else:
@@ -363,16 +361,15 @@ def _convert_delta_to_message_chunk(
return SystemMessageChunk(
content=content, id=id_, additional_kwargs=additional_kwargs
)
elif role == "function" or default_class == FunctionMessageChunk:
if role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
elif role == "tool" or default_class == ToolMessageChunk:
if role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=id_
)
elif role or default_class == ChatMessageChunk:
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_)
else:
return default_class(content=content, id=id_) # type: ignore
return default_class(content=content, id=id_) # type: ignore[call-arg]
def _update_token_usage(
@@ -382,24 +379,25 @@ def _update_token_usage(
# `reasoning_tokens` is nested inside `completion_tokens_details`
if isinstance(new_usage, int):
if not isinstance(overall_token_usage, int):
raise ValueError(
msg = (
f"Got different types for token usage: "
f"{type(new_usage)} and {type(overall_token_usage)}"
)
raise ValueError(msg)
return new_usage + overall_token_usage
elif isinstance(new_usage, dict):
if isinstance(new_usage, dict):
if not isinstance(overall_token_usage, dict):
raise ValueError(
msg = (
f"Got different types for token usage: "
f"{type(new_usage)} and {type(overall_token_usage)}"
)
raise ValueError(msg)
return {
k: _update_token_usage(overall_token_usage.get(k, 0), v)
for k, v in new_usage.items()
}
else:
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
return new_usage
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}", stacklevel=3)
return new_usage
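A standalone sketch mirroring the merge behavior above: integer counts add, nested dicts (such as ``completion_tokens_details``) merge recursively. The function name and inputs are illustrative only.

def merge_usage(old, new):
    if isinstance(new, int):
        return old + new
    if isinstance(new, dict):
        return {k: merge_usage(old.get(k, 0), v) for k, v in new.items()}
    return new


merged = merge_usage(
    {"completion_tokens": 6, "completion_tokens_details": {"reasoning_tokens": 0}},
    {"completion_tokens": 4, "completion_tokens_details": {"reasoning_tokens": 2}},
)
# merged == {"completion_tokens": 10, "completion_tokens_details": {"reasoning_tokens": 2}}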
def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
@@ -412,20 +410,19 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
"langchain-openai==0.3. To use `with_structured_output` with this model, "
'specify `method="function_calling"`.'
)
warnings.warn(message)
warnings.warn(message, stacklevel=3)
raise e
elif "Invalid schema for response_format" in e.message:
if "Invalid schema for response_format" in e.message:
message = (
"Invalid schema for OpenAI's structured output feature, which is the "
"default method for `with_structured_output` as of langchain-openai==0.3. "
'Specify `method="function_calling"` instead or update your schema. '
"See supported schemas: "
"https://platform.openai.com/docs/guides/structured-outputs#supported-schemas" # noqa: E501
"https://platform.openai.com/docs/guides/structured-outputs#supported-schemas"
)
warnings.warn(message)
warnings.warn(message, stacklevel=3)
raise e
else:
raise
raise
class _FunctionCall(TypedDict):
@@ -437,12 +434,6 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict):
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
class BaseChatOpenAI(BaseChatModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
@@ -458,7 +449,7 @@ class BaseChatOpenAI(BaseChatModel):
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
@@ -469,7 +460,7 @@ class BaseChatOpenAI(BaseChatModel):
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
stream_usage: bool = False
"""Whether to include usage metadata in streaming output. If True, an additional
@@ -489,7 +480,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Whether to return logprobs."""
top_logprobs: Optional[int] = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true
an associated log probability. `logprobs` must be set to true
if this parameter is used."""
logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion."""
@@ -507,7 +498,7 @@ class BaseChatOpenAI(BaseChatModel):
Reasoning models only, like OpenAI o1, o3, and o4-mini.
Currently supported values are low, medium, and high. Reducing reasoning effort
Currently supported values are low, medium, and high. Reducing reasoning effort
can result in faster responses and fewer tokens used on reasoning in a response.
.. versionadded:: 0.2.14
@@ -528,25 +519,25 @@ class BaseChatOpenAI(BaseChatModel):
.. versionadded:: 0.3.24
"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = Field(default=None, exclude=True)
"""Optional httpx.Client. Only used for sync invocations. Must specify
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
@@ -556,21 +547,21 @@ class BaseChatOpenAI(BaseChatModel):
include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata."""
disabled_params: Optional[dict[str, Any]] = Field(default=None)
"""Parameters of the OpenAI client or chat.completions endpoint that should be
"""Parameters of the OpenAI client or chat.completions endpoint that should be
disabled for the given model.
Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
parameter and the value is either None, meaning that parameter should never be
used, or it's a list of disabled values for the parameter.
For example, older models may not support the 'parallel_tool_calls' parameter at
all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
For example, older models may not support the 'parallel_tool_calls' parameter at
all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
in.
If a parameter is disabled then it will not be used by default in any methods, e.g.
in :meth:`~langchain_openai.chat_models.base.ChatOpenAI.with_structured_output`.
However this does not prevent a user from directly passing in the parameter during
invocation.
invocation.
"""
include: Optional[list[str]] = None
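A hedged configuration sketch touching several of the fields documented in this hunk; the model name is a placeholder:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-4o-mini",
    timeout=30,          # alias for request_timeout
    logprobs=True,
    top_logprobs=3,      # requires logprobs=True
    stream_usage=True,   # attach usage metadata to streamed chunks
)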
@@ -675,13 +666,12 @@ class BaseChatOpenAI(BaseChatModel):
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
return _build_model_kwargs(values, all_required_field_names)
@model_validator(mode="before")
@classmethod
def validate_temperature(cls, values: dict[str, Any]) -> Any:
"""Currently o1 models only allow temperature=1."""
"""o1 models only allow temperature=1."""
model = values.get("model_name") or values.get("model") or ""
if model.startswith("o1") and "temperature" not in values:
values["temperature"] = 1
@@ -691,9 +681,11 @@ class BaseChatOpenAI(BaseChatModel):
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
elif self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
msg = "n must be at least 1."
raise ValueError(msg)
if self.n is not None and self.n > 1 and self.streaming:
msg = "n must be 1 when streaming."
raise ValueError(msg)
# Check OPENAI_ORGANIZATION for backwards compatibility.
self.openai_organization = (
@@ -719,20 +711,22 @@ class BaseChatOpenAI(BaseChatModel):
openai_proxy = self.openai_proxy
http_client = self.http_client
http_async_client = self.http_async_client
raise ValueError(
msg = (
"Cannot specify 'openai_proxy' if one of "
"'http_client'/'http_async_client' is already specified. Received:\n"
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
)
raise ValueError(msg)
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
)
raise ImportError(msg) from e
self.http_client = httpx.Client(
proxy=self.openai_proxy, verify=global_ssl_context
)
@@ -747,10 +741,11 @@ class BaseChatOpenAI(BaseChatModel):
try:
import httpx
except ImportError as e:
raise ImportError(
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
)
raise ImportError(msg) from e
self.http_async_client = httpx.AsyncClient(
proxy=self.openai_proxy, verify=global_ssl_context
)
@@ -791,15 +786,13 @@ class BaseChatOpenAI(BaseChatModel):
"store": self.store,
}
params = {
return {
"model": self.model_name,
"stream": self.streaming,
**{k: v for k, v in exclude_if_none.items() if v is not None},
**self.model_kwargs,
}
return params
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
@@ -845,11 +838,10 @@ class BaseChatOpenAI(BaseChatModel):
)
if len(choices) == 0:
# logprobs is implicitly None
generation_chunk = ChatGenerationChunk(
return ChatGenerationChunk(
message=default_chunk_class(content="", usage_metadata=usage_metadata),
generation_info=base_generation_info,
)
return generation_chunk
choice = choices[0]
if choice["delta"] is None:
@@ -876,10 +868,9 @@ class BaseChatOpenAI(BaseChatModel):
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
message_chunk.usage_metadata = usage_metadata
generation_chunk = ChatGenerationChunk(
return ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
return generation_chunk
def _stream_responses(
self,
@@ -990,7 +981,9 @@ class BaseChatOpenAI(BaseChatModel):
yield generation_chunk
def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
self,
stream_usage: Optional[bool] = None, # noqa: FBT001
**kwargs: Any,
) -> bool:
"""Determine whether to include usage metadata in streaming output.
@@ -1029,7 +1022,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
response_stream = self.root_client.beta.chat.completions.stream(**payload)
@@ -1096,7 +1090,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
try:
@@ -1133,18 +1128,15 @@ class BaseChatOpenAI(BaseChatModel):
def _use_responses_api(self, payload: dict) -> bool:
if isinstance(self.use_responses_api, bool):
return self.use_responses_api
elif self.output_version == "responses/v1":
if (
self.output_version == "responses/v1"
or self.include is not None
or self.reasoning is not None
or self.truncation is not None
or self.use_previous_response_id
):
return True
elif self.include is not None:
return True
elif self.reasoning is not None:
return True
elif self.truncation is not None:
return True
elif self.use_previous_response_id:
return True
else:
return _use_responses_api(payload)
return _use_responses_api(payload)
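Any of the attributes checked in the consolidated condition above routes requests to the Responses API; the simplest opt-in is explicit (model name is a placeholder):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
response = llm.invoke("Hi there!")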
def _get_request_payload(
self,
@@ -1192,12 +1184,12 @@ class BaseChatOpenAI(BaseChatModel):
try:
choices = response_dict["choices"]
except KeyError as e:
raise KeyError(
f"Response missing `choices` key: {response_dict.keys()}"
) from e
msg = f"Response missing `choices` key: {response_dict.keys()}"
raise KeyError(msg) from e
if choices is None:
raise TypeError("Received response with null value for `choices`.")
msg = "Received response with null value for `choices`."
raise TypeError(msg)
token_usage = response_dict.get("usage")
@@ -1257,7 +1249,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
response_stream = self.root_async_client.beta.chat.completions.stream(
@@ -1328,7 +1321,8 @@ class BaseChatOpenAI(BaseChatModel):
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
"specified.",
stacklevel=2,
)
payload.pop("stream")
try:
@@ -1439,7 +1433,7 @@ class BaseChatOpenAI(BaseChatModel):
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python 3.7 or below
if sys.version_info[1] <= 7:
if sys.version_info[1] <= 7: # noqa: YTT203
return super().get_token_ids(text)
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
@@ -1459,20 +1453,21 @@ class BaseChatOpenAI(BaseChatModel):
as a URL. If these aren't installed image inputs will be ignored in token
counting.
OpenAI reference: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
OpenAI reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
"""
# TODO: Count bound tools as part of input.
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
if sys.version_info[1] <= 7:
if sys.version_info[1] <= 7: # noqa: YTT203
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
if model.startswith("gpt-3.5-turbo-0301"):
@@ -1480,16 +1475,17 @@ class BaseChatOpenAI(BaseChatModel):
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
elif model.startswith(("gpt-3.5-turbo", "gpt-4")):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
msg = (
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}. See "
"https://platform.openai.com/docs/guides/text-generation/managing-tokens" # noqa: E501
"https://platform.openai.com/docs/guides/text-generation/managing-tokens"
" for information on how messages are converted to tokens."
)
raise NotImplementedError(msg)
num_tokens = 0
messages_dict = [_convert_message_to_dict(m) for m in messages]
for message in messages_dict:
@@ -1524,13 +1520,12 @@ class BaseChatOpenAI(BaseChatModel):
elif val["type"] == "file":
warnings.warn(
"Token counts for file inputs are not supported. "
"Ignoring file inputs."
"Ignoring file inputs.",
stacklevel=2,
)
pass
else:
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
msg = f"Unrecognized content block type\n\n{val}"
raise ValueError(msg)
elif not value:
continue
else:
@@ -1552,7 +1547,7 @@ class BaseChatOpenAI(BaseChatModel):
self,
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -1575,8 +1570,8 @@ class BaseChatOpenAI(BaseChatModel):
(if any).
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
"""
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
if function_call is not None:
function_call = (
@@ -1586,18 +1581,20 @@ class BaseChatOpenAI(BaseChatModel):
else function_call
)
if isinstance(function_call, dict) and len(formatted_functions) != 1:
raise ValueError(
msg = (
"When specifying `function_call`, you must provide exactly one "
"function."
)
raise ValueError(msg)
if (
isinstance(function_call, dict)
and formatted_functions[0]["name"] != function_call["name"]
):
raise ValueError(
msg = (
f"Function call {function_call} was specified, but the only "
f"provided function was {formatted_functions[0]['name']}."
)
raise ValueError(msg)
kwargs = {**kwargs, "function_call": function_call}
return super().bind(functions=formatted_functions, **kwargs)
@@ -1606,7 +1603,7 @@ class BaseChatOpenAI(BaseChatModel):
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
Union[dict, str, Literal["auto", "none", "required", "any"], bool] # noqa: PYI051
] = None,
strict: Optional[bool] = None,
parallel_tool_calls: Optional[bool] = None,
@@ -1645,7 +1642,6 @@ class BaseChatOpenAI(BaseChatModel):
Support for ``strict`` argument added.
""" # noqa: E501
if parallel_tool_calls is not None:
kwargs["parallel_tool_calls"] = parallel_tool_calls
formatted_tools = [
@@ -1680,10 +1676,11 @@ class BaseChatOpenAI(BaseChatModel):
elif isinstance(tool_choice, dict):
pass
else:
raise ValueError(
msg = (
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
raise ValueError(msg)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
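A short sketch of the ``bind_tools`` signature shown above, including ``tool_choice`` and ``parallel_tool_calls``; the schema and model name are illustrative:

from pydantic import BaseModel, Field

from langchain_openai import ChatOpenAI


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    location: str = Field(description="City name")


llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools(
    [GetWeather], tool_choice="any", parallel_tool_calls=False
)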
@@ -1829,11 +1826,11 @@ class BaseChatOpenAI(BaseChatModel):
.. versionchanged:: 0.3.21
Pass ``kwargs`` through to the model.
""" # noqa: E501
if strict is not None and method == "json_mode":
raise ValueError(
"Argument `strict` is not supported with `method`='json_mode'"
)
msg = "Argument `strict` is not supported with `method`='json_mode'"
raise ValueError(msg)
is_pydantic_schema = _is_pydantic_class(schema)
if method == "json_schema":
@@ -1845,7 +1842,8 @@ class BaseChatOpenAI(BaseChatModel):
"Received a Pydantic BaseModel V1 schema. This is not supported by "
'method="json_schema". Please use method="function_calling" '
"or specify schema via JSON Schema or Pydantic V2 BaseModel. "
'Overriding to method="function_calling".'
'Overriding to method="function_calling".',
stacklevel=2,
)
method = "function_calling"
# Check for incompatible model
@@ -1860,28 +1858,28 @@ class BaseChatOpenAI(BaseChatModel):
f"see supported models here: "
f"https://platform.openai.com/docs/guides/structured-outputs#supported-models. " # noqa: E501
"To fix this warning, set `method='function_calling'. "
"Overriding to method='function_calling'."
"Overriding to method='function_calling'.",
stacklevel=2,
)
method = "function_calling"
if method == "function_calling":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
raise ValueError(msg)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
bind_kwargs = self._filter_disabled_params(
**{
**dict(
tool_choice=tool_name,
parallel_tool_calls=False,
strict=strict,
ls_structured_output_format={
"kwargs": {"method": method, "strict": strict},
"schema": schema,
},
),
"tool_choice": tool_name,
"parallel_tool_calls": False,
"strict": strict,
"ls_structured_output_format": {
"kwargs": {"method": method, "strict": strict},
"schema": schema,
},
**kwargs,
}
)
@@ -1899,13 +1897,11 @@ class BaseChatOpenAI(BaseChatModel):
elif method == "json_mode":
llm = self.bind(
**{
**dict(
response_format={"type": "json_object"},
ls_structured_output_format={
"kwargs": {"method": method},
"schema": schema,
},
),
"response_format": {"type": "json_object"},
"ls_structured_output_format": {
"kwargs": {"method": method},
"schema": schema,
},
**kwargs,
}
)
@@ -1916,10 +1912,11 @@ class BaseChatOpenAI(BaseChatModel):
)
elif method == "json_schema":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
raise ValueError(msg)
response_format = _convert_to_openai_response_format(schema, strict=strict)
bind_kwargs = {
**dict(
@@ -1943,10 +1940,11 @@ class BaseChatOpenAI(BaseChatModel):
else:
output_parser = JsonOutputParser()
else:
raise ValueError(
msg = (
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
raise ValueError(msg)
if include_raw:
parser_assign = RunnablePassthrough.assign(
@@ -1957,8 +1955,7 @@ class BaseChatOpenAI(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
return llm | output_parser
def _filter_disabled_params(self, **kwargs: Any) -> dict[str, Any]:
if not self.disabled_params:
@@ -1971,8 +1968,7 @@ class BaseChatOpenAI(BaseChatModel):
):
continue
# Keep param
else:
filtered[k] = v
filtered[k] = v
return filtered
def _get_generation_chunk_from_completion(
@@ -1999,7 +1995,7 @@ class BaseChatOpenAI(BaseChatModel):
class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""OpenAI chat model integration.
r"""OpenAI chat model integration.
.. dropdown:: Setup
:open:
@@ -2126,7 +2122,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
AIMessageChunk(content="", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content="J", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(
content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0"
)
AIMessageChunk(content=" la", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(
content=" programmation", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0"
@@ -2182,7 +2180,11 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"logprobs": None,
},
id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0",
usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36},
usage_metadata={
"input_tokens": 31,
"output_tokens": 5,
"total_tokens": 36,
},
)
.. dropdown:: Tool calling
@@ -2299,7 +2301,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
tool = {"type": "web_search_preview"}
llm_with_tools = llm.bind_tools([tool])
response = llm_with_tools.invoke("What was a positive news story from today?")
response = llm_with_tools.invoke(
"What was a positive news story from today?"
)
response.content
.. code-block:: python
@@ -2346,7 +2350,8 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
.. code-block:: python
second_response = llm.invoke(
"What is my name?", previous_response_id=response.response_metadata["id"]
"What is my name?",
previous_response_id=response.response_metadata["id"],
)
second_response.text()
@@ -2424,7 +2429,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
rating: Optional[int] = Field(
description="How funny the joke is, from 1 to 10"
)
structured_llm = llm.with_structured_output(Joke)
@@ -2684,8 +2691,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""Route to Chat Completions or Responses API."""
if self._use_responses_api({**kwargs, **self.model_kwargs}):
return super()._stream_responses(*args, **kwargs)
else:
return super()._stream(*args, **kwargs)
return super()._stream(*args, **kwargs)
async def _astream(
self, *args: Any, **kwargs: Any
@@ -2707,7 +2713,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
r"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
@@ -3057,6 +3063,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
# },
# 'parsing_error': None
# }
""" # noqa: E501
return super().with_structured_output(
schema, method=method, include_raw=include_raw, strict=strict, **kwargs
@@ -3113,13 +3120,12 @@ def _url_to_size(image_source: str) -> Optional[tuple[int, int]]:
response.raise_for_status()
width, height = Image.open(BytesIO(response.content)).size
return width, height
elif _is_b64(image_source):
if _is_b64(image_source):
_, encoded = image_source.split(",", 1)
data = base64.b64decode(encoded)
width, height = Image.open(BytesIO(data)).size
return width, height
else:
return None
return None
def _count_image_tokens(width: int, height: int) -> int:
@@ -3208,17 +3214,16 @@ def _oai_structured_outputs_parser(
if parsed := ai_msg.additional_kwargs.get("parsed"):
if isinstance(parsed, dict):
return schema(**parsed)
else:
return parsed
elif ai_msg.additional_kwargs.get("refusal"):
return parsed
if ai_msg.additional_kwargs.get("refusal"):
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
elif ai_msg.tool_calls:
if ai_msg.tool_calls:
return None
else:
raise ValueError(
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
f"field. Received message:\n\n{ai_msg}"
)
msg = (
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
f"field. Received message:\n\n{ai_msg}"
)
raise ValueError(msg)
class OpenAIRefusalError(Exception):
@@ -3315,14 +3320,16 @@ def _use_responses_api(payload: dict) -> bool:
def _get_last_messages(
messages: Sequence[BaseMessage],
) -> tuple[Sequence[BaseMessage], Optional[str]]:
"""
Return
"""Return the last messages in a conversation.
Returns:
1. Every message after the most-recent AIMessage that has a non-empty
``response_metadata["id"]`` (may be an empty list),
2. That id.
If the most-recent AIMessage does not have an id (or there is no
AIMessage at all) the entire conversation is returned together with ``None``.
"""
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
@@ -3330,8 +3337,7 @@ def _get_last_messages(
response_id = msg.response_metadata.get("id")
if response_id:
return messages[i + 1 :], response_id
else:
return messages, None
return messages, None
return messages, None
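An illustrative trace of the helper documented above, with hypothetical messages: everything after the most recent AIMessage carrying a response_metadata["id"] is returned together with that id.

from langchain_core.messages import AIMessage, HumanMessage

history = [
    HumanMessage("Hi, I'm Bob."),
    AIMessage("Hello Bob!", response_metadata={"id": "resp_123"}),
    HumanMessage("What is my name?"),
]
# _get_last_messages(history) would return:
#   ([HumanMessage("What is my name?")], "resp_123")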
@@ -3358,13 +3364,14 @@ def _construct_responses_api_payload(
if tool["type"] == "image_generation":
# Handle partial images (not yet supported)
if "partial_images" in tool:
raise NotImplementedError(
msg = (
"Partial image generation is not yet supported "
"via the LangChain ChatOpenAI client. Please "
"drop the 'partial_images' key from the image_generation "
"tool."
)
elif payload.get("stream") and "partial_images" not in tool:
raise NotImplementedError(msg)
if payload.get("stream") and "partial_images" not in tool:
# OpenAI requires this parameter be set; we ignore it during
# streaming.
tool["partial_images"] = 1
@@ -3390,10 +3397,11 @@ def _construct_responses_api_payload(
if schema := payload.pop("response_format", None):
if payload.get("text"):
text = payload["text"]
raise ValueError(
msg = (
"Can specify at most one of 'response_format' or 'text', received both:"
f"\n{schema=}\n{text=}"
)
raise ValueError(msg)
# For pydantic + non-streaming case, we use responses.parse.
# Otherwise, we use responses.create.
@@ -3447,7 +3455,7 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
def _pop_index_and_sub_index(block: dict) -> dict:
"""When streaming, langchain-core uses the ``index`` key to aggregate
text blocks. OpenAI API does not support this key, so we need to remove it.
"""
""" # noqa: D205
new_block = {k: v for k, v in block.items() if k != "index"}
if "summary" in new_block and isinstance(new_block["summary"], list):
new_summary = []
@@ -3758,7 +3766,7 @@ def _convert_responses_chunk_to_generation_chunk(
current_sub_index: int, # index of content block in output item
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
has_reasoning: bool = False,
has_reasoning: bool = False, # noqa: FBT001, FBT002
output_version: Literal["v0", "responses/v1"] = "v0",
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
@@ -3808,12 +3816,9 @@ def _convert_responses_chunk_to_generation_chunk(
content = []
tool_call_chunks: list = []
additional_kwargs: dict = {}
if metadata:
response_metadata = metadata
else:
response_metadata = {}
response_metadata = metadata or {}
usage_metadata = None
id = None
id_ = None
if chunk.type == "response.output_text.delta":
_advance(chunk.output_index, chunk.content_index)
content.append({"type": "text", "text": chunk.delta, "index": current_index})
@@ -3828,7 +3833,7 @@ def _convert_responses_chunk_to_generation_chunk(
elif chunk.type == "response.output_text.done":
content.append({"id": chunk.item_id, "index": current_index})
elif chunk.type == "response.created":
id = chunk.response.id
id_ = chunk.response.id
response_metadata["id"] = chunk.response.id # Backwards compatibility
elif chunk.type == "response.completed":
msg = cast(
@@ -3849,7 +3854,7 @@ def _convert_responses_chunk_to_generation_chunk(
}
elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
if output_version == "v0":
id = chunk.item.id
id_ = chunk.item.id
else:
pass
elif (
@@ -3944,7 +3949,7 @@ def _convert_responses_chunk_to_generation_chunk(
usage_metadata=usage_metadata,
response_metadata=response_metadata,
additional_kwargs=additional_kwargs,
id=id,
id=id_,
)
if output_version == "v0":
message = cast(

View File

@@ -1,4 +1,4 @@
from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings
from langchain_openai.embeddings.base import OpenAIEmbeddings
__all__ = ["OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
__all__ = ["AzureOpenAIEmbeddings", "OpenAIEmbeddings"]

View File

@@ -20,7 +20,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
To access AzureOpenAI embedding models you'll need to create an Azure account,
get an API key, and install the `langchain-openai` integration package.
Youll need to have an Azure OpenAI instance deployed.
You'll need to have an Azure OpenAI instance deployed.
You can deploy a version on Azure Portal following this
[guide](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal).
@@ -131,7 +131,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
alias="api_version",
)
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.
Set to "2023-05-15" by default if env variable `OPENAI_API_VERSION` is not set.
"""
azure_ad_token: Optional[SecretStr] = Field(
@@ -171,19 +171,21 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
if openai_api_base and self.validate_base_url:
if "/openai" not in openai_api_base:
self.openai_api_base = cast(str, self.openai_api_base) + "/openai"
raise ValueError(
msg = (
"As of openai>=1.0.0, Azure endpoints should be specified via "
"the `azure_endpoint` param not `openai_api_base` "
"(or alias `base_url`). "
)
raise ValueError(msg)
if self.deployment:
raise ValueError(
msg = (
"As of openai>=1.0.0, if `deployment` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `deployment` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
raise ValueError(msg)
client_params: dict = {
"api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint,

View File

@@ -21,7 +21,7 @@ def _process_batched_chunked_embeddings(
tokens: list[Union[list[int], str]],
batched_embeddings: list[list[float]],
indices: list[int],
skip_empty: bool,
skip_empty: bool, # noqa: FBT001
) -> list[Optional[list[float]]]:
# for each text, this is the list of embeddings (list of list of floats)
# corresponding to the chunks of the text
@@ -50,29 +50,25 @@ def _process_batched_chunked_embeddings(
embeddings.append(None)
continue
elif len(_result) == 1:
if len(_result) == 1:
# if only one embedding was produced, use it
embeddings.append(_result[0])
continue
else:
# else we need to take a weighted average
# should be same as
# average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
total_weight = sum(num_tokens_in_batch[i])
average = [
sum(
val * weight
for val, weight in zip(embedding, num_tokens_in_batch[i])
)
/ total_weight
for embedding in zip(*_result)
]
# else we need to take a weighted average
# should be same as
# average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
total_weight = sum(num_tokens_in_batch[i])
average = [
sum(val * weight for val, weight in zip(embedding, num_tokens_in_batch[i]))
/ total_weight
for embedding in zip(*_result)
]
# should be same as
# embeddings.append((average / np.linalg.norm(average)).tolist())
magnitude = sum(val**2 for val in average) ** 0.5
embeddings.append([val / magnitude for val in average])
# should be same as
# embeddings.append((average / np.linalg.norm(average)).tolist())
magnitude = sum(val**2 for val in average) ** 0.5
embeddings.append([val / magnitude for val in average])
return embeddings
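The comments in the hunk above name the numpy equivalents; a standalone sketch of the same weighted average plus L2 normalization (vectors and token weights are illustrative):

import numpy as np

chunk_embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])  # two chunks of one long text
num_tokens_in_chunk = np.array([3, 1])                  # weights

average = np.average(chunk_embeddings, axis=0, weights=num_tokens_in_chunk)
normalized = (average / np.linalg.norm(average)).tolist()
# normalized is the final embedding for the text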
@@ -265,21 +261,24 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
msg = f"Found {field_name} supplied twice."
raise ValueError(msg)
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
Please confirm that {field_name} is what you intended.""",
stacklevel=2,
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
msg = (
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
raise ValueError(msg)
values["model_kwargs"] = extra
return values
@@ -288,10 +287,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
raise ValueError(
"If you are using Azure, "
"please use the `AzureOpenAIEmbeddings` class."
msg = (
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
)
raise ValueError(msg)
client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
@@ -308,20 +307,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_proxy = self.openai_proxy
http_client = self.http_client
http_async_client = self.http_async_client
raise ValueError(
msg = (
"Cannot specify 'openai_proxy' if one of "
"'http_client'/'http_async_client' is already specified. Received:\n"
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
)
raise ValueError(msg)
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
)
raise ImportError(msg) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {"http_client": self.http_client}
self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type]
@@ -330,10 +331,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
try:
import httpx
except ImportError as e:
raise ImportError(
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
)
raise ImportError(msg) from e
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
async_specific = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
@@ -352,8 +354,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def _tokenize(
self, texts: list[str], chunk_size: int
) -> tuple[Iterable[int], list[Union[list[int], str]], list[int]]:
"""
Take the input `texts` and `chunk_size` and return 3 iterables as a tuple:
"""Take the input `texts` and `chunk_size` and return 3 iterables as a tuple.
We have `batches`, where batches are sets of individual texts
we want responses from the openai api. The length of a single batch is
@@ -380,12 +381,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
if not self.tiktoken_enabled:
try:
from transformers import AutoTokenizer
except ImportError:
raise ValueError(
except ImportError as e:
msg = (
"Could not import transformers python package. "
"This is needed for OpenAIEmbeddings to work without "
"`tiktoken`. Please install it with `pip install transformers`. "
)
raise ValueError(msg) from e
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=model_name
@@ -455,8 +457,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
chunk_size: Optional[int] = None,
**kwargs: Any,
) -> list[list[float]]:
"""
Generate length-safe embeddings for a list of texts.
"""Generate length-safe embeddings for a list of texts.
This method handles tokenization and embedding generation, respecting the
set embedding context length and chunk size. It supports both tiktoken
@@ -466,9 +467,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
texts (List[str]): A list of texts to embed.
engine (str): The engine or model to use for embeddings.
chunk_size (Optional[int]): The size of chunks for processing embeddings.
kwargs (Any): Additional keyword arguments to pass to the embedding API.
Returns:
List[List[float]]: A list of embeddings for each input text.
"""
_chunk_size = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
@@ -508,8 +511,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
chunk_size: Optional[int] = None,
**kwargs: Any,
) -> list[list[float]]:
"""
Asynchronously generate length-safe embeddings for a list of texts.
"""Asynchronously generate length-safe embeddings for a list of texts.
This method handles tokenization and asynchronous embedding generation,
respecting the set embedding context length and chunk size. It supports both
@@ -519,11 +521,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
texts (List[str]): A list of texts to embed.
engine (str): The engine or model to use for embeddings.
chunk_size (Optional[int]): The size of chunks for processing embeddings.
kwargs (Any): Additional keyword arguments to pass to the embedding API.
Returns:
List[List[float]]: A list of embeddings for each input text.
"""
"""
_chunk_size = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
_iter, tokens, indices = await run_in_executor(
@@ -570,6 +573,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
@@ -604,6 +608,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
@@ -634,6 +639,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
return self.embed_documents([text], **kwargs)[0]
@@ -646,6 +652,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embeddings = await self.aembed_documents([text], **kwargs)
return embeddings[0]

View File

@@ -1,4 +1,4 @@
from langchain_openai.llms.azure import AzureOpenAI
from langchain_openai.llms.base import OpenAI
__all__ = ["OpenAI", "AzureOpenAI"]
__all__ = ["AzureOpenAI", "OpenAI"]

View File

@@ -30,6 +30,7 @@ class AzureOpenAI(BaseOpenAI):
from langchain_openai import AzureOpenAI
openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct")
"""
azure_endpoint: Optional[str] = Field(
@@ -42,7 +43,7 @@ class AzureOpenAI(BaseOpenAI):
Example: `https://example-resource.azure.openai.com/`
"""
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
"""A model deployment.
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
@@ -87,7 +88,7 @@ class AzureOpenAI(BaseOpenAI):
)
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@@ -112,11 +113,14 @@ class AzureOpenAI(BaseOpenAI):
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
raise ValueError("n must be at least 1.")
msg = "n must be at least 1."
raise ValueError(msg)
if self.streaming and self.n > 1:
raise ValueError("Cannot stream results when n > 1.")
msg = "Cannot stream results when n > 1."
raise ValueError(msg)
if self.streaming and self.best_of > 1:
raise ValueError("Cannot stream results when best_of > 1.")
msg = "Cannot stream results when best_of > 1."
raise ValueError(msg)
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = self.openai_api_base
@@ -125,19 +129,21 @@ class AzureOpenAI(BaseOpenAI):
self.openai_api_base = (
cast(str, self.openai_api_base).rstrip("/") + "/openai"
)
raise ValueError(
msg = (
"As of openai>=1.0.0, Azure endpoints should be specified via "
"the `azure_endpoint` param not `openai_api_base` "
"(or alias `base_url`)."
)
raise ValueError(msg)
if self.deployment_name:
raise ValueError(
msg = (
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
raise ValueError(msg)
self.deployment_name = None
client_params: dict = {
"api_version": self.openai_api_version,
@@ -183,10 +189,7 @@ class AzureOpenAI(BaseOpenAI):
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {
**{"deployment_name": self.deployment_name},
**super()._identifying_params,
}
return {"deployment_name": self.deployment_name, **super()._identifying_params}
@property
def _invocation_params(self) -> dict[str, Any]:

View File

@@ -41,10 +41,10 @@ def _stream_response_to_generation_chunk(
return GenerationChunk(text="")
return GenerationChunk(
text=stream_response["choices"][0]["text"] or "",
generation_info=dict(
finish_reason=stream_response["choices"][0].get("finish_reason", None),
logprobs=stream_response["choices"][0].get("logprobs", None),
),
generation_info={
"finish_reason": stream_response["choices"][0].get("finish_reason", None),
"logprobs": stream_response["choices"][0].get("logprobs", None),
},
)
@@ -80,7 +80,7 @@ class BaseOpenAI(BaseLLM):
openai_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(
alias="organization",
@@ -98,7 +98,7 @@ class BaseOpenAI(BaseLLM):
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
logit_bias: Optional[dict[str, float]] = None
"""Adjust the probability of specific tokens being generated."""
@@ -116,25 +116,25 @@ class BaseOpenAI(BaseLLM):
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
extra_body: Optional[Mapping[str, Any]] = None
"""Optional additional JSON properties to include in the request parameters when
@@ -147,18 +147,20 @@ class BaseOpenAI(BaseLLM):
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
return _build_model_kwargs(values, all_required_field_names)
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
raise ValueError("n must be at least 1.")
msg = "n must be at least 1."
raise ValueError(msg)
if self.streaming and self.n > 1:
raise ValueError("Cannot stream results when n > 1.")
msg = "Cannot stream results when n > 1."
raise ValueError(msg)
if self.streaming and self.best_of > 1:
raise ValueError("Cannot stream results when best_of > 1.")
msg = "Cannot stream results when best_of > 1."
raise ValueError(msg)
client_params: dict = {
"api_key": (
@@ -280,6 +282,8 @@ class BaseOpenAI(BaseLLM):
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
run_manager: Optional callback manager to use for callbacks.
kwargs: Additional keyword arguments to pass to the model.
Returns:
The full LLM output.
@@ -288,6 +292,7 @@ class BaseOpenAI(BaseLLM):
.. code-block:: python
response = openai.generate(["Tell me a joke."])
"""
# TODO: write a unit test for this
params = self._invocation_params
@@ -302,7 +307,8 @@ class BaseOpenAI(BaseLLM):
for _prompts in sub_prompts:
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
msg = "Cannot stream results with multiple prompts."
raise ValueError(msg)
generation: Optional[GenerationChunk] = None
for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
@@ -311,7 +317,8 @@ class BaseOpenAI(BaseLLM):
else:
generation += chunk
if generation is None:
raise ValueError("Generation is empty after streaming.")
msg = "Generation is empty after streaming."
raise ValueError(msg)
choices.append(
{
"text": generation.text,
@@ -369,7 +376,8 @@ class BaseOpenAI(BaseLLM):
for _prompts in sub_prompts:
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
msg = "Cannot stream results with multiple prompts."
raise ValueError(msg)
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(
@@ -380,7 +388,8 @@ class BaseOpenAI(BaseLLM):
else:
generation += chunk
if generation is None:
raise ValueError("Generation is empty after streaming.")
msg = "Generation is empty after streaming."
raise ValueError(msg)
choices.append(
{
"text": generation.text,
@@ -417,15 +426,13 @@ class BaseOpenAI(BaseLLM):
params["stop"] = stop
if params["max_tokens"] == -1:
if len(prompts) != 1:
raise ValueError(
"max_tokens set to -1 not supported for multiple inputs."
)
msg = "max_tokens set to -1 not supported for multiple inputs."
raise ValueError(msg)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [
return [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts
def create_llm_result(
self,
@@ -445,10 +452,10 @@ class BaseOpenAI(BaseLLM):
[
Generation(
text=choice["text"],
generation_info=dict(
finish_reason=choice.get("finish_reason"),
logprobs=choice.get("logprobs"),
),
generation_info={
"finish_reason": choice.get("finish_reason"),
"logprobs": choice.get("logprobs"),
},
)
for choice in sub_choices
]
@@ -466,7 +473,7 @@ class BaseOpenAI(BaseLLM):
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
return {"model_name": self.model_name, **self._default_params}
@property
def _llm_type(self) -> str:
@@ -478,7 +485,7 @@ class BaseOpenAI(BaseLLM):
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python < 3.8
if sys.version_info[1] < 8:
if sys.version_info[1] < 8: # noqa: YTT203
return super().get_num_tokens(text)
model_name = self.tiktoken_model_name or self.model_name
@@ -507,6 +514,7 @@ class BaseOpenAI(BaseLLM):
.. code-block:: python
max_tokens = openai.modelname_to_contextsize("gpt-3.5-turbo-instruct")
"""
model_token_mapping = {
"gpt-4o-mini": 128_000,
@@ -543,7 +551,7 @@ class BaseOpenAI(BaseLLM):
if "ft-" in modelname:
modelname = modelname.split(":")[0]
context_size = model_token_mapping.get(modelname, None)
context_size = model_token_mapping.get(modelname)
if context_size is None:
raise ValueError(
@@ -571,6 +579,7 @@ class BaseOpenAI(BaseLLM):
.. code-block:: python
max_tokens = openai.max_tokens_for_prompt("Tell me a joke.")
"""
num_tokens = self.get_num_tokens(prompt)
return self.max_context_size - num_tokens
@@ -689,7 +698,7 @@ class OpenAI(BaseOpenAI):
@property
def _invocation_params(self) -> dict[str, Any]:
return {**{"model": self.model_name}, **super()._invocation_params}
return {"model": self.model_name, **super()._invocation_params}
@property
def lc_secrets(self) -> dict[str, str]:

View File

@@ -4,4 +4,4 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
__all__ = ["PydanticToolsParser", "JsonOutputToolsParser", "JsonOutputKeyToolsParser"]
__all__ = ["JsonOutputKeyToolsParser", "JsonOutputToolsParser", "PydanticToolsParser"]

View File

@@ -64,12 +64,67 @@ ignore_missing_imports = true
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP", "S"]
ignore = [ "UP007", ]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"DOC", # pydoclint
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # flake8-perf
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D203", # Messes with the formatter
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"D213", # Messes with the formatter
"PERF203", # Rarely useful
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
[tool.coverage.run]
omit = ["tests/*"]

View File

@@ -1,5 +1,7 @@
"""Test AzureChatOpenAI wrapper."""
from __future__ import annotations
import json
import os
from typing import Any, Optional
@@ -173,7 +175,6 @@ def test_openai_streaming(llm: AzureChatOpenAI) -> None:
@pytest.mark.scheduled
async def test_openai_astream(llm: AzureChatOpenAI) -> None:
"""Test streaming tokens from OpenAI."""
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
@@ -185,7 +186,6 @@ async def test_openai_astream(llm: AzureChatOpenAI) -> None:
@pytest.mark.scheduled
async def test_openai_abatch(llm: AzureChatOpenAI) -> None:
"""Test streaming tokens from AzureChatOpenAI."""
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
@@ -194,7 +194,6 @@ async def test_openai_abatch(llm: AzureChatOpenAI) -> None:
@pytest.mark.scheduled
async def test_openai_abatch_tags(llm: AzureChatOpenAI) -> None:
"""Test batch tokens from AzureChatOpenAI."""
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
@@ -205,7 +204,6 @@ async def test_openai_abatch_tags(llm: AzureChatOpenAI) -> None:
@pytest.mark.scheduled
def test_openai_batch(llm: AzureChatOpenAI) -> None:
"""Test batch tokens from AzureChatOpenAI."""
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
@@ -214,7 +212,6 @@ def test_openai_batch(llm: AzureChatOpenAI) -> None:
@pytest.mark.scheduled
async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
"""Test invoke tokens from AzureChatOpenAI."""
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None
@@ -223,8 +220,7 @@ async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
@pytest.mark.scheduled
def test_openai_invoke(llm: AzureChatOpenAI) -> None:
"""Test invoke tokens from AzureChatOpenAI."""
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
import os

View File

@@ -1,5 +1,7 @@
"""Test ChatOpenAI chat model."""
from __future__ import annotations
import base64
import json
from collections.abc import AsyncIterator
@@ -66,7 +68,7 @@ def test_chat_openai_model() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_chat_openai_system_message(use_responses_api: bool) -> None:
def test_chat_openai_system_message(use_responses_api: bool) -> None: # noqa: FBT001
"""Test ChatOpenAI wrapper with system message."""
chat = ChatOpenAI(use_responses_api=use_responses_api, max_tokens=MAX_TOKEN_COUNT) # type: ignore[call-arg]
system_message = SystemMessage(content="You are to chat with the user.")
@@ -108,7 +110,7 @@ def test_chat_openai_multiple_completions() -> None:
@pytest.mark.scheduled
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_chat_openai_streaming(use_responses_api: bool) -> None:
def test_chat_openai_streaming(use_responses_api: bool) -> None: # noqa: FBT001
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
@@ -204,7 +206,7 @@ async def test_async_chat_openai_bind_functions() -> None:
@pytest.mark.scheduled
@pytest.mark.parametrize("use_responses_api", [False, True])
async def test_openai_abatch_tags(use_responses_api: bool) -> None:
async def test_openai_abatch_tags(use_responses_api: bool) -> None: # noqa: FBT001
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=MAX_TOKEN_COUNT, use_responses_api=use_responses_api) # type: ignore[call-arg]
@@ -223,7 +225,7 @@ def test_openai_invoke() -> None:
service_tier="flex", # Also test service_tier
)
result = llm.invoke("Hello", config=dict(tags=["foo"]))
result = llm.invoke("Hello", config={"tags": ["foo"]})
assert isinstance(result.content, str)
# assert no response headers if include_response_headers is not set
@@ -255,11 +257,12 @@ def test_stream() -> None:
if chunk.response_metadata:
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
raise AssertionError(
msg = (
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
raise AssertionError(msg)
assert isinstance(aggregate, AIMessageChunk)
assert aggregate.usage_metadata is not None
assert aggregate.usage_metadata["input_tokens"] > 0
@@ -270,7 +273,7 @@ def test_stream() -> None:
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None: # noqa: FBT001
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
@@ -284,20 +287,22 @@ async def test_astream() -> None:
chunks_with_response_metadata += 1
assert isinstance(full, AIMessageChunk)
if chunks_with_response_metadata != 1:
raise AssertionError(
msg = (
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
raise AssertionError(msg)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
if expect_usage:
if chunks_with_token_counts != 1:
raise AssertionError(
msg = (
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
raise AssertionError(msg)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
@@ -401,14 +406,14 @@ async def test_async_response_metadata_streaming() -> None:
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
"""Get a username based on someone's name and hair color."""
name: str
hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
"""Make a sandwich given a list of ingredients."""
bread_type: str
cheese_type: str
@@ -442,7 +447,7 @@ def test_tool_use() -> None:
gathered = message
first = False
else:
gathered = gathered + message # type: ignore
gathered = gathered + message # type: ignore[assignment]
assert isinstance(gathered, AIMessageChunk)
assert isinstance(gathered.tool_call_chunks, list)
assert len(gathered.tool_call_chunks) == 1
@@ -458,7 +463,7 @@ def test_tool_use() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_manual_tool_call_msg(use_responses_api: bool) -> None:
def test_manual_tool_call_msg(use_responses_api: bool) -> None: # noqa: FBT001
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
@@ -499,12 +504,12 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
llm_with_tool.invoke(msgs)
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_bind_tools_tool_choice(use_responses_api: bool) -> None:
def test_bind_tools_tool_choice(use_responses_api: bool) -> None: # noqa: FBT001
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
@@ -536,7 +541,7 @@ def test_disable_parallel_tool_calling() -> None:
@pytest.mark.parametrize("model", ["gpt-4o-mini", "o1", "gpt-4"])
def test_openai_structured_output(model: str) -> None:
class MyModel(BaseModel):
"""A Person"""
"""A Person."""
name: str
age: int
@@ -553,7 +558,7 @@ def test_openai_proxy() -> None:
chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080")
mounts = chat_openai.client._client._client._mounts
assert len(mounts) == 1
for key, value in mounts.items():
for value in mounts.values():
proxy = value._pool._proxy_url.origin
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
@@ -561,7 +566,7 @@ def test_openai_proxy() -> None:
async_client_mounts = chat_openai.async_client._client._client._mounts
assert len(async_client_mounts) == 1
for key, value in async_client_mounts.items():
for value in async_client_mounts.values():
proxy = value._pool._proxy_url.origin
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
@@ -569,7 +574,7 @@ def test_openai_proxy() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_openai_response_headers(use_responses_api: bool) -> None:
def test_openai_response_headers(use_responses_api: bool) -> None: # noqa: FBT001
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(
include_response_headers=True, use_responses_api=use_responses_api
@@ -593,7 +598,7 @@ def test_openai_response_headers(use_responses_api: bool) -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
async def test_openai_response_headers_async(use_responses_api: bool) -> None:
async def test_openai_response_headers_async(use_responses_api: bool) -> None: # noqa: FBT001
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(
include_response_headers=True, use_responses_api=use_responses_api
@@ -681,7 +686,7 @@ def test_image_token_counting_png() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_tool_calling_strict(use_responses_api: bool) -> None:
def test_tool_calling_strict(use_responses_api: bool) -> None: # noqa: FBT001
"""Test tool calling with strict=True.
Responses API appears to have fewer constraints on schema when strict=True.
@@ -714,7 +719,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
full = chunk if full is None else full + chunk # type: ignore[assignment]
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
@@ -731,10 +736,9 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
def test_structured_output_strict(
model: str,
method: Literal["function_calling", "json_schema"],
use_responses_api: bool,
use_responses_api: bool, # noqa: FBT001
) -> None:
"""Test to verify structured output with strict=True."""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
@@ -771,10 +775,11 @@ def test_structured_output_strict(
@pytest.mark.parametrize("use_responses_api", [False, True])
@pytest.mark.parametrize(("model", "method"), [("gpt-4o-2024-08-06", "json_schema")])
def test_nested_structured_output_strict(
model: str, method: Literal["json_schema"], use_responses_api: bool
model: str,
method: Literal["json_schema"],
use_responses_api: bool, # noqa: FBT001
) -> None:
"""Test to verify structured output with strict=True for nested object."""
from typing import TypedDict
llm = ChatOpenAI(model=model, temperature=0, use_responses_api=use_responses_api)
@@ -814,7 +819,8 @@ def test_nested_structured_output_strict(
],
)
def test_json_schema_openai_format(
strict: bool, method: Literal["json_schema", "function_calling"]
strict: bool, # noqa: FBT001
method: Literal["json_schema", "function_calling"],
) -> None:
"""Test we can pass in OpenAI schema format specifying strict."""
llm = ChatOpenAI(model="gpt-4o-mini")
@@ -957,7 +963,7 @@ def test_prediction_tokens() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
def test_stream_o_series(use_responses_api: bool) -> None:
def test_stream_o_series(use_responses_api: bool) -> None: # noqa: FBT001
list(
ChatOpenAI(model="o3-mini", use_responses_api=use_responses_api).stream(
"how are you"
@@ -966,7 +972,7 @@ def test_stream_o_series(use_responses_api: bool) -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
async def test_astream_o_series(use_responses_api: bool) -> None:
async def test_astream_o_series(use_responses_api: bool) -> None: # noqa: FBT001
async for _ in ChatOpenAI(
model="o3-mini", use_responses_api=use_responses_api
).astream("how are you"):
@@ -1013,7 +1019,7 @@ async def test_astream_response_format() -> None:
@pytest.mark.parametrize("use_responses_api", [False, True])
@pytest.mark.parametrize("use_max_completion_tokens", [True, False])
def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None:
def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None: # noqa: FBT001
if use_max_completion_tokens:
kwargs: dict = {"max_completion_tokens": MAX_TOKEN_COUNT}
else:

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
import base64
from pathlib import Path
@@ -66,7 +66,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:
{readme}
"""
llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)
@@ -123,14 +123,13 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
_ = model.invoke([message])
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage: # noqa: FBT001
if stream:
full = None
for chunk in llm.stream(input_):
full = full + chunk if full else chunk # type: ignore[operator]
return cast(AIMessage, full)
else:
return cast(AIMessage, llm.invoke(input_))
return cast(AIMessage, llm.invoke(input_))
@pytest.mark.skip() # Test either finishes in 5 seconds or 5 minutes.

View File

@@ -1,5 +1,7 @@
"""Test Responses API usage."""
from __future__ import annotations
import json
import os
from typing import Annotated, Any, Literal, Optional, cast
@@ -144,7 +146,7 @@ async def test_web_search_async() -> None:
@pytest.mark.flaky(retries=3, delay=1)
def test_function_calling() -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
"""Return x * y."""
return x * y
llm = ChatOpenAI(model=MODEL_NAME)
@@ -277,7 +279,7 @@ async def test_parsed_dict_schema_async(schema: Any) -> None:
def test_function_calling_and_structured_output() -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
"""Return x * y."""
return x * y
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
@@ -332,7 +334,7 @@ def test_stateful_api() -> None:
"what's my name", previous_response_id=response.response_metadata["id"]
)
assert isinstance(second_response.content, list)
assert "bobo" in second_response.content[0]["text"].lower() # type: ignore
assert "bobo" in second_response.content[0]["text"].lower() # type: ignore[index]
def test_route_from_model_kwargs() -> None:

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests for Responses API"""
"""Standard LangChain interface tests for Responses API."""
from pathlib import Path
from typing import cast
@@ -31,7 +31,7 @@ class TestOpenAIResponses(TestOpenAIStandard):
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:
{readme}
"""
llm = ChatOpenAI(model="gpt-4.1-mini", output_version="responses/v1")
@@ -49,11 +49,10 @@ class TestOpenAIResponses(TestOpenAIStandard):
return _invoke(llm, input_, stream)
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage: # noqa: FBT001
if stream:
full = None
for chunk in llm.stream(input_):
full = full + chunk if full else chunk # type: ignore[operator]
return cast(AIMessage, full)
else:
return cast(AIMessage, llm.invoke(input_))
return cast(AIMessage, llm.invoke(input_))

View File

@@ -16,7 +16,6 @@ DEPLOYMENT_NAME = os.environ.get(
"AZURE_OPENAI_DEPLOYMENT_NAME",
os.environ.get("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", ""),
)
print
def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
@@ -63,7 +62,7 @@ def test_azure_openai_embedding_documents_chunk_size() -> None:
# Max 2048 chunks per batch on Azure OpenAI embeddings
assert embedding.chunk_size == 2048
assert len(output) == 20
assert all([len(out) == 1536 for out in output])
assert all(len(out) == 1536 for out in output)
@pytest.mark.scheduled
@@ -100,7 +99,6 @@ async def test_azure_openai_embedding_async_query() -> None:
@pytest.mark.scheduled
def test_azure_openai_embedding_with_empty_string() -> None:
"""Test openai embeddings with empty string."""
document = ["", "abc"]
embedding = _get_embeddings()
output = embedding.embed_documents(document)
@@ -112,7 +110,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
api_key=OPENAI_API_KEY,
azure_endpoint=OPENAI_API_BASE,
azure_deployment=DEPLOYMENT_NAME,
) # type: ignore
)
.embeddings.create(input="", model="text-embedding-ada-002")
.data[0]
.embedding

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
from langchain_core.embeddings import Embeddings
from langchain_tests.integration_tests.embeddings import EmbeddingsIntegrationTests

View File

@@ -98,7 +98,7 @@ async def test_openai_ainvoke(llm: AzureOpenAI) -> None:
@pytest.mark.scheduled
def test_openai_invoke(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)

View File

@@ -67,7 +67,7 @@ def test_invoke() -> None:
"""Test invoke tokens from OpenAI."""
llm = OpenAI()
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)
@@ -164,7 +164,7 @@ def test_openai_invoke() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)

View File

@@ -4,4 +4,3 @@ import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -80,8 +80,8 @@ def test_structured_output_old_model() -> None:
).with_structured_output(Output)
# assert tool calling was used instead of json_schema
assert "tools" in llm.steps[0].kwargs # type: ignore
assert "response_format" not in llm.steps[0].kwargs # type: ignore
assert "tools" in llm.steps[0].kwargs # type: ignore[attr-defined]
assert "response_format" not in llm.steps[0].kwargs # type: ignore[attr-defined]
def test_max_completion_tokens_in_payload() -> None:

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
import pytest
from langchain_core.language_models import BaseChatModel

View File

@@ -1,5 +1,7 @@
"""Test OpenAI Chat API wrapper."""
from __future__ import annotations
import json
from functools import partial
from types import TracebackType
@@ -45,7 +47,7 @@ from openai.types.responses.response_usage import (
OutputTokensDetails,
)
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from typing_extensions import Self, TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models._compat import (
@@ -229,7 +231,7 @@ def test__convert_dict_to_message_tool_call() -> None:
"type": "function",
},
]
raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
raw_tool_calls = sorted(raw_tool_calls, key=lambda x: x["id"])
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
@@ -260,8 +262,8 @@ def test__convert_dict_to_message_tool_call() -> None:
)
assert result == expected_output
reverted_message_dict = _convert_message_to_dict(expected_output)
reverted_message_dict["tool_calls"] = list(
sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
reverted_message_dict["tool_calls"] = sorted(
reverted_message_dict["tool_calls"], key=lambda x: x["id"]
)
assert reverted_message_dict == message
@@ -272,7 +274,7 @@ class MockAsyncContextManager:
self.chunk_list = chunk_list
self.chunk_num = len(chunk_list)
async def __aenter__(self) -> "MockAsyncContextManager":
async def __aenter__(self) -> Self:
return self
async def __aexit__(
@@ -283,7 +285,7 @@ class MockAsyncContextManager:
) -> None:
pass
def __aiter__(self) -> "MockAsyncContextManager":
def __aiter__(self) -> MockAsyncContextManager:
return self
async def __anext__(self) -> dict:
@@ -291,8 +293,7 @@ class MockAsyncContextManager:
chunk = self.chunk_list[self.current_chunk]
self.current_chunk += 1
return chunk
else:
raise StopAsyncIteration
raise StopAsyncIteration
class MockSyncContextManager:
@@ -301,7 +302,7 @@ class MockSyncContextManager:
self.chunk_list = chunk_list
self.chunk_num = len(chunk_list)
def __enter__(self) -> "MockSyncContextManager":
def __enter__(self) -> Self:
return self
def __exit__(
@@ -312,7 +313,7 @@ class MockSyncContextManager:
) -> None:
pass
def __iter__(self) -> "MockSyncContextManager":
def __iter__(self) -> MockSyncContextManager:
return self
def __next__(self) -> dict:
@@ -320,13 +321,12 @@ class MockSyncContextManager:
chunk = self.chunk_list[self.current_chunk]
self.current_chunk += 1
return chunk
else:
raise StopIteration
raise StopIteration
GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":""}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]}
@@ -339,12 +339,7 @@ GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","cr
@pytest.fixture
def mock_glm4_completion() -> list:
list_chunk_data = GLM4_STREAM_META.split("\n")
result_list = []
for msg in list_chunk_data:
if msg != "[DONE]":
result_list.append(json.loads(msg))
return result_list
return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
async def test_glm4_astream(mock_glm4_completion: list) -> None:
@@ -360,7 +355,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None:
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么只回答名字"):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@@ -385,7 +380,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None:
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
for chunk in llm.stream("你的名字叫什么只回答名字"):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@@ -402,7 +397,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
{"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
@@ -420,12 +415,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
@pytest.fixture
def mock_deepseek_completion() -> list[dict]:
list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
result_list = []
for msg in list_chunk_data:
if msg != "[DONE]":
result_list.append(json.loads(msg))
return result_list
return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
@@ -440,7 +430,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
usage_chunk = mock_deepseek_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么只回答名字"):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@@ -464,7 +454,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None:
usage_chunk = mock_deepseek_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
for chunk in llm.stream("你的名字叫什么只回答名字"):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@@ -488,12 +478,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
@pytest.fixture
def mock_openai_completion() -> list[dict]:
list_chunk_data = OPENAI_STREAM_DATA.split("\n")
result_list = []
for msg in list_chunk_data:
if msg != "[DONE]":
result_list.append(json.loads(msg))
return result_list
return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]
async def test_openai_astream(mock_openai_completion: list) -> None:
@@ -508,7 +493,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None:
usage_chunk = mock_openai_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "async_client", mock_client):
async for chunk in llm.astream("你的名字叫什么只回答名字"):
async for chunk in llm.astream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@@ -532,7 +517,7 @@ def test_openai_stream(mock_openai_completion: list) -> None:
usage_chunk = mock_openai_completion[-1]
usage_metadata: Optional[UsageMetadata] = None
with patch.object(llm, "client", mock_client):
for chunk in llm.stream("你的名字叫什么只回答名字"):
for chunk in llm.stream("你的名字叫什么?只回答名字"):
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
usage_metadata = chunk.usage_metadata
@@ -727,7 +712,7 @@ def test_format_message_content() -> None:
"input": {"location": "San Francisco, CA", "unit": "celsius"},
},
]
assert [{"type": "text", "text": "hello"}] == _format_message_content(content)
assert _format_message_content(content) == [{"type": "text", "text": "hello"}]
# Standard multi-modal inputs
content = [{"type": "image", "source_type": "url", "url": "https://..."}]
@@ -776,14 +761,14 @@ def test_format_message_content() -> None:
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
"""Get a username based on someone's name and hair color."""
name: str
hair_color: str
class MakeASandwich(BaseModel):
"Make a sandwich given a list of ingredients."
"""Make a sandwich given a list of ingredients."""
bread_type: str
cheese_type: str
@@ -805,7 +790,7 @@ class MakeASandwich(BaseModel):
],
)
@pytest.mark.parametrize("strict", [True, False, None])
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None: # noqa: FBT001
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.bind_tools(
@@ -822,8 +807,8 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
def test_with_structured_output(
schema: Union[type, dict[str, Any], None],
method: Literal["function_calling", "json_mode", "json_schema"],
include_raw: bool,
strict: Optional[bool],
include_raw: bool, # noqa: FBT001
strict: Optional[bool], # noqa: FBT001
) -> None:
"""Test passing in manually construct tool call message."""
if method == "json_mode":
@@ -873,11 +858,20 @@ def test_get_num_tokens_from_messages() -> None:
),
ToolMessage("foobar", tool_call_id="foo"),
]
expected = 176
actual = llm.get_num_tokens_from_messages(messages)
assert expected == actual
expected = 431
with patch(
"langchain_openai.chat_models.base._url_to_size", return_value=(512, 512)
) as _mock_url_to_size:
actual = llm.get_num_tokens_from_messages(messages)
with patch(
"langchain_openai.chat_models.base._url_to_size", return_value=(512, 128)
) as _mock_url_to_size:
actual = llm.get_num_tokens_from_messages(messages)
expected = 431
assert expected == actual
# Test file inputs
messages = [
HumanMessage(
[
@@ -896,6 +890,28 @@ def test_get_num_tokens_from_messages() -> None:
actual = llm.get_num_tokens_from_messages(messages)
assert actual == 13
# actual = llm.get_num_tokens_from_messages(messages)
# assert expected == actual
# # Test file inputs
# messages = [
# HumanMessage(
# [
# "Summarize this document.",
# {
# "type": "file",
# "file": {
# "filename": "my file",
# "file_data": "data:application/pdf;base64,<data>",
# },
# },
# ]
# )
# ]
# with pytest.warns(match="file inputs are not supported"):
# actual = llm.get_num_tokens_from_messages(messages)
# assert actual == 13
class Foo(BaseModel):
bar: int
@@ -914,7 +930,6 @@ class Foo(BaseModel):
)
def test_schema_from_with_structured_output(schema: type) -> None:
"""Test schema from with_structured_output."""
llm = ChatOpenAI(model="gpt-4o")
structured_llm = llm.with_structured_output(
@@ -1014,10 +1029,10 @@ def test__convert_to_openai_response_format() -> None:
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
@pytest.mark.parametrize("strict", [True, None])
def test_structured_output_strict(
method: Literal["function_calling", "json_schema"], strict: Optional[bool]
method: Literal["function_calling", "json_schema"],
strict: Optional[bool], # noqa: FBT001
) -> None:
"""Test to verify structured output with strict=True."""
llm = ChatOpenAI(model="gpt-4o-2024-08-06")
class Joke(BaseModel):
@@ -1033,7 +1048,6 @@ def test_structured_output_strict(
def test_nested_structured_output_strict() -> None:
"""Test to verify structured output with strict=True for nested object."""
llm = ChatOpenAI(model="gpt-4o-2024-08-06")
class SelfEvaluation(TypedDict):
@@ -1139,8 +1153,8 @@ def test_structured_output_old_model() -> None:
with pytest.warns(match="Cannot use method='json_schema'"):
llm = ChatOpenAI(model="gpt-4").with_structured_output(Output)
# assert tool calling was used instead of json_schema
assert "tools" in llm.steps[0].kwargs # type: ignore
assert "response_format" not in llm.steps[0].kwargs # type: ignore
assert "tools" in llm.steps[0].kwargs # type: ignore[attr-defined]
assert "response_format" not in llm.steps[0].kwargs # type: ignore[attr-defined]
def test_structured_outputs_parser() -> None:
@@ -1518,7 +1532,7 @@ def test__construct_lc_result_from_responses_api_complex_response() -> None:
arguments='{"location": "New York"}',
),
],
metadata=dict(key1="value1", key2="value2"),
metadata={"key1": "value1", "key2": "value2"},
incomplete_details=IncompleteDetails(reason="max_output_tokens"),
status="completed",
user="user_123",
@@ -1758,7 +1772,6 @@ def test__construct_lc_result_from_responses_api_file_search_response() -> None:
def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
"""Test a response with both web search and file search outputs."""
response = Response(
id="resp_123",
created_at=1234567890,
@@ -2229,7 +2242,6 @@ class FakeTracer(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
pass
def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
self.chat_model_start_inputs.append({"args": args, "kwargs": kwargs})

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any, Optional
from unittest.mock import MagicMock, patch
@@ -604,10 +606,9 @@ def _strip_none(obj: Any) -> Any:
"""Recursively strip None values from dictionaries and lists."""
if isinstance(obj, dict):
return {k: _strip_none(v) for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
if isinstance(obj, list):
return [_strip_none(v) for v in obj]
else:
return obj
return obj
def test_responses_stream() -> None:

View File

@@ -32,7 +32,6 @@ def test_embed_documents_with_custom_chunk_size() -> None:
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
_, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
mock_create.call_args
mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
@@ -52,7 +51,6 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
mock_create.call_args
mock_create.assert_any_call(input=texts[0:3], **embeddings._invocation_params)
mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)

View File

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests

View File

@@ -1,5 +1,7 @@
"""A fake callback handler for testing purposes."""
from __future__ import annotations
from itertools import chain
from typing import Any, Optional, Union
from uuid import UUID
@@ -188,7 +190,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override]
def __deepcopy__(self, memo: dict) -> FakeCallbackHandler: # type: ignore[override]
return self
@@ -266,5 +268,5 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
async def on_text(self, *args: Any, **kwargs: Any) -> None:
self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override]
def __deepcopy__(self, memo: dict) -> FakeAsyncCallbackHandler: # type: ignore[override]
return self

File diff suppressed because it is too large