partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)

Using `pyupgrade` to get all `partners` code up to 3.9 standards
(mostly, fixing old `typing` imports).
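
The rewrite is mechanical: deprecated capitalized aliases from `typing` (`Dict`, `List`, `Tuple`, `Type`, ...) become the built-in generics that are subscriptable as of Python 3.9, and the ABC aliases (`Iterator`, `Mapping`, `Sequence`, ...) move to `collections.abc`. A hypothetical before/after sketch, not a snippet from this diff:

```python
# Before: pre-3.9 style, container aliases imported from typing
from typing import Dict, Iterator, List, Optional

def first_match(items: List[str], prefix: str) -> Optional[Dict[str, str]]:
    matches: Iterator[str] = (item for item in items if item.startswith(prefix))
    hit = next(matches, None)
    return {"match": hit} if hit is not None else None

# After the upgrade (the pyproject.toml changes below enable ruff's
# pyupgrade rules, "UP", with target-version py39): built-in generics
# plus collections.abc; Optional/Union keep their typing spelling on 3.9.
from collections.abc import Iterator
from typing import Optional

def first_match(items: list[str], prefix: str) -> Optional[dict[str, str]]:
    matches: Iterator[str] = (item for item in items if item.startswith(prefix))
    hit = next(matches, None)
    return {"match": hit} if hit is not None else None
```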
Sydney Runkle 2025-04-11 07:18:44 -04:00 committed by GitHub
parent e72f3c26a0
commit 8c6734325b
GPG Key ID: B5690EEEBB952194
123 changed files with 1000 additions and 1109 deletions

View File

@@ -1,21 +1,14 @@
import copy
import re
import warnings
+from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from functools import cached_property
from operator import itemgetter
from typing import (
Any,
-AsyncIterator,
Callable,
-Dict,
-Iterator,
-List,
Literal,
-Mapping,
Optional,
-Sequence,
-Tuple,
-Type,
Union,
cast,
)
@@ -89,8 +82,8 @@ class AnthropicTool(TypedDict):
name: str
description: str
-input_schema: Dict[str, Any]
+input_schema: dict[str, Any]
-cache_control: NotRequired[Dict[str, str]]
+cache_control: NotRequired[dict[str, str]]
def _is_builtin_tool(tool: Any) -> bool:
@@ -109,7 +102,7 @@ def _is_builtin_tool(tool: Any) -> bool:
return any(tool_type.startswith(prefix) for prefix in _builtin_tool_prefixes)
-def _format_image(image_url: str) -> Dict:
+def _format_image(image_url: str) -> dict:
"""
Formats an image of format data:image/jpeg;base64,{b64_string}
to a dict for anthropic api
@@ -138,7 +131,7 @@ def _format_image(image_url: str) -> Dict:
def _merge_messages(
messages: Sequence[BaseMessage],
-) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
+) -> list[Union[SystemMessage, AIMessage, HumanMessage]]:
"""Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
merged: list = []
for curr in messages:
@@ -169,7 +162,7 @@ def _merge_messages(
for c in (SystemMessage, HumanMessage)
):
if isinstance(cast(BaseMessage, last).content, str):
-new_content: List = [
+new_content: list = [
{"type": "text", "text": cast(BaseMessage, last).content}
]
else:
@@ -185,8 +178,8 @@ def _merge_messages(
def _format_messages(
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-) -> Tuple[Union[str, List[Dict], None], List[Dict]]:
+) -> tuple[Union[str, list[dict], None], list[dict]]:
"""Format messages for anthropic."""
"""
@@ -198,8 +191,8 @@ def _format_messages(
for m in messages
]
"""
-system: Union[str, List[Dict], None] = None
+system: Union[str, list[dict], None] = None
-formatted_messages: List[Dict] = []
+formatted_messages: list[dict] = []
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
@@ -220,7 +213,7 @@ def _format_messages(
continue
role = _message_type_lookups[message.type]
-content: Union[str, List]
+content: Union[str, list]
if not isinstance(message.content, str):
# parse as dict
@@ -830,7 +823,7 @@ class ChatAnthropic(BaseChatModel):
max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
-stop_sequences: Optional[List[str]] = Field(None, alias="stop")
+stop_sequences: Optional[list[str]] = Field(None, alias="stop")
"""Default stop sequences."""
anthropic_api_url: Optional[str] = Field(
@@ -858,7 +851,7 @@ class ChatAnthropic(BaseChatModel):
default_headers: Optional[Mapping[str, str]] = None
"""Headers to pass to the Anthropic clients, will be used for every API call."""
-model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+model_kwargs: dict[str, Any] = Field(default_factory=dict)
streaming: bool = False
"""Whether to use streaming or not."""
@@ -868,7 +861,7 @@ class ChatAnthropic(BaseChatModel):
message chunks will be generated during the stream including usage metadata.
"""
-thinking: Optional[Dict[str, Any]] = Field(default=None)
+thinking: Optional[dict[str, Any]] = Field(default=None)
"""Parameters for Claude reasoning,
e.g., ``{"type": "enabled", "budget_tokens": 10_000}``"""
@@ -878,7 +871,7 @@ class ChatAnthropic(BaseChatModel):
return "anthropic-chat"
@property
-def lc_secrets(self) -> Dict[str, str]:
+def lc_secrets(self) -> dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@classmethod
@@ -886,12 +879,12 @@ class ChatAnthropic(BaseChatModel):
return True
@classmethod
-def get_lc_namespace(cls) -> List[str]:
+def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "anthropic"]
@property
-def _identifying_params(self) -> Dict[str, Any]:
+def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
@@ -907,7 +900,7 @@ class ChatAnthropic(BaseChatModel):
}
def _get_ls_params(
-self, stop: Optional[List[str]] = None, **kwargs: Any
+self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@@ -925,14 +918,14 @@ class ChatAnthropic(BaseChatModel):
@model_validator(mode="before")
@classmethod
-def build_extra(cls, values: Dict) -> Any:
+def build_extra(cls, values: dict) -> Any:
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
@cached_property
-def _client_params(self) -> Dict[str, Any]:
+def _client_params(self) -> dict[str, Any]:
-client_params: Dict[str, Any] = {
+client_params: dict[str, Any] = {
"api_key": self.anthropic_api_key.get_secret_value(),
"base_url": self.anthropic_api_url,
"max_retries": self.max_retries,
@@ -958,9 +951,9 @@ class ChatAnthropic(BaseChatModel):
self,
input_: LanguageModelInput,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
-**kwargs: Dict,
+**kwargs: dict,
-) -> Dict:
+) -> dict:
messages = self._convert_input(input_).to_messages()
system, formatted_messages = _format_messages(messages)
payload = {
@@ -981,8 +974,8 @@ class ChatAnthropic(BaseChatModel):
def _stream(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
@@ -1012,8 +1005,8 @@ class ChatAnthropic(BaseChatModel):
async def _astream(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
@@ -1088,8 +1081,8 @@ class ChatAnthropic(BaseChatModel):
def _generate(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -1104,8 +1097,8 @@ class ChatAnthropic(BaseChatModel):
async def _agenerate(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -1120,7 +1113,7 @@ class ChatAnthropic(BaseChatModel):
def _get_llm_for_structured_output_when_thinking_is_enabled(
self,
-schema: Union[Dict, type],
+schema: Union[dict, type],
formatted_tool: AnthropicTool,
) -> Runnable[LanguageModelInput, BaseMessage]:
thinking_admonition = (
@@ -1148,10 +1141,10 @@ class ChatAnthropic(BaseChatModel):
def bind_tools(
self,
-tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
-Union[Dict[str, str], Literal["any", "auto"], str]
+Union[dict[str, str], Literal["any", "auto"], str]
] = None,
parallel_tool_calls: Optional[bool] = None,
**kwargs: Any,
@@ -1326,11 +1319,11 @@ class ChatAnthropic(BaseChatModel):
def with_structured_output(
self,
-schema: Union[Dict, type],
+schema: Union[dict, type],
*,
include_raw: bool = False,
**kwargs: Any,
-) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@@ -1483,9 +1476,9 @@ class ChatAnthropic(BaseChatModel):
@beta()
def get_num_tokens_from_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
tools: Optional[
-Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None,
) -> int:
"""Count tokens in a sequence of input messages.
@@ -1546,7 +1539,7 @@ class ChatAnthropic(BaseChatModel):
https://docs.anthropic.com/en/docs/build-with-claude/token-counting
"""
formatted_system, formatted_messages = _format_messages(messages)
-kwargs: Dict[str, Any] = {}
+kwargs: dict[str, Any] = {}
if isinstance(formatted_system, str):
kwargs["system"] = formatted_system
if tools:
@@ -1562,7 +1555,7 @@ class ChatAnthropic(BaseChatModel):
def convert_to_anthropic_tool(
-tool: Union[Dict[str, Any], Type, Callable, BaseTool],
+tool: Union[dict[str, Any], type, Callable, BaseTool],
) -> AnthropicTool:
"""Convert a tool-like object to an Anthropic tool definition."""
# already in Anthropic tool format
@@ -1611,8 +1604,8 @@ class _AnthropicToolUse(TypedDict):
def _lc_tool_calls_to_anthropic_tool_use_blocks(
-tool_calls: List[ToolCall],
+tool_calls: list[ToolCall],
-) -> List[_AnthropicToolUse]:
+) -> list[_AnthropicToolUse]:
blocks = []
for tool_call in tool_calls:
blocks.append(
@@ -1735,7 +1728,7 @@ class ChatAnthropicMessages(ChatAnthropic):
def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
-input_token_details: Dict = {
+input_token_details: dict = {
"cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}

View File

@@ -1,8 +1,6 @@
import json
from typing import (
Any,
-Dict,
-List,
Union,
)
@@ -44,7 +42,7 @@ TOOL_PARAMETER_FORMAT = """<parameter>
</parameter>"""
-def _get_type(parameter: Dict[str, Any]) -> str:
+def _get_type(parameter: dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
if "anyOf" in parameter:
@@ -54,9 +52,9 @@ def _get_type(parameter: Dict[str, Any]) -> str:
return json.dumps(parameter)
-def get_system_message(tools: List[Dict]) -> str:
+def get_system_message(tools: list[dict]) -> str:
"""Generate a system message that describes the available tools."""
-tools_data: List[Dict] = [
+tools_data: list[dict] = [
{
"tool_name": tool["name"],
"tool_description": tool["description"],
@@ -86,13 +84,13 @@ def get_system_message(tools: List[Dict]) -> str:
return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)
-def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
+def _xml_to_dict(t: Any) -> Union[str, dict[str, Any]]:
# Base case: If the element has no children, return its text or an empty string.
if len(t) == 0:
return t.text or ""
# Recursive case: The element has children. Convert them into a dictionary.
-d: Dict[str, Any] = {}
+d: dict[str, Any] = {}
for child in t:
if child.tag not in d:
d[child.tag] = _xml_to_dict(child)
@@ -104,7 +102,7 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
return d
-def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
+def _xml_to_function_call(invoke: Any, tools: list[dict]) -> dict[str, Any]:
name = invoke.find("tool_name").text
arguments = _xml_to_dict(invoke.find("parameters"))
@@ -135,7 +133,7 @@ def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
}
-def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
+def _xml_to_tool_calls(elem: Any, tools: list[dict]) -> list[dict[str, Any]]:
"""
Convert an XML element and its children into a dictionary of dictionaries.
"""

View File

@@ -1,13 +1,9 @@
import re
import warnings
+from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
-AsyncIterator,
Callable,
-Dict,
-Iterator,
-List,
-Mapping,
Optional,
)
@@ -83,11 +79,11 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
-model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+model_kwargs: dict[str, Any] = Field(default_factory=dict)
@model_validator(mode="before")
@classmethod
-def build_extra(cls, values: Dict) -> Any:
+def build_extra(cls, values: dict) -> Any:
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
@@ -131,7 +127,7 @@ class _AnthropicCommon(BaseLanguageModel):
"""Get the identifying parameters."""
return {**{}, **self._default_params}
-def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
+def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")
@@ -165,7 +161,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
@model_validator(mode="before")
@classmethod
-def raise_warning(cls, values: Dict) -> Any:
+def raise_warning(cls, values: dict) -> Any:
"""Raise warning that this class is deprecated."""
warnings.warn(
"This Anthropic LLM is deprecated. "
@@ -180,7 +176,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
return "anthropic-llm"
@property
-def lc_secrets(self) -> Dict[str, str]:
+def lc_secrets(self) -> dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@classmethod
@@ -188,7 +184,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
return True
@property
-def _identifying_params(self) -> Dict[str, Any]:
+def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
@@ -203,7 +199,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
}
def _get_ls_params(
-self, stop: Optional[List[str]] = None, **kwargs: Any
+self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
@@ -233,7 +229,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
def _call(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -277,7 +273,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
async def _acall(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -303,7 +299,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
def _stream(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@@ -338,7 +334,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
async def _astream(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:

View File

@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Type, Union, cast
+from typing import Any, Optional, Union, cast
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
@@ -14,14 +14,14 @@ class ToolsOutputParser(BaseGenerationOutputParser):
"""Whether to return only the first tool call."""
args_only: bool = False
"""Whether to return only the arguments of the tool calls."""
-pydantic_schemas: Optional[List[Type[BaseModel]]] = None
+pydantic_schemas: Optional[list[type[BaseModel]]] = None
"""Pydantic schemas to parse tool calls into."""
model_config = ConfigDict(
extra="forbid",
)
-def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse a list of candidate model Generations into a specific format.
Args:
@@ -34,7 +34,7 @@ class ToolsOutputParser(BaseGenerationOutputParser):
if not result or not isinstance(result[0], ChatGeneration):
return None if self.first_tool_only else []
message = cast(AIMessage, result[0].message)
-tool_calls: List = [
+tool_calls: list = [
dict(tc) for tc in _extract_tool_calls_from_message(message)
]
if isinstance(message.content, list):
@@ -64,14 +64,14 @@ class ToolsOutputParser(BaseGenerationOutputParser):
return cls_(**tool_call["args"])
-def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]:
+def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if message.tool_calls:
return message.tool_calls
return extract_tool_calls(message.content)
-def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
+def extract_tool_calls(content: Union[str, list[Union[str, dict]]]) -> list[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if isinstance(content, list):
tool_calls = []

View File

@@ -55,8 +55,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
disallow_untyped_defs = "True"
plugins = ['pydantic.mypy']
+[tool.ruff]
+target-version = "py39"
[tool.ruff.lint]
-select = ["E", "F", "I", "T201"]
+select = ["E", "F", "I", "T201", "UP"]
+ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]
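
The same three build-config edits recur for each package in this commit: pin ruff's `target-version` to `py39`, enable the pyupgrade rule family (`UP`), and ignore `UP007`. The ignore is presumably there because `UP007` pushes `Union`/`Optional` annotations toward the PEP 604 `X | Y` spelling, which cannot be evaluated at runtime before Python 3.10. A hypothetical illustration of what the ignore leaves untouched:

```python
from typing import Optional, Union

# Kept as-is on a 3.9 target: evaluating the PEP 604 form at runtime
# (for example in a Pydantic field annotation) raises TypeError before 3.10.
stop_sequences: Optional[list[str]] = None
schema: Union[dict, type] = dict

# What UP007 would otherwise suggest (needs 3.10+, or deferred
# annotation evaluation via `from __future__ import annotations`):
# stop_sequences: list[str] | None = None
# schema: dict | type = dict
```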

View File

@@ -2,7 +2,7 @@
import json
from base64 import b64encode
-from typing import List, Optional
+from typing import Optional
import httpx
import pytest
@@ -270,7 +270,7 @@ def test_anthropic_call() -> None:
def test_anthropic_generate() -> None:
"""Test generate method of anthropic."""
chat = ChatAnthropic(model=MODEL_NAME)
-chat_messages: List[List[BaseMessage]] = [
+chat_messages: list[list[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")]
]
messages_copy = [messages.copy() for messages in chat_messages]
@@ -318,7 +318,7 @@ async def test_anthropic_async_streaming_callback() -> None:
callback_manager=callback_manager,
verbose=True,
)
-chat_messages: List[BaseMessage] = [
+chat_messages: list[BaseMessage] = [
HumanMessage(content="How many toes do dogs have?")
]
async for token in chat.astream(chat_messages):
@@ -809,7 +809,7 @@ def test_image_tool_calling() -> None:
fav_color: str
-human_content: List[dict] = [
+human_content: list[dict] = [
{
"type": "text",
"text": "what's your favorite color in this image",

View File

@@ -1,7 +1,7 @@
"""Test ChatAnthropic chat model."""
from enum import Enum
-from typing import List, Optional
+from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
@@ -136,7 +136,7 @@ def test_anthropic_complex_structured_output() -> None:
sender_address: Optional[str] = Field(
None, description="The sender's address, if available"
)
-action_items: List[str] = Field(
+action_items: list[str] = Field(
..., description="A list of action items requested by the email"
)
topic: str = Field(

View File

@@ -1,6 +1,6 @@
"""Test Anthropic API wrapper."""
-from typing import Generator
+from collections.abc import Generator
import pytest
from langchain_core.callbacks import CallbackManager

View File

@@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""
from pathlib import Path
-from typing import Dict, List, Literal, Type, cast
+from typing import Literal, cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
@@ -14,7 +14,7 @@ REPO_ROOT_DIR = Path(__file__).parents[5]
class TestAnthropicStandard(ChatModelIntegrationTests):
@property
-def chat_model_class(self) -> Type[BaseChatModel]:
+def chat_model_class(self) -> type[BaseChatModel]:
return ChatAnthropic
@property
@@ -36,9 +36,9 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
@property
def supported_usage_metadata_details(
self,
-) -> Dict[
+) -> dict[
Literal["invoke", "stream"],
-List[
+list[
Literal[
"audio_input",
"audio_output",
@@ -58,7 +58,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
model="claude-3-5-sonnet-20240620",  # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},  # type: ignore[call-arg]
)
-with open(REPO_ROOT_DIR / "README.md", "r") as f:
+with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:
@@ -87,7 +87,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
model="claude-3-5-sonnet-20240620",  # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},  # type: ignore[call-arg]
)
-with open(REPO_ROOT_DIR / "README.md", "r") as f:
+with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:

View File

@@ -1,7 +1,7 @@
"""Test chat model integration."""
import os
-from typing import Any, Callable, Dict, Literal, Type, cast
+from typing import Any, Callable, Literal, cast
import pytest
from anthropic.types import Message, TextBlock, Usage
@@ -297,7 +297,7 @@ def test__merge_messages_mutation() -> None:
@pytest.fixture()
-def pydantic() -> Type[BaseModel]:
+def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel):
"""dummy function"""
@@ -328,7 +328,7 @@ def dummy_tool() -> BaseTool:
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
class DummyFunction(BaseTool):  # type: ignore[override]
-args_schema: Type[BaseModel] = Schema
+args_schema: type[BaseModel] = Schema
name: str = "dummy_function"
description: str = "dummy function"
@@ -339,7 +339,7 @@ def dummy_tool() -> BaseTool:
@pytest.fixture()
-def json_schema() -> Dict:
+def json_schema() -> dict:
return {
"title": "dummy_function",
"description": "dummy function",
@@ -357,7 +357,7 @@ def json_schema() -> Dict:
@pytest.fixture()
-def openai_function() -> Dict:
+def openai_function() -> dict:
return {
"name": "dummy_function",
"description": "dummy function",
@@ -377,11 +377,11 @@ def openai_function() -> Dict:
def test_convert_to_anthropic_tool(
-pydantic: Type[BaseModel],
+pydantic: type[BaseModel],
function: Callable,
dummy_tool: BaseTool,
-json_schema: Dict,
+json_schema: dict,
-openai_function: Dict,
+openai_function: dict,
) -> None:
expected = {
"name": "dummy_function",

View File

@@ -1,4 +1,4 @@
-from typing import Any, List, Literal
+from typing import Any, Literal
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
@@ -6,7 +6,7 @@ from pydantic import BaseModel
from langchain_anthropic.output_parsers import ToolsOutputParser
-_CONTENT: List = [
+_CONTENT: list = [
{
"type": "text",
"text": "thought",
@@ -19,7 +19,7 @@ _CONTENT: List = [
{"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
]
-_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))]  # type: ignore[misc]
+_RESULT: list = [ChatGeneration(message=AIMessage(_CONTENT))]  # type: ignore[misc]
class _Foo1(BaseModel):

View File

@@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
-from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests
@@ -10,7 +8,7 @@ from langchain_anthropic import ChatAnthropic
class TestAnthropicStandard(ChatModelUnitTests):
@property
-def chat_model_class(self) -> Type[BaseChatModel]:
+def chat_model_class(self) -> type[BaseChatModel]:
return ChatAnthropic
@property

View File

@@ -8,17 +8,12 @@ from __future__ import annotations
import base64
import logging
import uuid
+from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
-Dict,
-Iterable,
-List,
Optional,
-Sequence,
-Tuple,
-Type,
Union,
)
@@ -37,11 +32,11 @@ logger = logging.getLogger()
DEFAULT_K = 4  # Number of Documents to return.
-def _results_to_docs(results: Any) -> List[Document]:
+def _results_to_docs(results: Any) -> list[Document]:
return [doc for doc, _ in _results_to_docs_and_scores(results)]
-def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
+def _results_to_docs_and_scores(results: Any) -> list[tuple[Document, float]]:
return [
# TODO: Chroma can do batch querying,
# we shouldn't hard code to the 1st result
@@ -58,7 +53,7 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
]
-def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarray]]:
+def _results_to_docs_and_vectors(results: Any) -> list[tuple[Document, np.ndarray]]:
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
@@ -69,7 +64,7 @@ def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarra
]
-Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
+Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
@@ -104,7 +99,7 @@ def maximal_marginal_relevance(
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
-) -> List[int]:
+) -> list[int]:
"""Calculate maximal marginal relevance.
Args:
@@ -287,7 +282,7 @@ class Chroma(VectorStore):
embedding_function: Optional[Embeddings] = None,
persist_directory: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
-collection_metadata: Optional[Dict] = None,
+collection_metadata: Optional[dict] = None,
client: Optional[chromadb.ClientAPI] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
create_collection_if_not_exists: Optional[bool] = True,
@@ -370,13 +365,13 @@ class Chroma(VectorStore):
@xor_args(("query_texts", "query_embeddings"))
def __query_collection(
self,
-query_texts: Optional[List[str]] = None,
+query_texts: Optional[list[str]] = None,
-query_embeddings: Optional[List[List[float]]] = None,
+query_embeddings: Optional[list[list[float]]] = None,
n_results: int = 4,
-where: Optional[Dict[str, str]] = None,
+where: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> Union[List[Document], chromadb.QueryResult]:
+) -> Union[list[Document], chromadb.QueryResult]:
"""Query the chroma collection.
Args:
@@ -411,11 +406,11 @@ class Chroma(VectorStore):
def add_images(
self,
-uris: List[str],
+uris: list[str],
-metadatas: Optional[List[dict]] = None,
+metadatas: Optional[list[dict]] = None,
-ids: Optional[List[str]] = None,
+ids: Optional[list[str]] = None,
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
"""Run more images through the embeddings and add to the vectorstore.
Args:
@@ -502,10 +497,10 @@ class Chroma(VectorStore):
def add_texts(
self,
texts: Iterable[str],
-metadatas: Optional[List[dict]] = None,
+metadatas: Optional[list[dict]] = None,
-ids: Optional[List[str]] = None,
+ids: Optional[list[str]] = None,
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
@@ -591,9 +586,9 @@ class Chroma(VectorStore):
self,
query: str,
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Run similarity search with Chroma.
Args:
@@ -612,12 +607,12 @@ class Chroma(VectorStore):
def similarity_search_by_vector(
self,
-embedding: List[float],
+embedding: list[float],
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
@@ -642,12 +637,12 @@ class Chroma(VectorStore):
def similarity_search_by_vector_with_relevance_scores(
self,
-embedding: List[float],
+embedding: list[float],
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Tuple[Document, float]]:
+) -> list[tuple[Document, float]]:
"""Return docs most similar to embedding vector and similarity score.
Args:
@@ -675,10 +670,10 @@ class Chroma(VectorStore):
self,
query: str,
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Tuple[Document, float]]:
+) -> list[tuple[Document, float]]:
"""Run similarity search with Chroma with distance.
Args:
@@ -717,10 +712,10 @@ class Chroma(VectorStore):
self,
query: str,
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Tuple[Document, np.ndarray]]:
+) -> list[tuple[Document, np.ndarray]]:
"""Run similarity search with Chroma with vectors.
Args:
@@ -800,9 +795,9 @@ class Chroma(VectorStore):
self,
uri: str,
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Search for similar images based on the given image URI.
Args:
@@ -844,9 +839,9 @@ class Chroma(VectorStore):
self,
uri: str,
k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Tuple[Document, float]]:
+) -> list[tuple[Document, float]]:
"""Search for similar images based on the given image URI.
Args:
@@ -886,14 +881,14 @@ class Chroma(VectorStore):
def max_marginal_relevance_search_by_vector(
self,
-embedding: List[float],
+embedding: list[float],
k: int = DEFAULT_K,
fetch_k: int = 20,
lambda_mult: float = 0.5,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -942,10 +937,10 @@ class Chroma(VectorStore):
k: int = DEFAULT_K,
fetch_k: int = 20,
lambda_mult: float = 0.5,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -1005,8 +1000,8 @@ class Chroma(VectorStore):
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
-include: Optional[List[str]] = None,
+include: Optional[list[str]] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Gets the collection.
Args:
@@ -1081,7 +1076,7 @@ class Chroma(VectorStore):
return self.update_documents([document_id], [document])
# type: ignore
-def update_documents(self, ids: List[str], documents: List[Document]) -> None:
+def update_documents(self, ids: list[str], documents: list[Document]) -> None:
"""Update a document in the collection.
Args:
@@ -1129,16 +1124,16 @@ class Chroma(VectorStore):
@classmethod
def from_texts(
-cls: Type[Chroma],
+cls: type[Chroma],
-texts: List[str],
+texts: list[str],
embedding: Optional[Embeddings] = None,
-metadatas: Optional[List[dict]] = None,
+metadatas: Optional[list[dict]] = None,
-ids: Optional[List[str]] = None,
+ids: Optional[list[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
persist_directory: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
client: Optional[chromadb.ClientAPI] = None,
-collection_metadata: Optional[Dict] = None,
+collection_metadata: Optional[dict] = None,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a raw documents.
@@ -1200,15 +1195,15 @@ class Chroma(VectorStore):
@classmethod
def from_documents(
-cls: Type[Chroma],
+cls: type[Chroma],
-documents: List[Document],
+documents: list[Document],
embedding: Optional[Embeddings] = None,
-ids: Optional[List[str]] = None,
+ids: Optional[list[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
persist_directory: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
client: Optional[chromadb.ClientAPI] = None,  # Add this line
-collection_metadata: Optional[Dict] = None,
+collection_metadata: Optional[dict] = None,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a list of documents.
@@ -1249,7 +1244,7 @@ class Chroma(VectorStore):
**kwargs,
)
-def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
"""Delete by vector IDs.
Args:

View File

@@ -58,8 +58,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = true
+[tool.ruff]
+target-version = "py39"
[tool.ruff.lint]
-select = ["E", "F", "I", "T201", "D"]
+select = ["E", "F", "I", "T201", "D", "UP"]
+ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@@ -1,7 +1,6 @@
"""Fake Embedding class for testing purposes."""
import math
-from typing import List
from langchain_core.embeddings import Embeddings
@@ -11,22 +10,22 @@ fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
-return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
+return [[1.0] * 9 + [float(i)] for i in range(len(texts))]
-async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)
-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
-return [float(1.0)] * 9 + [float(0.0)]
+return [1.0] * 9 + [0.0]
-async def aembed_query(self, text: str) -> List[float]:
+async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)
@@ -35,22 +34,22 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
vectors for the same texts."""
def __init__(self, dimensionality: int = 10) -> None:
-self.known_texts: List[str] = []
+self.known_texts: list[str] = []
self.dimensionality = dimensionality
-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
-vector = [float(1.0)] * (self.dimensionality - 1) + [
+vector = [1.0] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors
-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
@@ -61,13 +60,13 @@ class AngularTwoDimensionalEmbeddings(Embeddings):
From angles (as strings in units of pi) to unit embedding vectors on a circle.
"""
-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]
-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the

View File

@@ -3,8 +3,8 @@
import os.path
import tempfile
import uuid
+from collections.abc import Generator
from typing import (
-Generator,
cast,
)
@@ -222,7 +222,7 @@ def test_chroma_with_metadatas_with_scores_using_vector() -> None:
def test_chroma_search_filter() -> None:
"""Test end to end construction and search with metadata filtering."""
texts = ["far", "bar", "baz"]
-metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
+metadatas = [{"first_letter": f"{text[0]}"} for text in texts]
ids = [f"id_{i}" for i in range(len(texts))]
docsearch = Chroma.from_texts(
collection_name="test_collection",
@@ -245,7 +245,7 @@ def test_chroma_search_filter() -> None:
def test_chroma_search_filter_with_scores() -> None:
"""Test end to end construction and scored search with metadata filtering."""
texts = ["far", "bar", "baz"]
-metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
+metadatas = [{"first_letter": f"{text[0]}"} for text in texts]
ids = [f"id_{i}" for i in range(len(texts))]
docsearch = Chroma.from_texts(
collection_name="test_collection",

View File

@@ -1,4 +1,4 @@
-from typing import Generator
+from collections.abc import Generator
import pytest
from langchain_core.vectorstores import VectorStore

View File

@@ -1,7 +1,8 @@
"""DeepSeek chat models."""
+from collections.abc import Iterator
from json import JSONDecodeError
-from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, Optional, TypeVar, Union
import openai
from langchain_core.callbacks import (
@@ -19,8 +20,8 @@ from typing_extensions import Self
DEFAULT_API_BASE = "https://api.deepseek.com/v1"
_BM = TypeVar("_BM", bound=BaseModel)
-_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
+_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
-_DictOrPydantic = Union[Dict, _BM]
+_DictOrPydantic = Union[dict, _BM]
class ChatDeepSeek(BaseChatOpenAI):
@@ -178,7 +179,7 @@ class ChatDeepSeek(BaseChatOpenAI):
return "chat-deepseek"
@property
-def lc_secrets(self) -> Dict[str, str]:
+def lc_secrets(self) -> dict[str, str]:
"""A map of constructor argument names to secret ids."""
return {"api_key": "DEEPSEEK_API_KEY"}
@@ -217,7 +218,7 @@ class ChatDeepSeek(BaseChatOpenAI):
def _create_chat_result(
self,
response: Union[dict, openai.BaseModel],
-generation_info: Optional[Dict] = None,
+generation_info: Optional[dict] = None,
) -> ChatResult:
rtn = super()._create_chat_result(response, generation_info)
@@ -243,8 +244,8 @@ class ChatDeepSeek(BaseChatOpenAI):
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
-default_chunk_class: Type,
+default_chunk_class: type,
-base_generation_info: Optional[Dict],
+base_generation_info: Optional[dict],
) -> Optional[ChatGenerationChunk]:
generation_chunk = super()._convert_chunk_to_generation_chunk(
chunk,
@@ -268,8 +269,8 @@ class ChatDeepSeek(BaseChatOpenAI):
def _stream(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@@ -287,8 +288,8 @@ class ChatDeepSeek(BaseChatOpenAI):
def _generate(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:

View File

@ -45,8 +45,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]
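The ruff additions above enable the pyupgrade ("UP") rules against a py39 target while ignoring UP007, the rule that rewrites Optional/Union annotations into PEP 604 "X | Y" unions; that syntax only evaluates on Python 3.10+ (or behind a __future__ import), so keeping the typing forms is the conservative choice for 3.9. A hedged before/after sketch:

from typing import Optional, Union

def lookup(key: str, default: Optional[str] = None) -> Union[str, None]:
    # UP007 would suggest: def lookup(key: str, default: str | None = None) -> str | None:
    # which needs Python 3.10 (or from __future__ import annotations) to evaluate at runtime.
    return default if not key else key

print(lookup("", "fallback"))  # fallback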

View File

@ -1,6 +1,6 @@
"""Test ChatDeepSeek chat model.""" """Test ChatDeepSeek chat model."""
from typing import Optional, Type from typing import Optional
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@ -13,7 +13,7 @@ from langchain_deepseek.chat_models import ChatDeepSeek
class TestChatDeepSeek(ChatModelIntegrationTests): class TestChatDeepSeek(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[ChatDeepSeek]: def chat_model_class(self) -> type[ChatDeepSeek]:
return ChatDeepSeek return ChatDeepSeek
@property @property
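The test fixture above now annotates its class-valued property with the lowercase type[...] form. A self-contained illustration of the same pattern with made-up names:

class Widget:
    pass

class WidgetSuite:
    @property
    def widget_class(self) -> type[Widget]:
        # type[Widget] is the 3.9 builtin-generic spelling of typing.Type[Widget].
        return Widget

assert WidgetSuite().widget_class is Widget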

View File

@ -1,6 +1,6 @@
"""Test chat model integration.""" """Test chat model integration."""
from typing import Any, Dict, Literal, Type, Union from typing import Any, Literal, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
from langchain_core.messages import AIMessageChunk from langchain_core.messages import AIMessageChunk
@ -28,9 +28,9 @@ class MockOpenAIResponse(BaseModel):
exclude_none: bool = False, exclude_none: bool = False,
round_trip: bool = False, round_trip: bool = False,
warnings: Union[Literal["none", "warn", "error"], bool] = True, warnings: Union[Literal["none", "warn", "error"], bool] = True,
context: Union[Dict[str, Any], None] = None, context: Union[dict[str, Any], None] = None,
serialize_as_any: bool = False, serialize_as_any: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
choices_list = [] choices_list = []
for choice in self.choices: for choice in self.choices:
if isinstance(choice.message, ChatCompletionMessage): if isinstance(choice.message, ChatCompletionMessage):
@ -57,7 +57,7 @@ class MockOpenAIResponse(BaseModel):
class TestChatDeepSeekUnit(ChatModelUnitTests): class TestChatDeepSeekUnit(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[ChatDeepSeek]: def chat_model_class(self) -> type[ChatDeepSeek]:
return ChatDeepSeek return ChatDeepSeek
@property @property
@ -134,7 +134,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_with_reasoning_content(self) -> None: def test_convert_chunk_with_reasoning_content(self) -> None:
"""Test that reasoning_content is properly extracted from streaming chunk.""" """Test that reasoning_content is properly extracted from streaming chunk."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key")) chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = { chunk: dict[str, Any] = {
"choices": [ "choices": [
{ {
"delta": { "delta": {
@ -158,7 +158,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_with_reasoning(self) -> None: def test_convert_chunk_with_reasoning(self) -> None:
"""Test that reasoning is properly extracted from streaming chunk.""" """Test that reasoning is properly extracted from streaming chunk."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key")) chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = { chunk: dict[str, Any] = {
"choices": [ "choices": [
{ {
"delta": { "delta": {
@ -182,7 +182,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_without_reasoning(self) -> None: def test_convert_chunk_without_reasoning(self) -> None:
"""Test that chunk without reasoning fields works correctly.""" """Test that chunk without reasoning fields works correctly."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key")) chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = {"choices": [{"delta": {"content": "Main content"}}]} chunk: dict[str, Any] = {"choices": [{"delta": {"content": "Main content"}}]}
chunk_result = chat_model._convert_chunk_to_generation_chunk( chunk_result = chat_model._convert_chunk_to_generation_chunk(
chunk, AIMessageChunk, None chunk, AIMessageChunk, None
@ -194,7 +194,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_with_empty_delta(self) -> None: def test_convert_chunk_with_empty_delta(self) -> None:
"""Test that chunk with empty delta works correctly.""" """Test that chunk with empty delta works correctly."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key")) chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = {"choices": [{"delta": {}}]} chunk: dict[str, Any] = {"choices": [{"delta": {}}]}
chunk_result = chat_model._convert_chunk_to_generation_chunk( chunk_result = chat_model._convert_chunk_to_generation_chunk(
chunk, AIMessageChunk, None chunk, AIMessageChunk, None

View File

@ -1,11 +1,10 @@
import os # type: ignore[import-not-found] import os # type: ignore[import-not-found]
from typing import Dict
from exa_py import Exa # type: ignore from exa_py import Exa # type: ignore
from langchain_core.utils import convert_to_secret_str from langchain_core.utils import convert_to_secret_str
def initialize_client(values: Dict) -> Dict: def initialize_client(values: dict) -> dict:
"""Initialize the client.""" """Initialize the client."""
exa_api_key = values.get("exa_api_key") or os.environ.get("EXA_API_KEY") or "" exa_api_key = values.get("exa_api_key") or os.environ.get("EXA_API_KEY") or ""
values["exa_api_key"] = convert_to_secret_str(exa_api_key) values["exa_api_key"] = convert_to_secret_str(exa_api_key)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from exa_py import Exa # type: ignore[untyped-import] from exa_py import Exa # type: ignore[untyped-import]
from exa_py.api import ( from exa_py.api import (
@ -13,7 +13,7 @@ from pydantic import Field, SecretStr, model_validator
from langchain_exa._utilities import initialize_client from langchain_exa._utilities import initialize_client
def _get_metadata(result: Any) -> Dict[str, Any]: def _get_metadata(result: Any) -> dict[str, Any]:
"""Get the metadata from a result object.""" """Get the metadata from a result object."""
metadata = { metadata = {
"title": result.title, "title": result.title,
@ -35,9 +35,9 @@ class ExaSearchRetriever(BaseRetriever):
k: int = 10 # num_results k: int = 10 # num_results
"""The number of search results to return.""" """The number of search results to return."""
include_domains: Optional[List[str]] = None include_domains: Optional[list[str]] = None
"""A list of domains to include in the search.""" """A list of domains to include in the search."""
exclude_domains: Optional[List[str]] = None exclude_domains: Optional[list[str]] = None
"""A list of domains to exclude from the search.""" """A list of domains to exclude from the search."""
start_crawl_date: Optional[str] = None start_crawl_date: Optional[str] = None
"""The start date for the crawl (in YYYY-MM-DD format).""" """The start date for the crawl (in YYYY-MM-DD format)."""
@ -62,14 +62,14 @@ class ExaSearchRetriever(BaseRetriever):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_environment(cls, values: Dict) -> Any: def validate_environment(cls, values: dict) -> Any:
"""Validate the environment.""" """Validate the environment."""
values = initialize_client(values) values = initialize_client(values)
return values return values
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> list[Document]:
response = self.client.search_and_contents( # type: ignore[misc] response = self.client.search_and_contents( # type: ignore[misc]
query, query,
num_results=self.k, num_results=self.k,

View File

@ -1,6 +1,6 @@
"""Tool for the Exa Search API.""" """Tool for the Exa Search API."""
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from exa_py import Exa # type: ignore[untyped-import] from exa_py import Exa # type: ignore[untyped-import]
from exa_py.api import ( from exa_py.api import (
@ -66,7 +66,7 @@ class ExaSearchResults(BaseTool): # type: ignore[override]
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_environment(cls, values: Dict) -> Any: def validate_environment(cls, values: dict) -> Any:
"""Validate the environment.""" """Validate the environment."""
values = initialize_client(values) values = initialize_client(values)
return values return values
@ -77,15 +77,15 @@ class ExaSearchResults(BaseTool): # type: ignore[override]
num_results: int, num_results: int,
text_contents_options: Optional[Union[TextContentsOptions, bool]] = None, text_contents_options: Optional[Union[TextContentsOptions, bool]] = None,
highlights: Optional[Union[HighlightsContentsOptions, bool]] = None, highlights: Optional[Union[HighlightsContentsOptions, bool]] = None,
include_domains: Optional[List[str]] = None, include_domains: Optional[list[str]] = None,
exclude_domains: Optional[List[str]] = None, exclude_domains: Optional[list[str]] = None,
start_crawl_date: Optional[str] = None, start_crawl_date: Optional[str] = None,
end_crawl_date: Optional[str] = None, end_crawl_date: Optional[str] = None,
start_published_date: Optional[str] = None, start_published_date: Optional[str] = None,
end_published_date: Optional[str] = None, end_published_date: Optional[str] = None,
use_autoprompt: Optional[bool] = None, use_autoprompt: Optional[bool] = None,
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[List[Dict], str]: ) -> Union[list[dict], str]:
"""Use the tool.""" """Use the tool."""
try: try:
return self.client.search_and_contents( return self.client.search_and_contents(
@ -120,7 +120,7 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override]
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_environment(cls, values: Dict) -> Any: def validate_environment(cls, values: dict) -> Any:
"""Validate the environment.""" """Validate the environment."""
values = initialize_client(values) values = initialize_client(values)
return values return values
@ -131,8 +131,8 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override]
num_results: int, num_results: int,
text_contents_options: Optional[Union[TextContentsOptions, bool]] = None, text_contents_options: Optional[Union[TextContentsOptions, bool]] = None,
highlights: Optional[Union[HighlightsContentsOptions, bool]] = None, highlights: Optional[Union[HighlightsContentsOptions, bool]] = None,
include_domains: Optional[List[str]] = None, include_domains: Optional[list[str]] = None,
exclude_domains: Optional[List[str]] = None, exclude_domains: Optional[list[str]] = None,
start_crawl_date: Optional[str] = None, start_crawl_date: Optional[str] = None,
end_crawl_date: Optional[str] = None, end_crawl_date: Optional[str] = None,
start_published_date: Optional[str] = None, start_published_date: Optional[str] = None,
@ -140,7 +140,7 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override]
exclude_source_domain: Optional[bool] = None, exclude_source_domain: Optional[bool] = None,
category: Optional[str] = None, category: Optional[str] = None,
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[List[Dict], str]: ) -> Union[list[dict], str]:
"""Use the tool.""" """Use the tool."""
try: try:
return self.client.find_similar_and_contents( return self.client.find_similar_and_contents(
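The Exa tools above are typed to return Union[list[dict], str]: result dicts on success, a plain string otherwise (the except branch is truncated in this view, so the string case is an assumption here). A minimal sketch of that contract with a stubbed client call:

from typing import Union

def run_search(query: str, num_results: int = 2) -> Union[list[dict], str]:
    # Stubbed results; a real tool would call the Exa client here and may return
    # an error string instead of raising (assumed, since the except body is not shown).
    try:
        return [{"query": query, "rank": i + 1} for i in range(num_results)]
    except Exception as exc:
        return f"search failed: {exc}"

print(run_search("langchain"))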

View File

@ -45,8 +45,12 @@ langchain-core = { path = "../../core", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]

View File

@ -4,20 +4,13 @@ from __future__ import annotations
import json import json
import logging import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence,
Tuple,
Type,
TypedDict, TypedDict,
Union, Union,
cast, cast,
@ -109,7 +102,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
# Fix for azure # Fix for azure
# Also Fireworks returns None for tool invocations # Also Fireworks returns None for tool invocations
content = _dict.get("content", "") or "" content = _dict.get("content", "") or ""
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if function_call := _dict.get("function_call"): if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call) additional_kwargs["function_call"] = dict(function_call)
tool_calls = [] tool_calls = []
@ -157,7 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns: Returns:
The dictionary. The dictionary.
""" """
message_dict: Dict[str, Any] message_dict: dict[str, Any]
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content} message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
@ -205,14 +198,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_chunk_to_message_chunk( def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk] chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
choice = chunk["choices"][0] choice = chunk["choices"][0]
_dict = choice["delta"] _dict = choice["delta"]
role = cast(str, _dict.get("role")) role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "") content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {} additional_kwargs: dict = {}
tool_call_chunks: List[ToolCallChunk] = [] tool_call_chunks: list[ToolCallChunk] = []
if _dict.get("function_call"): if _dict.get("function_call"):
function_call = dict(_dict["function_call"]) function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None: if "name" in function_call and function_call["name"] is None:
@ -290,17 +283,17 @@ class ChatFireworks(BaseChatModel):
""" """
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"} return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"] return ["langchain", "chat_models", "fireworks"]
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: dict[str, Any] = {}
if self.fireworks_api_base: if self.fireworks_api_base:
attributes["fireworks_api_base"] = self.fireworks_api_base attributes["fireworks_api_base"] = self.fireworks_api_base
@ -319,9 +312,9 @@ class ChatFireworks(BaseChatModel):
"""Model name to use.""" """Model name to use."""
temperature: float = 0.0 temperature: float = 0.0
"""What sampling temperature to use.""" """What sampling temperature to use."""
stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences") stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences.""" """Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field( fireworks_api_key: SecretStr = Field(
alias="api_key", alias="api_key",
@ -344,7 +337,7 @@ class ChatFireworks(BaseChatModel):
) )
"""Base URL path for API requests, leave blank if not using a proxy or service """Base URL path for API requests, leave blank if not using a proxy or service
emulator.""" emulator."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout" default=None, alias="timeout"
) )
"""Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or """Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
@ -364,7 +357,7 @@ class ChatFireworks(BaseChatModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names) values = _build_model_kwargs(values, all_required_field_names)
@ -398,7 +391,7 @@ class ChatFireworks(BaseChatModel):
return self return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Fireworks API.""" """Get the default parameters for calling Fireworks API."""
params = { params = {
"model": self.model_name, "model": self.model_name,
@ -413,7 +406,7 @@ class ChatFireworks(BaseChatModel):
return params return params
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
@ -429,7 +422,7 @@ class ChatFireworks(BaseChatModel):
ls_params["ls_stop"] = ls_stop ls_params["ls_stop"] = ls_stop
return ls_params return ls_params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {} overall_token_usage: dict = {}
system_fingerprint = None system_fingerprint = None
for output in llm_outputs: for output in llm_outputs:
@ -452,15 +445,15 @@ class ChatFireworks(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params): for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict): if not isinstance(chunk, dict):
chunk = chunk.model_dump() chunk = chunk.model_dump()
@ -487,8 +480,8 @@ class ChatFireworks(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
@ -509,8 +502,8 @@ class ChatFireworks(BaseChatModel):
return self._create_chat_result(response) return self._create_chat_result(response)
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = self._default_params params = self._default_params
if stop is not None: if stop is not None:
params["stop"] = stop params["stop"] = stop
@ -547,15 +540,15 @@ class ChatFireworks(BaseChatModel):
async def _astream( async def _astream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in self.async_client.acreate(messages=message_dicts, **params): async for chunk in self.async_client.acreate(messages=message_dicts, **params):
if not isinstance(chunk, dict): if not isinstance(chunk, dict):
chunk = chunk.model_dump() chunk = chunk.model_dump()
@ -584,8 +577,8 @@ class ChatFireworks(BaseChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
@ -607,13 +600,13 @@ class ChatFireworks(BaseChatModel):
return self._create_chat_result(response) return self._create_chat_result(response)
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {"model_name": self.model_name, **self._default_params} return {"model_name": self.model_name, **self._default_params}
def _get_invocation_params( def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Get the parameters used to invoke the model.""" """Get the parameters used to invoke the model."""
return { return {
"model": self.model_name, "model": self.model_name,
@ -634,7 +627,7 @@ class ChatFireworks(BaseChatModel):
) )
def bind_functions( def bind_functions(
self, self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[ function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]] Union[_FunctionCall, str, Literal["auto", "none"]]
] = None, ] = None,
@ -690,7 +683,7 @@ class ChatFireworks(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
*, *,
tool_choice: Optional[ tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool] Union[dict, str, Literal["auto", "any", "none"], bool]
@ -738,14 +731,14 @@ class ChatFireworks(BaseChatModel):
def with_structured_output( def with_structured_output(
self, self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None, schema: Optional[Union[dict, type[BaseModel]]] = None,
*, *,
method: Literal[ method: Literal[
"function_calling", "json_mode", "json_schema" "function_calling", "json_mode", "json_schema"
] = "function_calling", ] = "function_calling",
include_raw: bool = False, include_raw: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema. """Model wrapper that returns outputs formatted to match the given schema.
Args: Args:
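As with the DeepSeek module, the Fireworks chat model now sources Mapping, Sequence, and the iterator types from collections.abc and uses lowercase dict/list/tuple/type throughout. A stripped-down sketch of the chunk-handling shape (assumed field access, no langchain types):

from collections.abc import Mapping
from typing import Any

def delta_content(chunk: Mapping[str, Any]) -> str:
    # Mirrors the choices[0]["delta"]["content"] or "" pattern used while streaming.
    choice = chunk["choices"][0]
    return choice.get("delta", {}).get("content") or ""

print(delta_content({"choices": [{"delta": {"content": "hi"}}]}))  # hi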

View File

@ -1,5 +1,3 @@
from typing import List
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.utils import secret_from_env from langchain_core.utils import secret_from_env
from openai import OpenAI from openai import OpenAI
@ -96,13 +94,13 @@ class FireworksEmbeddings(BaseModel, Embeddings):
) )
return self return self
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.""" """Embed search docs."""
return [ return [
i.embedding i.embedding
for i in self.client.embeddings.create(input=texts, model=self.model).data for i in self.client.embeddings.create(input=texts, model=self.model).data
] ]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Embed query text.""" """Embed query text."""
return self.embed_documents([text])[0] return self.embed_documents([text])[0]
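The embeddings interface above keeps its shape and only drops the typing.List spelling. A runnable stand-in (fake vectors, no API calls) showing the nested built-in generics:

class FakeEmbeddings:
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # Placeholder vectors of [length, 0.0] per text, just to show the return shape.
        return [[float(len(t)), 0.0] for t in texts]

    def embed_query(self, text: str) -> list[float]:
        return self.embed_documents([text])[0]

print(FakeEmbeddings().embed_query("hello"))  # [5.0, 0.0]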

View File

@ -1,7 +1,7 @@
"""Wrapper around Fireworks AI's Completion API.""" """Wrapper around Fireworks AI's Completion API."""
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Optional
import requests import requests
from aiohttp import ClientSession from aiohttp import ClientSession
@ -63,7 +63,7 @@ class Fireworks(LLM):
for question answering or summarization. A value greater than 1 introduces more for question answering or summarization. A value greater than 1 introduces more
randomness in the output. randomness in the output.
""" """
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
top_k: Optional[int] = None top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It """Used to limit the number of choices for the next predicted word or token. It
@ -90,7 +90,7 @@ class Fireworks(LLM):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names) values = _build_model_kwargs(values, all_required_field_names)
@ -109,7 +109,7 @@ class Fireworks(LLM):
return f"langchain-fireworks/{__version__}" return f"langchain-fireworks/{__version__}"
@property @property
def default_params(self) -> Dict[str, Any]: def default_params(self) -> dict[str, Any]:
return { return {
"model": self.model, "model": self.model,
"temperature": self.temperature, "temperature": self.temperature,
@ -122,7 +122,7 @@ class Fireworks(LLM):
def _call( def _call(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -139,7 +139,7 @@ class Fireworks(LLM):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
stop_to_use = stop[0] if stop and len(stop) == 1 else stop stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = { payload: dict[str, Any] = {
**self.default_params, **self.default_params,
"prompt": prompt, "prompt": prompt,
"stop": stop_to_use, "stop": stop_to_use,
@ -168,7 +168,7 @@ class Fireworks(LLM):
async def _acall( async def _acall(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -185,7 +185,7 @@ class Fireworks(LLM):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
stop_to_use = stop[0] if stop and len(stop) == 1 else stop stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = { payload: dict[str, Any] = {
**self.default_params, **self.default_params,
"prompt": prompt, "prompt": prompt,
"stop": stop_to_use, "stop": stop_to_use,

View File

@ -48,8 +48,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]

View File

@ -4,12 +4,12 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
""" """
import json import json
from typing import Any, Literal, Optional from typing import Annotated, Any, Literal, Optional
import pytest import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict from typing_extensions import TypedDict
from langchain_fireworks import ChatFireworks from langchain_fireworks import ChatFireworks
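One import shuffle worth noting: Annotated moves from typing_extensions to typing, which has shipped it since Python 3.9 (PEP 593), so typing_extensions is kept only for TypedDict here. A small illustrative sketch (not taken from the tests) of an Annotated field:

from typing import Annotated, TypedDict

class WeatherInput(TypedDict):
    # The string metadata is illustrative; Annotated imports from typing on 3.9+.
    location: Annotated[str, "city name, e.g. 'Paris'"]
    unit: Annotated[str, "either 'celsius' or 'fahrenheit'"]

example: WeatherInput = {"location": "Paris", "unit": "celsius"}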

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -14,7 +12,7 @@ from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelIntegrationTests): class TestFireworksStandard(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatFireworks return ChatFireworks
@property @property

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@ -10,7 +8,7 @@ from langchain_fireworks import FireworksEmbeddings
class TestFireworksStandard(EmbeddingsUnitTests): class TestFireworksStandard(EmbeddingsUnitTests):
@property @property
def embeddings_class(self) -> Type[Embeddings]: def embeddings_class(self) -> type[Embeddings]:
return FireworksEmbeddings return FireworksEmbeddings
@property @property
@ -18,7 +16,7 @@ class TestFireworksStandard(EmbeddingsUnitTests):
return {"api_key": "test_api_key"} return {"api_key": "test_api_key"}
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"FIREWORKS_API_KEY": "api_key", "FIREWORKS_API_KEY": "api_key",

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ( # type: ignore[import-not-found] from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found] ChatModelUnitTests, # type: ignore[import-not-found]
@ -12,7 +10,7 @@ from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelUnitTests): class TestFireworksStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatFireworks return ChatFireworks
@property @property
@ -20,7 +18,7 @@ class TestFireworksStandard(ChatModelUnitTests):
return {"api_key": "test_api_key"} return {"api_key": "test_api_key"}
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"FIREWORKS_API_KEY": "api_key", "FIREWORKS_API_KEY": "api_key",

View File

@ -4,20 +4,13 @@ from __future__ import annotations
import json import json
import warnings import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence,
Tuple,
Type,
TypedDict, TypedDict,
Union, Union,
cast, cast,
@ -307,9 +300,9 @@ class ChatGroq(BaseChatModel):
"""Model name to use.""" """Model name to use."""
temperature: float = 0.7 temperature: float = 0.7
"""What sampling temperature to use.""" """What sampling temperature to use."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences") stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences.""" """Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
groq_api_key: Optional[SecretStr] = Field( groq_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None) alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None)
@ -324,7 +317,7 @@ class ChatGroq(BaseChatModel):
groq_proxy: Optional[str] = Field( groq_proxy: Optional[str] = Field(
default_factory=from_env("GROQ_PROXY", default=None) default_factory=from_env("GROQ_PROXY", default=None)
) )
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout" default=None, alias="timeout"
) )
"""Timeout for requests to Groq completion API. Can be float, httpx.Timeout or """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
@ -353,7 +346,7 @@ class ChatGroq(BaseChatModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
@ -392,7 +385,7 @@ class ChatGroq(BaseChatModel):
self.default_headers or {} self.default_headers or {}
) )
client_params: Dict[str, Any] = { client_params: dict[str, Any] = {
"api_key": ( "api_key": (
self.groq_api_key.get_secret_value() if self.groq_api_key else None self.groq_api_key.get_secret_value() if self.groq_api_key else None
), ),
@ -406,13 +399,13 @@ class ChatGroq(BaseChatModel):
try: try:
import groq import groq
sync_specific: Dict[str, Any] = {"http_client": self.http_client} sync_specific: dict[str, Any] = {"http_client": self.http_client}
if not self.client: if not self.client:
self.client = groq.Groq( self.client = groq.Groq(
**client_params, **sync_specific **client_params, **sync_specific
).chat.completions ).chat.completions
if not self.async_client: if not self.async_client:
async_specific: Dict[str, Any] = {"http_client": self.http_async_client} async_specific: dict[str, Any] = {"http_client": self.http_async_client}
self.async_client = groq.AsyncGroq( self.async_client = groq.AsyncGroq(
**client_params, **async_specific **client_params, **async_specific
).chat.completions ).chat.completions
@ -427,7 +420,7 @@ class ChatGroq(BaseChatModel):
# Serializable class method overrides # Serializable class method overrides
# #
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return {"groq_api_key": "GROQ_API_KEY"} return {"groq_api_key": "GROQ_API_KEY"}
@classmethod @classmethod
@ -444,7 +437,7 @@ class ChatGroq(BaseChatModel):
return "groq-chat" return "groq-chat"
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
@ -480,8 +473,8 @@ class ChatGroq(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -500,8 +493,8 @@ class ChatGroq(BaseChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -521,8 +514,8 @@ class ChatGroq(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
@ -530,7 +523,7 @@ class ChatGroq(BaseChatModel):
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params): for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict): if not isinstance(chunk, dict):
chunk = chunk.model_dump() chunk = chunk.model_dump()
@ -560,8 +553,8 @@ class ChatGroq(BaseChatModel):
async def _astream( async def _astream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
@ -569,7 +562,7 @@ class ChatGroq(BaseChatModel):
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await self.async_client.create( async for chunk in await self.async_client.create(
messages=message_dicts, **params messages=message_dicts, **params
): ):
@ -605,7 +598,7 @@ class ChatGroq(BaseChatModel):
# Internal methods # Internal methods
# #
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Groq API.""" """Get the default parameters for calling Groq API."""
params = { params = {
"model": self.model_name, "model": self.model_name,
@ -652,15 +645,15 @@ class ChatGroq(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = self._default_params params = self._default_params
if stop is not None: if stop is not None:
params["stop"] = stop params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages] message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params return message_dicts, params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {} overall_token_usage: dict = {}
system_fingerprint = None system_fingerprint = None
for output in llm_outputs: for output in llm_outputs:
@ -688,7 +681,7 @@ class ChatGroq(BaseChatModel):
) )
def bind_functions( def bind_functions(
self, self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[ function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]] Union[_FunctionCall, str, Literal["auto", "none"]]
] = None, ] = None,
@ -743,7 +736,7 @@ class ChatGroq(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
*, *,
tool_choice: Optional[ tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool] Union[dict, str, Literal["auto", "any", "none"], bool]
@ -791,12 +784,12 @@ class ChatGroq(BaseChatModel):
def with_structured_output( def with_structured_output(
self, self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None, schema: Optional[Union[dict, type[BaseModel]]] = None,
*, *,
method: Literal["function_calling", "json_mode"] = "function_calling", method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False, include_raw: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema. """Model wrapper that returns outputs formatted to match the given schema.
Args: Args:
@ -1096,7 +1089,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns: Returns:
The dictionary. The dictionary.
""" """
message_dict: Dict[str, Any] message_dict: dict[str, Any]
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content} message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
@ -1142,13 +1135,13 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_chunk_to_message_chunk( def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk] chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
choice = chunk["choices"][0] choice = chunk["choices"][0]
_dict = choice["delta"] _dict = choice["delta"]
role = cast(str, _dict.get("role")) role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "") content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if _dict.get("function_call"): if _dict.get("function_call"):
function_call = dict(_dict["function_call"]) function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None: if "name" in function_call and function_call["name"] is None:
@ -1202,7 +1195,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return HumanMessage(content=_dict.get("content", "")) return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant": elif role == "assistant":
content = _dict.get("content", "") or "" content = _dict.get("content", "") or ""
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if function_call := _dict.get("function_call"): if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call) additional_kwargs["function_call"] = dict(function_call)
tool_calls = [] tool_calls = []
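Note that fields such as stop keep their Optional[Union[list[str], str]] spelling because UP007 is ignored (see the pyproject hunks); only the container types go lowercase. A hedged helper sketch, not part of ChatGroq, that handles that shape:

from typing import Optional, Union

def normalize_stop(stop: Optional[Union[list[str], str]]) -> list[str]:
    # Accept None, a single stop string, or a list of them; always return a list.
    if stop is None:
        return []
    return [stop] if isinstance(stop, str) else list(stop)

assert normalize_stop("END") == ["END"]
assert normalize_stop(None) == []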

View File

@ -40,8 +40,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "W"] select = ["E", "F", "I", "W", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter from langchain_core.rate_limiters import InMemoryRateLimiter
@ -17,7 +15,7 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.2)
class BaseTestGroq(ChatModelIntegrationTests): class BaseTestGroq(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatGroq return ChatGroq
@pytest.mark.xfail(reason="Not yet implemented.") @pytest.mark.xfail(reason="Not yet implemented.")

View File

@ -1,7 +1,7 @@
"""A fake callback handler for testing purposes.""" """A fake callback handler for testing purposes."""
from itertools import chain from itertools import chain
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
@ -15,7 +15,7 @@ class BaseFakeCallbackHandler(BaseModel):
starts: int = 0 starts: int = 0
ends: int = 0 ends: int = 0
errors: int = 0 errors: int = 0
errors_args: List[Any] = [] errors_args: list[Any] = []
text: int = 0 text: int = 0
ignore_llm_: bool = False ignore_llm_: bool = False
ignore_chain_: bool = False ignore_chain_: bool = False
@ -264,8 +264,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests.chat_models import ( from langchain_tests.unit_tests.chat_models import (
ChatModelUnitTests, ChatModelUnitTests,
@ -12,7 +10,7 @@ from langchain_groq import ChatGroq
class TestGroqStandard(ChatModelUnitTests): class TestGroqStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatGroq return ChatGroq
@property @property

View File

@ -1,16 +1,13 @@
"""Hugging Face Chat Wrapper.""" """Hugging Face Chat Wrapper."""
import json import json
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
List,
Literal, Literal,
Optional, Optional,
Sequence,
Type,
Union, Union,
cast, cast,
) )
@ -46,8 +43,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
class TGI_RESPONSE: class TGI_RESPONSE:
"""Response from the TextGenInference API.""" """Response from the TextGenInference API."""
choices: List[Any] choices: list[Any]
usage: Dict usage: dict
@dataclass @dataclass
@ -56,12 +53,12 @@ class TGI_MESSAGE:
role: str role: str
content: str content: str
tool_calls: List[Dict] tool_calls: list[dict]
def _convert_message_to_chat_message( def _convert_message_to_chat_message(
message: BaseMessage, message: BaseMessage,
) -> Dict: ) -> dict:
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content) return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
@ -104,7 +101,7 @@ def _convert_TGI_message_to_LC_message(
content = cast(str, _message.content) content = cast(str, _message.content)
if content is None: if content is None:
content = "" content = ""
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if tool_calls := _message.tool_calls: if tool_calls := _message.tool_calls:
if "arguments" in tool_calls[0]["function"]: if "arguments" in tool_calls[0]["function"]:
functions = tool_calls[0]["function"].pop("arguments") functions = tool_calls[0]["function"].pop("arguments")
@ -358,8 +355,8 @@ class ChatHuggingFace(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -380,8 +377,8 @@ class ChatHuggingFace(BaseChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -398,7 +395,7 @@ class ChatHuggingFace(BaseChatModel):
def _to_chat_prompt( def _to_chat_prompt(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
) -> str: ) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM.""" """Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages: if not messages:
@ -472,7 +469,7 @@ class ChatHuggingFace(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*, *,
tool_choice: Optional[ tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required"], bool] Union[dict, str, Literal["auto", "none", "required"], bool]
@ -529,8 +526,8 @@ class ChatHuggingFace(BaseChatModel):
return super().bind(tools=formatted_tools, **kwargs) return super().bind(tools=formatted_tools, **kwargs)
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> List[Dict[Any, Any]]: ) -> list[dict[Any, Any]]:
message_dicts = [_convert_message_to_chat_message(m) for m in messages] message_dicts = [_convert_message_to_chat_message(m) for m in messages]
return message_dicts return message_dicts
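The TGI_RESPONSE and TGI_MESSAGE dataclasses above adopt the lowercase container types as well. A self-contained analogue with invented field values:

from dataclasses import dataclass, field
from typing import Any

@dataclass
class FakeResponse:
    choices: list[Any] = field(default_factory=list)
    usage: dict = field(default_factory=dict)

resp = FakeResponse(choices=["ok"], usage={"total_tokens": 3})
print(resp.usage["total_tokens"])  # 3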

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@ -40,16 +40,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
cache_folder: Optional[str] = None cache_folder: Optional[str] = None
"""Path to store models. """Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`, """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
`prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`. `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer""" See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict) encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method for the documents of """Keyword arguments to pass when calling the `encode` method for the documents of
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`, the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
`precision`, `normalize_embeddings`, and more. `precision`, `normalize_embeddings`, and more.
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode""" See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict) query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method for the query of """Keyword arguments to pass when calling the `encode` method for the query of
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`, the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
`precision`, `normalize_embeddings`, and more. `precision`, `normalize_embeddings`, and more.
@ -102,8 +102,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
) )
def _embed( def _embed(
self, texts: list[str], encode_kwargs: Dict[str, Any] self, texts: list[str], encode_kwargs: dict[str, Any]
) -> List[List[float]]: ) -> list[list[float]]:
""" """
Embed a text using the HuggingFace transformer model. Embed a text using the HuggingFace transformer model.
@ -138,7 +138,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
return embeddings.tolist() return embeddings.tolist()
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Compute doc embeddings using a HuggingFace transformer model. """Compute doc embeddings using a HuggingFace transformer model.
Args: Args:
@ -149,7 +149,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
""" """
return self._embed(texts, self.encode_kwargs) return self._embed(texts, self.encode_kwargs)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Compute query embeddings using a HuggingFace transformer model. """Compute query embeddings using a HuggingFace transformer model.
Args: Args:

View File

@ -1,5 +1,5 @@
import os import os
from typing import Any, List, Optional from typing import Any, Optional
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.utils import from_env from langchain_core.utils import from_env
@ -101,7 +101,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
) )
return self return self
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs. """Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
Args: Args:
@ -117,7 +117,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
responses = self.client.feature_extraction(text=texts, **_model_kwargs) responses = self.client.feature_extraction(text=texts, **_model_kwargs)
return responses return responses
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs. """Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
Args: Args:
@ -134,7 +134,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
) )
return responses return responses
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text. """Call out to HuggingFaceHub's embedding endpoint for embedding query text.
Args: Args:
@ -146,7 +146,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
response = self.embed_documents([text])[0] response = self.embed_documents([text])[0]
return response return response
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> list[float]:
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text. """Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
Args: Args:

View File

@ -2,7 +2,8 @@ import inspect
import json # type: ignore[import-not-found] import json # type: ignore[import-not-found]
import logging import logging
import os import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -96,7 +97,7 @@ class HuggingFaceEndpoint(LLM):
"""Whether to prepend the prompt to the generated text""" """Whether to prepend the prompt to the generated text"""
truncate: Optional[int] = None truncate: Optional[int] = None
"""Truncate inputs tokens to the given size""" """Truncate inputs tokens to the given size"""
stop_sequences: List[str] = Field(default_factory=list) stop_sequences: list[str] = Field(default_factory=list)
"""Stop generating tokens if a member of `stop_sequences` is generated""" """Stop generating tokens if a member of `stop_sequences` is generated"""
seed: Optional[int] = None seed: Optional[int] = None
"""Random sampling seed""" """Random sampling seed"""
@ -111,9 +112,9 @@ class HuggingFaceEndpoint(LLM):
watermark: bool = False watermark: bool = False
"""Watermarking with [A Watermark for Large Language Models] """Watermarking with [A Watermark for Large Language Models]
(https://arxiv.org/abs/2301.10226)""" (https://arxiv.org/abs/2301.10226)"""
server_kwargs: Dict[str, Any] = Field(default_factory=dict) server_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any text-generation-inference server parameters not explicitly specified""" """Holds any text-generation-inference server parameters not explicitly specified"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified""" """Holds any model parameters valid for `call` not explicitly specified"""
model: str model: str
client: Any = None #: :meta private: client: Any = None #: :meta private:
@ -128,7 +129,7 @@ class HuggingFaceEndpoint(LLM):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
@ -252,7 +253,7 @@ class HuggingFaceEndpoint(LLM):
return self return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling text generation inference API.""" """Get the default parameters for calling text generation inference API."""
return { return {
"max_new_tokens": self.max_new_tokens, "max_new_tokens": self.max_new_tokens,
@ -285,8 +286,8 @@ class HuggingFaceEndpoint(LLM):
return "huggingface_endpoint" return "huggingface_endpoint"
def _invocation_params( def _invocation_params(
self, runtime_stop: Optional[List[str]], **kwargs: Any self, runtime_stop: Optional[list[str]], **kwargs: Any
) -> Dict[str, Any]: ) -> dict[str, Any]:
params = {**self._default_params, **kwargs} params = {**self._default_params, **kwargs}
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or []) params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
return params return params
@ -294,7 +295,7 @@ class HuggingFaceEndpoint(LLM):
def _call( def _call(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -326,7 +327,7 @@ class HuggingFaceEndpoint(LLM):
async def _acall( async def _acall(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -357,7 +358,7 @@ class HuggingFaceEndpoint(LLM):
def _stream( def _stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:
@ -394,7 +395,7 @@ class HuggingFaceEndpoint(LLM):
async def _astream( async def _astream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[GenerationChunk]: ) -> AsyncIterator[GenerationChunk]:
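The import shuffles above are the other half of the same cleanup: since Python 3.9 the typing versions of Iterator, AsyncIterator, Mapping and friends are deprecated aliases for the classes in collections.abc, which are themselves subscriptable. A minimal sketch of the resulting import style (the function is illustrative, not taken from this file):

from collections.abc import Iterator, Mapping

def iter_keys(params: Mapping[str, int]) -> Iterator[str]:
    # collections.abc.Mapping / Iterator are subscriptable on Python 3.9+,
    # so no typing.Mapping / typing.Iterator import is needed.
    yield from params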

View File

@ -2,7 +2,8 @@ from __future__ import annotations # type: ignore[import-not-found]
import importlib.util import importlib.util
import logging import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional from collections.abc import Iterator, Mapping
from typing import Any, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM from langchain_core.language_models.llms import BaseLLM
@ -82,7 +83,7 @@ class HuggingFacePipeline(BaseLLM):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Ensure model_id is set either by pipeline or user input.""" """Ensure model_id is set either by pipeline or user input."""
if "model_id" not in values: if "model_id" not in values:
if "pipeline" in values and values["pipeline"]: if "pipeline" in values and values["pipeline"]:
@ -297,13 +298,13 @@ class HuggingFacePipeline(BaseLLM):
def _generate( def _generate(
self, self,
prompts: List[str], prompts: list[str],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
# List to hold all results # List to hold all results
text_generations: List[str] = [] text_generations: list[str] = []
pipeline_kwargs = kwargs.get("pipeline_kwargs", {}) pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
skip_prompt = kwargs.get("skip_prompt", False) skip_prompt = kwargs.get("skip_prompt", False)
@ -347,7 +348,7 @@ class HuggingFacePipeline(BaseLLM):
def _stream( def _stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:

View File

@ -49,8 +49,12 @@ langchain-community = { path = "../../community", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]
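The pyproject changes are the same in every partner package touched here: pin ruff's target-version to py39, enable the pyupgrade rule family ("UP"), and ignore UP007. UP007 would rewrite Optional[X] and Union[X, Y] annotations to the PEP 604 `X | Y` form, but that spelling only evaluates at runtime on Python 3.10+ (or with postponed evaluation), which is presumably why it stays disabled while 3.9 remains the floor. A rough illustration of the distinction, assuming annotations are evaluated eagerly (no `from __future__ import annotations`):

from typing import Optional, Union

# Accepted on Python 3.9 -- this is the spelling UP007 would normally replace:
def parse(value: Optional[str]) -> Union[int, float]:
    return float(value) if value and "." in value else int(value or 0)

# What UP007 would produce. On Python 3.9 the evaluated annotation
# `str | None` raises "TypeError: unsupported operand type(s) for |:"
# unless PEP 563 postponed evaluation is enabled:
# def parse(value: str | None) -> int | float: ...

With this configuration in place, `ruff check --fix` (or a direct `pyupgrade --py39-plus` run) applies the remaining UP rewrites mechanically; the exact invocation is assumed here, not taken from the repo's tooling.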

View File

@ -1,7 +1,5 @@
"""Test HuggingFace embeddings.""" """Test HuggingFace embeddings."""
from typing import Type
from langchain_tests.integration_tests import EmbeddingsIntegrationTests from langchain_tests.integration_tests import EmbeddingsIntegrationTests
from langchain_huggingface.embeddings import ( from langchain_huggingface.embeddings import (
@ -12,7 +10,7 @@ from langchain_huggingface.embeddings import (
class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests): class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
@property @property
def embeddings_class(self) -> Type[HuggingFaceEmbeddings]: def embeddings_class(self) -> type[HuggingFaceEmbeddings]:
return HuggingFaceEmbeddings return HuggingFaceEmbeddings
@property @property
@ -22,7 +20,7 @@ class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
class TestHuggingFaceEndpointEmbeddings(EmbeddingsIntegrationTests): class TestHuggingFaceEndpointEmbeddings(EmbeddingsIntegrationTests):
@property @property
def embeddings_class(self) -> Type[HuggingFaceEndpointEmbeddings]: def embeddings_class(self) -> type[HuggingFaceEndpointEmbeddings]:
return HuggingFaceEndpointEmbeddings return HuggingFaceEndpointEmbeddings
@property @property

View File

@ -1,4 +1,4 @@
from typing import Generator from collections.abc import Generator
from langchain_huggingface.llms import HuggingFacePipeline from langchain_huggingface.llms import HuggingFacePipeline
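Generator gets the same alias treatment; in test modules it typically annotates yield-style pytest fixtures, and collections.abc.Generator takes the usual three parameters. A hedged sketch of that shape — the fixture itself is invented, not shown in this diff:

from collections.abc import Generator

import pytest

@pytest.fixture
def fake_pipeline() -> Generator[str, None, None]:
    # Yield-style fixture: yields the resource, then runs any teardown.
    yield "fake-pipeline"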

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -12,7 +10,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
class TestHuggingFaceEndpoint(ChatModelIntegrationTests): class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatHuggingFace return ChatHuggingFace
@property @property

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List # type: ignore[import-not-found] from typing import Any # type: ignore[import-not-found]
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
@ -45,7 +45,7 @@ from langchain_huggingface.llms.huggingface_endpoint import (
], ],
) )
def test_convert_message_to_chat_message( def test_convert_message_to_chat_message(
message: BaseMessage, expected: Dict[str, str] message: BaseMessage, expected: dict[str, str]
) -> None: ) -> None:
result = _convert_message_to_chat_message(message) result = _convert_message_to_chat_message(message)
assert result == expected assert result == expected
@ -150,7 +150,7 @@ def test_create_chat_result(chat_hugging_face: Any) -> None:
], ],
) )
def test_to_chat_prompt_errors( def test_to_chat_prompt_errors(
chat_hugging_face: Any, messages: List[BaseMessage], expected_error: str chat_hugging_face: Any, messages: list[BaseMessage], expected_error: str
) -> None: ) -> None:
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
chat_hugging_face._to_chat_prompt(messages) chat_hugging_face._to_chat_prompt(messages)
@ -194,7 +194,7 @@ def test_to_chat_prompt_valid_messages(chat_hugging_face: Any) -> None:
], ],
) )
def test_to_chatml_format( def test_to_chatml_format(
chat_hugging_face: Any, message: BaseMessage, expected: Dict[str, str] chat_hugging_face: Any, message: BaseMessage, expected: dict[str, str]
) -> None: ) -> None:
result = chat_hugging_face._to_chatml_format(message) result = chat_hugging_face._to_chatml_format(message)
assert result == expected assert result == expected
@ -207,7 +207,7 @@ def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
assert "Unknown message type:" in str(e.value) assert "Unknown message type:" in str(e.value)
def tool_mock() -> Dict: def tool_mock() -> dict:
return {"function": {"name": "test_tool"}} return {"function": {"name": "test_tool"}}
@ -232,7 +232,7 @@ def tool_mock() -> Dict:
) )
def test_bind_tools_errors( def test_bind_tools_errors(
chat_hugging_face: Any, chat_hugging_face: Any,
tools: Dict[str, str], tools: dict[str, str],
tool_choice: Any, tool_choice: Any,
expected_exception: Any, expected_exception: Any,
expected_message: str, expected_message: str,

View File

@ -7,20 +7,14 @@ import os
import re import re
import ssl import ssl
import uuid import uuid
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
AsyncContextManager,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Literal, Literal,
Optional, Optional,
Sequence,
Tuple,
Type,
Union, Union,
cast, cast,
) )
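One of the less common swaps in this import block: typing.AsyncContextManager is likewise a deprecated alias, and its runtime class lives at contextlib.AbstractAsyncContextManager (subscriptable since 3.9). A small hedged sketch of the pattern used for the SSE event source below; the names are stand-ins, not this module's API:

from collections.abc import AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager

@asynccontextmanager
async def open_events() -> AsyncIterator[list[str]]:
    # Stand-in for an httpx-sse EventSource context manager.
    yield ["event: message"]

async def consume(mgr: AbstractAsyncContextManager[list[str]]) -> list[str]:
    async with mgr as events:
        return events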
@ -142,13 +136,13 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
def _convert_mistral_chat_message_to_message( def _convert_mistral_chat_message_to_message(
_message: Dict, _message: dict,
) -> BaseMessage: ) -> BaseMessage:
role = _message["role"] role = _message["role"]
assert role == "assistant", f"Expected role to be 'assistant', got {role}" assert role == "assistant", f"Expected role to be 'assistant', got {role}"
content = cast(str, _message["content"]) content = cast(str, _message["content"])
additional_kwargs: Dict = {} additional_kwargs: dict = {}
tool_calls = [] tool_calls = []
invalid_tool_calls = [] invalid_tool_calls = []
if raw_tool_calls := _message.get("tool_calls"): if raw_tool_calls := _message.get("tool_calls"):
@ -196,8 +190,8 @@ async def _araise_on_error(response: httpx.Response) -> None:
async def _aiter_sse( async def _aiter_sse(
event_source_mgr: AsyncContextManager[EventSource], event_source_mgr: AbstractAsyncContextManager[EventSource],
) -> AsyncIterator[Dict]: ) -> AsyncIterator[dict]:
"""Iterate over the server-sent events.""" """Iterate over the server-sent events."""
async with event_source_mgr as event_source: async with event_source_mgr as event_source:
await _araise_on_error(event_source.response) await _araise_on_error(event_source.response)
@ -234,7 +228,7 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk( def _convert_chunk_to_message_chunk(
chunk: Dict, default_class: Type[BaseMessageChunk] chunk: dict, default_class: type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
_choice = chunk["choices"][0] _choice = chunk["choices"][0]
_delta = _choice["delta"] _delta = _choice["delta"]
@ -243,7 +237,7 @@ def _convert_chunk_to_message_chunk(
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {} additional_kwargs: dict = {}
response_metadata = {} response_metadata = {}
if raw_tool_calls := _delta.get("tool_calls"): if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls additional_kwargs["tool_calls"] = raw_tool_calls
@ -295,7 +289,7 @@ def _convert_chunk_to_message_chunk(
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict: def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
"""Format Langchain ToolCall to dict expected by Mistral.""" """Format Langchain ToolCall to dict expected by Mistral."""
result: Dict[str, Any] = { result: dict[str, Any] = {
"function": { "function": {
"name": tool_call["name"], "name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]), "arguments": json.dumps(tool_call["args"]),
@ -309,7 +303,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict: def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
"""Format Langchain InvalidToolCall to dict expected by Mistral.""" """Format Langchain InvalidToolCall to dict expected by Mistral."""
result: Dict[str, Any] = { result: dict[str, Any] = {
"function": { "function": {
"name": invalid_tool_call["name"], "name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"], "arguments": invalid_tool_call["args"],
@ -323,13 +317,13 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
def _convert_message_to_mistral_chat_message( def _convert_message_to_mistral_chat_message(
message: BaseMessage, message: BaseMessage,
) -> Dict: ) -> dict:
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content) return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content) return dict(role="user", content=message.content)
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
message_dict: Dict[str, Any] = {"role": "assistant"} message_dict: dict[str, Any] = {"role": "assistant"}
tool_calls = [] tool_calls = []
if message.tool_calls or message.invalid_tool_calls: if message.tool_calls or message.invalid_tool_calls:
for tool_call in message.tool_calls: for tool_call in message.tool_calls:
@ -407,7 +401,7 @@ class ChatMistralAI(BaseChatModel):
random_seed: Optional[int] = None random_seed: Optional[int] = None
safe_mode: Optional[bool] = None safe_mode: Optional[bool] = None
streaming: bool = False streaming: bool = False
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any invocation parameters not explicitly specified.""" """Holds any invocation parameters not explicitly specified."""
model_config = ConfigDict( model_config = ConfigDict(
@ -417,14 +411,14 @@ class ChatMistralAI(BaseChatModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names) values = _build_model_kwargs(values, all_required_field_names)
return values return values
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling the API.""" """Get the default parameters for calling the API."""
defaults = { defaults = {
"model": self.model, "model": self.model,
@ -439,7 +433,7 @@ class ChatMistralAI(BaseChatModel):
return filtered return filtered
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
@ -456,7 +450,7 @@ class ChatMistralAI(BaseChatModel):
return ls_params return ls_params
@property @property
def _client_params(self) -> Dict[str, Any]: def _client_params(self) -> dict[str, Any]:
"""Get the parameters used for the client.""" """Get the parameters used for the client."""
return self._default_params return self._default_params
@ -473,7 +467,7 @@ class ChatMistralAI(BaseChatModel):
stream = kwargs["stream"] stream = kwargs["stream"]
if stream: if stream:
def iter_sse() -> Iterator[Dict]: def iter_sse() -> Iterator[dict]:
with connect_sse( with connect_sse(
self.client, "POST", "/chat/completions", json=kwargs self.client, "POST", "/chat/completions", json=kwargs
) as event_source: ) as event_source:
@ -492,7 +486,7 @@ class ChatMistralAI(BaseChatModel):
rtn = _completion_with_retry(**kwargs) rtn = _completion_with_retry(**kwargs)
return rtn return rtn
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {} overall_token_usage: dict = {}
for output in llm_outputs: for output in llm_outputs:
if output is None: if output is None:
@ -557,8 +551,8 @@ class ChatMistralAI(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
@ -577,7 +571,7 @@ class ChatMistralAI(BaseChatModel):
) )
return self._create_chat_result(response) return self._create_chat_result(response)
def _create_chat_result(self, response: Dict) -> ChatResult: def _create_chat_result(self, response: dict) -> ChatResult:
generations = [] generations = []
token_usage = response.get("usage", {}) token_usage = response.get("usage", {})
for res in response["choices"]: for res in response["choices"]:
@ -603,8 +597,8 @@ class ChatMistralAI(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> Tuple[List[Dict], Dict[str, Any]]: ) -> tuple[list[dict], dict[str, Any]]:
params = self._client_params params = self._client_params
if stop is not None or "stop" in params: if stop is not None or "stop" in params:
if "stop" in params: if "stop" in params:
@ -617,15 +611,15 @@ class ChatMistralAI(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.completion_with_retry( for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params messages=message_dicts, run_manager=run_manager, **params
): ):
@ -643,15 +637,15 @@ class ChatMistralAI(BaseChatModel):
async def _astream( async def _astream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await acompletion_with_retry( async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params self, messages=message_dicts, run_manager=run_manager, **params
): ):
@ -669,8 +663,8 @@ class ChatMistralAI(BaseChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
@ -691,7 +685,7 @@ class ChatMistralAI(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
@ -733,14 +727,14 @@ class ChatMistralAI(BaseChatModel):
def with_structured_output( def with_structured_output(
self, self,
schema: Optional[Union[Dict, Type]] = None, schema: Optional[Union[dict, type]] = None,
*, *,
method: Literal[ method: Literal[
"function_calling", "json_mode", "json_schema" "function_calling", "json_mode", "json_schema"
] = "function_calling", ] = "function_calling",
include_raw: bool = False, include_raw: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema. """Model wrapper that returns outputs formatted to match the given schema.
Args: Args:
@ -1048,7 +1042,7 @@ class ChatMistralAI(BaseChatModel):
return llm | output_parser return llm | output_parser
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return self._default_params return self._default_params
@ -1058,7 +1052,7 @@ class ChatMistralAI(BaseChatModel):
return "mistralai-chat" return "mistralai-chat"
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return {"mistral_api_key": "MISTRAL_API_KEY"} return {"mistral_api_key": "MISTRAL_API_KEY"}
@classmethod @classmethod
@ -1067,14 +1061,14 @@ class ChatMistralAI(BaseChatModel):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "chat_models", "mistralai"] return ["langchain", "chat_models", "mistralai"]
def _convert_to_openai_response_format( def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
) -> Dict: ) -> dict:
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels.""" """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
if ( if (
isinstance(schema, dict) isinstance(schema, dict)
@ -1094,8 +1088,10 @@ def _convert_to_openai_response_format(
function["schema"] = function.pop("parameters") function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function} response_format = {"type": "json_schema", "json_schema": function}
if strict is not None and strict is not response_format["json_schema"].get( if (
"strict" strict is not None
and strict is not response_format["json_schema"].get("strict")
and isinstance(schema, dict)
): ):
msg = ( msg = (
f"Output schema already has 'strict' value set to " f"Output schema already has 'strict' value set to "

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
import warnings import warnings
from typing import Iterable, List from collections.abc import Iterable
import httpx import httpx
from httpx import Response from httpx import Response
@ -33,7 +33,7 @@ class DummyTokenizer:
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)""" """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
@staticmethod @staticmethod
def encode_batch(texts: List[str]) -> List[List[str]]: def encode_batch(texts: list[str]) -> list[list[str]]:
return [list(text) for text in texts] return [list(text) for text in texts]
@ -177,7 +177,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = Tokenizer.from_pretrained( self.tokenizer = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1" "mistralai/Mixtral-8x7B-v0.1"
) )
except IOError: # huggingface_hub GatedRepoError except OSError: # huggingface_hub GatedRepoError
warnings.warn( warnings.warn(
"Could not download mistral tokenizer from Huggingface for " "Could not download mistral tokenizer from Huggingface for "
"calculating batch sizes. Set a Huggingface token via the " "calculating batch sizes. Set a Huggingface token via the "
@ -187,10 +187,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = DummyTokenizer() self.tokenizer = DummyTokenizer()
return self return self
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]: def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
"""Split a list of texts into batches of less than 16k tokens """Split a list of texts into batches of less than 16k tokens
for Mistral API.""" for Mistral API."""
batch: List[str] = [] batch: list[str] = []
batch_tokens = 0 batch_tokens = 0
text_token_lengths = [ text_token_lengths = [
@ -211,7 +211,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
if batch: if batch:
yield batch yield batch
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts. """Embed a list of document texts.
Args: Args:
@ -230,7 +230,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
wait=wait_fixed(self.wait_time), wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries), stop=stop_after_attempt(self.max_retries),
) )
def _embed_batch(batch: List[str]) -> Response: def _embed_batch(batch: list[str]) -> Response:
response = self.client.post( response = self.client.post(
url="/embeddings", url="/embeddings",
json=dict( json=dict(
@ -252,7 +252,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
logger.error(f"An error occurred with MistralAI: {e}") logger.error(f"An error occurred with MistralAI: {e}")
raise raise
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts. """Embed a list of document texts.
Args: Args:
@ -283,7 +283,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
logger.error(f"An error occurred with MistralAI: {e}") logger.error(f"An error occurred with MistralAI: {e}")
raise raise
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Embed a single query text. """Embed a single query text.
Args: Args:
@ -294,7 +294,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
""" """
return self.embed_documents([text])[0] return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> list[float]:
"""Embed a single query text. """Embed a single query text.
Args: Args:
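The `except IOError` → `except OSError` rewrite in this file is pyupgrade's alias cleanup (UP024 in ruff): IOError has been a plain alias of OSError since Python 3.3, so both spellings catch exactly the same exceptions and only the canonical name remains. A tiny check, with a made-up path:

assert IOError is OSError  # same class object; the change is purely cosmetic

try:
    open("/nonexistent/tokenizer.json")
except OSError:  # previously spelled `except IOError`
    tokenizer = "DummyTokenizer"  # fall back, loosely mirroring the embeddings code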

View File

@ -44,8 +44,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ( # type: ignore[import-not-found] from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found] ChatModelIntegrationTests, # type: ignore[import-not-found]
@ -12,7 +10,7 @@ from langchain_mistralai import ChatMistralAI
class TestMistralStandard(ChatModelIntegrationTests): class TestMistralStandard(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatMistralAI return ChatMistralAI
@property @property

View File

@ -1,7 +1,8 @@
"""Test MistralAI Chat API wrapper.""" """Test MistralAI Chat API wrapper."""
import os import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast from collections.abc import AsyncGenerator, Generator
from typing import Any, cast
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx import httpx
@ -104,13 +105,13 @@ def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
], ],
) )
def test_convert_message_to_mistral_chat_message( def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: Dict message: BaseMessage, expected: dict
) -> None: ) -> None:
result = _convert_message_to_mistral_chat_message(message) result = _convert_message_to_mistral_chat_message(message)
assert result == expected assert result == expected
def _make_completion_response_from_token(token: str) -> Dict: def _make_completion_response_from_token(token: str) -> dict:
return dict( return dict(
id="abc123", id="abc123",
model="fake_model", model="fake_model",
@ -236,7 +237,7 @@ def test__convert_dict_to_message_tool_call() -> None:
def test_custom_token_counting() -> None: def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]: def token_encoder(text: str) -> list[int]:
return [1, 2, 3] return [1, 2, 3]
llm = ChatMistralAI(custom_get_token_ids=token_encoder) llm = ChatMistralAI(custom_get_token_ids=token_encoder)

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ( # type: ignore[import-not-found] from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found] ChatModelUnitTests, # type: ignore[import-not-found]
@ -12,5 +10,5 @@ from langchain_mistralai import ChatMistralAI
class TestMistralStandard(ChatModelUnitTests): class TestMistralStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatMistralAI return ChatMistralAI

View File

@ -1,5 +1,5 @@
import os import os
from typing import List, Literal, Optional, overload from typing import Literal, Optional, overload
import nomic # type: ignore[import] import nomic # type: ignore[import]
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -86,7 +86,7 @@ class NomicEmbeddings(Embeddings):
self.device = device self.device = device
self.vision_model = vision_model self.vision_model = vision_model
def embed(self, texts: List[str], *, task_type: str) -> List[List[float]]: def embed(self, texts: list[str], *, task_type: str) -> list[list[float]]:
"""Embed texts. """Embed texts.
Args: Args:
@ -105,7 +105,7 @@ class NomicEmbeddings(Embeddings):
) )
return output["embeddings"] return output["embeddings"]
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs. """Embed search docs.
Args: Args:
@ -116,7 +116,7 @@ class NomicEmbeddings(Embeddings):
task_type="search_document", task_type="search_document",
) )
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Embed query text. """Embed query text.
Args: Args:
@ -127,7 +127,7 @@ class NomicEmbeddings(Embeddings):
task_type="search_query", task_type="search_query",
)[0] )[0]
def embed_image(self, uris: List[str]) -> List[List[float]]: def embed_image(self, uris: list[str]) -> list[list[float]]:
return embed.image( return embed.image(
images=uris, images=uris,
model=self.vision_model, model=self.vision_model,

View File

@ -40,13 +40,18 @@ dev = ["langchain-core"]
[tool.uv.sources] [tool.uv.sources]
langchain-core = { path = "../../core", editable = true } langchain-core = { path = "../../core", editable = true }
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
"E", # pycodestyle "E", # pycodestyle
"F", # pyflakes "F", # pyflakes
"I", # isort "I", # isort
"T201", # print "T201", # print
"UP", # pyupgrade
] ]
ignore = [ "UP007", ]
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"

View File

@ -3,7 +3,6 @@
It provides infrastructure for interacting with the Ollama service. It provides infrastructure for interacting with the Ollama service.
""" """
from importlib import metadata from importlib import metadata
from langchain_ollama.chat_models import ChatOllama from langchain_ollama.chat_models import ChatOllama

View File

@ -1,21 +1,14 @@
"""Ollama chat models.""" """Ollama chat models."""
import json import json
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Final, Final,
Iterator,
List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence,
Tuple,
Type,
Union, Union,
cast, cast,
) )
@ -153,7 +146,7 @@ def _parse_arguments_from_tool_call(
def _get_tool_calls_from_response( def _get_tool_calls_from_response(
response: Mapping[str, Any], response: Mapping[str, Any],
) -> List[ToolCall]: ) -> list[ToolCall]:
"""Get tool calls from ollama response.""" """Get tool calls from ollama response."""
tool_calls = [] tool_calls = []
if "message" in response: if "message" in response:
@ -341,7 +334,7 @@ class ChatOllama(BaseChatModel):
model: str model: str
"""Model name to use.""" """Model name to use."""
extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks. """Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`. Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens. If a tuple is supplied, they are assumed to be the (start, end) tokens.
@ -399,7 +392,7 @@ class ChatOllama(BaseChatModel):
to a specific number will make the model generate the same text for to a specific number will make the model generate the same text for
the same prompt.""" the same prompt."""
stop: Optional[List[str]] = None stop: Optional[list[str]] = None
"""Sets the stop tokens to use.""" """Sets the stop tokens to use."""
tfs_z: Optional[float] = None tfs_z: Optional[float] = None
@ -443,10 +436,10 @@ class ChatOllama(BaseChatModel):
def _chat_params( def _chat_params(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
ollama_messages = self._convert_messages_to_ollama_messages(messages) ollama_messages = self._convert_messages_to_ollama_messages(messages)
if self.stop is not None and stop is not None: if self.stop is not None and stop is not None:
@ -499,13 +492,13 @@ class ChatOllama(BaseChatModel):
return self return self
def _convert_messages_to_ollama_messages( def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage] self, messages: list[BaseMessage]
) -> Sequence[Message]: ) -> Sequence[Message]:
ollama_messages: List = [] ollama_messages: list = []
for message in messages: for message in messages:
role: Literal["user", "assistant", "system", "tool"] role: Literal["user", "assistant", "system", "tool"]
tool_call_id: Optional[str] = None tool_call_id: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None tool_calls: Optional[list[dict[str, Any]]] = None
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
role = "user" role = "user"
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
@ -531,7 +524,7 @@ class ChatOllama(BaseChatModel):
if isinstance(message.content, str): if isinstance(message.content, str):
content = message.content content = message.content
else: else:
for content_part in cast(List[Dict], message.content): for content_part in cast(list[dict], message.content):
if content_part.get("type") == "text": if content_part.get("type") == "text":
content += f"\n{content_part['text']}" content += f"\n{content_part['text']}"
elif content_part.get("type") == "tool_use": elif content_part.get("type") == "tool_use":
@ -583,7 +576,7 @@ class ChatOllama(BaseChatModel):
def _extract_reasoning( def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> Tuple[BaseMessageChunk, bool]: ) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content.""" """Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning: if not self.extract_reasoning:
return message_chunk, is_thinking return message_chunk, is_thinking
@ -605,8 +598,8 @@ class ChatOllama(BaseChatModel):
async def _acreate_chat_stream( async def _acreate_chat_stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]: ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
chat_params = self._chat_params(messages, stop, **kwargs) chat_params = self._chat_params(messages, stop, **kwargs)
@ -619,8 +612,8 @@ class ChatOllama(BaseChatModel):
def _create_chat_stream( def _create_chat_stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]: ) -> Iterator[Union[Mapping[str, Any], str]]:
chat_params = self._chat_params(messages, stop, **kwargs) chat_params = self._chat_params(messages, stop, **kwargs)
@ -632,8 +625,8 @@ class ChatOllama(BaseChatModel):
def _chat_stream_with_aggregation( def _chat_stream_with_aggregation(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False, verbose: bool = False,
**kwargs: Any, **kwargs: Any,
@ -657,8 +650,8 @@ class ChatOllama(BaseChatModel):
async def _achat_stream_with_aggregation( async def _achat_stream_with_aggregation(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False, verbose: bool = False,
**kwargs: Any, **kwargs: Any,
@ -681,7 +674,7 @@ class ChatOllama(BaseChatModel):
return final_chunk return final_chunk
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
@ -697,8 +690,8 @@ class ChatOllama(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -719,8 +712,8 @@ class ChatOllama(BaseChatModel):
def _iterate_over_stream( def _iterate_over_stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
is_thinking = False is_thinking = False
@ -758,8 +751,8 @@ class ChatOllama(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
@ -773,8 +766,8 @@ class ChatOllama(BaseChatModel):
async def _aiterate_over_stream( async def _aiterate_over_stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False is_thinking = False
@ -812,8 +805,8 @@ class ChatOllama(BaseChatModel):
async def _astream( async def _astream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
@ -827,8 +820,8 @@ class ChatOllama(BaseChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -854,7 +847,7 @@ class ChatOllama(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*, *,
tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None, tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
**kwargs: Any, **kwargs: Any,
@ -877,12 +870,12 @@ class ChatOllama(BaseChatModel):
def with_structured_output( def with_structured_output(
self, self,
schema: Union[Dict, type], schema: Union[dict, type],
*, *,
method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema", method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema",
include_raw: bool = False, include_raw: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema. """Model wrapper that returns outputs formatted to match the given schema.
Args: Args:
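A detail worth flagging for the `cast(list[dict], message.content)` rewrite in this file: PEP 585 generics are real runtime objects on 3.9+, so they work not only in annotations but also as ordinary expressions — which is what cast() needs, and the same reason the `tuple[str, str]` inside the extract_reasoning field annotation is safe for pydantic to evaluate on 3.9. A quick hedged check:

from typing import cast

blocks = cast(list[dict], [{"type": "text", "text": "hi"}])
# list[dict] is a types.GenericAlias instance at runtime on Python 3.9+,
# so it can be passed around, stored, or introspected like any object.
assert isinstance(blocks, list)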

View File

@ -1,6 +1,6 @@
"""Ollama embeddings models.""" """Ollama embeddings models."""
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client from ollama import AsyncClient, Client
@ -188,7 +188,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""The temperature of the model. Increasing the temperature will """The temperature of the model. Increasing the temperature will
make the model answer more creatively. (Default: 0.8)""" make the model answer more creatively. (Default: 0.8)"""
stop: Optional[List[str]] = None stop: Optional[list[str]] = None
"""Sets the stop tokens to use.""" """Sets the stop tokens to use."""
tfs_z: Optional[float] = None tfs_z: Optional[float] = None
@ -211,7 +211,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
) )
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Ollama.""" """Get the default parameters for calling Ollama."""
return { return {
"mirostat": self.mirostat, "mirostat": self.mirostat,
@ -237,18 +237,18 @@ class OllamaEmbeddings(BaseModel, Embeddings):
self._async_client = AsyncClient(host=self.base_url, **client_kwargs) self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
return self return self
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.""" """Embed search docs."""
embedded_docs = self._client.embed( embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"] )["embeddings"]
return embedded_docs return embedded_docs
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Embed query text.""" """Embed query text."""
return self.embed_documents([text])[0] return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.""" """Embed search docs."""
embedded_docs = ( embedded_docs = (
await self._async_client.embed( await self._async_client.embed(
@ -257,6 +257,6 @@ class OllamaEmbeddings(BaseModel, Embeddings):
)["embeddings"] )["embeddings"]
return embedded_docs return embedded_docs
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> list[float]:
"""Embed query text.""" """Embed query text."""
return (await self.aembed_documents([text]))[0] return (await self.aembed_documents([text]))[0]

View File

@ -1,13 +1,9 @@
"""Ollama large language models.""" """Ollama large language models."""
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal, Literal,
Mapping,
Optional, Optional,
Union, Union,
) )
@ -89,7 +85,7 @@ class OllamaLLM(BaseLLM):
to a specific number will make the model generate the same text for to a specific number will make the model generate the same text for
the same prompt.""" the same prompt."""
stop: Optional[List[str]] = None stop: Optional[list[str]] = None
"""Sets the stop tokens to use.""" """Sets the stop tokens to use."""
tfs_z: Optional[float] = None tfs_z: Optional[float] = None
@ -134,9 +130,9 @@ class OllamaLLM(BaseLLM):
def _generate_params( def _generate_params(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if self.stop is not None and stop is not None: if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.") raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None: elif self.stop is not None:
@ -181,7 +177,7 @@ class OllamaLLM(BaseLLM):
return "ollama-llm" return "ollama-llm"
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs) params = super()._get_ls_params(stop=stop, **kwargs)
@ -200,7 +196,7 @@ class OllamaLLM(BaseLLM):
async def _acreate_generate_stream( async def _acreate_generate_stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]: ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
async for part in await self._async_client.generate( async for part in await self._async_client.generate(
@ -211,7 +207,7 @@ class OllamaLLM(BaseLLM):
def _create_generate_stream( def _create_generate_stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]: ) -> Iterator[Union[Mapping[str, Any], str]]:
yield from self._client.generate( yield from self._client.generate(
@ -221,7 +217,7 @@ class OllamaLLM(BaseLLM):
async def _astream_with_aggregation( async def _astream_with_aggregation(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False, verbose: bool = False,
**kwargs: Any, **kwargs: Any,
@ -253,7 +249,7 @@ class OllamaLLM(BaseLLM):
def _stream_with_aggregation( def _stream_with_aggregation(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False, verbose: bool = False,
**kwargs: Any, **kwargs: Any,
@ -284,8 +280,8 @@ class OllamaLLM(BaseLLM):
def _generate( def _generate(
self, self,
prompts: List[str], prompts: list[str],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
@ -303,8 +299,8 @@ class OllamaLLM(BaseLLM):
async def _agenerate( async def _agenerate(
self, self,
prompts: List[str], prompts: list[str],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
@ -323,7 +319,7 @@ class OllamaLLM(BaseLLM):
def _stream( def _stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:
@ -345,7 +341,7 @@ class OllamaLLM(BaseLLM):
async def _astream( async def _astream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[GenerationChunk]: ) -> AsyncIterator[GenerationChunk]:
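The pattern in this file is the same mechanical rewrite applied across the commit: container annotations switch to the PEP 585 builtin generics that are subscriptable at runtime from Python 3.9, the ABCs (`AsyncIterator`, `Iterator`, `Mapping`) move from `typing` to `collections.abc`, and only `Optional`/`Union` keep coming from `typing`. A minimal sketch of the resulting style, loosely modeled on the signatures above (the function bodies are placeholders, not code from this repository):

```python
from collections.abc import AsyncIterator, Mapping
from typing import Any, Optional, Union


# Previously: Optional[List[str]] and Dict[str, Any], with List, Dict,
# AsyncIterator and Mapping all imported from typing.
def _generate_params(
    stop: Optional[list[str]] = None, **kwargs: Any
) -> dict[str, Any]:
    return {"stop": stop, **kwargs}


async def _stream_parts() -> AsyncIterator[Union[Mapping[str, Any], str]]:
    yield "token"
```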

View File

@ -40,6 +40,9 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
"E", # pycodestyle "E", # pycodestyle
@ -47,8 +50,9 @@ select = [
"I", # isort "I", # isort
"T201", # print "T201", # print
"D", # pydocstyle "D", # pydocstyle
"UP", # pyupgrade
] ]
ignore = [ "UP007", ]
[tool.ruff.lint.pydocstyle] [tool.ruff.lint.pydocstyle]
convention = "google" convention = "google"
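For context on the config above: `"UP"` switches on ruff's pyupgrade rules, and `UP007` is the rule that rewrites `Union[X, Y]`/`Optional[X]` to the `X | Y` form. Ignoring it is the conservative choice for `target-version = "py39"`, since the bar syntax only evaluates natively on Python 3.10+ (or behind `from __future__ import annotations`). A small illustrative sketch of the distinction, not taken from the repository:

```python
from typing import Optional, Union


# Preferred on a 3.9 target after this change (PEP 585 builtin generics):
def lookup(keys: list[str]) -> dict[str, int]: ...


# Left alone because UP007 is ignored: `str | None` / `int | str` would need
# Python 3.10+ (or postponed annotations) to evaluate at runtime.
def parse(value: Optional[str] = None) -> Union[int, str]: ...
```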

View File

@ -1,4 +1,5 @@
"""load multiple Python files specified as command line arguments.""" """load multiple Python files specified as command line arguments."""
import sys import sys
import traceback import traceback
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader

View File

@ -1,10 +1,10 @@
"""Ollama specific chat model integration tests""" """Ollama specific chat model integration tests"""
from typing import List, Optional from typing import Annotated, Optional
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict from typing_extensions import TypedDict
from langchain_ollama import ChatOllama from langchain_ollama import ChatOllama
@ -78,7 +78,7 @@ def test_structured_output_deeply_nested(model: str) -> None:
class Data(BaseModel): class Data(BaseModel):
"""Extracted data about people.""" """Extracted data about people."""
people: List[Person] people: list[Person]
chat = llm.with_structured_output(Data) # type: ignore[arg-type] chat = llm.with_structured_output(Data) # type: ignore[arg-type]
text = ( text = (

View File

@ -1,7 +1,5 @@
"""Test chat model integration using standard integration tests.""" """Test chat model integration using standard integration tests."""
from typing import Type
from langchain_tests.integration_tests import ChatModelIntegrationTests from langchain_tests.integration_tests import ChatModelIntegrationTests
from langchain_ollama.chat_models import ChatOllama from langchain_ollama.chat_models import ChatOllama
@ -9,7 +7,7 @@ from langchain_ollama.chat_models import ChatOllama
class TestChatOllama(ChatModelIntegrationTests): class TestChatOllama(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[ChatOllama]: def chat_model_class(self) -> type[ChatOllama]:
return ChatOllama return ChatOllama
@property @property

View File

@ -1,7 +1,5 @@
"""Test Ollama embeddings.""" """Test Ollama embeddings."""
from typing import Type
from langchain_tests.integration_tests import EmbeddingsIntegrationTests from langchain_tests.integration_tests import EmbeddingsIntegrationTests
from langchain_ollama.embeddings import OllamaEmbeddings from langchain_ollama.embeddings import OllamaEmbeddings
@ -9,7 +7,7 @@ from langchain_ollama.embeddings import OllamaEmbeddings
class TestOllamaEmbeddings(EmbeddingsIntegrationTests): class TestOllamaEmbeddings(EmbeddingsIntegrationTests):
@property @property
def embeddings_class(self) -> Type[OllamaEmbeddings]: def embeddings_class(self) -> type[OllamaEmbeddings]:
return OllamaEmbeddings return OllamaEmbeddings
@property @property

View File

@ -1,6 +1,6 @@
"""Test chat model integration.""" """Test chat model integration."""
import json import json
from typing import Dict, Type
from langchain_tests.unit_tests import ChatModelUnitTests from langchain_tests.unit_tests import ChatModelUnitTests
@ -9,11 +9,11 @@ from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_
class TestChatOllama(ChatModelUnitTests): class TestChatOllama(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[ChatOllama]: def chat_model_class(self) -> type[ChatOllama]:
return ChatOllama return ChatOllama
@property @property
def chat_model_params(self) -> Dict: def chat_model_params(self) -> dict:
return {"model": "llama3-groq-tool-use"} return {"model": "llama3-groq-tool-use"}

View File

@ -4,18 +4,8 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import ( from collections.abc import Awaitable
Any, from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
import openai import openai
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
@ -34,8 +24,8 @@ logger = logging.getLogger(__name__)
_BM = TypeVar("_BM", bound=BaseModel) _BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]] _DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
_DictOrPydantic = Union[Dict, _BM] _DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict): class _AllReturnType(TypedDict):
@ -547,7 +537,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
Used for tracing and token counting. Does NOT affect completion. Used for tracing and token counting. Does NOT affect completion.
""" """
disabled_params: Optional[Dict[str, Any]] = Field(default=None) disabled_params: Optional[dict[str, Any]] = Field(default=None)
"""Parameters of the OpenAI client or chat.completions endpoint that should be """Parameters of the OpenAI client or chat.completions endpoint that should be
disabled for the given model. disabled for the given model.
@ -570,12 +560,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
""" """
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"] return ["langchain", "chat_models", "azure_openai"]
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return { return {
"openai_api_key": "AZURE_OPENAI_API_KEY", "openai_api_key": "AZURE_OPENAI_API_KEY",
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN", "azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
@ -672,7 +662,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
return self return self
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return { return {
**{"azure_deployment": self.deployment_name}, **{"azure_deployment": self.deployment_name},
@ -684,14 +674,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
return "azure-openai-chat" return "azure-openai-chat"
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> dict[str, Any]:
return { return {
"openai_api_type": self.openai_api_type, "openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version, "openai_api_version": self.openai_api_version,
} }
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get the parameters used to invoke the model.""" """Get the parameters used to invoke the model."""
params = super()._get_ls_params(stop=stop, **kwargs) params = super()._get_ls_params(stop=stop, **kwargs)
@ -710,7 +700,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
def _create_chat_result( def _create_chat_result(
self, self,
response: Union[dict, openai.BaseModel], response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None, generation_info: Optional[dict] = None,
) -> ChatResult: ) -> ChatResult:
chat_result = super()._create_chat_result(response, generation_info) chat_result = super()._create_chat_result(response, generation_info)

View File

@ -10,6 +10,7 @@ import re
import ssl import ssl
import sys import sys
import warnings import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from json import JSONDecodeError from json import JSONDecodeError
@ -18,17 +19,9 @@ from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence,
Tuple,
Type,
TypedDict, TypedDict,
TypeVar, TypeVar,
Union, Union,
@ -137,7 +130,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
# Fix for azure # Fix for azure
# Also OpenAI returns None for tool invocations # Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or "" content = _dict.get("content", "") or ""
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if function_call := _dict.get("function_call"): if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call) additional_kwargs["function_call"] = dict(function_call)
tool_calls = [] tool_calls = []
@ -243,7 +236,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns: Returns:
The dictionary. The dictionary.
""" """
message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)} message_dict: dict[str, Any] = {"content": _format_message_content(message.content)}
if (name := message.name or message.additional_kwargs.get("name")) is not None: if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name message_dict["name"] = name
@ -304,12 +297,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_delta_to_message_chunk( def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
id_ = _dict.get("id") id_ = _dict.get("id")
role = cast(str, _dict.get("role")) role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "") content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if _dict.get("function_call"): if _dict.get("function_call"):
function_call = dict(_dict["function_call"]) function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None: if "name" in function_call and function_call["name"] is None:
@ -418,8 +411,8 @@ class _FunctionCall(TypedDict):
_BM = TypeVar("_BM", bound=BaseModel) _BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] _DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[Dict, _BM] _DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict): class _AllReturnType(TypedDict):
@ -437,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Model name to use.""" """Model name to use."""
temperature: Optional[float] = None temperature: Optional[float] = None
"""What sampling temperature to use.""" """What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = Field( openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
@ -451,7 +444,7 @@ class BaseChatOpenAI(BaseChatModel):
openai_proxy: Optional[str] = Field( openai_proxy: Optional[str] = Field(
default_factory=from_env("OPENAI_PROXY", default=None) default_factory=from_env("OPENAI_PROXY", default=None)
) )
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout" default=None, alias="timeout"
) )
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
@ -476,7 +469,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Number of most likely tokens to return at each token position, each with """Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true an associated log probability. `logprobs` must be set to true
if this parameter is used.""" if this parameter is used."""
logit_bias: Optional[Dict[int, int]] = None logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion.""" """Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
@ -517,14 +510,14 @@ class BaseChatOpenAI(BaseChatModel):
http_async_client: Union[Any, None] = Field(default=None, exclude=True) http_async_client: Union[Any, None] = Field(default=None, exclude=True)
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify """Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations.""" http_client as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences") stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences.""" """Default stop sequences."""
extra_body: Optional[Mapping[str, Any]] = None extra_body: Optional[Mapping[str, Any]] = None
"""Optional additional JSON properties to include in the request parameters when """Optional additional JSON properties to include in the request parameters when
making requests to OpenAI compatible APIs, such as vLLM.""" making requests to OpenAI compatible APIs, such as vLLM."""
include_response_headers: bool = False include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata.""" """Whether to include response headers in the output message response_metadata."""
disabled_params: Optional[Dict[str, Any]] = Field(default=None) disabled_params: Optional[dict[str, Any]] = Field(default=None)
"""Parameters of the OpenAI client or chat.completions endpoint that should be """Parameters of the OpenAI client or chat.completions endpoint that should be
disabled for the given model. disabled for the given model.
@ -554,7 +547,7 @@ class BaseChatOpenAI(BaseChatModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names) values = _build_model_kwargs(values, all_required_field_names)
@ -562,7 +555,7 @@ class BaseChatOpenAI(BaseChatModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_temperature(cls, values: Dict[str, Any]) -> Any: def validate_temperature(cls, values: dict[str, Any]) -> Any:
"""Currently o1 models only allow temperature=1.""" """Currently o1 models only allow temperature=1."""
model = values.get("model_name") or values.get("model") or "" model = values.get("model_name") or values.get("model") or ""
if model.startswith("o1") and "temperature" not in values: if model.startswith("o1") and "temperature" not in values:
@ -642,7 +635,7 @@ class BaseChatOpenAI(BaseChatModel):
return self return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
exclude_if_none = { exclude_if_none = {
"presence_penalty": self.presence_penalty, "presence_penalty": self.presence_penalty,
@ -669,7 +662,7 @@ class BaseChatOpenAI(BaseChatModel):
return params return params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {} overall_token_usage: dict = {}
system_fingerprint = None system_fingerprint = None
for output in llm_outputs: for output in llm_outputs:
@ -697,8 +690,8 @@ class BaseChatOpenAI(BaseChatModel):
def _convert_chunk_to_generation_chunk( def _convert_chunk_to_generation_chunk(
self, self,
chunk: dict, chunk: dict,
default_chunk_class: Type, default_chunk_class: type,
base_generation_info: Optional[Dict], base_generation_info: Optional[dict],
) -> Optional[ChatGenerationChunk]: ) -> Optional[ChatGenerationChunk]:
if chunk.get("type") == "content.delta": # from beta.chat.completions.stream if chunk.get("type") == "content.delta": # from beta.chat.completions.stream
return None return None
@ -749,8 +742,8 @@ class BaseChatOpenAI(BaseChatModel):
def _stream_responses( def _stream_responses(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
@ -783,8 +776,8 @@ class BaseChatOpenAI(BaseChatModel):
async def _astream_responses( async def _astream_responses(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
@ -838,8 +831,8 @@ class BaseChatOpenAI(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
*, *,
stream_usage: Optional[bool] = None, stream_usage: Optional[bool] = None,
@ -850,7 +843,7 @@ class BaseChatOpenAI(BaseChatModel):
if stream_usage: if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage} kwargs["stream_options"] = {"include_usage": stream_usage}
payload = self._get_request_payload(messages, stop=stop, **kwargs) payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {} base_generation_info = {}
if "response_format" in payload: if "response_format" in payload:
@ -908,8 +901,8 @@ class BaseChatOpenAI(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -965,7 +958,7 @@ class BaseChatOpenAI(BaseChatModel):
self, self,
input_: LanguageModelInput, input_: LanguageModelInput,
*, *,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:
messages = self._convert_input(input_).to_messages() messages = self._convert_input(input_).to_messages()
@ -982,7 +975,7 @@ class BaseChatOpenAI(BaseChatModel):
def _create_chat_result( def _create_chat_result(
self, self,
response: Union[dict, openai.BaseModel], response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None, generation_info: Optional[dict] = None,
) -> ChatResult: ) -> ChatResult:
generations = [] generations = []
@ -1032,8 +1025,8 @@ class BaseChatOpenAI(BaseChatModel):
async def _astream( async def _astream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*, *,
stream_usage: Optional[bool] = None, stream_usage: Optional[bool] = None,
@ -1044,7 +1037,7 @@ class BaseChatOpenAI(BaseChatModel):
if stream_usage: if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage} kwargs["stream_options"] = {"include_usage": stream_usage}
payload = self._get_request_payload(messages, stop=stop, **kwargs) payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {} base_generation_info = {}
if "response_format" in payload: if "response_format" in payload:
@ -1106,8 +1099,8 @@ class BaseChatOpenAI(BaseChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -1160,13 +1153,13 @@ class BaseChatOpenAI(BaseChatModel):
) )
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {"model_name": self.model_name, **self._default_params} return {"model_name": self.model_name, **self._default_params}
def _get_invocation_params( def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Get the parameters used to invoke the model.""" """Get the parameters used to invoke the model."""
return { return {
"model": self.model_name, "model": self.model_name,
@ -1176,7 +1169,7 @@ class BaseChatOpenAI(BaseChatModel):
} }
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
@ -1199,7 +1192,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Return type of chat model.""" """Return type of chat model."""
return "openai-chat" return "openai-chat"
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
if self.tiktoken_model_name is not None: if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name model = self.tiktoken_model_name
else: else:
@ -1211,7 +1204,7 @@ class BaseChatOpenAI(BaseChatModel):
encoding = tiktoken.get_encoding(model) encoding = tiktoken.get_encoding(model)
return model, encoding return model, encoding
def get_token_ids(self, text: str) -> List[int]: def get_token_ids(self, text: str) -> list[int]:
"""Get the tokens present in the text with tiktoken package.""" """Get the tokens present in the text with tiktoken package."""
if self.custom_get_token_ids is not None: if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text) return self.custom_get_token_ids(text)
@ -1223,9 +1216,9 @@ class BaseChatOpenAI(BaseChatModel):
def get_num_tokens_from_messages( def get_num_tokens_from_messages(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
tools: Optional[ tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]] Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None, ] = None,
) -> int: ) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@ -1327,7 +1320,7 @@ class BaseChatOpenAI(BaseChatModel):
) )
def bind_functions( def bind_functions(
self, self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[ function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]] Union[_FunctionCall, str, Literal["auto", "none"]]
] = None, ] = None,
@ -1380,7 +1373,7 @@ class BaseChatOpenAI(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*, *,
tool_choice: Optional[ tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool] Union[dict, str, Literal["auto", "none", "required", "any"], bool]
@ -1727,7 +1720,7 @@ class BaseChatOpenAI(BaseChatModel):
else: else:
return llm | output_parser return llm | output_parser
def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]: def _filter_disabled_params(self, **kwargs: Any) -> dict[str, Any]:
if not self.disabled_params: if not self.disabled_params:
return kwargs return kwargs
filtered = {} filtered = {}
@ -2301,17 +2294,17 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""Maximum number of tokens to generate.""" """Maximum number of tokens to generate."""
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"} return {"openai_api_key": "OPENAI_API_KEY"}
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "chat_models", "openai"] return ["langchain", "chat_models", "openai"]
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: dict[str, Any] = {}
if self.openai_organization: if self.openai_organization:
attributes["openai_organization"] = self.openai_organization attributes["openai_organization"] = self.openai_organization
@ -2330,7 +2323,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
return True return True
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
params = super()._default_params params = super()._default_params
if "max_tokens" in params: if "max_tokens" in params:
@ -2342,7 +2335,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, self,
input_: LanguageModelInput, input_: LanguageModelInput,
*, *,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:
payload = super()._get_request_payload(input_, stop=stop, **kwargs) payload = super()._get_request_payload(input_, stop=stop, **kwargs)
@ -2719,7 +2712,7 @@ def _lc_invalid_tool_call_to_openai_tool_call(
} }
def _url_to_size(image_source: str) -> Optional[Tuple[int, int]]: def _url_to_size(image_source: str) -> Optional[tuple[int, int]]:
try: try:
from PIL import Image # type: ignore[import] from PIL import Image # type: ignore[import]
except ImportError: except ImportError:
@ -2771,7 +2764,7 @@ def _is_b64(s: str) -> bool:
return s.startswith("data:image") return s.startswith("data:image")
def _resize(width: int, height: int) -> Tuple[int, int]: def _resize(width: int, height: int) -> tuple[int, int]:
# larger side must be <= 2048 # larger side must be <= 2048
if width > 2048 or height > 2048: if width > 2048 or height > 2048:
if width > height: if width > height:
@ -2792,8 +2785,8 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
def _convert_to_openai_response_format( def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
) -> Union[Dict, TypeBaseModel]: ) -> Union[dict, TypeBaseModel]:
if isinstance(schema, type) and is_basemodel_subclass(schema): if isinstance(schema, type) and is_basemodel_subclass(schema):
return schema return schema
@ -2815,8 +2808,10 @@ def _convert_to_openai_response_format(
function["schema"] = function.pop("parameters") function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function} response_format = {"type": "json_schema", "json_schema": function}
if strict is not None and strict is not response_format["json_schema"].get( if (
"strict" strict is not None
and strict is not response_format["json_schema"].get("strict")
and isinstance(schema, dict)
): ):
msg = ( msg = (
f"Output schema already has 'strict' value set to " f"Output schema already has 'strict' value set to "
@ -2829,7 +2824,7 @@ def _convert_to_openai_response_format(
def _oai_structured_outputs_parser( def _oai_structured_outputs_parser(
ai_msg: AIMessage, schema: Type[_BM] ai_msg: AIMessage, schema: type[_BM]
) -> Optional[PydanticBaseModel]: ) -> Optional[PydanticBaseModel]:
if parsed := ai_msg.additional_kwargs.get("parsed"): if parsed := ai_msg.additional_kwargs.get("parsed"):
if isinstance(parsed, dict): if isinstance(parsed, dict):
@ -3141,7 +3136,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
def _construct_lc_result_from_responses_api( def _construct_lc_result_from_responses_api(
response: Response, response: Response,
schema: Optional[Type[_BM]] = None, schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
) -> ChatResult: ) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response.""" """Construct ChatResponse from OpenAI Response API response."""
@ -3278,7 +3273,7 @@ def _construct_lc_result_from_responses_api(
def _convert_responses_chunk_to_generation_chunk( def _convert_responses_chunk_to_generation_chunk(
chunk: Any, schema: Optional[Type[_BM]] = None, metadata: Optional[dict] = None chunk: Any, schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None
) -> Optional[ChatGenerationChunk]: ) -> Optional[ChatGenerationChunk]:
content = [] content = []
tool_call_chunks: list = [] tool_call_chunks: list = []
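In the hunks shown for this file, the only change that is not a pure annotation swap is the extra `isinstance(schema, dict)` guard in `_convert_to_openai_response_format`; everything else, including the `Sequence[Union[dict[str, Any], type, Callable, BaseTool]]` hint on `bind_tools`, keeps the existing behavior. As a reminder of what that tool hint accepts, a hedged usage sketch follows; the model name is illustrative and `OPENAI_API_KEY` is assumed to be set in the environment.

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str


def get_population(city: str) -> int:
    """Return the population of a city."""
    raise NotImplementedError


llm = ChatOpenAI(model="gpt-4o-mini")  # assumes OPENAI_API_KEY is set

# dict schemas, plain Python/pydantic types, and callables are all accepted,
# per the Sequence[Union[dict[str, Any], type, Callable, BaseTool]] hint above.
llm_with_tools = llm.bind_tools([GetWeather, get_population])
```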

View File

@ -2,12 +2,13 @@
from __future__ import annotations from __future__ import annotations
from typing import Awaitable, Callable, Optional, Union from collections.abc import Awaitable
from typing import Callable, Optional, Union, cast
import openai import openai
from langchain_core.utils import from_env, secret_from_env from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast from typing_extensions import Self
from langchain_openai.embeddings.base import OpenAIEmbeddings from langchain_openai.embeddings.base import OpenAIEmbeddings

View File

@ -2,20 +2,8 @@ from __future__ import annotations
import logging import logging
import warnings import warnings
from typing import ( from collections.abc import Iterable, Mapping, Sequence
Any, from typing import Any, Literal, Optional, Union, cast
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
import openai import openai
import tiktoken import tiktoken
@ -29,19 +17,19 @@ logger = logging.getLogger(__name__)
def _process_batched_chunked_embeddings( def _process_batched_chunked_embeddings(
num_texts: int, num_texts: int,
tokens: List[Union[List[int], str]], tokens: list[Union[list[int], str]],
batched_embeddings: List[List[float]], batched_embeddings: list[list[float]],
indices: List[int], indices: list[int],
skip_empty: bool, skip_empty: bool,
) -> List[Optional[List[float]]]: ) -> list[Optional[list[float]]]:
# for each text, this is the list of embeddings (list of list of floats) # for each text, this is the list of embeddings (list of list of floats)
# corresponding to the chunks of the text # corresponding to the chunks of the text
results: List[List[List[float]]] = [[] for _ in range(num_texts)] results: list[list[list[float]]] = [[] for _ in range(num_texts)]
# for each text, this is the token length of each chunk # for each text, this is the token length of each chunk
# for transformers tokenization, this is the string length # for transformers tokenization, this is the string length
# for tiktoken, this is the number of tokens # for tiktoken, this is the number of tokens
num_tokens_in_batch: List[List[int]] = [[] for _ in range(num_texts)] num_tokens_in_batch: list[list[int]] = [[] for _ in range(num_texts)]
for i in range(len(indices)): for i in range(len(indices)):
if skip_empty and len(batched_embeddings[i]) == 1: if skip_empty and len(batched_embeddings[i]) == 1:
@ -50,10 +38,10 @@ def _process_batched_chunked_embeddings(
num_tokens_in_batch[indices[i]].append(len(tokens[i])) num_tokens_in_batch[indices[i]].append(len(tokens[i]))
# for each text, this is the final embedding # for each text, this is the final embedding
embeddings: List[Optional[List[float]]] = [] embeddings: list[Optional[list[float]]] = []
for i in range(num_texts): for i in range(num_texts):
# an embedding for each chunk # an embedding for each chunk
_result: List[List[float]] = results[i] _result: list[list[float]] = results[i]
if len(_result) == 0: if len(_result) == 0:
# this will be populated with the embedding of an empty string # this will be populated with the embedding of an empty string
@ -213,13 +201,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
), ),
) )
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
allowed_special: Union[Literal["all"], Set[str], None] = None allowed_special: Union[Literal["all"], set[str], None] = None
disallowed_special: Union[Literal["all"], Set[str], Sequence[str], None] = None disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None
chunk_size: int = 1000 chunk_size: int = 1000
"""Maximum number of texts to embed in each batch""" """Maximum number of texts to embed in each batch"""
max_retries: int = 2 max_retries: int = 2
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field( request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field(
default=None, alias="timeout" default=None, alias="timeout"
) )
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
@ -240,7 +228,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
when tiktoken is called, you can specify a model name to use here.""" when tiktoken is called, you can specify a model name to use here."""
show_progress_bar: bool = False show_progress_bar: bool = False
"""Whether to show a progress bar when embedding.""" """Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
skip_empty: bool = False skip_empty: bool = False
"""Whether to skip empty strings when embedding or raise an error. """Whether to skip empty strings when embedding or raise an error.
@ -270,7 +258,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
@ -354,15 +342,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return self return self
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> dict[str, Any]:
params: Dict = {"model": self.model, **self.model_kwargs} params: dict = {"model": self.model, **self.model_kwargs}
if self.dimensions is not None: if self.dimensions is not None:
params["dimensions"] = self.dimensions params["dimensions"] = self.dimensions
return params return params
def _tokenize( def _tokenize(
self, texts: List[str], chunk_size: int self, texts: list[str], chunk_size: int
) -> Tuple[Iterable[int], List[Union[List[int], str]], List[int]]: ) -> tuple[Iterable[int], list[Union[list[int], str]], list[int]]:
""" """
Take the input `texts` and `chunk_size` and return 3 iterables as a tuple: Take the input `texts` and `chunk_size` and return 3 iterables as a tuple:
@ -383,8 +371,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
indices: An iterable of the same length as `tokens` that maps each token-array indices: An iterable of the same length as `tokens` that maps each token-array
to the index of the original text in `texts`. to the index of the original text in `texts`.
""" """
tokens: List[Union[List[int], str]] = [] tokens: list[Union[list[int], str]] = []
indices: List[int] = [] indices: list[int] = []
model_name = self.tiktoken_model_name or self.model model_name = self.tiktoken_model_name or self.model
# If tiktoken flag set to False # If tiktoken flag set to False
@ -403,11 +391,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
) )
for i, text in enumerate(texts): for i, text in enumerate(texts):
# Tokenize the text using HuggingFace transformers # Tokenize the text using HuggingFace transformers
tokenized: List[int] = tokenizer.encode(text, add_special_tokens=False) tokenized: list[int] = tokenizer.encode(text, add_special_tokens=False)
# Split tokens into chunks respecting the embedding_ctx_length # Split tokens into chunks respecting the embedding_ctx_length
for j in range(0, len(tokenized), self.embedding_ctx_length): for j in range(0, len(tokenized), self.embedding_ctx_length):
token_chunk: List[int] = tokenized[ token_chunk: list[int] = tokenized[
j : j + self.embedding_ctx_length j : j + self.embedding_ctx_length
] ]
@ -420,7 +408,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
encoding = tiktoken.encoding_for_model(model_name) encoding = tiktoken.encoding_for_model(model_name)
except KeyError: except KeyError:
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
encoder_kwargs: Dict[str, Any] = { encoder_kwargs: dict[str, Any] = {
k: v k: v
for k, v in { for k, v in {
"allowed_special": self.allowed_special, "allowed_special": self.allowed_special,
@ -459,8 +447,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to # please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings( def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]: ) -> list[list[float]]:
""" """
Generate length-safe embeddings for a list of texts. Generate length-safe embeddings for a list of texts.
@ -478,7 +466,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
_chunk_size = chunk_size or self.chunk_size _chunk_size = chunk_size or self.chunk_size
_iter, tokens, indices = self._tokenize(texts, _chunk_size) _iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: List[List[float]] = [] batched_embeddings: list[list[float]] = []
for i in _iter: for i in _iter:
response = self.client.create( response = self.client.create(
input=tokens[i : i + _chunk_size], **self._invocation_params input=tokens[i : i + _chunk_size], **self._invocation_params
@ -490,9 +478,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings = _process_batched_chunked_embeddings( embeddings = _process_batched_chunked_embeddings(
len(texts), tokens, batched_embeddings, indices, self.skip_empty len(texts), tokens, batched_embeddings, indices, self.skip_empty
) )
_cached_empty_embedding: Optional[List[float]] = None _cached_empty_embedding: Optional[list[float]] = None
def empty_embedding() -> List[float]: def empty_embedding() -> list[float]:
nonlocal _cached_empty_embedding nonlocal _cached_empty_embedding
if _cached_empty_embedding is None: if _cached_empty_embedding is None:
average_embedded = self.client.create( average_embedded = self.client.create(
@ -508,8 +496,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to # please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
async def _aget_len_safe_embeddings( async def _aget_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]: ) -> list[list[float]]:
""" """
Asynchronously generate length-safe embeddings for a list of texts. Asynchronously generate length-safe embeddings for a list of texts.
@ -528,7 +516,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_chunk_size = chunk_size or self.chunk_size _chunk_size = chunk_size or self.chunk_size
_iter, tokens, indices = self._tokenize(texts, _chunk_size) _iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: List[List[float]] = [] batched_embeddings: list[list[float]] = []
_chunk_size = chunk_size or self.chunk_size _chunk_size = chunk_size or self.chunk_size
for i in range(0, len(tokens), _chunk_size): for i in range(0, len(tokens), _chunk_size):
response = await self.async_client.create( response = await self.async_client.create(
@ -542,9 +530,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings = _process_batched_chunked_embeddings( embeddings = _process_batched_chunked_embeddings(
len(texts), tokens, batched_embeddings, indices, self.skip_empty len(texts), tokens, batched_embeddings, indices, self.skip_empty
) )
_cached_empty_embedding: Optional[List[float]] = None _cached_empty_embedding: Optional[list[float]] = None
async def empty_embedding() -> List[float]: async def empty_embedding() -> list[float]:
nonlocal _cached_empty_embedding nonlocal _cached_empty_embedding
if _cached_empty_embedding is None: if _cached_empty_embedding is None:
average_embedded = await self.async_client.create( average_embedded = await self.async_client.create(
@ -558,8 +546,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return [e if e is not None else await empty_embedding() for e in embeddings] return [e if e is not None else await empty_embedding() for e in embeddings]
def embed_documents( def embed_documents(
self, texts: List[str], chunk_size: int | None = None self, texts: list[str], chunk_size: int | None = None
) -> List[List[float]]: ) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs. """Call out to OpenAI's embedding endpoint for embedding search docs.
Args: Args:
@ -572,7 +560,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
chunk_size_ = chunk_size or self.chunk_size chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length: if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = [] embeddings: list[list[float]] = []
for i in range(0, len(texts), chunk_size_): for i in range(0, len(texts), chunk_size_):
response = self.client.create( response = self.client.create(
input=texts[i : i + chunk_size_], **self._invocation_params input=texts[i : i + chunk_size_], **self._invocation_params
@ -588,8 +576,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return self._get_len_safe_embeddings(texts, engine=engine) return self._get_len_safe_embeddings(texts, engine=engine)
async def aembed_documents( async def aembed_documents(
self, texts: List[str], chunk_size: int | None = None self, texts: list[str], chunk_size: int | None = None
) -> List[List[float]]: ) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs. """Call out to OpenAI's embedding endpoint async for embedding search docs.
Args: Args:
@ -602,7 +590,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
chunk_size_ = chunk_size or self.chunk_size chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length: if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = [] embeddings: list[list[float]] = []
for i in range(0, len(texts), chunk_size_): for i in range(0, len(texts), chunk_size_):
response = await self.async_client.create( response = await self.async_client.create(
input=texts[i : i + chunk_size_], **self._invocation_params input=texts[i : i + chunk_size_], **self._invocation_params
@ -617,7 +605,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
engine = cast(str, self.deployment) engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine) return await self._aget_len_safe_embeddings(texts, engine=engine)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text. """Call out to OpenAI's embedding endpoint for embedding query text.
Args: Args:
@ -628,7 +616,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
return self.embed_documents([text])[0] return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> list[float]:
"""Call out to OpenAI's embedding endpoint async for embedding query text. """Call out to OpenAI's embedding endpoint async for embedding query text.
Args: Args:

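One detail worth noting in this file: the signatures spell the optional chunk size as `chunk_size: int | None` even on a 3.9 target. That is safe here because the module already uses `from __future__ import annotations` (visible in the hunk header above), so annotations are never evaluated at runtime. A brief usage sketch of the chunked embedding entry points; the model name and chunk sizes are illustrative and `OPENAI_API_KEY` is assumed to be set.

```python
from langchain_openai import OpenAIEmbeddings

# Model and chunk_size values are illustrative, not taken from this diff.
emb = OpenAIEmbeddings(model="text-embedding-3-small", chunk_size=500)

doc_vectors = emb.embed_documents(["doc one", "doc two"])  # list[list[float]]
query_vector = emb.embed_query("what is in doc one?")      # list[float]

# Per-call override of the batch size, matching the signature shown above.
small_batches = emb.embed_documents(["doc one", "doc two"], chunk_size=100)
```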
View File

@ -1,13 +1,14 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union from collections.abc import Awaitable, Mapping
from typing import Any, Callable, Optional, Union, cast
import openai import openai
from langchain_core.language_models import LangSmithParams from langchain_core.language_models import LangSmithParams
from langchain_core.utils import from_env, secret_from_env from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast from typing_extensions import Self
from langchain_openai.llms.base import BaseOpenAI from langchain_openai.llms.base import BaseOpenAI
@ -91,12 +92,12 @@ class AzureOpenAI(BaseOpenAI):
""" """
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"] return ["langchain", "llms", "openai"]
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return { return {
"openai_api_key": "AZURE_OPENAI_API_KEY", "openai_api_key": "AZURE_OPENAI_API_KEY",
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN", "azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
@ -188,12 +189,12 @@ class AzureOpenAI(BaseOpenAI):
} }
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> dict[str, Any]:
openai_params = {"model": self.deployment_name} openai_params = {"model": self.deployment_name}
return {**openai_params, **super()._invocation_params} return {**openai_params, **super()._invocation_params}
def _get_ls_params( def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs) params = super()._get_ls_params(stop=stop, **kwargs)
@ -209,7 +210,7 @@ class AzureOpenAI(BaseOpenAI):
return "azure" return "azure"
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> dict[str, Any]:
return { return {
"openai_api_type": self.openai_api_type, "openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version, "openai_api_version": self.openai_api_version,

View File

@ -2,21 +2,8 @@ from __future__ import annotations
import logging import logging
import sys import sys
from typing import ( from collections.abc import AsyncIterator, Collection, Iterator, Mapping
AbstractSet, from typing import Any, Literal, Optional, Union
Any,
AsyncIterator,
Collection,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Union,
)
import openai import openai
import tiktoken import tiktoken
@ -35,7 +22,7 @@ logger = logging.getLogger(__name__)
def _update_token_usage( def _update_token_usage(
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any] keys: set[str], response: dict[str, Any], token_usage: dict[str, Any]
) -> None: ) -> None:
"""Update token usage.""" """Update token usage."""
_keys_to_use = keys.intersection(response["usage"]) _keys_to_use = keys.intersection(response["usage"])
@ -47,7 +34,7 @@ def _update_token_usage(
def _stream_response_to_generation_chunk( def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any], stream_response: dict[str, Any],
) -> GenerationChunk: ) -> GenerationChunk:
"""Convert a stream response to a generation chunk.""" """Convert a stream response to a generation chunk."""
if not stream_response["choices"]: if not stream_response["choices"]:
@ -84,7 +71,7 @@ class BaseOpenAI(BaseLLM):
"""How many completions to generate for each prompt.""" """How many completions to generate for each prompt."""
best_of: int = 1 best_of: int = 1
"""Generates best_of completions server-side and returns the "best".""" """Generates best_of completions server-side and returns the "best"."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = Field( openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
@ -108,12 +95,12 @@ class BaseOpenAI(BaseLLM):
) )
batch_size: int = 20 batch_size: int = 20
"""Batch size to use when passing multiple documents to generate.""" """Batch size to use when passing multiple documents to generate."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout" default=None, alias="timeout"
) )
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None.""" None."""
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[dict[str, float]] = None
"""Adjust the probability of specific tokens being generated.""" """Adjust the probability of specific tokens being generated."""
max_retries: int = 2 max_retries: int = 2
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
@ -124,7 +111,7 @@ class BaseOpenAI(BaseLLM):
as well the chosen tokens.""" as well the chosen tokens."""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
allowed_special: Union[Literal["all"], AbstractSet[str]] = set() allowed_special: Union[Literal["all"], set[str]] = set()
"""Set of special tokens that are allowed。""" """Set of special tokens that are allowed。"""
disallowed_special: Union[Literal["all"], Collection[str]] = "all" disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。""" """Set of special tokens that are not allowed。"""
@ -157,7 +144,7 @@ class BaseOpenAI(BaseLLM):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names) values = _build_model_kwargs(values, all_required_field_names)
@ -197,9 +184,9 @@ class BaseOpenAI(BaseLLM):
return self return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
normal_params: Dict[str, Any] = { normal_params: dict[str, Any] = {
"temperature": self.temperature, "temperature": self.temperature,
"top_p": self.top_p, "top_p": self.top_p,
"frequency_penalty": self.frequency_penalty, "frequency_penalty": self.frequency_penalty,
@ -228,7 +215,7 @@ class BaseOpenAI(BaseLLM):
def _stream( def _stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:
@ -255,7 +242,7 @@ class BaseOpenAI(BaseLLM):
async def _astream( async def _astream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[GenerationChunk]: ) -> AsyncIterator[GenerationChunk]:
@ -283,8 +270,8 @@ class BaseOpenAI(BaseLLM):
def _generate( def _generate(
self, self,
prompts: List[str], prompts: list[str],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
@ -307,7 +294,7 @@ class BaseOpenAI(BaseLLM):
params = {**params, **kwargs} params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop) sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = [] choices = []
token_usage: Dict[str, int] = {} token_usage: dict[str, int] = {}
# Get the token usage from the response. # Get the token usage from the response.
# Includes prompt, completion, and total tokens used. # Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"} _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
@ -363,8 +350,8 @@ class BaseOpenAI(BaseLLM):
async def _agenerate( async def _agenerate(
self, self,
prompts: List[str], prompts: list[str],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
@ -373,7 +360,7 @@ class BaseOpenAI(BaseLLM):
params = {**params, **kwargs} params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop) sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = [] choices = []
token_usage: Dict[str, int] = {} token_usage: dict[str, int] = {}
# Get the token usage from the response. # Get the token usage from the response.
# Includes prompt, completion, and total tokens used. # Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"} _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
@ -419,10 +406,10 @@ class BaseOpenAI(BaseLLM):
def get_sub_prompts( def get_sub_prompts(
self, self,
params: Dict[str, Any], params: dict[str, Any],
prompts: List[str], prompts: list[str],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
) -> List[List[str]]: ) -> list[list[str]]:
"""Get the sub prompts for llm call.""" """Get the sub prompts for llm call."""
if stop is not None: if stop is not None:
params["stop"] = stop params["stop"] = stop
@ -441,9 +428,9 @@ class BaseOpenAI(BaseLLM):
def create_llm_result( def create_llm_result(
self, self,
choices: Any, choices: Any,
prompts: List[str], prompts: list[str],
params: Dict[str, Any], params: dict[str, Any],
token_usage: Dict[str, int], token_usage: dict[str, int],
*, *,
system_fingerprint: Optional[str] = None, system_fingerprint: Optional[str] = None,
) -> LLMResult: ) -> LLMResult:
@ -470,7 +457,7 @@ class BaseOpenAI(BaseLLM):
return LLMResult(generations=generations, llm_output=llm_output) return LLMResult(generations=generations, llm_output=llm_output)
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> dict[str, Any]:
"""Get the parameters used to invoke the model.""" """Get the parameters used to invoke the model."""
return self._default_params return self._default_params
@ -484,7 +471,7 @@ class BaseOpenAI(BaseLLM):
"""Return type of llm.""" """Return type of llm."""
return "openai" return "openai"
def get_token_ids(self, text: str) -> List[int]: def get_token_ids(self, text: str) -> list[int]:
"""Get the token IDs using the tiktoken package.""" """Get the token IDs using the tiktoken package."""
if self.custom_get_token_ids is not None: if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text) return self.custom_get_token_ids(text)
@ -689,7 +676,7 @@ class OpenAI(BaseOpenAI):
""" # noqa: E501 """ # noqa: E501
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"] return ["langchain", "llms", "openai"]
@ -699,16 +686,16 @@ class OpenAI(BaseOpenAI):
return True return True
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> dict[str, Any]:
return {**{"model": self.model_name}, **super()._invocation_params} return {**{"model": self.model_name}, **super()._invocation_params}
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"} return {"openai_api_key": "OPENAI_API_KEY"}
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: dict[str, Any] = {}
if self.openai_api_base: if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base attributes["openai_api_base"] = self.openai_api_base

View File

@ -59,8 +59,12 @@ disallow_untyped_defs = "True"
module = "transformers" module = "transformers"
ignore_missing_imports = true ignore_missing_imports = true
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.ruff.format] [tool.ruff.format]
docstring-code-format = true docstring-code-format = true
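For orientation (this note and the sketch below are editorial, not part of the commit): with target-version = "py39" set, ruff's UP (pyupgrade) rules rewrite pre-3.9 typing imports to the PEP 585 builtin generics, while UP007 is left ignored because it would rewrite Optional[X]/Union[X, Y] to the PEP 604 X | Y form, which on 3.9 is only safe inside postponed annotations. A minimal before/after sketch in Python, with hypothetical function names:

# Before: pre-3.9 style, flagged by UP006/UP035 once the UP rules are enabled
from typing import Dict, List, Optional

def summarize(texts: List[str]) -> Dict[str, int]:
    return {t: len(t) for t in texts}

# After: builtin generics, valid at runtime on Python 3.9+
from typing import Optional  # Optional/Union are kept because UP007 is ignored

def summarize(texts: list[str]) -> dict[str, int]:
    return {t: len(t) for t in texts}

def maybe_first(texts: list[str]) -> Optional[str]:
    return texts[0] if texts else None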

View File

@ -1,7 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
import os import os
from typing import Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests from langchain_tests.integration_tests import ChatModelIntegrationTests
@ -14,7 +13,7 @@ OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
class TestAzureOpenAIStandard(ChatModelIntegrationTests): class TestAzureOpenAIStandard(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return AzureChatOpenAI return AzureChatOpenAI
@property @property
@ -40,7 +39,7 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
"""Test a legacy model.""" """Test a legacy model."""
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return AzureChatOpenAI return AzureChatOpenAI
@property @property

View File

@ -2,9 +2,10 @@
import base64 import base64
import json import json
from collections.abc import AsyncIterator
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Any, AsyncIterator, List, Literal, Optional, cast from typing import Any, Literal, Optional, cast
import httpx import httpx
import openai import openai
@ -531,14 +532,14 @@ class MakeASandwich(BaseModel):
bread_type: str bread_type: str
cheese_type: str cheese_type: str
condiments: List[str] condiments: list[str]
vegetables: List[str] vegetables: list[str]
def test_tool_use() -> None: def test_tool_use() -> None:
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True) llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")] msgs: list = [HumanMessage("Sally has green hair, what would her username be?")]
ai_msg = llm_with_tool.invoke(msgs) ai_msg = llm_with_tool.invoke(msgs)
assert isinstance(ai_msg, AIMessage) assert isinstance(ai_msg, AIMessage)
@ -583,7 +584,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
) )
llm_with_tool = llm.bind_tools(tools=[GenerateUsername]) llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [ msgs: list = [
HumanMessage("Sally has green hair, what would her username be?"), HumanMessage("Sally has green hair, what would her username be?"),
AIMessage( AIMessage(
content="", content="",
@ -1045,7 +1046,7 @@ def test_audio_output_modality() -> None:
}, },
) )
history: List[BaseMessage] = [ history: list[BaseMessage] = [
HumanMessage("Make me a short audio clip of you yelling") HumanMessage("Make me a short audio clip of you yelling")
] ]

View File

@ -1,7 +1,7 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Type, cast from typing import Literal, cast
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
@ -14,7 +14,7 @@ REPO_ROOT_DIR = Path(__file__).parents[6]
class TestOpenAIStandard(ChatModelIntegrationTests): class TestOpenAIStandard(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI
@property @property
@ -36,9 +36,9 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
@property @property
def supported_usage_metadata_details( def supported_usage_metadata_details(
self, self,
) -> Dict[ ) -> dict[
Literal["invoke", "stream"], Literal["invoke", "stream"],
List[ list[
Literal[ Literal[
"audio_input", "audio_input",
"audio_output", "audio_output",
@ -51,7 +51,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
return {"invoke": ["reasoning_output", "cache_read_input"], "stream": []} return {"invoke": ["reasoning_output", "cache_read_input"], "stream": []}
def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage: def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
with open(REPO_ROOT_DIR / "README.md", "r") as f: with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read() readme = f.read()
input_ = f"""What's langchain? Here's the langchain README: input_ = f"""What's langchain? Here's the langchain README:

View File

@ -2,7 +2,7 @@
import json import json
import os import os
from typing import Any, Optional, cast from typing import Annotated, Any, Optional, cast
import openai import openai
import pytest import pytest
@ -13,7 +13,7 @@ from langchain_core.messages import (
BaseMessageChunk, BaseMessageChunk,
) )
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests for Responses API""" """Standard LangChain interface tests for Responses API"""
from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@ -11,7 +9,7 @@ from tests.integration_tests.chat_models.test_base_standard import TestOpenAISta
class TestOpenAIResponses(TestOpenAIStandard): class TestOpenAIResponses(TestOpenAIStandard):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI
@property @property

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_tests.integration_tests.embeddings import EmbeddingsIntegrationTests from langchain_tests.integration_tests.embeddings import EmbeddingsIntegrationTests
@ -10,7 +8,7 @@ from langchain_openai import OpenAIEmbeddings
class TestOpenAIStandard(EmbeddingsIntegrationTests): class TestOpenAIStandard(EmbeddingsIntegrationTests):
@property @property
def embeddings_class(self) -> Type[Embeddings]: def embeddings_class(self) -> type[Embeddings]:
return OpenAIEmbeddings return OpenAIEmbeddings
@property @property

View File

@ -1,7 +1,8 @@
"""Test AzureOpenAI wrapper.""" """Test AzureOpenAI wrapper."""
import os import os
from typing import Any, Generator from collections.abc import Generator
from typing import Any
import pytest import pytest
from langchain_core.callbacks import CallbackManager from langchain_core.callbacks import CallbackManager
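The hunk above, and several later files, follow the same pattern: abstract iterator and container types now come from collections.abc rather than typing. A minimal editorial sketch (hypothetical names, not taken from the diff) of the resulting import layout on Python 3.9:

from collections.abc import Generator, Iterator, Mapping  # ABCs no longer imported from typing
from typing import Any, Optional  # Optional stays in typing since UP007 is ignored

def iter_items(data: Mapping[str, Any]) -> Iterator[tuple[str, Any]]:
    # builtin tuple is subscriptable on 3.9+, so no typing.Tuple import is needed
    yield from data.items()

def countdown(start: int) -> Generator[int, None, None]:
    while start > 0:
        yield start
        start -= 1

def first_key(data: Mapping[str, Any]) -> Optional[str]:
    return next(iter(data.keys()), None)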

View File

@ -1,6 +1,6 @@
"""Test OpenAI llm.""" """Test OpenAI llm."""
from typing import Generator from collections.abc import Generator
import pytest import pytest
from langchain_core.callbacks import CallbackManager from langchain_core.callbacks import CallbackManager

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Tuple, Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -12,7 +10,7 @@ from langchain_openai import AzureChatOpenAI
class TestOpenAIStandard(ChatModelUnitTests): class TestOpenAIStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return AzureChatOpenAI return AzureChatOpenAI
@property @property
@ -30,7 +28,7 @@ class TestOpenAIStandard(ChatModelUnitTests):
super().test_bind_tool_pydantic(model, my_adder_tool) super().test_bind_tool_pydantic(model, my_adder_tool)
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"AZURE_OPENAI_API_KEY": "api_key", "AZURE_OPENAI_API_KEY": "api_key",

View File

@ -3,7 +3,7 @@
import json import json
from functools import partial from functools import partial
from types import TracebackType from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast from typing import Any, Literal, Optional, Union, cast
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -241,7 +241,7 @@ class MockAsyncContextManager:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[type[BaseException]],
exc: Optional[BaseException], exc: Optional[BaseException],
tb: Optional[TracebackType], tb: Optional[TracebackType],
) -> None: ) -> None:
@ -270,7 +270,7 @@ class MockSyncContextManager:
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[type[BaseException]],
exc: Optional[BaseException], exc: Optional[BaseException],
tb: Optional[TracebackType], tb: Optional[TracebackType],
) -> None: ) -> None:
@ -382,7 +382,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
@pytest.fixture @pytest.fixture
def mock_deepseek_completion() -> List[Dict]: def mock_deepseek_completion() -> list[dict]:
list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n") list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
result_list = [] result_list = []
for msg in list_chunk_data: for msg in list_chunk_data:
@ -450,7 +450,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
@pytest.fixture @pytest.fixture
def mock_openai_completion() -> List[Dict]: def mock_openai_completion() -> list[dict]:
list_chunk_data = OPENAI_STREAM_DATA.split("\n") list_chunk_data = OPENAI_STREAM_DATA.split("\n")
result_list = [] result_list = []
for msg in list_chunk_data: for msg in list_chunk_data:
@ -615,7 +615,7 @@ def test_openai_invoke_name(mock_client: MagicMock) -> None:
def test_custom_token_counting() -> None: def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]: def token_encoder(text: str) -> list[int]:
return [1, 2, 3] return [1, 2, 3]
llm = ChatOpenAI(custom_get_token_ids=token_encoder) llm = ChatOpenAI(custom_get_token_ids=token_encoder)
@ -662,8 +662,8 @@ class MakeASandwich(BaseModel):
bread_type: str bread_type: str
cheese_type: str cheese_type: str
condiments: List[str] condiments: list[str]
vegetables: List[str] vegetables: list[str]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -695,7 +695,7 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
@pytest.mark.parametrize("include_raw", [True, False]) @pytest.mark.parametrize("include_raw", [True, False])
@pytest.mark.parametrize("strict", [True, False, None]) @pytest.mark.parametrize("strict", [True, False, None])
def test_with_structured_output( def test_with_structured_output(
schema: Union[Type, Dict[str, Any], None], schema: Union[type, dict[str, Any], None],
method: Literal["function_calling", "json_mode", "json_schema"], method: Literal["function_calling", "json_mode", "json_schema"],
include_raw: bool, include_raw: bool,
strict: Optional[bool], strict: Optional[bool],
@ -787,7 +787,7 @@ class Foo(BaseModel):
# FooV1 # FooV1
], ],
) )
def test_schema_from_with_structured_output(schema: Type) -> None: def test_schema_from_with_structured_output(schema: type) -> None:
"""Test schema from with_structured_output.""" """Test schema from with_structured_output."""
llm = ChatOpenAI(model="gpt-4o") llm = ChatOpenAI(model="gpt-4o")

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,11 +8,11 @@ from langchain_openai import ChatOpenAI
class TestOpenAIStandard(ChatModelUnitTests): class TestOpenAIStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"OPENAI_API_KEY": "api_key", "OPENAI_API_KEY": "api_key",

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,7 +8,7 @@ from langchain_openai import ChatOpenAI
class TestOpenAIResponses(ChatModelUnitTests): class TestOpenAIResponses(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI
@property @property
@ -18,7 +16,7 @@ class TestOpenAIResponses(ChatModelUnitTests):
return {"use_responses_api": True} return {"use_responses_api": True}
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"OPENAI_API_KEY": "api_key", "OPENAI_API_KEY": "api_key",

View File

@ -1,5 +1,3 @@
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@ -8,7 +6,7 @@ from langchain_openai import AzureOpenAIEmbeddings
class TestAzureOpenAIStandard(EmbeddingsUnitTests): class TestAzureOpenAIStandard(EmbeddingsUnitTests):
@property @property
def embeddings_class(self) -> Type[Embeddings]: def embeddings_class(self) -> type[Embeddings]:
return AzureOpenAIEmbeddings return AzureOpenAIEmbeddings
@property @property
@ -16,7 +14,7 @@ class TestAzureOpenAIStandard(EmbeddingsUnitTests):
return {"api_key": "api_key", "azure_endpoint": "https://endpoint.com"} return {"api_key": "api_key", "azure_endpoint": "https://endpoint.com"}
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"AZURE_OPENAI_API_KEY": "api_key", "AZURE_OPENAI_API_KEY": "api_key",

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@ -10,11 +8,11 @@ from langchain_openai import OpenAIEmbeddings
class TestOpenAIStandard(EmbeddingsUnitTests): class TestOpenAIStandard(EmbeddingsUnitTests):
@property @property
def embeddings_class(self) -> Type[Embeddings]: def embeddings_class(self) -> type[Embeddings]:
return OpenAIEmbeddings return OpenAIEmbeddings
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ( return (
{ {
"OPENAI_API_KEY": "api_key", "OPENAI_API_KEY": "api_key",

View File

@ -1,7 +1,7 @@
"""A fake callback handler for testing purposes.""" """A fake callback handler for testing purposes."""
from itertools import chain from itertools import chain
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
@ -15,7 +15,7 @@ class BaseFakeCallbackHandler(BaseModel):
starts: int = 0 starts: int = 0
ends: int = 0 ends: int = 0
errors: int = 0 errors: int = 0
errors_args: List[Any] = [] errors_args: list[Any] = []
text: int = 0 text: int = 0
ignore_llm_: bool = False ignore_llm_: bool = False
ignore_chain_: bool = False ignore_chain_: bool = False
@ -195,8 +195,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,

View File

@ -1,5 +1,4 @@
import os import os
from typing import List
import pytest import pytest
@ -65,7 +64,7 @@ def test_get_token_ids(model: str) -> None:
def test_custom_token_counting() -> None: def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]: def token_encoder(text: str) -> list[int]:
return [1, 2, 3] return [1, 2, 3]
llm = OpenAI(custom_get_token_ids=token_encoder) llm = OpenAI(custom_get_token_ids=token_encoder)

View File

@ -1,4 +1,4 @@
from typing import Type, cast from typing import cast
import pytest import pytest
from langchain_core.load import dumpd from langchain_core.load import dumpd
@ -72,7 +72,7 @@ def test_azure_openai_embeddings_secrets() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings] "model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
) )
def test_azure_openai_api_key_is_secret_string(model_class: Type) -> None: def test_azure_openai_api_key_is_secret_string(model_class: type) -> None:
"""Test that the API key is stored as a SecretStr.""" """Test that the API key is stored as a SecretStr."""
model = model_class( model = model_class(
openai_api_key="secret-api-key", openai_api_key="secret-api-key",
@ -88,7 +88,7 @@ def test_azure_openai_api_key_is_secret_string(model_class: Type) -> None:
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings] "model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
) )
def test_azure_openai_api_key_masked_when_passed_from_env( def test_azure_openai_api_key_masked_when_passed_from_env(
model_class: Type, monkeypatch: MonkeyPatch, capsys: CaptureFixture model_class: type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None: ) -> None:
"""Test that the API key is masked when passed from an environment variable.""" """Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key") monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key")
@ -109,7 +109,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings] "model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
) )
def test_azure_openai_api_key_masked_when_passed_via_constructor( def test_azure_openai_api_key_masked_when_passed_via_constructor(
model_class: Type, capsys: CaptureFixture model_class: type, capsys: CaptureFixture
) -> None: ) -> None:
"""Test that the API key is masked when passed via the constructor.""" """Test that the API key is masked when passed via the constructor."""
model = model_class( model = model_class(
@ -133,7 +133,7 @@ def test_azure_openai_api_key_masked_when_passed_via_constructor(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings] "model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
) )
def test_azure_openai_uses_actual_secret_value_from_secretstr( def test_azure_openai_uses_actual_secret_value_from_secretstr(
model_class: Type, model_class: type,
) -> None: ) -> None:
"""Test that the actual secret value is correctly retrieved.""" """Test that the actual secret value is correctly retrieved."""
model = model_class( model = model_class(
@ -147,7 +147,7 @@ def test_azure_openai_uses_actual_secret_value_from_secretstr(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) @pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_is_secret_string(model_class: Type) -> None: def test_openai_api_key_is_secret_string(model_class: type) -> None:
"""Test that the API key is stored as a SecretStr.""" """Test that the API key is stored as a SecretStr."""
model = model_class(openai_api_key="secret-api-key") model = model_class(openai_api_key="secret-api-key")
assert isinstance(model.openai_api_key, SecretStr) assert isinstance(model.openai_api_key, SecretStr)
@ -155,7 +155,7 @@ def test_openai_api_key_is_secret_string(model_class: Type) -> None:
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) @pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_masked_when_passed_from_env( def test_openai_api_key_masked_when_passed_from_env(
model_class: Type, monkeypatch: MonkeyPatch, capsys: CaptureFixture model_class: type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None: ) -> None:
"""Test that the API key is masked when passed from an environment variable.""" """Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("OPENAI_API_KEY", "secret-api-key") monkeypatch.setenv("OPENAI_API_KEY", "secret-api-key")
@ -168,7 +168,7 @@ def test_openai_api_key_masked_when_passed_from_env(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) @pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_masked_when_passed_via_constructor( def test_openai_api_key_masked_when_passed_via_constructor(
model_class: Type, capsys: CaptureFixture model_class: type, capsys: CaptureFixture
) -> None: ) -> None:
"""Test that the API key is masked when passed via the constructor.""" """Test that the API key is masked when passed via the constructor."""
model = model_class(openai_api_key="secret-api-key") model = model_class(openai_api_key="secret-api-key")
@ -179,14 +179,14 @@ def test_openai_api_key_masked_when_passed_via_constructor(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) @pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_uses_actual_secret_value_from_secretstr(model_class: Type) -> None: def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> None:
"""Test that the actual secret value is correctly retrieved.""" """Test that the actual secret value is correctly retrieved."""
model = model_class(openai_api_key="secret-api-key") model = model_class(openai_api_key="secret-api-key")
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key" assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI]) @pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
def test_azure_serialized_secrets(model_class: Type) -> None: def test_azure_serialized_secrets(model_class: type) -> None:
"""Test that the actual secret value is correctly retrieved.""" """Test that the actual secret value is correctly retrieved."""
model = model_class( model = model_class(
openai_api_key="secret-api-key", api_version="foo", azure_endpoint="foo" openai_api_key="secret-api-key", api_version="foo", azure_endpoint="foo"

View File

@ -3,20 +3,9 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Iterator, Mapping
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import Any, Literal, Optional, TypeVar, Union
Any,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import openai import openai
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import CallbackManagerForLLMRun
@ -50,8 +39,8 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self from typing_extensions import Self
_BM = TypeVar("_BM", bound=BaseModel) _BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] _DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[Dict, _BM] _DictOrPydantic = Union[dict, _BM]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -162,14 +151,14 @@ class ChatPerplexity(BaseChatModel):
"""Model name.""" """Model name."""
temperature: float = 0.7 temperature: float = 0.7
"""What sampling temperature to use.""" """What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
pplx_api_key: Optional[SecretStr] = Field( pplx_api_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key" default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key"
) )
"""Base URL path for API requests, """Base URL path for API requests,
leave blank if not using a proxy or service emulator.""" leave blank if not using a proxy or service emulator."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = Field( request_timeout: Optional[Union[float, tuple[float, float]]] = Field(
None, alias="timeout" None, alias="timeout"
) )
"""Timeout for requests to PerplexityChat completion API. Default is None.""" """Timeout for requests to PerplexityChat completion API. Default is None."""
@ -183,12 +172,12 @@ class ChatPerplexity(BaseChatModel):
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
return {"pplx_api_key": "PPLX_API_KEY"} return {"pplx_api_key": "PPLX_API_KEY"}
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
@ -232,7 +221,7 @@ class ChatPerplexity(BaseChatModel):
return self return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling PerplexityChat API.""" """Get the default parameters for calling PerplexityChat API."""
return { return {
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
@ -241,7 +230,7 @@ class ChatPerplexity(BaseChatModel):
**self.model_kwargs, **self.model_kwargs,
} }
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]: def _convert_message_to_dict(self, message: BaseMessage) -> dict[str, Any]:
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content} message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
@ -255,8 +244,8 @@ class ChatPerplexity(BaseChatModel):
return message_dict return message_dict
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = dict(self._invocation_params) params = dict(self._invocation_params)
if stop is not None: if stop is not None:
if "stop" in params: if "stop" in params:
@ -266,11 +255,11 @@ class ChatPerplexity(BaseChatModel):
return message_dicts, params return message_dicts, params
def _convert_delta_to_message_chunk( def _convert_delta_to_message_chunk(
self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] self, _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
role = _dict.get("role") role = _dict.get("role")
content = _dict.get("content") or "" content = _dict.get("content") or ""
additional_kwargs: Dict = {} additional_kwargs: dict = {}
if _dict.get("function_call"): if _dict.get("function_call"):
function_call = dict(_dict["function_call"]) function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None: if "name" in function_call and function_call["name"] is None:
@ -296,8 +285,8 @@ class ChatPerplexity(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
@ -367,8 +356,8 @@ class ChatPerplexity(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -402,7 +391,7 @@ class ChatPerplexity(BaseChatModel):
@property @property
def _invocation_params(self) -> Mapping[str, Any]: def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model.""" """Get the parameters used to invoke the model."""
pplx_creds: Dict[str, Any] = {"model": self.model} pplx_creds: dict[str, Any] = {"model": self.model}
return {**pplx_creds, **self._default_params} return {**pplx_creds, **self._default_params}
@property @property

View File

@ -55,8 +55,12 @@ plugins = ['pydantic.mypy']
module = "transformers" module = "transformers"
ignore_missing_imports = true ignore_missing_imports = true
[tool.ruff]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.ruff.format] [tool.ruff.format]
docstring-code-format = true docstring-code-format = true

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests.""" """Standard LangChain interface tests."""
from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests from langchain_tests.integration_tests import ChatModelIntegrationTests
@ -11,7 +9,7 @@ from langchain_perplexity import ChatPerplexity
class TestPerplexityStandard(ChatModelIntegrationTests): class TestPerplexityStandard(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatPerplexity return ChatPerplexity
@property @property

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
from langchain_core.messages import AIMessageChunk, BaseMessageChunk from langchain_core.messages import AIMessageChunk, BaseMessageChunk
@ -51,7 +51,7 @@ def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
"choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}], "choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}],
"citations": ["example.com", "example2.com"], "citations": ["example.com", "example2.com"],
} }
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock() mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object( patcher = mocker.patch.object(
@ -103,7 +103,7 @@ def test_perplexity_stream_includes_citations_and_images(mocker: MockerFixture)
} }
], ],
} }
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock() mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object( patcher = mocker.patch.object(
@ -162,7 +162,7 @@ def test_perplexity_stream_includes_citations_and_related_questions(
"citations": ["example.com", "example2.com"], "citations": ["example.com", "example2.com"],
"related_questions": ["example_question_1", "example_question_2"], "related_questions": ["example_question_1", "example_question_2"],
} }
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock() mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object( patcher = mocker.patch.object(

View File

@ -1,7 +1,5 @@
"""Test Perplexity Chat API wrapper.""" """Test Perplexity Chat API wrapper."""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,9 +8,9 @@ from langchain_perplexity import ChatPerplexity
class TestPerplexityStandard(ChatModelUnitTests): class TestPerplexityStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> type[BaseChatModel]:
return ChatPerplexity return ChatPerplexity
@property @property
def init_from_env_params(self) -> Tuple[dict, dict, dict]: def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ({"PPLX_API_KEY": "api_key"}, {}, {"pplx_api_key": "api_key"}) return ({"PPLX_API_KEY": "api_key"}, {}, {"pplx_api_key": "api_key"})

View File

@ -5,7 +5,7 @@ import json
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union from typing import Any, Generic, Literal, Optional, TypeVar, Union
import yaml import yaml
from pydantic import BaseModel, ConfigDict, Field, FilePath from pydantic import BaseModel, ConfigDict, Field, FilePath
@ -24,7 +24,7 @@ class PropertySettings(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
type: Literal["string", "number", "array", "object", "boolean"] type: Literal["string", "number", "array", "object", "boolean"]
default: Union[str, int, float, List, Dict, bool, None] = Field(default=None) default: Union[str, int, float, list, dict, bool, None] = Field(default=None)
description: str = Field(default="") description: str = Field(default="")
@ -58,8 +58,8 @@ class Prompty(BaseModel):
# metadata # metadata
name: str = Field(default="") name: str = Field(default="")
description: str = Field(default="") description: str = Field(default="")
authors: List[str] = Field(default=[]) authors: list[str] = Field(default=[])
tags: List[str] = Field(default=[]) tags: list[str] = Field(default=[])
version: str = Field(default="") version: str = Field(default="")
base: str = Field(default="") base: str = Field(default="")
basePrompty: Optional[Prompty] = Field(default=None) basePrompty: Optional[Prompty] = Field(default=None)
@ -70,8 +70,8 @@ class Prompty(BaseModel):
sample: dict = Field(default={}) sample: dict = Field(default={})
# input / output # input / output
inputs: Dict[str, PropertySettings] = Field(default={}) inputs: dict[str, PropertySettings] = Field(default={})
outputs: Dict[str, PropertySettings] = Field(default={}) outputs: dict[str, PropertySettings] = Field(default={})
# template # template
template: TemplateSettings template: TemplateSettings
@ -79,7 +79,7 @@ class Prompty(BaseModel):
file: FilePath = Field(default="") # type: ignore[assignment] file: FilePath = Field(default="") # type: ignore[assignment]
content: str = Field(default="") content: str = Field(default="")
def to_safe_dict(self) -> Dict[str, Any]: def to_safe_dict(self) -> dict[str, Any]:
d = {} d = {}
for k, v in self: for k, v in self:
if v != "" and v != {} and v != [] and v is not None: if v != "" and v != {} and v != [] and v is not None:
@ -130,7 +130,7 @@ class Prompty(BaseModel):
attribute.startswith("file:") attribute.startswith("file:")
and Path(parent / attribute.split(":")[1]).exists() and Path(parent / attribute.split(":")[1]).exists()
): ):
with open(parent / attribute.split(":")[1], "r") as f: with open(parent / attribute.split(":")[1]) as f:
items = json.load(f) items = json.load(f)
if isinstance(items, list): if isinstance(items, list):
return [Prompty.normalize(value, parent) for value in items] return [Prompty.normalize(value, parent) for value in items]
@ -155,8 +155,8 @@ class Prompty(BaseModel):
def param_hoisting( def param_hoisting(
top: Dict[str, Any], bottom: Dict[str, Any], top_key: Any = None top: dict[str, Any], bottom: dict[str, Any], top_key: Any = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Merge two dictionaries with hoisting of parameters from bottom to top. """Merge two dictionaries with hoisting of parameters from bottom to top.
Args: Args:
@ -198,18 +198,18 @@ class NoOpParser(Invoker):
return data return data
class InvokerFactory(object): class InvokerFactory:
"""Factory for creating invokers.""" """Factory for creating invokers."""
_instance = None _instance = None
_renderers: Dict[str, Type[Invoker]] = {} _renderers: dict[str, type[Invoker]] = {}
_parsers: Dict[str, Type[Invoker]] = {} _parsers: dict[str, type[Invoker]] = {}
_executors: Dict[str, Type[Invoker]] = {} _executors: dict[str, type[Invoker]] = {}
_processors: Dict[str, Type[Invoker]] = {} _processors: dict[str, type[Invoker]] = {}
def __new__(cls) -> InvokerFactory: def __new__(cls) -> InvokerFactory:
if cls._instance is None: if cls._instance is None:
cls._instance = super(InvokerFactory, cls).__new__(cls) cls._instance = super().__new__(cls)
# Add NOOP invokers # Add NOOP invokers
cls._renderers["NOOP"] = NoOpParser cls._renderers["NOOP"] = NoOpParser
cls._parsers["NOOP"] = NoOpParser cls._parsers["NOOP"] = NoOpParser
@ -221,7 +221,7 @@ class InvokerFactory(object):
self, self,
type: Literal["renderer", "parser", "executor", "processor"], type: Literal["renderer", "parser", "executor", "processor"],
name: str, name: str,
invoker: Type[Invoker], invoker: type[Invoker],
) -> None: ) -> None:
if type == "renderer": if type == "renderer":
self._renderers[name] = invoker self._renderers[name] = invoker
@ -264,7 +264,7 @@ class InvokerFactory(object):
else: else:
raise ValueError(f"Invalid type {type}") raise ValueError(f"Invalid type {type}")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return { return {
"renderers": { "renderers": {
k: f"{v.__module__}.{v.__name__}" for k, v in self._renderers.items() k: f"{v.__module__}.{v.__name__}" for k, v in self._renderers.items()

View File

@ -1,4 +1,4 @@
from typing import Any, Dict from typing import Any
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableLambda from langchain_core.runnables import Runnable, RunnableLambda
@ -10,10 +10,10 @@ from .utils import load, prepare
def create_chat_prompt( def create_chat_prompt(
path: str, path: str,
input_name_agent_scratchpad: str = "agent_scratchpad", input_name_agent_scratchpad: str = "agent_scratchpad",
) -> Runnable[Dict[str, Any], ChatPromptTemplate]: ) -> Runnable[dict[str, Any], ChatPromptTemplate]:
"""Create a chat prompt from a Langchain schema.""" """Create a chat prompt from a Langchain schema."""
def runnable_chat_lambda(inputs: Dict[str, Any]) -> ChatPromptTemplate: def runnable_chat_lambda(inputs: dict[str, Any]) -> ChatPromptTemplate:
p = load(path) p = load(path)
parsed = prepare(p, inputs) parsed = prepare(p, inputs)
# Parsed messages have been templated # Parsed messages have been templated

View File

@ -1,6 +1,6 @@
import base64 import base64
import re import re
from typing import Dict, List, Type, Union from typing import Union
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
@ -15,7 +15,7 @@ from .core import Invoker, Prompty, SimpleModel
class RoleMap: class RoleMap:
_ROLE_MAP: Dict[str, Type[BaseMessage]] = { _ROLE_MAP: dict[str, type[BaseMessage]] = {
"system": SystemMessage, "system": SystemMessage,
"user": HumanMessage, "user": HumanMessage,
"human": HumanMessage, "human": HumanMessage,
@ -26,7 +26,7 @@ class RoleMap:
ROLES = _ROLE_MAP.keys() ROLES = _ROLE_MAP.keys()
@classmethod @classmethod
def get_message_class(cls, role: str) -> Type[BaseMessage]: def get_message_class(cls, role: str) -> type[BaseMessage]:
return cls._ROLE_MAP[role] return cls._ROLE_MAP[role]
@ -60,7 +60,7 @@ class PromptyChatParser(Invoker):
"and .jpg / .jpeg are supported." "and .jpg / .jpeg are supported."
) )
def parse_content(self, content: str) -> Union[str, List]: def parse_content(self, content: str) -> Union[str, list]:
"""for parsing inline images""" """for parsing inline images"""
# regular expression to parse markdown images # regular expression to parse markdown images
image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)" image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)"

View File

@ -1,6 +1,6 @@
import traceback import traceback
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Union from typing import Any, Union
from .core import ( from .core import (
Frontmatter, Frontmatter,
@ -120,7 +120,7 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
def prepare( def prepare(
prompt: Prompty, prompt: Prompty,
inputs: Dict[str, Any] = {}, inputs: dict[str, Any] = {},
) -> Any: ) -> Any:
"""Prepare the inputs for the prompty. """Prepare the inputs for the prompty.
@ -166,9 +166,9 @@ def prepare(
def run( def run(
prompt: Prompty, prompt: Prompty,
content: Union[Dict, List, str], content: Union[dict, list, str],
configuration: Dict[str, Any] = {}, configuration: dict[str, Any] = {},
parameters: Dict[str, Any] = {}, parameters: dict[str, Any] = {},
raw: bool = False, raw: bool = False,
) -> Any: ) -> Any:
"""Run the prompty. """Run the prompty.
@ -219,9 +219,9 @@ def run(
def execute( def execute(
prompt: Union[str, Prompty], prompt: Union[str, Prompty],
configuration: Dict[str, Any] = {}, configuration: dict[str, Any] = {},
parameters: Dict[str, Any] = {}, parameters: dict[str, Any] = {},
inputs: Dict[str, Any] = {}, inputs: dict[str, Any] = {},
raw: bool = False, raw: bool = False,
connection: str = "default", connection: str = "default",
) -> Any: ) -> Any:

View File

@ -45,9 +45,12 @@ langchain-core = { path = "../../core", editable = true }
langchain-text-splitters = { path = "../../text-splitters", editable = true } langchain-text-splitters = { path = "../../text-splitters", editable = true }
langchain = { path = "../../langchain", editable = true } langchain = { path = "../../langchain", editable = true }
[tool.ruff] [tool.ruff]
select = ["E", "F", "I"] target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"

Some files were not shown because too many files have changed in this diff.