mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00

Compare commits (2 commits): langchain-... vs mdrxy/open

| Author | SHA1 | Date |
|---|---|---|
|  | ca10f0eee5 |  |
|  | e7e8391fd0 |  |
@@ -3,10 +3,10 @@ from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI

__all__ = [
    "OpenAI",
    "ChatOpenAI",
    "OpenAIEmbeddings",
    "AzureOpenAI",
    "AzureChatOpenAI",
    "AzureOpenAI",
    "AzureOpenAIEmbeddings",
    "ChatOpenAI",
    "OpenAI",
    "OpenAIEmbeddings",
]
@@ -1,4 +1,4 @@
from langchain_openai.chat_models.azure import AzureChatOpenAI
from langchain_openai.chat_models.base import ChatOpenAI

__all__ = ["ChatOpenAI", "AzureChatOpenAI"]
__all__ = ["AzureChatOpenAI", "ChatOpenAI"]
@@ -6,7 +6,10 @@ for each instance of ChatOpenAI.
Logic is largely replicated from openai._base_client.
"""

from __future__ import annotations

import asyncio
import contextlib
import os
from functools import lru_cache
from typing import Any, Optional

@@ -15,30 +18,26 @@ import openai


class _SyncHttpxClientWrapper(openai.DefaultHttpxClient):
    """Borrowed from openai._base_client"""
    """Borrowed from openai._base_client."""

    def __del__(self) -> None:
        if self.is_closed:
            return

        try:
        with contextlib.suppress(Exception):
            self.close()
        except Exception:  # noqa: S110
            pass


class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient):
    """Borrowed from openai._base_client"""
    """Borrowed from openai._base_client."""

    def __del__(self) -> None:
        if self.is_closed:
            return

        try:
            # TODO(someday): support non asyncio runtimes here
            # TODO(someday): support non asyncio runtimes here
            with contextlib.suppress(Exception):
                asyncio.get_running_loop().create_task(self.aclose())
        except Exception:  # noqa: S110
            pass


def _build_sync_httpx_client(
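These wrapper classes only make the default clients safe to garbage-collect; callers who need their own transport settings can pass prebuilt clients instead. A minimal sketch, assuming a recent httpx and a placeholder proxy URL (the `http_client` / `http_async_client` fields are documented later in this diff):

import httpx
from langchain_openai import ChatOpenAI

# Placeholder proxy URL and timeout; any httpx configuration works here.
llm = ChatOpenAI(
    model="gpt-4o-mini",  # placeholder model name
    http_client=httpx.Client(proxy="http://localhost:3128", timeout=30.0),
    http_async_client=httpx.AsyncClient(proxy="http://localhost:3128", timeout=30.0),
)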
@@ -1,5 +1,4 @@
"""
This module converts between AIMessage output formats for the Responses API.
"""Converts between AIMessage output formats for the Responses API.

ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs:

@@ -16,7 +15,11 @@ ChatOpenAI v0.3 stores reasoning and tool outputs in AIMessage.additional_kwargs
            "summary": [{"type": "summary_text", "text": "Reasoning summary"}],
        },
        "tool_outputs": [
            {"type": "web_search_call", "id": "websearch_123", "status": "completed"}
            {
                "type": "web_search_call",
                "id": "websearch_123",
                "status": "completed",
            }
        ],
        "refusal": "I cannot assist with that.",
    },

@@ -54,7 +57,9 @@ content blocks, rather than on the AIMessage.id, which now stores the response I

For backwards compatibility, this module provides functions to convert between the
old and new formats. The functions are used internally by ChatOpenAI.
"""  # noqa: E501
"""

from __future__ import annotations

import json
from typing import Union

@@ -65,7 +70,8 @@ _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"


def _convert_to_v03_ai_message(
    message: AIMessage, has_reasoning: bool = False
    message: AIMessage,
    has_reasoning: bool = False,  # noqa: FBT001, FBT002
) -> AIMessage:
    """Mutate an AIMessage to the old-style v0.3 format."""
    if isinstance(message.content, list):
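To make the v0.3 layout described in this module docstring concrete, here is a small sketch that builds such a message by hand; the content text is an invented placeholder, while the `additional_kwargs` keys mirror the docstring above:

from langchain_core.messages import AIMessage

# v0.3-style message: reasoning, tool outputs, and refusal live in additional_kwargs.
legacy_style = AIMessage(
    content="The answer is 42.",  # placeholder text
    additional_kwargs={
        "reasoning": {
            "summary": [{"type": "summary_text", "text": "Reasoning summary"}],
        },
        "tool_outputs": [
            {"type": "web_search_call", "id": "websearch_123", "status": "completed"},
        ],
        "refusal": None,
    },
)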
@@ -5,12 +5,11 @@ from __future__ import annotations
import logging
import os
from collections.abc import AsyncIterator, Awaitable, Iterator
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union

import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import from_env, secret_from_env

@@ -28,18 +27,12 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
_DictOrPydantic = Union[dict, _BM]


class _AllReturnType(TypedDict):
    raw: BaseMessage
    parsed: Optional[_DictOrPydantic]
    parsing_error: Optional[BaseException]


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and is_basemodel_subclass(obj)


class AzureChatOpenAI(BaseChatOpenAI):
    """Azure OpenAI chat model integration.
    r"""Azure OpenAI chat model integration.

    Setup:
        Head to the https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart?tabs=command-line%2Cpython-new&pivots=programming-language-python
@@ -138,7 +131,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
AIMessage(
|
||||
content="J'adore programmer.",
|
||||
usage_metadata={"input_tokens": 28, "output_tokens": 6, "total_tokens": 34},
|
||||
usage_metadata={
|
||||
"input_tokens": 28,
|
||||
"output_tokens": 6,
|
||||
"total_tokens": 34,
|
||||
},
|
||||
response_metadata={
|
||||
"token_usage": {
|
||||
"completion_tokens": 6,
|
||||
@@ -184,8 +181,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
AIMessageChunk(content="ad", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
|
||||
AIMessageChunk(content="ore", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
|
||||
AIMessageChunk(content=" la", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
|
||||
AIMessageChunk(content=" programm", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
|
||||
AIMessageChunk(content="ation", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
|
||||
AIMessageChunk(
|
||||
content=" programm", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f"
|
||||
)
|
||||
AIMessageChunk(
|
||||
content="ation", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f"
|
||||
)
|
||||
AIMessageChunk(content=".", id="run-a6f294d3-0700-4f6a-abc2-c6ef1178c37f")
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
@@ -294,7 +295,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
setup: str = Field(description="The setup of the joke")
|
||||
punchline: str = Field(description="The punchline to the joke")
|
||||
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
|
||||
rating: Optional[int] = Field(
|
||||
description="How funny the joke is, from 1 to 10"
|
||||
)
|
||||
|
||||
|
||||
structured_llm = llm.with_structured_output(Joke)
|
||||
@@ -465,8 +468,8 @@ class AzureChatOpenAI(BaseChatOpenAI):
        Example: `https://example-resource.azure.openai.com/`
    """
    deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
    """A model deployment.

    """A model deployment.

    If given sets the base client URL to include `/deployments/{azure_deployment}`.
    Note: this means you won't be able to use non-deployment endpoints.
    """

@@ -497,27 +500,27 @@ class AzureChatOpenAI(BaseChatOpenAI):
    """
    azure_ad_token_provider: Union[Callable[[], str], None] = None
    """A function that returns an Azure Active Directory token.


    Will be invoked on every sync request. For async requests,
    will be invoked if `azure_ad_async_token_provider` is not provided.
    """

    azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
    """A function that returns an Azure Active Directory token.


    Will be invoked on every async request.
    """

    model_version: str = ""
    """The version of the model (e.g. "0125" for gpt-3.5-0125).

    Azure OpenAI doesn't return model version with the response by default so it must
    Azure OpenAI doesn't return model version with the response by default so it must
    be manually specified if you want to use this information downstream, e.g. when
    calculating costs.

    When you specify the version, it will be appended to the model name in the
    response. Setting correct version will help you to calculate the cost properly.
    Model version is not validated, so make sure you set it correctly to get the
    When you specify the version, it will be appended to the model name in the
    response. Setting correct version will help you to calculate the cost properly.
    Model version is not validated, so make sure you set it correctly to get the
    correct cost.
    """
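A hedged configuration sketch tying the fields above together; the endpoint, deployment name, API version, and the use of `azure-identity` for the token provider are assumptions for illustration, not values taken from this diff:

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from langchain_openai import AzureChatOpenAI

# Assumed azure-identity setup; the scope is the standard Cognitive Services scope.
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

llm = AzureChatOpenAI(
    azure_endpoint="https://example-resource.azure.openai.com/",  # placeholder
    azure_deployment="my-deployment",  # placeholder
    openai_api_version="2024-02-01",  # assumption
    azure_ad_token_provider=token_provider,  # invoked on every sync request
    model_version="0125",  # appended to the model name, e.g. for cost tracking
)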
@@ -527,36 +530,36 @@ class AzureChatOpenAI(BaseChatOpenAI):
    """Legacy, for openai<1.0.0 support."""

    validate_base_url: bool = True
    """If legacy arg openai_api_base is passed in, try to infer if it is a base_url or
    """If legacy arg openai_api_base is passed in, try to infer if it is a base_url or
    azure_endpoint and update client params accordingly.
    """

    model_name: Optional[str] = Field(default=None, alias="model")  # type: ignore[assignment]
    """Name of the deployed OpenAI model, e.g. "gpt-4o", "gpt-35-turbo", etc.

    """Name of the deployed OpenAI model, e.g. "gpt-4o", "gpt-35-turbo", etc.

    Distinct from the Azure deployment name, which is set by the Azure user.
    Used for tracing and token counting. Does NOT affect completion.
    """

    disabled_params: Optional[dict[str, Any]] = Field(default=None)
    """Parameters of the OpenAI client or chat.completions endpoint that should be
    """Parameters of the OpenAI client or chat.completions endpoint that should be
    disabled for the given model.

    Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
    Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
    parameter and the value is either None, meaning that parameter should never be
    used, or it's a list of disabled values for the parameter.

    For example, older models may not support the 'parallel_tool_calls' parameter at
    all, in which case ``disabled_params={"parallel_tool_calls: None}`` can ben passed
    For example, older models may not support the 'parallel_tool_calls' parameter at
    all, in which case ``disabled_params={"parallel_tool_calls: None}`` can ben passed
    in.


    If a parameter is disabled then it will not be used by default in any methods, e.g.
    in
    in
    :meth:`~langchain_openai.chat_models.azure.AzureChatOpenAI.with_structured_output`.
    However this does not prevent a user from directly passed in the parameter during
    invocation.

    By default, unless ``model_name="gpt-4o"`` is specified, then
    invocation.

    By default, unless ``model_name="gpt-4o"`` is specified, then
    'parallel_tools_calls' will be disabled.
    """
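As a concrete illustration of the `disabled_params` behaviour described above, a minimal sketch (the deployment name is a placeholder; note the docstring intends the key `"parallel_tool_calls"`):

from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
    azure_deployment="my-older-deployment",  # placeholder
    # Never send this parameter; with_structured_output will then omit it by default,
    # though callers can still pass it explicitly at invocation time.
    disabled_params={"parallel_tool_calls": None},
)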
@@ -580,9 +583,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.n is not None and self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
elif self.n is not None and self.n > 1 and self.streaming:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
msg = "n must be at least 1."
|
||||
raise ValueError(msg)
|
||||
if self.n is not None and self.n > 1 and self.streaming:
|
||||
msg = "n must be 1 when streaming."
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.disabled_params is None:
|
||||
# As of 09-17-2024 'parallel_tool_calls' param is only supported for gpt-4o.
|
||||
@@ -602,13 +607,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
openai_api_base = self.openai_api_base
|
||||
if openai_api_base and self.validate_base_url:
|
||||
if "/openai" not in openai_api_base:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"As of openai>=1.0.0, Azure endpoints should be specified via "
|
||||
"the `azure_endpoint` param not `openai_api_base` "
|
||||
"(or alias `base_url`)."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if self.deployment_name:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"As of openai>=1.0.0, if `azure_deployment` (or alias "
|
||||
"`deployment_name`) is specified then "
|
||||
"`base_url` (or alias `openai_api_base`) should not be. "
|
||||
@@ -620,6 +626,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"Or you can equivalently specify:\n\n"
|
||||
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
|
||||
)
|
||||
raise ValueError(msg)
|
||||
client_params: dict = {
|
||||
"api_version": self.openai_api_version,
|
||||
"azure_endpoint": self.azure_endpoint,
|
||||
@@ -665,10 +672,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
@property
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"azure_deployment": self.deployment_name},
|
||||
**super()._identifying_params,
|
||||
}
|
||||
return {"azure_deployment": self.deployment_name, **super()._identifying_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@@ -709,10 +713,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
response = response.model_dump()
|
||||
for res in response["choices"]:
|
||||
if res.get("finish_reason", None) == "content_filter":
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Azure has not provided the response due to a content filter "
|
||||
"being triggered"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if "model" in response:
|
||||
model = response["model"]
|
||||
@@ -740,8 +745,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"""Route to Chat Completions or Responses API."""
|
||||
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||
return super()._stream_responses(*args, **kwargs)
|
||||
else:
|
||||
return super()._stream(*args, **kwargs)
|
||||
return super()._stream(*args, **kwargs)
|
||||
|
||||
async def _astream(
|
||||
self, *args: Any, **kwargs: Any
|
||||
@@ -763,7 +767,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
r"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema:
|
||||
@@ -930,7 +934,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
)
|
||||
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
|
||||
llm = AzureChatOpenAI(
|
||||
azure_deployment="...", model="gpt-4o", temperature=0
|
||||
)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke(
|
||||
@@ -961,7 +967,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
)
|
||||
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
|
||||
llm = AzureChatOpenAI(
|
||||
azure_deployment="...", model="gpt-4o", temperature=0
|
||||
)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, method="function_calling"
|
||||
)
|
||||
@@ -990,7 +998,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
justification: str
|
||||
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
|
||||
llm = AzureChatOpenAI(
|
||||
azure_deployment="...", model="gpt-4o", temperature=0
|
||||
)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
@@ -1022,7 +1032,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
]
|
||||
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="...", model="gpt-4o", temperature=0)
|
||||
llm = AzureChatOpenAI(
|
||||
azure_deployment="...", model="gpt-4o", temperature=0
|
||||
)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke(
|
||||
@@ -1119,6 +1131,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
# },
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
""" # noqa: E501
|
||||
return super().with_structured_output(
|
||||
schema, method=method, include_raw=include_raw, strict=strict, **kwargs
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -138,13 +139,14 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
|
||||
Returns:
|
||||
The LangChain message.
|
||||
|
||||
"""
|
||||
role = _dict.get("role")
|
||||
name = _dict.get("name")
|
||||
id_ = _dict.get("id")
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
|
||||
elif role == "assistant":
|
||||
if role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
@@ -172,22 +174,19 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
elif role in ("system", "developer"):
|
||||
if role == "developer":
|
||||
additional_kwargs = {"__openai_role__": role}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
if role in ("system", "developer"):
|
||||
additional_kwargs = {"__openai_role__": role} if role == "developer" else {}
|
||||
return SystemMessage(
|
||||
content=_dict.get("content", ""),
|
||||
name=name,
|
||||
id=id_,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
elif role == "function":
|
||||
if role == "function":
|
||||
return FunctionMessage(
|
||||
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
|
||||
)
|
||||
elif role == "tool":
|
||||
if role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
@@ -198,8 +197,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
name=name,
|
||||
id=id_,
|
||||
)
|
||||
else:
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _format_message_content(content: Any) -> Any:
|
||||
@@ -214,7 +212,7 @@ def _format_message_content(content: Any) -> Any:
|
||||
and block["type"] in ("tool_use", "thinking", "reasoning_content")
|
||||
):
|
||||
continue
|
||||
elif isinstance(block, dict) and is_data_content_block(block):
|
||||
if isinstance(block, dict) and is_data_content_block(block):
|
||||
formatted_content.append(convert_to_openai_data_block(block))
|
||||
# Anthropic image blocks
|
||||
elif (
|
||||
@@ -255,6 +253,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
|
||||
"""
|
||||
message_dict: dict[str, Any] = {"content": _format_message_content(message.content)}
|
||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||
@@ -314,7 +313,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
supported_props = {"content", "role", "tool_call_id"}
|
||||
message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
msg = f"Got unknown type {message}"
|
||||
raise TypeError(msg)
|
||||
return message_dict
|
||||
|
||||
|
||||
@@ -333,7 +333,7 @@ def _convert_delta_to_message_chunk(
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
with contextlib.suppress(KeyError):
|
||||
tool_call_chunks = [
|
||||
tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
@@ -343,19 +343,17 @@ def _convert_delta_to_message_chunk(
|
||||
)
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content, id=id_)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=id_,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
)
|
||||
elif role in ("system", "developer") or default_class == SystemMessageChunk:
|
||||
if role in ("system", "developer") or default_class == SystemMessageChunk:
|
||||
if role == "developer":
|
||||
additional_kwargs = {"__openai_role__": "developer"}
|
||||
else:
|
||||
@@ -363,16 +361,15 @@ def _convert_delta_to_message_chunk(
|
||||
return SystemMessageChunk(
|
||||
content=content, id=id_, additional_kwargs=additional_kwargs
|
||||
)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
if role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
if role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(
|
||||
content=content, tool_call_id=_dict["tool_call_id"], id=id_
|
||||
)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role, id=id_)
|
||||
else:
|
||||
return default_class(content=content, id=id_) # type: ignore
|
||||
return default_class(content=content, id=id_) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _update_token_usage(
|
||||
@@ -382,24 +379,25 @@ def _convert_delta_to_message_chunk(
    # `reasoning_tokens` is nested inside `completion_tokens_details`
    if isinstance(new_usage, int):
        if not isinstance(overall_token_usage, int):
            raise ValueError(
            msg = (
                f"Got different types for token usage: "
                f"{type(new_usage)} and {type(overall_token_usage)}"
            )
            raise ValueError(msg)
        return new_usage + overall_token_usage
    elif isinstance(new_usage, dict):
    if isinstance(new_usage, dict):
        if not isinstance(overall_token_usage, dict):
            raise ValueError(
            msg = (
                f"Got different types for token usage: "
                f"{type(new_usage)} and {type(overall_token_usage)}"
            )
            raise ValueError(msg)
        return {
            k: _update_token_usage(overall_token_usage.get(k, 0), v)
            for k, v in new_usage.items()
        }
    else:
        warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
        return new_usage
    warnings.warn(f"Unexpected type for token usage: {type(new_usage)}", stacklevel=3)
    return new_usage


def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
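For intuition, a small worked example of the recursive merge implemented above: integers are summed and nested dicts are merged key by key. The numbers are invented:

overall = {"completion_tokens": 6, "prompt_tokens": 28, "total_tokens": 34,
           "completion_tokens_details": {"reasoning_tokens": 0}}
new = {"completion_tokens": 10, "prompt_tokens": 31, "total_tokens": 41,
       "completion_tokens_details": {"reasoning_tokens": 4}}

# _update_token_usage(overall, new) would produce:
#   {"completion_tokens": 16, "prompt_tokens": 59, "total_tokens": 75,
#    "completion_tokens_details": {"reasoning_tokens": 4}}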
@@ -412,20 +410,19 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
            "langchain-openai==0.3. To use `with_structured_output` with this model, "
            'specify `method="function_calling"`.'
        )
        warnings.warn(message)
        warnings.warn(message, stacklevel=3)
        raise e
    elif "Invalid schema for response_format" in e.message:
    if "Invalid schema for response_format" in e.message:
        message = (
            "Invalid schema for OpenAI's structured output feature, which is the "
            "default method for `with_structured_output` as of langchain-openai==0.3. "
            'Specify `method="function_calling"` instead or update your schema. '
            "See supported schemas: "
            "https://platform.openai.com/docs/guides/structured-outputs#supported-schemas"  # noqa: E501
            "https://platform.openai.com/docs/guides/structured-outputs#supported-schemas"
        )
        warnings.warn(message)
        warnings.warn(message, stacklevel=3)
        raise e
    else:
        raise
    raise


class _FunctionCall(TypedDict):
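Both warnings above steer callers toward tool-based structured output; a minimal usage sketch (the model name is a placeholder, and the `Joke` schema echoes the example used elsewhere in this diff):

from pydantic import BaseModel, Field

from langchain_openai import ChatOpenAI


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")


llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
structured_llm = llm.with_structured_output(Joke, method="function_calling")
joke = structured_llm.invoke("Tell me a joke about cats")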
@@ -437,12 +434,6 @@ _DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
|
||||
_DictOrPydantic = Union[dict, _BM]
|
||||
|
||||
|
||||
class _AllReturnType(TypedDict):
|
||||
raw: BaseMessage
|
||||
parsed: Optional[_DictOrPydantic]
|
||||
parsing_error: Optional[BaseException]
|
||||
|
||||
|
||||
class BaseChatOpenAI(BaseChatModel):
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
@@ -458,7 +449,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
|
||||
)
|
||||
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
openai_organization: Optional[str] = Field(default=None, alias="organization")
|
||||
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
|
||||
@@ -469,7 +460,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
stream_usage: bool = False
|
||||
"""Whether to include usage metadata in streaming output. If True, an additional
|
||||
@@ -489,7 +480,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"""Whether to return logprobs."""
|
||||
top_logprobs: Optional[int] = None
|
||||
"""Number of most likely tokens to return at each token position, each with
|
||||
an associated log probability. `logprobs` must be set to true
|
||||
an associated log probability. `logprobs` must be set to true
|
||||
if this parameter is used."""
|
||||
logit_bias: Optional[dict[int, int]] = None
|
||||
"""Modify the likelihood of specified tokens appearing in the completion."""
|
||||
@@ -507,7 +498,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
Reasoning models only, like OpenAI o1, o3, and o4-mini.
|
||||
|
||||
Currently supported values are low, medium, and high. Reducing reasoning effort
|
||||
Currently supported values are low, medium, and high. Reducing reasoning effort
|
||||
can result in faster responses and fewer tokens used on reasoning in a response.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
@@ -528,25 +519,25 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
.. versionadded:: 0.3.24
|
||||
"""
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
default_headers: Union[Mapping[str, str], None] = None
|
||||
default_query: Union[Mapping[str, object], None] = None
|
||||
# Configure a custom httpx client. See the
|
||||
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
|
||||
http_client: Union[Any, None] = Field(default=None, exclude=True)
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
http_async_client as well if you'd like a custom client for async invocations.
|
||||
"""
|
||||
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
@@ -556,21 +547,21 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
include_response_headers: bool = False
|
||||
"""Whether to include response headers in the output message response_metadata."""
|
||||
disabled_params: Optional[dict[str, Any]] = Field(default=None)
|
||||
"""Parameters of the OpenAI client or chat.completions endpoint that should be
|
||||
"""Parameters of the OpenAI client or chat.completions endpoint that should be
|
||||
disabled for the given model.
|
||||
|
||||
Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
|
||||
|
||||
Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
|
||||
parameter and the value is either None, meaning that parameter should never be
|
||||
used, or it's a list of disabled values for the parameter.
|
||||
|
||||
For example, older models may not support the 'parallel_tool_calls' parameter at
|
||||
all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
|
||||
|
||||
For example, older models may not support the 'parallel_tool_calls' parameter at
|
||||
all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
|
||||
in.
|
||||
|
||||
|
||||
If a parameter is disabled then it will not be used by default in any methods, e.g.
|
||||
in :meth:`~langchain_openai.chat_models.base.ChatOpenAI.with_structured_output`.
|
||||
However this does not prevent a user from directly passed in the parameter during
|
||||
invocation.
|
||||
invocation.
|
||||
"""
|
||||
|
||||
include: Optional[list[str]] = None
|
||||
@@ -675,13 +666,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
return _build_model_kwargs(values, all_required_field_names)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_temperature(cls, values: dict[str, Any]) -> Any:
|
||||
"""Currently o1 models only allow temperature=1."""
|
||||
"""o1 models only allow temperature=1."""
|
||||
model = values.get("model_name") or values.get("model") or ""
|
||||
if model.startswith("o1") and "temperature" not in values:
|
||||
values["temperature"] = 1
|
||||
@@ -691,9 +681,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.n is not None and self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
elif self.n is not None and self.n > 1 and self.streaming:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
msg = "n must be at least 1."
|
||||
raise ValueError(msg)
|
||||
if self.n is not None and self.n > 1 and self.streaming:
|
||||
msg = "n must be 1 when streaming."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check OPENAI_ORGANIZATION for backwards compatibility.
|
||||
self.openai_organization = (
|
||||
@@ -719,20 +711,22 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
openai_proxy = self.openai_proxy
|
||||
http_client = self.http_client
|
||||
http_async_client = self.http_async_client
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Cannot specify 'openai_proxy' if one of "
|
||||
"'http_client'/'http_async_client' is already specified. Received:\n"
|
||||
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if not self.client:
|
||||
if self.openai_proxy and not self.http_client:
|
||||
try:
|
||||
import httpx
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import httpx python package. "
|
||||
"Please install it with `pip install httpx`."
|
||||
) from e
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
self.http_client = httpx.Client(
|
||||
proxy=self.openai_proxy, verify=global_ssl_context
|
||||
)
|
||||
@@ -747,10 +741,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
try:
|
||||
import httpx
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import httpx python package. "
|
||||
"Please install it with `pip install httpx`."
|
||||
) from e
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
self.http_async_client = httpx.AsyncClient(
|
||||
proxy=self.openai_proxy, verify=global_ssl_context
|
||||
)
|
||||
@@ -791,15 +786,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"store": self.store,
|
||||
}
|
||||
|
||||
params = {
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"stream": self.streaming,
|
||||
**{k: v for k, v in exclude_if_none.items() if v is not None},
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
@@ -845,11 +838,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
if len(choices) == 0:
|
||||
# logprobs is implicitly None
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
return ChatGenerationChunk(
|
||||
message=default_chunk_class(content="", usage_metadata=usage_metadata),
|
||||
generation_info=base_generation_info,
|
||||
)
|
||||
return generation_chunk
|
||||
|
||||
choice = choices[0]
|
||||
if choice["delta"] is None:
|
||||
@@ -876,10 +868,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
|
||||
message_chunk.usage_metadata = usage_metadata
|
||||
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
return ChatGenerationChunk(
|
||||
message=message_chunk, generation_info=generation_info or None
|
||||
)
|
||||
return generation_chunk
|
||||
|
||||
def _stream_responses(
|
||||
self,
|
||||
@@ -990,7 +981,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
yield generation_chunk
|
||||
|
||||
def _should_stream_usage(
|
||||
self, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
self,
|
||||
stream_usage: Optional[bool] = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Determine whether to include usage metadata in streaming output.
|
||||
|
||||
@@ -1029,7 +1022,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if self.include_response_headers:
|
||||
warnings.warn(
|
||||
"Cannot currently include response headers when response_format is "
|
||||
"specified."
|
||||
"specified.",
|
||||
stacklevel=2,
|
||||
)
|
||||
payload.pop("stream")
|
||||
response_stream = self.root_client.beta.chat.completions.stream(**payload)
|
||||
@@ -1096,7 +1090,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if self.include_response_headers:
|
||||
warnings.warn(
|
||||
"Cannot currently include response headers when response_format is "
|
||||
"specified."
|
||||
"specified.",
|
||||
stacklevel=2,
|
||||
)
|
||||
payload.pop("stream")
|
||||
try:
|
||||
@@ -1133,18 +1128,15 @@ class BaseChatOpenAI(BaseChatModel):
    def _use_responses_api(self, payload: dict) -> bool:
        if isinstance(self.use_responses_api, bool):
            return self.use_responses_api
        elif self.output_version == "responses/v1":
        if (
            self.output_version == "responses/v1"
            or self.include is not None
            or self.reasoning is not None
            or self.truncation is not None
            or self.use_previous_response_id
        ):
            return True
        elif self.include is not None:
            return True
        elif self.reasoning is not None:
            return True
        elif self.truncation is not None:
            return True
        elif self.use_previous_response_id:
            return True
        else:
            return _use_responses_api(payload)
        return _use_responses_api(payload)

    def _get_request_payload(
        self,
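The consolidated condition above means any one of these constructor options routes requests through the Responses API; a brief sketch with placeholder model names:

from langchain_openai import ChatOpenAI

explicit = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
via_reasoning = ChatOpenAI(model="o4-mini", reasoning={"effort": "medium"})
via_output_version = ChatOpenAI(model="gpt-4o-mini", output_version="responses/v1")
# Without any of these, the payload itself decides via the module-level helper.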
@@ -1192,12 +1184,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
try:
|
||||
choices = response_dict["choices"]
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
f"Response missing `choices` key: {response_dict.keys()}"
|
||||
) from e
|
||||
msg = f"Response missing `choices` key: {response_dict.keys()}"
|
||||
raise KeyError(msg) from e
|
||||
|
||||
if choices is None:
|
||||
raise TypeError("Received response with null value for `choices`.")
|
||||
msg = "Received response with null value for `choices`."
|
||||
raise TypeError(msg)
|
||||
|
||||
token_usage = response_dict.get("usage")
|
||||
|
||||
@@ -1257,7 +1249,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if self.include_response_headers:
|
||||
warnings.warn(
|
||||
"Cannot currently include response headers when response_format is "
|
||||
"specified."
|
||||
"specified.",
|
||||
stacklevel=2,
|
||||
)
|
||||
payload.pop("stream")
|
||||
response_stream = self.root_async_client.beta.chat.completions.stream(
|
||||
@@ -1328,7 +1321,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if self.include_response_headers:
|
||||
warnings.warn(
|
||||
"Cannot currently include response headers when response_format is "
|
||||
"specified."
|
||||
"specified.",
|
||||
stacklevel=2,
|
||||
)
|
||||
payload.pop("stream")
|
||||
try:
|
||||
@@ -1439,7 +1433,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
# tiktoken NOT supported for Python 3.7 or below
|
||||
if sys.version_info[1] <= 7:
|
||||
if sys.version_info[1] <= 7: # noqa: YTT203
|
||||
return super().get_token_ids(text)
|
||||
_, encoding_model = self._get_encoding_model()
|
||||
return encoding_model.encode(text)
|
||||
@@ -1459,20 +1453,21 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
as a URL. If these aren't installed image inputs will be ignored in token
|
||||
counting.
|
||||
|
||||
OpenAI reference: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
|
||||
OpenAI reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||
to be converted to tool schemas.
|
||||
|
||||
"""
|
||||
# TODO: Count bound tools as part of input.
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if sys.version_info[1] <= 7:
|
||||
if sys.version_info[1] <= 7: # noqa: YTT203
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
model, encoding = self._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo-0301"):
|
||||
@@ -1480,16 +1475,17 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
|
||||
elif model.startswith(("gpt-3.5-turbo", "gpt-4")):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
msg = (
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}. See "
|
||||
"https://platform.openai.com/docs/guides/text-generation/managing-tokens" # noqa: E501
|
||||
"https://platform.openai.com/docs/guides/text-generation/managing-tokens"
|
||||
" for information on how messages are converted to tokens."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
num_tokens = 0
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
@@ -1524,13 +1520,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
elif val["type"] == "file":
|
||||
warnings.warn(
|
||||
"Token counts for file inputs are not supported. "
|
||||
"Ignoring file inputs."
|
||||
"Ignoring file inputs.",
|
||||
stacklevel=2,
|
||||
)
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized content block type\n\n{val}"
|
||||
)
|
||||
msg = f"Unrecognized content block type\n\n{val}"
|
||||
raise ValueError(msg)
|
||||
elif not value:
|
||||
continue
|
||||
else:
|
||||
@@ -1552,7 +1547,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
self,
|
||||
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
@@ -1575,8 +1570,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
(if any).
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
"""
|
||||
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
||||
if function_call is not None:
|
||||
function_call = (
|
||||
@@ -1586,18 +1581,20 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
else function_call
|
||||
)
|
||||
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"When specifying `function_call`, you must provide exactly one "
|
||||
"function."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if (
|
||||
isinstance(function_call, dict)
|
||||
and formatted_functions[0]["name"] != function_call["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Function call {function_call} was specified, but the only "
|
||||
f"provided function was {formatted_functions[0]['name']}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
kwargs = {**kwargs, "function_call": function_call}
|
||||
return super().bind(functions=formatted_functions, **kwargs)
|
||||
|
||||
@@ -1606,7 +1603,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool] # noqa: PYI051
|
||||
] = None,
|
||||
strict: Optional[bool] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
@@ -1645,7 +1642,6 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
Support for ``strict`` argument added.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
if parallel_tool_calls is not None:
|
||||
kwargs["parallel_tool_calls"] = parallel_tool_calls
|
||||
formatted_tools = [
|
||||
@@ -1680,10 +1676,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
elif isinstance(tool_choice, dict):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Unrecognized tool_choice type. Expected str, bool or dict. "
|
||||
f"Received: {tool_choice}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
@@ -1829,11 +1826,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
.. versionchanged:: 0.3.21
|
||||
Pass ``kwargs`` through to the model.
|
||||
|
||||
""" # noqa: E501
|
||||
if strict is not None and method == "json_mode":
|
||||
raise ValueError(
|
||||
"Argument `strict` is not supported with `method`='json_mode'"
|
||||
)
|
||||
msg = "Argument `strict` is not supported with `method`='json_mode'"
|
||||
raise ValueError(msg)
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
|
||||
if method == "json_schema":
|
||||
@@ -1845,7 +1842,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"Received a Pydantic BaseModel V1 schema. This is not supported by "
|
||||
'method="json_schema". Please use method="function_calling" '
|
||||
"or specify schema via JSON Schema or Pydantic V2 BaseModel. "
|
||||
'Overriding to method="function_calling".'
|
||||
'Overriding to method="function_calling".',
|
||||
stacklevel=2,
|
||||
)
|
||||
method = "function_calling"
|
||||
# Check for incompatible model
|
||||
@@ -1860,28 +1858,28 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
f"see supported models here: "
|
||||
f"https://platform.openai.com/docs/guides/structured-outputs#supported-models. " # noqa: E501
|
||||
"To fix this warning, set `method='function_calling'. "
|
||||
"Overriding to method='function_calling'."
|
||||
"Overriding to method='function_calling'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
method = "function_calling"
|
||||
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"schema must be specified when method is not 'json_mode'. "
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
bind_kwargs = self._filter_disabled_params(
|
||||
**{
|
||||
**dict(
|
||||
tool_choice=tool_name,
|
||||
parallel_tool_calls=False,
|
||||
strict=strict,
|
||||
ls_structured_output_format={
|
||||
"kwargs": {"method": method, "strict": strict},
|
||||
"schema": schema,
|
||||
},
|
||||
),
|
||||
"tool_choice": tool_name,
|
||||
"parallel_tool_calls": False,
|
||||
"strict": strict,
|
||||
"ls_structured_output_format": {
|
||||
"kwargs": {"method": method, "strict": strict},
|
||||
"schema": schema,
|
||||
},
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
@@ -1899,13 +1897,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
elif method == "json_mode":
|
||||
llm = self.bind(
|
||||
**{
|
||||
**dict(
|
||||
response_format={"type": "json_object"},
|
||||
ls_structured_output_format={
|
||||
"kwargs": {"method": method},
|
||||
"schema": schema,
|
||||
},
|
||||
),
|
||||
"response_format": {"type": "json_object"},
|
||||
"ls_structured_output_format": {
|
||||
"kwargs": {"method": method},
|
||||
"schema": schema,
|
||||
},
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
@@ -1916,10 +1912,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
elif method == "json_schema":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"schema must be specified when method is not 'json_mode'. "
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
||||
bind_kwargs = {
|
||||
**dict(
|
||||
@@ -1943,10 +1940,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
else:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
f"'json_mode'. Received: '{method}'"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
@@ -1957,8 +1955,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
return llm | output_parser
|
||||
|
||||
def _filter_disabled_params(self, **kwargs: Any) -> dict[str, Any]:
|
||||
if not self.disabled_params:
|
||||
@@ -1971,8 +1968,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
):
|
||||
continue
|
||||
# Keep param
|
||||
else:
|
||||
filtered[k] = v
|
||||
filtered[k] = v
|
||||
return filtered
|
||||
|
||||
def _get_generation_chunk_from_completion(
|
||||
@@ -1999,7 +1995,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
|
||||
class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
"""OpenAI chat model integration.
|
||||
r"""OpenAI chat model integration.
|
||||
|
||||
.. dropdown:: Setup
|
||||
:open:
|
||||
@@ -2126,7 +2122,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
AIMessageChunk(content="", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
|
||||
AIMessageChunk(content="J", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
|
||||
AIMessageChunk(content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
|
||||
AIMessageChunk(
|
||||
content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0"
|
||||
)
|
||||
AIMessageChunk(content=" la", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
|
||||
AIMessageChunk(
|
||||
content=" programmation", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0"
|
||||
@@ -2182,7 +2180,11 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
"logprobs": None,
|
||||
},
|
||||
id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0",
|
||||
usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36},
|
||||
usage_metadata={
|
||||
"input_tokens": 31,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 36,
|
||||
},
|
||||
)
|
||||
|
||||
.. dropdown:: Tool calling
|
||||
@@ -2299,7 +2301,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
tool = {"type": "web_search_preview"}
|
||||
llm_with_tools = llm.bind_tools([tool])
|
||||
|
||||
response = llm_with_tools.invoke("What was a positive news story from today?")
|
||||
response = llm_with_tools.invoke(
|
||||
"What was a positive news story from today?"
|
||||
)
|
||||
response.content
|
||||
|
||||
.. code-block:: python
|
||||
@@ -2346,7 +2350,8 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
.. code-block:: python
|
||||
|
||||
second_response = llm.invoke(
|
||||
"What is my name?", previous_response_id=response.response_metadata["id"]
|
||||
"What is my name?",
|
||||
previous_response_id=response.response_metadata["id"],
|
||||
)
|
||||
second_response.text()
|
||||
|
||||
@@ -2424,7 +2429,9 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
setup: str = Field(description="The setup of the joke")
|
||||
punchline: str = Field(description="The punchline to the joke")
|
||||
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
|
||||
rating: Optional[int] = Field(
|
||||
description="How funny the joke is, from 1 to 10"
|
||||
)
|
||||
|
||||
|
||||
structured_llm = llm.with_structured_output(Joke)
|
||||
@@ -2684,8 +2691,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
"""Route to Chat Completions or Responses API."""
|
||||
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||
return super()._stream_responses(*args, **kwargs)
|
||||
else:
|
||||
return super()._stream(*args, **kwargs)
|
||||
return super()._stream(*args, **kwargs)
|
||||
|
||||
async def _astream(
|
||||
self, *args: Any, **kwargs: Any
|
||||
@@ -2707,7 +2713,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
r"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema:
|
||||
@@ -3057,6 +3063,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
# },
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
""" # noqa: E501
|
||||
return super().with_structured_output(
|
||||
schema, method=method, include_raw=include_raw, strict=strict, **kwargs
|
||||
@@ -3113,13 +3120,12 @@ def _url_to_size(image_source: str) -> Optional[tuple[int, int]]:
|
||||
response.raise_for_status()
|
||||
width, height = Image.open(BytesIO(response.content)).size
|
||||
return width, height
|
||||
elif _is_b64(image_source):
|
||||
if _is_b64(image_source):
|
||||
_, encoded = image_source.split(",", 1)
|
||||
data = base64.b64decode(encoded)
|
||||
width, height = Image.open(BytesIO(data)).size
|
||||
return width, height
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _count_image_tokens(width: int, height: int) -> int:
|
||||
@@ -3208,17 +3214,16 @@ def _oai_structured_outputs_parser(
|
||||
if parsed := ai_msg.additional_kwargs.get("parsed"):
|
||||
if isinstance(parsed, dict):
|
||||
return schema(**parsed)
|
||||
else:
|
||||
return parsed
|
||||
elif ai_msg.additional_kwargs.get("refusal"):
|
||||
return parsed
|
||||
if ai_msg.additional_kwargs.get("refusal"):
|
||||
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
||||
elif ai_msg.tool_calls:
|
||||
if ai_msg.tool_calls:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
||||
f"field. Received message:\n\n{ai_msg}"
|
||||
)
|
||||
msg = (
|
||||
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
||||
f"field. Received message:\n\n{ai_msg}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class OpenAIRefusalError(Exception):
|
||||
@@ -3315,14 +3320,16 @@ def _use_responses_api(payload: dict) -> bool:
def _get_last_messages(
    messages: Sequence[BaseMessage],
) -> tuple[Sequence[BaseMessage], Optional[str]]:
    """
    Return
    """Return the last messages in a conversation.

    Returns:
        1. Every message after the most-recent AIMessage that has a non-empty
           ``response_metadata["id"]`` (may be an empty list),
        2. That id.

        If the most-recent AIMessage does not have an id (or there is no
        AIMessage at all) the entire conversation is returned together with ``None``.

    """
    for i in range(len(messages) - 1, -1, -1):
        msg = messages[i]
@@ -3330,8 +3337,7 @@ def _get_last_messages(
            response_id = msg.response_metadata.get("id")
            if response_id:
                return messages[i + 1 :], response_id
            else:
                return messages, None
            return messages, None

    return messages, None
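A short sketch of the slicing behaviour the docstring above describes; the message contents and response id are invented placeholders, and `_get_last_messages` is module-private, so this is illustration only:

from langchain_core.messages import AIMessage, HumanMessage

history = [
    HumanMessage("What is LangChain?"),
    AIMessage("A framework.", response_metadata={"id": "resp_123"}),
    HumanMessage("Who maintains it?"),
]

# _get_last_messages(history) would return everything after the last AIMessage
# that carries a response id, together with that id:
#   ([HumanMessage("Who maintains it?")], "resp_123")
# If no AIMessage has an id, the whole history is returned with None.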
@@ -3358,13 +3364,14 @@ def _construct_responses_api_payload(
|
||||
if tool["type"] == "image_generation":
|
||||
# Handle partial images (not yet supported)
|
||||
if "partial_images" in tool:
|
||||
raise NotImplementedError(
|
||||
msg = (
|
||||
"Partial image generation is not yet supported "
|
||||
"via the LangChain ChatOpenAI client. Please "
|
||||
"drop the 'partial_images' key from the image_generation "
|
||||
"tool."
|
||||
)
|
||||
elif payload.get("stream") and "partial_images" not in tool:
|
||||
raise NotImplementedError(msg)
|
||||
if payload.get("stream") and "partial_images" not in tool:
|
||||
# OpenAI requires this parameter be set; we ignore it during
|
||||
# streaming.
|
||||
tool["partial_images"] = 1
@@ -3390,10 +3397,11 @@ def _construct_responses_api_payload(
if schema := payload.pop("response_format", None):
if payload.get("text"):
text = payload["text"]
raise ValueError(
msg = (
"Can specify at most one of 'response_format' or 'text', received both:"
f"\n{schema=}\n{text=}"
)
raise ValueError(msg)

# For pydantic + non-streaming case, we use responses.parse.
# Otherwise, we use responses.create.
@@ -3447,7 +3455,7 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
def _pop_index_and_sub_index(block: dict) -> dict:
"""When streaming, langchain-core uses the ``index`` key to aggregate
text blocks. OpenAI API does not support this key, so we need to remove it.
"""
""" # noqa: D205
new_block = {k: v for k, v in block.items() if k != "index"}
if "summary" in new_block and isinstance(new_block["summary"], list):
new_summary = []
@@ -3758,7 +3766,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
current_sub_index: int, # index of content block in output item
|
||||
schema: Optional[type[_BM]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
has_reasoning: bool = False,
|
||||
has_reasoning: bool = False, # noqa: FBT001, FBT002
|
||||
output_version: Literal["v0", "responses/v1"] = "v0",
|
||||
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
|
||||
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
|
||||
@@ -3808,12 +3816,9 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
content = []
|
||||
tool_call_chunks: list = []
|
||||
additional_kwargs: dict = {}
|
||||
if metadata:
|
||||
response_metadata = metadata
|
||||
else:
|
||||
response_metadata = {}
|
||||
response_metadata = metadata or {}
|
||||
usage_metadata = None
|
||||
id = None
|
||||
id_ = None
|
||||
if chunk.type == "response.output_text.delta":
|
||||
_advance(chunk.output_index, chunk.content_index)
|
||||
content.append({"type": "text", "text": chunk.delta, "index": current_index})
|
||||
@@ -3828,7 +3833,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
elif chunk.type == "response.output_text.done":
|
||||
content.append({"id": chunk.item_id, "index": current_index})
|
||||
elif chunk.type == "response.created":
|
||||
id = chunk.response.id
|
||||
id_ = chunk.response.id
|
||||
response_metadata["id"] = chunk.response.id # Backwards compatibility
|
||||
elif chunk.type == "response.completed":
|
||||
msg = cast(
|
||||
@@ -3849,7 +3854,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
}
|
||||
elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
|
||||
if output_version == "v0":
|
||||
id = chunk.item.id
|
||||
id_ = chunk.item.id
|
||||
else:
|
||||
pass
|
||||
elif (
|
||||
@@ -3944,7 +3949,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
usage_metadata=usage_metadata,
|
||||
response_metadata=response_metadata,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=id,
|
||||
id=id_,
|
||||
)
|
||||
if output_version == "v0":
|
||||
message = cast(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings
|
||||
from langchain_openai.embeddings.base import OpenAIEmbeddings
|
||||
|
||||
__all__ = ["OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
|
||||
__all__ = ["AzureOpenAIEmbeddings", "OpenAIEmbeddings"]
|
||||
|
||||
@@ -20,7 +20,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
|
||||
To access AzureOpenAI embedding models you'll need to create an Azure account,
|
||||
get an API key, and install the `langchain-openai` integration package.
|
||||
|
||||
You’ll need to have an Azure OpenAI instance deployed.
|
||||
You'll need to have an Azure OpenAI instance deployed.
|
||||
You can deploy a version on Azure Portal following this
|
||||
[guide](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal).
|
||||
|
||||
@@ -131,7 +131,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
|
||||
alias="api_version",
|
||||
)
|
||||
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.
|
||||
|
||||
|
||||
Set to "2023-05-15" by default if env variable `OPENAI_API_VERSION` is not set.
|
||||
"""
|
||||
azure_ad_token: Optional[SecretStr] = Field(
|
||||
@@ -171,19 +171,21 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override]
|
||||
if openai_api_base and self.validate_base_url:
|
||||
if "/openai" not in openai_api_base:
|
||||
self.openai_api_base = cast(str, self.openai_api_base) + "/openai"
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"As of openai>=1.0.0, Azure endpoints should be specified via "
|
||||
"the `azure_endpoint` param not `openai_api_base` "
|
||||
"(or alias `base_url`). "
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if self.deployment:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"As of openai>=1.0.0, if `deployment` (or alias "
|
||||
"`azure_deployment`) is specified then "
|
||||
"`openai_api_base` (or alias `base_url`) should not be. "
|
||||
"Instead use `deployment` (or alias `azure_deployment`) "
|
||||
"and `azure_endpoint`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
client_params: dict = {
|
||||
"api_version": self.openai_api_version,
|
||||
"azure_endpoint": self.azure_endpoint,
|
||||
|
||||
@@ -21,7 +21,7 @@ def _process_batched_chunked_embeddings(
|
||||
tokens: list[Union[list[int], str]],
|
||||
batched_embeddings: list[list[float]],
|
||||
indices: list[int],
|
||||
skip_empty: bool,
|
||||
skip_empty: bool, # noqa: FBT001
|
||||
) -> list[Optional[list[float]]]:
|
||||
# for each text, this is the list of embeddings (list of list of floats)
|
||||
# corresponding to the chunks of the text
|
||||
@@ -50,29 +50,25 @@ def _process_batched_chunked_embeddings(
embeddings.append(None)
continue

elif len(_result) == 1:
if len(_result) == 1:
# if only one embedding was produced, use it
embeddings.append(_result[0])
continue

else:
# else we need to weighted average
# should be same as
# average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
total_weight = sum(num_tokens_in_batch[i])
average = [
sum(
val * weight
for val, weight in zip(embedding, num_tokens_in_batch[i])
)
/ total_weight
for embedding in zip(*_result)
]
# else we need to weighted average
# should be same as
# average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
total_weight = sum(num_tokens_in_batch[i])
average = [
sum(val * weight for val, weight in zip(embedding, num_tokens_in_batch[i]))
/ total_weight
for embedding in zip(*_result)
]

# should be same as
# embeddings.append((average / np.linalg.norm(average)).tolist())
magnitude = sum(val**2 for val in average) ** 0.5
embeddings.append([val / magnitude for val in average])
# should be same as
# embeddings.append((average / np.linalg.norm(average)).tolist())
magnitude = sum(val**2 for val in average) ** 0.5
embeddings.append([val / magnitude for val in average])

return embeddings
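# --- Illustrative sketch (not from this changeset) ---
# Tiny worked example of the weighted-average-then-normalize step above, with
# invented numbers: two chunk embeddings of dimension 2, weighted by chunk
# token counts, mirroring the commented numpy equivalents.
chunk_embeddings = [[1.0, 0.0], [0.0, 1.0]]
chunk_token_counts = [3, 1]

total_weight = sum(chunk_token_counts)  # 4
average = [
    sum(val * weight for val, weight in zip(dim, chunk_token_counts)) / total_weight
    for dim in zip(*chunk_embeddings)
]  # [0.75, 0.25], i.e. np.average(chunk_embeddings, axis=0, weights=[3, 1])

magnitude = sum(val**2 for val in average) ** 0.5
normalized = [val / magnitude for val in average]  # unit length, as with np.linalg.norm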
@@ -265,21 +261,24 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
msg = f"Found {field_name} supplied twice."
|
||||
raise ValueError(msg)
|
||||
if field_name not in all_required_field_names:
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
Please confirm that {field_name} is what you intended.""",
|
||||
stacklevel=2,
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
@@ -288,10 +287,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
|
||||
raise ValueError(
|
||||
"If you are using Azure, "
|
||||
"please use the `AzureOpenAIEmbeddings` class."
|
||||
msg = (
|
||||
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
client_params: dict = {
|
||||
"api_key": (
|
||||
self.openai_api_key.get_secret_value() if self.openai_api_key else None
|
||||
@@ -308,20 +307,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
openai_proxy = self.openai_proxy
|
||||
http_client = self.http_client
|
||||
http_async_client = self.http_async_client
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Cannot specify 'openai_proxy' if one of "
|
||||
"'http_client'/'http_async_client' is already specified. Received:\n"
|
||||
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
)
raise ImportError(msg) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {"http_client": self.http_client}
self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type]
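# --- Illustrative sketch (not from this changeset) ---
# Usage-level sketch of the proxy path exercised above; the proxy URL and the
# model name are placeholders.
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    openai_proxy="http://localhost:8080",  # placeholder proxy URL
)
vectors = embeddings.embed_documents(["hello world"])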
@@ -330,10 +331,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
import httpx
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import httpx python package. "
|
||||
"Please install it with `pip install httpx`."
|
||||
) from e
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
|
||||
async_specific = {"http_client": self.http_async_client}
|
||||
self.async_client = openai.AsyncOpenAI(
|
||||
@@ -352,8 +354,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
def _tokenize(
|
||||
self, texts: list[str], chunk_size: int
|
||||
) -> tuple[Iterable[int], list[Union[list[int], str]], list[int]]:
|
||||
"""
|
||||
Take the input `texts` and `chunk_size` and return 3 iterables as a tuple:
|
||||
"""Take the input `texts` and `chunk_size` and return 3 iterables as a tuple.
|
||||
|
||||
We have `batches`, where batches are sets of individual texts
|
||||
we want responses from the openai api. The length of a single batch is
|
||||
@@ -380,12 +381,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
if not self.tiktoken_enabled:
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"Could not import transformers python package. "
|
||||
"This is needed for OpenAIEmbeddings to work without "
|
||||
"`tiktoken`. Please install it with `pip install transformers`. "
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=model_name
|
||||
@@ -455,8 +457,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
chunk_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate length-safe embeddings for a list of texts.
|
||||
"""Generate length-safe embeddings for a list of texts.
|
||||
|
||||
This method handles tokenization and embedding generation, respecting the
|
||||
set embedding context length and chunk size. It supports both tiktoken
|
||||
@@ -466,9 +467,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
texts (List[str]): A list of texts to embed.
|
||||
engine (str): The engine or model to use for embeddings.
|
||||
chunk_size (Optional[int]): The size of chunks for processing embeddings.
|
||||
kwargs (Any): Additional keyword arguments to pass to the embedding API.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: A list of embeddings for each input text.
|
||||
|
||||
"""
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
client_kwargs = {**self._invocation_params, **kwargs}
|
||||
@@ -508,8 +511,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
chunk_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Asynchronously generate length-safe embeddings for a list of texts.
|
||||
"""Asynchronously generate length-safe embeddings for a list of texts.
|
||||
|
||||
This method handles tokenization and asynchronous embedding generation,
|
||||
respecting the set embedding context length and chunk size. It supports both
|
||||
@@ -519,11 +521,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
texts (List[str]): A list of texts to embed.
|
||||
engine (str): The engine or model to use for embeddings.
|
||||
chunk_size (Optional[int]): The size of chunks for processing embeddings.
|
||||
kwargs (Any): Additional keyword arguments to pass to the embedding API.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: A list of embeddings for each input text.
|
||||
"""
|
||||
|
||||
"""
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
client_kwargs = {**self._invocation_params, **kwargs}
|
||||
_iter, tokens, indices = await run_in_executor(
|
||||
@@ -570,6 +573,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
chunk_size_ = chunk_size or self.chunk_size
|
||||
client_kwargs = {**self._invocation_params, **kwargs}
|
||||
@@ -604,6 +608,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
chunk_size_ = chunk_size or self.chunk_size
|
||||
client_kwargs = {**self._invocation_params, **kwargs}
|
||||
@@ -634,6 +639,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
|
||||
"""
|
||||
return self.embed_documents([text], **kwargs)[0]
|
||||
|
||||
@@ -646,6 +652,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
|
||||
"""
|
||||
embeddings = await self.aembed_documents([text], **kwargs)
|
||||
return embeddings[0]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from langchain_openai.llms.azure import AzureOpenAI
|
||||
from langchain_openai.llms.base import OpenAI
|
||||
|
||||
__all__ = ["OpenAI", "AzureOpenAI"]
|
||||
__all__ = ["AzureOpenAI", "OpenAI"]
|
||||
|
||||
@@ -30,6 +30,7 @@ class AzureOpenAI(BaseOpenAI):
|
||||
from langchain_openai import AzureOpenAI
|
||||
|
||||
openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct")
|
||||
|
||||
"""
|
||||
|
||||
azure_endpoint: Optional[str] = Field(
|
||||
@@ -42,7 +43,7 @@ class AzureOpenAI(BaseOpenAI):
|
||||
Example: `https://example-resource.azure.openai.com/`
|
||||
"""
|
||||
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
|
||||
"""A model deployment.
|
||||
"""A model deployment.
|
||||
|
||||
If given sets the base client URL to include `/deployments/{azure_deployment}`.
|
||||
Note: this means you won't be able to use non-deployment endpoints.
|
||||
@@ -87,7 +88,7 @@ class AzureOpenAI(BaseOpenAI):
|
||||
)
|
||||
"""Legacy, for openai<1.0.0 support."""
|
||||
validate_base_url: bool = True
|
||||
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
|
||||
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
|
||||
infer if it is a base_url or azure_endpoint and update accordingly.
|
||||
"""
|
||||
|
||||
@@ -112,11 +113,14 @@ class AzureOpenAI(BaseOpenAI):
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
msg = "n must be at least 1."
|
||||
raise ValueError(msg)
|
||||
if self.streaming and self.n > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
msg = "Cannot stream results when n > 1."
|
||||
raise ValueError(msg)
|
||||
if self.streaming and self.best_of > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
msg = "Cannot stream results when best_of > 1."
|
||||
raise ValueError(msg)
|
||||
# For backwards compatibility. Before openai v1, no distinction was made
|
||||
# between azure_endpoint and base_url (openai_api_base).
|
||||
openai_api_base = self.openai_api_base
|
||||
@@ -125,19 +129,21 @@ class AzureOpenAI(BaseOpenAI):
|
||||
self.openai_api_base = (
|
||||
cast(str, self.openai_api_base).rstrip("/") + "/openai"
|
||||
)
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"As of openai>=1.0.0, Azure endpoints should be specified via "
|
||||
"the `azure_endpoint` param not `openai_api_base` "
|
||||
"(or alias `base_url`)."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if self.deployment_name:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"As of openai>=1.0.0, if `deployment_name` (or alias "
|
||||
"`azure_deployment`) is specified then "
|
||||
"`openai_api_base` (or alias `base_url`) should not be. "
|
||||
"Instead use `deployment_name` (or alias `azure_deployment`) "
|
||||
"and `azure_endpoint`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self.deployment_name = None
|
||||
client_params: dict = {
|
||||
"api_version": self.openai_api_version,
|
||||
@@ -183,10 +189,7 @@ class AzureOpenAI(BaseOpenAI):
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
**{"deployment_name": self.deployment_name},
|
||||
**super()._identifying_params,
|
||||
}
|
||||
return {"deployment_name": self.deployment_name, **super()._identifying_params}
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
|
||||
@@ -41,10 +41,10 @@ def _stream_response_to_generation_chunk(
|
||||
return GenerationChunk(text="")
|
||||
return GenerationChunk(
|
||||
text=stream_response["choices"][0]["text"] or "",
|
||||
generation_info=dict(
|
||||
finish_reason=stream_response["choices"][0].get("finish_reason", None),
|
||||
logprobs=stream_response["choices"][0].get("logprobs", None),
|
||||
),
|
||||
generation_info={
|
||||
"finish_reason": stream_response["choices"][0].get("finish_reason", None),
|
||||
"logprobs": stream_response["choices"][0].get("logprobs", None),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ class BaseOpenAI(BaseLLM):
|
||||
openai_api_base: Optional[str] = Field(
|
||||
alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
|
||||
)
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
openai_organization: Optional[str] = Field(
|
||||
alias="organization",
|
||||
@@ -98,7 +98,7 @@ class BaseOpenAI(BaseLLM):
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
logit_bias: Optional[dict[str, float]] = None
|
||||
"""Adjust the probability of specific tokens being generated."""
|
||||
@@ -116,25 +116,25 @@ class BaseOpenAI(BaseLLM):
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
default_headers: Union[Mapping[str, str], None] = None
|
||||
default_query: Union[Mapping[str, object], None] = None
|
||||
# Configure a custom httpx client. See the
|
||||
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
|
||||
http_client: Union[Any, None] = None
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
http_async_client as well if you'd like a custom client for async invocations.
|
||||
"""
|
||||
http_async_client: Union[Any, None] = None
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
extra_body: Optional[Mapping[str, Any]] = None
|
||||
"""Optional additional JSON properties to include in the request parameters when
|
||||
@@ -147,18 +147,20 @@ class BaseOpenAI(BaseLLM):
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
return _build_model_kwargs(values, all_required_field_names)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
msg = "n must be at least 1."
|
||||
raise ValueError(msg)
|
||||
if self.streaming and self.n > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
msg = "Cannot stream results when n > 1."
|
||||
raise ValueError(msg)
|
||||
if self.streaming and self.best_of > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
msg = "Cannot stream results when best_of > 1."
|
||||
raise ValueError(msg)
|
||||
|
||||
client_params: dict = {
|
||||
"api_key": (
|
||||
@@ -280,6 +282,8 @@ class BaseOpenAI(BaseLLM):
|
||||
Args:
|
||||
prompts: The prompts to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
run_manager: Optional callback manager to use for callbacks.
|
||||
kwargs: Additional keyword arguments to pass to the model.
|
||||
|
||||
Returns:
|
||||
The full LLM output.
|
||||
@@ -288,6 +292,7 @@ class BaseOpenAI(BaseLLM):
|
||||
.. code-block:: python
|
||||
|
||||
response = openai.generate(["Tell me a joke."])
|
||||
|
||||
"""
|
||||
# TODO: write a unit test for this
|
||||
params = self._invocation_params
|
||||
@@ -302,7 +307,8 @@ class BaseOpenAI(BaseLLM):
|
||||
for _prompts in sub_prompts:
|
||||
if self.streaming:
|
||||
if len(_prompts) > 1:
|
||||
raise ValueError("Cannot stream results with multiple prompts.")
|
||||
msg = "Cannot stream results with multiple prompts."
|
||||
raise ValueError(msg)
|
||||
|
||||
generation: Optional[GenerationChunk] = None
|
||||
for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
|
||||
@@ -311,7 +317,8 @@ class BaseOpenAI(BaseLLM):
|
||||
else:
|
||||
generation += chunk
|
||||
if generation is None:
|
||||
raise ValueError("Generation is empty after streaming.")
|
||||
msg = "Generation is empty after streaming."
|
||||
raise ValueError(msg)
|
||||
choices.append(
|
||||
{
|
||||
"text": generation.text,
|
||||
@@ -369,7 +376,8 @@ class BaseOpenAI(BaseLLM):
|
||||
for _prompts in sub_prompts:
|
||||
if self.streaming:
|
||||
if len(_prompts) > 1:
|
||||
raise ValueError("Cannot stream results with multiple prompts.")
|
||||
msg = "Cannot stream results with multiple prompts."
|
||||
raise ValueError(msg)
|
||||
|
||||
generation: Optional[GenerationChunk] = None
|
||||
async for chunk in self._astream(
|
||||
@@ -380,7 +388,8 @@ class BaseOpenAI(BaseLLM):
|
||||
else:
|
||||
generation += chunk
|
||||
if generation is None:
|
||||
raise ValueError("Generation is empty after streaming.")
|
||||
msg = "Generation is empty after streaming."
|
||||
raise ValueError(msg)
|
||||
choices.append(
|
||||
{
|
||||
"text": generation.text,
|
||||
@@ -417,15 +426,13 @@ class BaseOpenAI(BaseLLM):
params["stop"] = stop
if params["max_tokens"] == -1:
if len(prompts) != 1:
raise ValueError(
"max_tokens set to -1 not supported for multiple inputs."
)
msg = "max_tokens set to -1 not supported for multiple inputs."
raise ValueError(msg)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [
return [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts
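# --- Illustrative sketch (not from this changeset) ---
# What the batching comprehension above produces, with invented inputs.
prompts = ["p1", "p2", "p3", "p4", "p5"]
batch_size = 2
sub_prompts = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]
assert sub_prompts == [["p1", "p2"], ["p3", "p4"], ["p5"]]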
def create_llm_result(
|
||||
self,
|
||||
@@ -445,10 +452,10 @@ class BaseOpenAI(BaseLLM):
|
||||
[
|
||||
Generation(
|
||||
text=choice["text"],
|
||||
generation_info=dict(
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
logprobs=choice.get("logprobs"),
|
||||
),
|
||||
generation_info={
|
||||
"finish_reason": choice.get("finish_reason"),
|
||||
"logprobs": choice.get("logprobs"),
|
||||
},
|
||||
)
|
||||
for choice in sub_choices
|
||||
]
|
||||
@@ -466,7 +473,7 @@ class BaseOpenAI(BaseLLM):
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_name": self.model_name}, **self._default_params}
|
||||
return {"model_name": self.model_name, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@@ -478,7 +485,7 @@ class BaseOpenAI(BaseLLM):
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
# tiktoken NOT supported for Python < 3.8
|
||||
if sys.version_info[1] < 8:
|
||||
if sys.version_info[1] < 8: # noqa: YTT203
|
||||
return super().get_num_tokens(text)
|
||||
|
||||
model_name = self.tiktoken_model_name or self.model_name
|
||||
@@ -507,6 +514,7 @@ class BaseOpenAI(BaseLLM):
|
||||
.. code-block:: python
|
||||
|
||||
max_tokens = openai.modelname_to_contextsize("gpt-3.5-turbo-instruct")
|
||||
|
||||
"""
|
||||
model_token_mapping = {
|
||||
"gpt-4o-mini": 128_000,
|
||||
@@ -543,7 +551,7 @@ class BaseOpenAI(BaseLLM):
|
||||
if "ft-" in modelname:
|
||||
modelname = modelname.split(":")[0]
|
||||
|
||||
context_size = model_token_mapping.get(modelname, None)
|
||||
context_size = model_token_mapping.get(modelname)
|
||||
|
||||
if context_size is None:
|
||||
raise ValueError(
|
||||
@@ -571,6 +579,7 @@ class BaseOpenAI(BaseLLM):
.. code-block:: python

max_tokens = openai.max_tokens_for_prompt("Tell me a joke.")

"""
num_tokens = self.get_num_tokens(prompt)
return self.max_context_size - num_tokens
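# --- Illustrative sketch (not from this changeset) ---
# Worked arithmetic for max_tokens_for_prompt, with assumed numbers: a
# 4097-token context (roughly gpt-3.5-turbo-instruct-sized, an assumption)
# and a prompt that tokenizes to 20 tokens.
max_context_size = 4097
prompt_tokens = 20
remaining = max_context_size - prompt_tokens
assert remaining == 4077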
@@ -689,7 +698,7 @@ class OpenAI(BaseOpenAI):
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
return {**{"model": self.model_name}, **super()._invocation_params}
|
||||
return {"model": self.model_name, **super()._invocation_params}
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
|
||||
@@ -4,4 +4,4 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
|
||||
__all__ = ["PydanticToolsParser", "JsonOutputToolsParser", "JsonOutputKeyToolsParser"]
|
||||
__all__ = ["JsonOutputKeyToolsParser", "JsonOutputToolsParser", "PydanticToolsParser"]
|
||||
|
||||
@@ -64,12 +64,67 @@ ignore_missing_imports = true
target-version = "py39"

[tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP", "S"]
ignore = [ "UP007", ]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"DOC", # pydoclint
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # flake8-perf
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D203", # Messes with the formatter
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"D213", # Messes with the formatter
"PERF203", # Rarely useful
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
]
unfixable = ["B028"] # People should intentionally tune the stacklevel

[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true

[tool.coverage.run]
omit = ["tests/*"]

@@ -1,5 +1,7 @@
|
||||
"""Test AzureChatOpenAI wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
@@ -173,7 +175,6 @@ def test_openai_streaming(llm: AzureChatOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_astream(llm: AzureChatOpenAI) -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
async for chunk in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(chunk.content, str)
|
||||
@@ -185,7 +186,6 @@ async def test_openai_astream(llm: AzureChatOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch(llm: AzureChatOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureChatOpenAI."""
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
@@ -194,7 +194,6 @@ async def test_openai_abatch(llm: AzureChatOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch_tags(llm: AzureChatOpenAI) -> None:
|
||||
"""Test batch tokens from AzureChatOpenAI."""
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
@@ -205,7 +204,6 @@ async def test_openai_abatch_tags(llm: AzureChatOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_batch(llm: AzureChatOpenAI) -> None:
|
||||
"""Test batch tokens from AzureChatOpenAI."""
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
@@ -214,7 +212,6 @@ def test_openai_batch(llm: AzureChatOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
|
||||
"""Test invoke tokens from AzureChatOpenAI."""
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
assert result.response_metadata.get("model_name") is not None
|
||||
@@ -223,8 +220,7 @@ async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke(llm: AzureChatOpenAI) -> None:
|
||||
"""Test invoke tokens from AzureChatOpenAI."""
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
assert result.response_metadata.get("model_name") is not None
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
"""Standard LangChain interface tests."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Test ChatOpenAI chat model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
@@ -66,7 +68,7 @@ def test_chat_openai_model() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_chat_openai_system_message(use_responses_api: bool) -> None:
|
||||
def test_chat_openai_system_message(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test ChatOpenAI wrapper with system message."""
|
||||
chat = ChatOpenAI(use_responses_api=use_responses_api, max_tokens=MAX_TOKEN_COUNT) # type: ignore[call-arg]
|
||||
system_message = SystemMessage(content="You are to chat with the user.")
|
||||
@@ -108,7 +110,7 @@ def test_chat_openai_multiple_completions() -> None:
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_chat_openai_streaming(use_responses_api: bool) -> None:
|
||||
def test_chat_openai_streaming(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
@@ -204,7 +206,7 @@ async def test_async_chat_openai_bind_functions() -> None:
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
async def test_openai_abatch_tags(use_responses_api: bool) -> None:
|
||||
async def test_openai_abatch_tags(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=MAX_TOKEN_COUNT, use_responses_api=use_responses_api) # type: ignore[call-arg]
|
||||
|
||||
@@ -223,7 +225,7 @@ def test_openai_invoke() -> None:
|
||||
service_tier="flex", # Also test service_tier
|
||||
)
|
||||
|
||||
result = llm.invoke("Hello", config=dict(tags=["foo"]))
|
||||
result = llm.invoke("Hello", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
# assert no response headers if include_response_headers is not set
|
||||
@@ -255,11 +257,12 @@ def test_stream() -> None:
|
||||
if chunk.response_metadata:
|
||||
chunks_with_response_metadata += 1
|
||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||
raise AssertionError(
|
||||
msg = (
|
||||
"Expected exactly one chunk with metadata. "
|
||||
"AIMessageChunk aggregation can add these metadata. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
assert isinstance(aggregate, AIMessageChunk)
|
||||
assert aggregate.usage_metadata is not None
|
||||
assert aggregate.usage_metadata["input_tokens"] > 0
|
||||
@@ -270,7 +273,7 @@ def test_stream() -> None:
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
|
||||
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
|
||||
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None: # noqa: FBT001
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
chunks_with_token_counts = 0
|
||||
chunks_with_response_metadata = 0
|
||||
@@ -284,20 +287,22 @@ async def test_astream() -> None:
|
||||
chunks_with_response_metadata += 1
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
if chunks_with_response_metadata != 1:
|
||||
raise AssertionError(
|
||||
msg = (
|
||||
"Expected exactly one chunk with metadata. "
|
||||
"AIMessageChunk aggregation can add these metadata. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
assert full.response_metadata.get("finish_reason") is not None
|
||||
assert full.response_metadata.get("model_name") is not None
|
||||
if expect_usage:
|
||||
if chunks_with_token_counts != 1:
|
||||
raise AssertionError(
|
||||
msg = (
|
||||
"Expected exactly one chunk with token counts. "
|
||||
"AIMessageChunk aggregation adds counts. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
assert full.usage_metadata is not None
|
||||
assert full.usage_metadata["input_tokens"] > 0
|
||||
assert full.usage_metadata["output_tokens"] > 0
|
||||
@@ -401,14 +406,14 @@ async def test_async_response_metadata_streaming() -> None:
|
||||
|
||||
|
||||
class GenerateUsername(BaseModel):
|
||||
"Get a username based on someone's name and hair color."
|
||||
"""Get a username based on someone's name and hair color."""
|
||||
|
||||
name: str
|
||||
hair_color: str
|
||||
|
||||
|
||||
class MakeASandwich(BaseModel):
|
||||
"Make a sandwich given a list of ingredients."
|
||||
"""Make a sandwich given a list of ingredients."""
|
||||
|
||||
bread_type: str
|
||||
cheese_type: str
|
||||
@@ -442,7 +447,7 @@ def test_tool_use() -> None:
|
||||
gathered = message
|
||||
first = False
|
||||
else:
|
||||
gathered = gathered + message # type: ignore
|
||||
gathered = gathered + message # type: ignore[assignment]
|
||||
assert isinstance(gathered, AIMessageChunk)
|
||||
assert isinstance(gathered.tool_call_chunks, list)
|
||||
assert len(gathered.tool_call_chunks) == 1
|
||||
@@ -458,7 +463,7 @@ def test_tool_use() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_manual_tool_call_msg(use_responses_api: bool) -> None:
|
||||
def test_manual_tool_call_msg(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
|
||||
@@ -499,12 +504,12 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
|
||||
),
|
||||
ToolMessage("sally_green_hair", tool_call_id="foo"),
|
||||
]
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
llm_with_tool.invoke(msgs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_bind_tools_tool_choice(use_responses_api: bool) -> None:
|
||||
def test_bind_tools_tool_choice(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
|
||||
@@ -536,7 +541,7 @@ def test_disable_parallel_tool_calling() -> None:
|
||||
@pytest.mark.parametrize("model", ["gpt-4o-mini", "o1", "gpt-4"])
|
||||
def test_openai_structured_output(model: str) -> None:
|
||||
class MyModel(BaseModel):
|
||||
"""A Person"""
|
||||
"""A Person."""
|
||||
|
||||
name: str
|
||||
age: int
|
||||
@@ -553,7 +558,7 @@ def test_openai_proxy() -> None:
|
||||
chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080")
|
||||
mounts = chat_openai.client._client._client._mounts
|
||||
assert len(mounts) == 1
|
||||
for key, value in mounts.items():
|
||||
for value in mounts.values():
|
||||
proxy = value._pool._proxy_url.origin
|
||||
assert proxy.scheme == b"http"
|
||||
assert proxy.host == b"localhost"
|
||||
@@ -561,7 +566,7 @@ def test_openai_proxy() -> None:
|
||||
|
||||
async_client_mounts = chat_openai.async_client._client._client._mounts
|
||||
assert len(async_client_mounts) == 1
|
||||
for key, value in async_client_mounts.items():
|
||||
for value in async_client_mounts.values():
|
||||
proxy = value._pool._proxy_url.origin
|
||||
assert proxy.scheme == b"http"
|
||||
assert proxy.host == b"localhost"
|
||||
@@ -569,7 +574,7 @@ def test_openai_proxy() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_openai_response_headers(use_responses_api: bool) -> None:
|
||||
def test_openai_response_headers(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test ChatOpenAI response headers."""
|
||||
chat_openai = ChatOpenAI(
|
||||
include_response_headers=True, use_responses_api=use_responses_api
|
||||
@@ -593,7 +598,7 @@ def test_openai_response_headers(use_responses_api: bool) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
async def test_openai_response_headers_async(use_responses_api: bool) -> None:
|
||||
async def test_openai_response_headers_async(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test ChatOpenAI response headers."""
|
||||
chat_openai = ChatOpenAI(
|
||||
include_response_headers=True, use_responses_api=use_responses_api
|
||||
@@ -681,7 +686,7 @@ def test_image_token_counting_png() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_tool_calling_strict(use_responses_api: bool) -> None:
|
||||
def test_tool_calling_strict(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
"""Test tool calling with strict=True.
|
||||
|
||||
Responses API appears to have fewer constraints on schema when strict=True.
|
||||
@@ -714,7 +719,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
|
||||
# Test stream
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in model_with_tools.stream(query):
|
||||
full = chunk if full is None else full + chunk # type: ignore
|
||||
full = chunk if full is None else full + chunk # type: ignore[assignment]
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message(full)
|
||||
|
||||
@@ -731,10 +736,9 @@ def test_tool_calling_strict(use_responses_api: bool) -> None:
|
||||
def test_structured_output_strict(
|
||||
model: str,
|
||||
method: Literal["function_calling", "json_schema"],
|
||||
use_responses_api: bool,
|
||||
use_responses_api: bool, # noqa: FBT001
|
||||
) -> None:
|
||||
"""Test to verify structured output with strict=True."""
|
||||
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
@@ -771,10 +775,11 @@ def test_structured_output_strict(
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
@pytest.mark.parametrize(("model", "method"), [("gpt-4o-2024-08-06", "json_schema")])
|
||||
def test_nested_structured_output_strict(
|
||||
model: str, method: Literal["json_schema"], use_responses_api: bool
|
||||
model: str,
|
||||
method: Literal["json_schema"],
|
||||
use_responses_api: bool, # noqa: FBT001
|
||||
) -> None:
|
||||
"""Test to verify structured output with strict=True for nested object."""
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
llm = ChatOpenAI(model=model, temperature=0, use_responses_api=use_responses_api)
|
||||
@@ -814,7 +819,8 @@ def test_nested_structured_output_strict(
|
||||
],
|
||||
)
|
||||
def test_json_schema_openai_format(
|
||||
strict: bool, method: Literal["json_schema", "function_calling"]
|
||||
strict: bool, # noqa: FBT001
|
||||
method: Literal["json_schema", "function_calling"],
|
||||
) -> None:
|
||||
"""Test we can pass in OpenAI schema format specifying strict."""
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
@@ -957,7 +963,7 @@ def test_prediction_tokens() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_stream_o_series(use_responses_api: bool) -> None:
|
||||
def test_stream_o_series(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
list(
|
||||
ChatOpenAI(model="o3-mini", use_responses_api=use_responses_api).stream(
|
||||
"how are you"
|
||||
@@ -966,7 +972,7 @@ def test_stream_o_series(use_responses_api: bool) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
async def test_astream_o_series(use_responses_api: bool) -> None:
|
||||
async def test_astream_o_series(use_responses_api: bool) -> None: # noqa: FBT001
|
||||
async for _ in ChatOpenAI(
|
||||
model="o3-mini", use_responses_api=use_responses_api
|
||||
).astream("how are you"):
|
||||
@@ -1013,7 +1019,7 @@ async def test_astream_response_format() -> None:
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
@pytest.mark.parametrize("use_max_completion_tokens", [True, False])
|
||||
def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None:
|
||||
def test_o1(use_max_completion_tokens: bool, use_responses_api: bool) -> None: # noqa: FBT001
|
||||
if use_max_completion_tokens:
|
||||
kwargs: dict = {"max_completion_tokens": MAX_TOKEN_COUNT}
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
"""Standard LangChain interface tests."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
@@ -66,7 +66,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
readme = f.read()
|
||||
|
||||
input_ = f"""What's langchain? Here's the langchain README:
|
||||
|
||||
|
||||
{readme}
|
||||
"""
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)
|
||||
@@ -123,14 +123,13 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
_ = model.invoke([message])
|
||||
|
||||
|
||||
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
|
||||
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage: # noqa: FBT001
|
||||
if stream:
|
||||
full = None
|
||||
for chunk in llm.stream(input_):
|
||||
full = full + chunk if full else chunk # type: ignore[operator]
|
||||
return cast(AIMessage, full)
|
||||
else:
|
||||
return cast(AIMessage, llm.invoke(input_))
|
||||
return cast(AIMessage, llm.invoke(input_))
|
||||
|
||||
|
||||
@pytest.mark.skip() # Test either finishes in 5 seconds or 5 minutes.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Test Responses API usage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated, Any, Literal, Optional, cast
|
||||
@@ -144,7 +146,7 @@ async def test_web_search_async() -> None:
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_function_calling() -> None:
|
||||
def multiply(x: int, y: int) -> int:
|
||||
"""return x * y"""
|
||||
"""Return x * y."""
|
||||
return x * y
|
||||
|
||||
llm = ChatOpenAI(model=MODEL_NAME)
|
||||
@@ -277,7 +279,7 @@ async def test_parsed_dict_schema_async(schema: Any) -> None:
|
||||
|
||||
def test_function_calling_and_structured_output() -> None:
|
||||
def multiply(x: int, y: int) -> int:
|
||||
"""return x * y"""
|
||||
"""Return x * y."""
|
||||
return x * y
|
||||
|
||||
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
|
||||
@@ -332,7 +334,7 @@ def test_stateful_api() -> None:
|
||||
"what's my name", previous_response_id=response.response_metadata["id"]
|
||||
)
|
||||
assert isinstance(second_response.content, list)
|
||||
assert "bobo" in second_response.content[0]["text"].lower() # type: ignore
|
||||
assert "bobo" in second_response.content[0]["text"].lower() # type: ignore[index]
|
||||
|
||||
|
||||
def test_route_from_model_kwargs() -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Standard LangChain interface tests for Responses API"""
|
||||
"""Standard LangChain interface tests for Responses API."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
@@ -31,7 +31,7 @@ class TestOpenAIResponses(TestOpenAIStandard):
|
||||
readme = f.read()
|
||||
|
||||
input_ = f"""What's langchain? Here's the langchain README:
|
||||
|
||||
|
||||
{readme}
|
||||
"""
|
||||
llm = ChatOpenAI(model="gpt-4.1-mini", output_version="responses/v1")
|
||||
@@ -49,11 +49,10 @@ class TestOpenAIResponses(TestOpenAIStandard):
|
||||
return _invoke(llm, input_, stream)
|
||||
|
||||
|
||||
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
|
||||
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage: # noqa: FBT001
|
||||
if stream:
|
||||
full = None
|
||||
for chunk in llm.stream(input_):
|
||||
full = full + chunk if full else chunk # type: ignore[operator]
|
||||
return cast(AIMessage, full)
|
||||
else:
|
||||
return cast(AIMessage, llm.invoke(input_))
|
||||
return cast(AIMessage, llm.invoke(input_))
|
||||
|
||||
@@ -16,7 +16,6 @@ DEPLOYMENT_NAME = os.environ.get(
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||
os.environ.get("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", ""),
|
||||
)
|
||||
print
|
||||
|
||||
|
||||
def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
|
||||
@@ -63,7 +62,7 @@ def test_azure_openai_embedding_documents_chunk_size() -> None:
|
||||
# Max 2048 chunks per batch on Azure OpenAI embeddings
|
||||
assert embedding.chunk_size == 2048
|
||||
assert len(output) == 20
|
||||
assert all([len(out) == 1536 for out in output])
|
||||
assert all(len(out) == 1536 for out in output)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@@ -100,7 +99,6 @@ async def test_azure_openai_embedding_async_query() -> None:
|
||||
@pytest.mark.scheduled
|
||||
def test_azure_openai_embedding_with_empty_string() -> None:
|
||||
"""Test openai embeddings with empty string."""
|
||||
|
||||
document = ["", "abc"]
|
||||
embedding = _get_embeddings()
|
||||
output = embedding.embed_documents(document)
|
||||
@@ -112,7 +110,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
|
||||
api_key=OPENAI_API_KEY,
|
||||
azure_endpoint=OPENAI_API_BASE,
|
||||
azure_deployment=DEPLOYMENT_NAME,
|
||||
) # type: ignore
|
||||
)
|
||||
.embeddings.create(input="", model="text-embedding-ada-002")
|
||||
.data[0]
|
||||
.embedding
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
"""Standard LangChain interface tests."""
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_tests.integration_tests.embeddings import EmbeddingsIntegrationTests
|
||||
|
||||
@@ -98,7 +98,7 @@ async def test_openai_ainvoke(llm: AzureOpenAI) -> None:
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ def test_invoke() -> None:
|
||||
"""Test invoke tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@@ -164,7 +164,7 @@ def test_openai_invoke() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
|
||||
@@ -4,4 +4,3 @@ import pytest
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
||||
|
||||
@@ -80,8 +80,8 @@ def test_structured_output_old_model() -> None:
|
||||
).with_structured_output(Output)
|
||||
|
||||
# assert tool calling was used instead of json_schema
|
||||
assert "tools" in llm.steps[0].kwargs # type: ignore
|
||||
assert "response_format" not in llm.steps[0].kwargs # type: ignore
|
||||
assert "tools" in llm.steps[0].kwargs # type: ignore[attr-defined]
|
||||
assert "response_format" not in llm.steps[0].kwargs # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_max_completion_tokens_in_payload() -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
"""Standard LangChain interface tests."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Test OpenAI Chat API wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from functools import partial
|
||||
from types import TracebackType
|
||||
@@ -45,7 +47,7 @@ from openai.types.responses.response_usage import (
|
||||
OutputTokensDetails,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import Self, TypedDict
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models._compat import (
|
||||
@@ -229,7 +231,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
"type": "function",
|
||||
},
|
||||
]
|
||||
raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
|
||||
raw_tool_calls = sorted(raw_tool_calls, key=lambda x: x["id"])
|
||||
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(
|
||||
@@ -260,8 +262,8 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
)
|
||||
assert result == expected_output
|
||||
reverted_message_dict = _convert_message_to_dict(expected_output)
|
||||
reverted_message_dict["tool_calls"] = list(
|
||||
sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
|
||||
reverted_message_dict["tool_calls"] = sorted(
|
||||
reverted_message_dict["tool_calls"], key=lambda x: x["id"]
|
||||
)
|
||||
assert reverted_message_dict == message
|
||||
|
||||
@@ -272,7 +274,7 @@ class MockAsyncContextManager:
        self.chunk_list = chunk_list
        self.chunk_num = len(chunk_list)

    async def __aenter__(self) -> "MockAsyncContextManager":
    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
@@ -283,7 +285,7 @@ class MockAsyncContextManager:
    ) -> None:
        pass

    def __aiter__(self) -> "MockAsyncContextManager":
    def __aiter__(self) -> MockAsyncContextManager:
        return self

    async def __anext__(self) -> dict:
@@ -291,8 +293,7 @@ class MockAsyncContextManager:
            chunk = self.chunk_list[self.current_chunk]
            self.current_chunk += 1
            return chunk
        else:
            raise StopAsyncIteration
        raise StopAsyncIteration


class MockSyncContextManager:
@@ -301,7 +302,7 @@ class MockSyncContextManager:
        self.chunk_list = chunk_list
        self.chunk_num = len(chunk_list)

    def __enter__(self) -> "MockSyncContextManager":
    def __enter__(self) -> Self:
        return self

    def __exit__(
@@ -312,7 +313,7 @@ class MockSyncContextManager:
    ) -> None:
        pass

    def __iter__(self) -> "MockSyncContextManager":
    def __iter__(self) -> MockSyncContextManager:
        return self

    def __next__(self) -> dict:
@@ -320,13 +321,12 @@ class MockSyncContextManager:
            chunk = self.chunk_list[self.current_chunk]
            self.current_chunk += 1
            return chunk
        else:
            raise StopIteration
        raise StopIteration

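The two mock context-manager hunks above swap quoted forward references for typing_extensions.Self on the methods that return the instance, and drop the quotes elsewhere now that from __future__ import annotations postpones evaluation. A minimal sketch of the same pattern follows; the ChunkReplay class is a hypothetical stand-in, not part of this diff:

from __future__ import annotations

from typing_extensions import Self


class ChunkReplay:
    """Hypothetical example class used only to illustrate the annotation pattern."""

    def __init__(self, chunks: list[dict]) -> None:
        self.chunks = chunks
        self.index = 0

    def __enter__(self) -> Self:
        # Self stays accurate for subclasses, unlike a hard-coded class name.
        return self

    def __exit__(self, *exc_info: object) -> None:
        pass

    def __iter__(self) -> ChunkReplay:
        # With postponed annotations, the bare class name no longer needs quotes.
        return self

    def __next__(self) -> dict:
        if self.index < len(self.chunks):
            chunk = self.chunks[self.index]
            self.index += 1
            return chunk
        raise StopIteration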
GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]}
@@ -339,12 +339,7 @@ GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","cr
@pytest.fixture
def mock_glm4_completion() -> list:
    list_chunk_data = GLM4_STREAM_META.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))

    return result_list
    return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]


async def test_glm4_astream(mock_glm4_completion: list) -> None:
@@ -360,7 +355,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None:

    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata
@@ -385,7 +380,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None:

    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata
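The streaming tests above all follow the same shape: a fixture parses a captured newline-delimited dump into chunk dicts, a mock context manager replays them, and patch.object swaps the mock in for the real client so llm.stream can be exercised offline. A rough, self-contained sketch of that shape; the dump, ReplayContextManager, and parse_stream_dump below are illustrative stand-ins, not the diff's fixtures:

import json
from unittest.mock import MagicMock

# Illustrative two-chunk dump; the real fixtures hold captured provider output.
FAKE_STREAM_DATA = """{"id": "chunk-1", "choices": [{"delta": {"content": "Hello"}}]}
{"id": "chunk-2", "choices": [{"delta": {"content": " world"}}]}
[DONE]"""


def parse_stream_dump(raw: str) -> list[dict]:
    """Turn the newline-delimited dump into chunk dicts, skipping the sentinel."""
    return [json.loads(line) for line in raw.split("\n") if line != "[DONE]"]


class ReplayContextManager:
    """Replays recorded chunks, in the spirit of the diff's MockSyncContextManager."""

    def __init__(self, chunks: list[dict]) -> None:
        self.chunks = chunks

    def __enter__(self) -> "ReplayContextManager":
        return self

    def __exit__(self, *exc_info: object) -> None:
        pass

    def __iter__(self):
        return iter(self.chunks)


# A MagicMock stands in for the OpenAI client: create() returns the replay object.
mock_client = MagicMock()
mock_client.create.return_value = ReplayContextManager(parse_stream_dump(FAKE_STREAM_DATA))

# In the real tests the mock is injected via patch.object(llm, "client", mock_client)
# and consumed through llm.stream(...); here we simply drain it directly.
with mock_client.create() as stream:
    text = "".join(chunk["choices"][0]["delta"]["content"] for chunk in stream)
assert text == "Hello world"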
@@ -402,7 +397,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
{"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"由","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
@@ -420,12 +415,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
@pytest.fixture
def mock_deepseek_completion() -> list[dict]:
    list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))

    return result_list
    return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]


async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
@@ -440,7 +430,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
    usage_chunk = mock_deepseek_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata
@@ -464,7 +454,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None:
    usage_chunk = mock_deepseek_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata
@@ -488,12 +478,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
@pytest.fixture
def mock_openai_completion() -> list[dict]:
    list_chunk_data = OPENAI_STREAM_DATA.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))

    return result_list
    return [json.loads(msg) for msg in list_chunk_data if msg != "[DONE]"]


async def test_openai_astream(mock_openai_completion: list) -> None:
@@ -508,7 +493,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None:
    usage_chunk = mock_openai_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata
@@ -532,7 +517,7 @@ def test_openai_stream(mock_openai_completion: list) -> None:
    usage_chunk = mock_openai_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata
@@ -727,7 +712,7 @@ def test_format_message_content() -> None:
            "input": {"location": "San Francisco, CA", "unit": "celsius"},
        },
    ]
    assert [{"type": "text", "text": "hello"}] == _format_message_content(content)
    assert _format_message_content(content) == [{"type": "text", "text": "hello"}]

    # Standard multi-modal inputs
    content = [{"type": "image", "source_type": "url", "url": "https://..."}]
@@ -776,14 +761,14 @@ def test_format_message_content() -> None:


class GenerateUsername(BaseModel):
    "Get a username based on someone's name and hair color."
    """Get a username based on someone's name and hair color."""

    name: str
    hair_color: str


class MakeASandwich(BaseModel):
    "Make a sandwich given a list of ingredients."
    """Make a sandwich given a list of ingredients."""

    bread_type: str
    cheese_type: str
@@ -805,7 +790,7 @@ class MakeASandwich(BaseModel):
    ],
)
@pytest.mark.parametrize("strict", [True, False, None])
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:  # noqa: FBT001
    """Test passing in manually construct tool call message."""
    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
    llm.bind_tools(
@@ -822,8 +807,8 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
def test_with_structured_output(
    schema: Union[type, dict[str, Any], None],
    method: Literal["function_calling", "json_mode", "json_schema"],
    include_raw: bool,
    strict: Optional[bool],
    include_raw: bool,  # noqa: FBT001
    strict: Optional[bool],  # noqa: FBT001
) -> None:
    """Test passing in manually construct tool call message."""
    if method == "json_mode":
@@ -873,11 +858,20 @@ def test_get_num_tokens_from_messages() -> None:
        ),
        ToolMessage("foobar", tool_call_id="foo"),
    ]
    expected = 176
    actual = llm.get_num_tokens_from_messages(messages)
    assert expected == actual
    expected = 431

    with patch(
        "langchain_openai.chat_models.base._url_to_size", return_value=(512, 512)
    ) as _mock_url_to_size:
        actual = llm.get_num_tokens_from_messages(messages)

    with patch(
        "langchain_openai.chat_models.base._url_to_size", return_value=(512, 128)
    ) as _mock_url_to_size:
        actual = llm.get_num_tokens_from_messages(messages)
    expected = 431
    assert expected == actual

    # Test file inputs
    messages = [
        HumanMessage(
            [
@@ -896,6 +890,28 @@ def test_get_num_tokens_from_messages() -> None:
        actual = llm.get_num_tokens_from_messages(messages)
    assert actual == 13

    # actual = llm.get_num_tokens_from_messages(messages)
    # assert expected == actual

    # # Test file inputs
    # messages = [
    #     HumanMessage(
    #         [
    #             "Summarize this document.",
    #             {
    #                 "type": "file",
    #                 "file": {
    #                     "filename": "my file",
    #                     "file_data": "data:application/pdf;base64,<data>",
    #                 },
    #             },
    #         ]
    #     )
    # ]
    # with pytest.warns(match="file inputs are not supported"):
    #     actual = llm.get_num_tokens_from_messages(messages)
    # assert actual == 13
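The token-counting hunk above pins the image dimensions by patching langchain_openai.chat_models.base._url_to_size, so the test never fetches the image and the expected count stays deterministic. A hedged sketch of the general technique follows; the helper, the token formula, and the numbers are made up for illustration and are not the library's internals:

from unittest.mock import patch


def url_to_size(url: str) -> tuple[int, int]:
    """Stand-in for a helper that would normally download the image."""
    raise RuntimeError("network access not allowed in unit tests")


def count_image_tokens(url: str) -> int:
    """Toy token formula: one token per 32x32 tile, purely illustrative."""
    width, height = url_to_size(url)
    return (width // 32) * (height // 32)


# Patch the helper where it is looked up, just as the diff patches
# langchain_openai.chat_models.base._url_to_size, so the count is deterministic.
with patch(f"{__name__}.url_to_size", return_value=(512, 512)):
    assert count_image_tokens("https://example.com/image.png") == 256

with patch(f"{__name__}.url_to_size", return_value=(512, 128)):
    assert count_image_tokens("https://example.com/image.png") == 64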
class Foo(BaseModel):
    bar: int
@@ -914,7 +930,6 @@ class Foo(BaseModel):
)
def test_schema_from_with_structured_output(schema: type) -> None:
    """Test schema from with_structured_output."""

    llm = ChatOpenAI(model="gpt-4o")

    structured_llm = llm.with_structured_output(
@@ -1014,10 +1029,10 @@ def test__convert_to_openai_response_format() -> None:
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
@pytest.mark.parametrize("strict", [True, None])
def test_structured_output_strict(
    method: Literal["function_calling", "json_schema"], strict: Optional[bool]
    method: Literal["function_calling", "json_schema"],
    strict: Optional[bool],  # noqa: FBT001
) -> None:
    """Test to verify structured output with strict=True."""

    llm = ChatOpenAI(model="gpt-4o-2024-08-06")

    class Joke(BaseModel):
@@ -1033,7 +1048,6 @@ def test_structured_output_strict(
def test_nested_structured_output_strict() -> None:
    """Test to verify structured output with strict=True for nested object."""

    llm = ChatOpenAI(model="gpt-4o-2024-08-06")

    class SelfEvaluation(TypedDict):
@@ -1139,8 +1153,8 @@ def test_structured_output_old_model() -> None:
    with pytest.warns(match="Cannot use method='json_schema'"):
        llm = ChatOpenAI(model="gpt-4").with_structured_output(Output)
    # assert tool calling was used instead of json_schema
    assert "tools" in llm.steps[0].kwargs  # type: ignore
    assert "response_format" not in llm.steps[0].kwargs  # type: ignore
    assert "tools" in llm.steps[0].kwargs  # type: ignore[attr-defined]
    assert "response_format" not in llm.steps[0].kwargs  # type: ignore[attr-defined]
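The hunk above leans on the documented fallback: when a model predates the structured-output response format, with_structured_output warns and binds the schema as a tool instead of a response_format. One way application code can avoid the warning is to request the tool-calling method explicitly. A rough sketch, assuming a hypothetical Output schema and an OPENAI_API_KEY in the environment:

from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Output(BaseModel):
    """Hypothetical schema used only for illustration."""

    answer: str
    justification: str


# Older completions-era models do not support method="json_schema", so ask for
# tool calling up front instead of relying on the warn-and-fallback path.
llm = ChatOpenAI(model="gpt-4", temperature=0)
structured_llm = llm.with_structured_output(Output, method="function_calling")

# result = structured_llm.invoke("What is 2 + 2? Justify briefly.")
# result would then be an Output instance parsed from the tool call.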
def test_structured_outputs_parser() -> None:
@@ -1518,7 +1532,7 @@ def test__construct_lc_result_from_responses_api_complex_response() -> None:
                arguments='{"location": "New York"}',
            ),
        ],
        metadata=dict(key1="value1", key2="value2"),
        metadata={"key1": "value1", "key2": "value2"},
        incomplete_details=IncompleteDetails(reason="max_output_tokens"),
        status="completed",
        user="user_123",
@@ -1758,7 +1772,6 @@ def test__construct_lc_result_from_responses_api_file_search_response() -> None:

def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
    """Test a response with both web search and file search outputs."""

    response = Response(
        id="resp_123",
        created_at=1234567890,
@@ -2229,7 +2242,6 @@ class FakeTracer(BaseTracer):
    def _persist_run(self, run: Run) -> None:
        """Persist a run."""
        pass

    def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
        self.chat_model_start_inputs.append({"args": args, "kwargs": kwargs})
@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""

from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""

from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests

@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, Optional
from unittest.mock import MagicMock, patch
@@ -604,10 +606,9 @@ def _strip_none(obj: Any) -> Any:
    """Recursively strip None values from dictionaries and lists."""
    if isinstance(obj, dict):
        return {k: _strip_none(v) for k, v in obj.items() if v is not None}
    elif isinstance(obj, list):
    if isinstance(obj, list):
        return [_strip_none(v) for v in obj]
    else:
        return obj
    return obj


def test_responses_stream() -> None:
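For context on the _strip_none helper above, the guard-clause version behaves exactly like the old if/elif/else chain. A quick standalone illustration of what it produces, copying the helper here only for the example (the payload is made up):

from typing import Any


def _strip_none(obj: Any) -> Any:
    """Recursively strip None values from dictionaries and lists."""
    if isinstance(obj, dict):
        return {k: _strip_none(v) for k, v in obj.items() if v is not None}
    if isinstance(obj, list):
        return [_strip_none(v) for v in obj]
    return obj


payload = {
    "model": "gpt-4o-mini",
    "temperature": None,
    "messages": [{"role": "user", "content": "hi", "name": None}],
}
# None-valued keys disappear at every nesting level; other values pass through.
assert _strip_none(payload) == {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hi"}],
}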
@@ -32,7 +32,6 @@ def test_embed_documents_with_custom_chunk_size() -> None:

    result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
    _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
    mock_create.call_args
    mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
    mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
@@ -52,7 +51,6 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:

    result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)

    mock_create.call_args
    mock_create.assert_any_call(input=texts[0:3], **embeddings._invocation_params)
    mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)

@@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""

from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
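The two embeddings hunks above assert that embed_documents respects a per-call chunk_size by checking which batches reached the mocked create endpoint. A small sketch of the same batching-and-assertion idea against a hypothetical client, not the OpenAIEmbeddings internals:

from unittest.mock import MagicMock


def embed_documents(client: MagicMock, texts: list[str], chunk_size: int) -> list[list[float]]:
    """Toy embedder: sends texts to client.create in chunk_size batches."""
    vectors: list[list[float]] = []
    for start in range(0, len(texts), chunk_size):
        batch = texts[start : start + chunk_size]
        response = client.create(input=batch, model="text-embedding-3-small")
        vectors.extend(response["data"])
    return vectors


mock_client = MagicMock()
mock_client.create.return_value = {"data": [[0.0], [0.0], [0.0]]}

texts = ["a", "b", "c", "d"]
embed_documents(mock_client, texts, chunk_size=3)

# Mirrors the diff's assertions: one call for the first three texts, one for the rest.
mock_client.create.assert_any_call(input=texts[0:3], model="text-embedding-3-small")
mock_client.create.assert_any_call(input=texts[3:4], model="text-embedding-3-small")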
@@ -1,5 +1,7 @@
"""A fake callback handler for testing purposes."""

from __future__ import annotations

from itertools import chain
from typing import Any, Optional, Union
from uuid import UUID
@@ -188,7 +190,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
    def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
        self.on_retriever_error_common()

    def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":  # type: ignore[override]
    def __deepcopy__(self, memo: dict) -> FakeCallbackHandler:  # type: ignore[override]
        return self


@@ -266,5 +268,5 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
    async def on_text(self, *args: Any, **kwargs: Any) -> None:
        self.on_text_common()

    def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":  # type: ignore[override]
    def __deepcopy__(self, memo: dict) -> FakeAsyncCallbackHandler:  # type: ignore[override]
        return self
libs/partners/openai/uv.lock (generated, 2835 lines): diff suppressed because it is too large.