partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)

Using `pyupgrade` to bring all `partners` code up to Python 3.9 standards
(mostly replacing old `typing` imports with their built-in and `collections.abc` equivalents).
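
For context, here is a minimal before/after sketch of the kind of rewrite this commit applies (hypothetical names, not lines taken from the diff): PEP 585 builtin generics replace `typing.Dict`/`List`/`Tuple`/`Type`, and abstract container types move from `typing` to `collections.abc`.

```python
# Before:
#   from typing import Dict, Iterator, List, Optional
#   def stops(stop: Optional[List[str]] = None) -> Dict[str, List[str]]: ...
#   def chunks() -> Iterator[str]: ...

# After (valid on Python 3.9+):
from collections.abc import Iterator
from typing import Optional


def stops(stop: Optional[list[str]] = None) -> dict[str, list[str]]:
    """Return a mapping from a label to the effective stop sequences."""
    return {"stop": stop or []}


def chunks() -> Iterator[str]:
    """Yield streamed chunks."""
    yield "hello"
```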
Sydney Runkle 2025-04-11 07:18:44 -04:00 committed by GitHub
parent e72f3c26a0
commit 8c6734325b
123 changed files with 1000 additions and 1109 deletions

View File

@ -1,21 +1,14 @@
import copy
import re
import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from functools import cached_property
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -89,8 +82,8 @@ class AnthropicTool(TypedDict):
name: str
description: str
input_schema: Dict[str, Any]
cache_control: NotRequired[Dict[str, str]]
input_schema: dict[str, Any]
cache_control: NotRequired[dict[str, str]]
def _is_builtin_tool(tool: Any) -> bool:
@ -109,7 +102,7 @@ def _is_builtin_tool(tool: Any) -> bool:
return any(tool_type.startswith(prefix) for prefix in _builtin_tool_prefixes)
def _format_image(image_url: str) -> Dict:
def _format_image(image_url: str) -> dict:
"""
Formats an image of format data:image/jpeg;base64,{b64_string}
to a dict for anthropic api
@ -138,7 +131,7 @@ def _format_image(image_url: str) -> Dict:
def _merge_messages(
messages: Sequence[BaseMessage],
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
) -> list[Union[SystemMessage, AIMessage, HumanMessage]]:
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
merged: list = []
for curr in messages:
@ -169,7 +162,7 @@ def _merge_messages(
for c in (SystemMessage, HumanMessage)
):
if isinstance(cast(BaseMessage, last).content, str):
new_content: List = [
new_content: list = [
{"type": "text", "text": cast(BaseMessage, last).content}
]
else:
@ -185,8 +178,8 @@ def _merge_messages(
def _format_messages(
messages: List[BaseMessage],
) -> Tuple[Union[str, List[Dict], None], List[Dict]]:
messages: list[BaseMessage],
) -> tuple[Union[str, list[dict], None], list[dict]]:
"""Format messages for anthropic."""
"""
@ -198,8 +191,8 @@ def _format_messages(
for m in messages
]
"""
system: Union[str, List[Dict], None] = None
formatted_messages: List[Dict] = []
system: Union[str, list[dict], None] = None
formatted_messages: list[dict] = []
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
@ -220,7 +213,7 @@ def _format_messages(
continue
role = _message_type_lookups[message.type]
content: Union[str, List]
content: Union[str, list]
if not isinstance(message.content, str):
# parse as dict
@ -830,7 +823,7 @@ class ChatAnthropic(BaseChatModel):
max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
stop_sequences: Optional[List[str]] = Field(None, alias="stop")
stop_sequences: Optional[list[str]] = Field(None, alias="stop")
"""Default stop sequences."""
anthropic_api_url: Optional[str] = Field(
@ -858,7 +851,7 @@ class ChatAnthropic(BaseChatModel):
default_headers: Optional[Mapping[str, str]] = None
"""Headers to pass to the Anthropic clients, will be used for every API call."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
streaming: bool = False
"""Whether to use streaming or not."""
@ -868,7 +861,7 @@ class ChatAnthropic(BaseChatModel):
message chunks will be generated during the stream including usage metadata.
"""
thinking: Optional[Dict[str, Any]] = Field(default=None)
thinking: Optional[dict[str, Any]] = Field(default=None)
"""Parameters for Claude reasoning,
e.g., ``{"type": "enabled", "budget_tokens": 10_000}``"""
@ -878,7 +871,7 @@ class ChatAnthropic(BaseChatModel):
return "anthropic-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@classmethod
@ -886,12 +879,12 @@ class ChatAnthropic(BaseChatModel):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "anthropic"]
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
@ -907,7 +900,7 @@ class ChatAnthropic(BaseChatModel):
}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -925,14 +918,14 @@ class ChatAnthropic(BaseChatModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Any:
def build_extra(cls, values: dict) -> Any:
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
@cached_property
def _client_params(self) -> Dict[str, Any]:
client_params: Dict[str, Any] = {
def _client_params(self) -> dict[str, Any]:
client_params: dict[str, Any] = {
"api_key": self.anthropic_api_key.get_secret_value(),
"base_url": self.anthropic_api_url,
"max_retries": self.max_retries,
@ -958,9 +951,9 @@ class ChatAnthropic(BaseChatModel):
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
**kwargs: Dict,
) -> Dict:
stop: Optional[list[str]] = None,
**kwargs: dict,
) -> dict:
messages = self._convert_input(input_).to_messages()
system, formatted_messages = _format_messages(messages)
payload = {
@ -981,8 +974,8 @@ class ChatAnthropic(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
@ -1012,8 +1005,8 @@ class ChatAnthropic(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
@ -1088,8 +1081,8 @@ class ChatAnthropic(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -1104,8 +1097,8 @@ class ChatAnthropic(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -1120,7 +1113,7 @@ class ChatAnthropic(BaseChatModel):
def _get_llm_for_structured_output_when_thinking_is_enabled(
self,
schema: Union[Dict, type],
schema: Union[dict, type],
formatted_tool: AnthropicTool,
) -> Runnable[LanguageModelInput, BaseMessage]:
thinking_admonition = (
@ -1148,10 +1141,10 @@ class ChatAnthropic(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[Dict[str, str], Literal["any", "auto"], str]
Union[dict[str, str], Literal["any", "auto"], str]
] = None,
parallel_tool_calls: Optional[bool] = None,
**kwargs: Any,
@ -1326,11 +1319,11 @@ class ChatAnthropic(BaseChatModel):
def with_structured_output(
self,
schema: Union[Dict, type],
schema: Union[dict, type],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@ -1483,9 +1476,9 @@ class ChatAnthropic(BaseChatModel):
@beta()
def get_num_tokens_from_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None,
) -> int:
"""Count tokens in a sequence of input messages.
@ -1546,7 +1539,7 @@ class ChatAnthropic(BaseChatModel):
https://docs.anthropic.com/en/docs/build-with-claude/token-counting
"""
formatted_system, formatted_messages = _format_messages(messages)
kwargs: Dict[str, Any] = {}
kwargs: dict[str, Any] = {}
if isinstance(formatted_system, str):
kwargs["system"] = formatted_system
if tools:
@ -1562,7 +1555,7 @@ class ChatAnthropic(BaseChatModel):
def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type, Callable, BaseTool],
tool: Union[dict[str, Any], type, Callable, BaseTool],
) -> AnthropicTool:
"""Convert a tool-like object to an Anthropic tool definition."""
# already in Anthropic tool format
@ -1611,8 +1604,8 @@ class _AnthropicToolUse(TypedDict):
def _lc_tool_calls_to_anthropic_tool_use_blocks(
tool_calls: List[ToolCall],
) -> List[_AnthropicToolUse]:
tool_calls: list[ToolCall],
) -> list[_AnthropicToolUse]:
blocks = []
for tool_call in tool_calls:
blocks.append(
@ -1735,7 +1728,7 @@ class ChatAnthropicMessages(ChatAnthropic):
def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
input_token_details: Dict = {
input_token_details: dict = {
"cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}

View File

@ -1,8 +1,6 @@
import json
from typing import (
Any,
Dict,
List,
Union,
)
@ -44,7 +42,7 @@ TOOL_PARAMETER_FORMAT = """<parameter>
</parameter>"""
def _get_type(parameter: Dict[str, Any]) -> str:
def _get_type(parameter: dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
if "anyOf" in parameter:
@ -54,9 +52,9 @@ def _get_type(parameter: Dict[str, Any]) -> str:
return json.dumps(parameter)
def get_system_message(tools: List[Dict]) -> str:
def get_system_message(tools: list[dict]) -> str:
"""Generate a system message that describes the available tools."""
tools_data: List[Dict] = [
tools_data: list[dict] = [
{
"tool_name": tool["name"],
"tool_description": tool["description"],
@ -86,13 +84,13 @@ def get_system_message(tools: List[Dict]) -> str:
return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)
def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
def _xml_to_dict(t: Any) -> Union[str, dict[str, Any]]:
# Base case: If the element has no children, return its text or an empty string.
if len(t) == 0:
return t.text or ""
# Recursive case: The element has children. Convert them into a dictionary.
d: Dict[str, Any] = {}
d: dict[str, Any] = {}
for child in t:
if child.tag not in d:
d[child.tag] = _xml_to_dict(child)
@ -104,7 +102,7 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
return d
def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
def _xml_to_function_call(invoke: Any, tools: list[dict]) -> dict[str, Any]:
name = invoke.find("tool_name").text
arguments = _xml_to_dict(invoke.find("parameters"))
@ -135,7 +133,7 @@ def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
}
def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
def _xml_to_tool_calls(elem: Any, tools: list[dict]) -> list[dict[str, Any]]:
"""
Convert an XML element and its children into a dictionary of dictionaries.
"""

View File

@ -1,13 +1,9 @@
import re
import warnings
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
)
@ -83,11 +79,11 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Any:
def build_extra(cls, values: dict) -> Any:
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
@ -131,7 +127,7 @@ class _AnthropicCommon(BaseLanguageModel):
"""Get the identifying parameters."""
return {**{}, **self._default_params}
def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")
@ -165,7 +161,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
@model_validator(mode="before")
@classmethod
def raise_warning(cls, values: Dict) -> Any:
def raise_warning(cls, values: dict) -> Any:
"""Raise warning that this class is deprecated."""
warnings.warn(
"This Anthropic LLM is deprecated. "
@ -180,7 +176,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
return "anthropic-llm"
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@classmethod
@ -188,7 +184,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
return True
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
@ -203,7 +199,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
@ -233,7 +229,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -277,7 +273,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -303,7 +299,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@ -338,7 +334,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type, Union, cast
from typing import Any, Optional, Union, cast
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
@ -14,14 +14,14 @@ class ToolsOutputParser(BaseGenerationOutputParser):
"""Whether to return only the first tool call."""
args_only: bool = False
"""Whether to return only the arguments of the tool calls."""
pydantic_schemas: Optional[List[Type[BaseModel]]] = None
pydantic_schemas: Optional[list[type[BaseModel]]] = None
"""Pydantic schemas to parse tool calls into."""
model_config = ConfigDict(
extra="forbid",
)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -34,7 +34,7 @@ class ToolsOutputParser(BaseGenerationOutputParser):
if not result or not isinstance(result[0], ChatGeneration):
return None if self.first_tool_only else []
message = cast(AIMessage, result[0].message)
tool_calls: List = [
tool_calls: list = [
dict(tc) for tc in _extract_tool_calls_from_message(message)
]
if isinstance(message.content, list):
@ -64,14 +64,14 @@ class ToolsOutputParser(BaseGenerationOutputParser):
return cls_(**tool_call["args"])
def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]:
def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if message.tool_calls:
return message.tool_calls
return extract_tool_calls(message.content)
def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
def extract_tool_calls(content: Union[str, list[Union[str, dict]]]) -> list[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if isinstance(content, list):
tool_calls = []

View File

@ -55,8 +55,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
disallow_untyped_defs = "True"
plugins = ['pydantic.mypy']
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -2,7 +2,7 @@
import json
from base64 import b64encode
from typing import List, Optional
from typing import Optional
import httpx
import pytest
@ -270,7 +270,7 @@ def test_anthropic_call() -> None:
def test_anthropic_generate() -> None:
"""Test generate method of anthropic."""
chat = ChatAnthropic(model=MODEL_NAME)
chat_messages: List[List[BaseMessage]] = [
chat_messages: list[list[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")]
]
messages_copy = [messages.copy() for messages in chat_messages]
@ -318,7 +318,7 @@ async def test_anthropic_async_streaming_callback() -> None:
callback_manager=callback_manager,
verbose=True,
)
chat_messages: List[BaseMessage] = [
chat_messages: list[BaseMessage] = [
HumanMessage(content="How many toes do dogs have?")
]
async for token in chat.astream(chat_messages):
@ -809,7 +809,7 @@ def test_image_tool_calling() -> None:
fav_color: str
human_content: List[dict] = [
human_content: list[dict] = [
{
"type": "text",
"text": "what's your favorite color in this image",

View File

@ -1,7 +1,7 @@
"""Test ChatAnthropic chat model."""
from enum import Enum
from typing import List, Optional
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
@ -136,7 +136,7 @@ def test_anthropic_complex_structured_output() -> None:
sender_address: Optional[str] = Field(
None, description="The sender's address, if available"
)
action_items: List[str] = Field(
action_items: list[str] = Field(
..., description="A list of action items requested by the email"
)
topic: str = Field(

View File

@ -1,6 +1,6 @@
"""Test Anthropic API wrapper."""
from typing import Generator
from collections.abc import Generator
import pytest
from langchain_core.callbacks import CallbackManager

View File

@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""
from pathlib import Path
from typing import Dict, List, Literal, Type, cast
from typing import Literal, cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
@ -14,7 +14,7 @@ REPO_ROOT_DIR = Path(__file__).parents[5]
class TestAnthropicStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatAnthropic
@property
@ -36,9 +36,9 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
@property
def supported_usage_metadata_details(
self,
) -> Dict[
) -> dict[
Literal["invoke", "stream"],
List[
list[
Literal[
"audio_input",
"audio_output",
@ -58,7 +58,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
)
with open(REPO_ROOT_DIR / "README.md", "r") as f:
with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:
@ -87,7 +87,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
)
with open(REPO_ROOT_DIR / "README.md", "r") as f:
with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:

View File

@ -1,7 +1,7 @@
"""Test chat model integration."""
import os
from typing import Any, Callable, Dict, Literal, Type, cast
from typing import Any, Callable, Literal, cast
import pytest
from anthropic.types import Message, TextBlock, Usage
@ -297,7 +297,7 @@ def test__merge_messages_mutation() -> None:
@pytest.fixture()
def pydantic() -> Type[BaseModel]:
def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel):
"""dummy function"""
@ -328,7 +328,7 @@ def dummy_tool() -> BaseTool:
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
class DummyFunction(BaseTool): # type: ignore[override]
args_schema: Type[BaseModel] = Schema
args_schema: type[BaseModel] = Schema
name: str = "dummy_function"
description: str = "dummy function"
@ -339,7 +339,7 @@ def dummy_tool() -> BaseTool:
@pytest.fixture()
def json_schema() -> Dict:
def json_schema() -> dict:
return {
"title": "dummy_function",
"description": "dummy function",
@ -357,7 +357,7 @@ def json_schema() -> Dict:
@pytest.fixture()
def openai_function() -> Dict:
def openai_function() -> dict:
return {
"name": "dummy_function",
"description": "dummy function",
@ -377,11 +377,11 @@ def openai_function() -> Dict:
def test_convert_to_anthropic_tool(
pydantic: Type[BaseModel],
pydantic: type[BaseModel],
function: Callable,
dummy_tool: BaseTool,
json_schema: Dict,
openai_function: Dict,
json_schema: dict,
openai_function: dict,
) -> None:
expected = {
"name": "dummy_function",

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal
from typing import Any, Literal
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
@ -6,7 +6,7 @@ from pydantic import BaseModel
from langchain_anthropic.output_parsers import ToolsOutputParser
_CONTENT: List = [
_CONTENT: list = [
{
"type": "text",
"text": "thought",
@ -19,7 +19,7 @@ _CONTENT: List = [
{"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
]
_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))] # type: ignore[misc]
_RESULT: list = [ChatGeneration(message=AIMessage(_CONTENT))] # type: ignore[misc]
class _Foo1(BaseModel):

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,7 +8,7 @@ from langchain_anthropic import ChatAnthropic
class TestAnthropicStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatAnthropic
@property

View File

@ -8,17 +8,12 @@ from __future__ import annotations
import base64
import logging
import uuid
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@ -37,11 +32,11 @@ logger = logging.getLogger()
DEFAULT_K = 4 # Number of Documents to return.
def _results_to_docs(results: Any) -> List[Document]:
def _results_to_docs(results: Any) -> list[Document]:
return [doc for doc, _ in _results_to_docs_and_scores(results)]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
def _results_to_docs_and_scores(results: Any) -> list[tuple[Document, float]]:
return [
# TODO: Chroma can do batch querying,
# we shouldn't hard code to the 1st result
@ -58,7 +53,7 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
]
def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarray]]:
def _results_to_docs_and_vectors(results: Any) -> list[tuple[Document, np.ndarray]]:
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
@ -69,7 +64,7 @@ def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarra
]
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
@ -104,7 +99,7 @@ def maximal_marginal_relevance(
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
) -> list[int]:
"""Calculate maximal marginal relevance.
Args:
@ -287,7 +282,7 @@ class Chroma(VectorStore):
embedding_function: Optional[Embeddings] = None,
persist_directory: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
collection_metadata: Optional[Dict] = None,
collection_metadata: Optional[dict] = None,
client: Optional[chromadb.ClientAPI] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
create_collection_if_not_exists: Optional[bool] = True,
@ -370,13 +365,13 @@ class Chroma(VectorStore):
@xor_args(("query_texts", "query_embeddings"))
def __query_collection(
self,
query_texts: Optional[List[str]] = None,
query_embeddings: Optional[List[List[float]]] = None,
query_texts: Optional[list[str]] = None,
query_embeddings: Optional[list[list[float]]] = None,
n_results: int = 4,
where: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
where: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> Union[List[Document], chromadb.QueryResult]:
) -> Union[list[Document], chromadb.QueryResult]:
"""Query the chroma collection.
Args:
@ -411,11 +406,11 @@ class Chroma(VectorStore):
def add_images(
self,
uris: List[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
uris: list[str],
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more images through the embeddings and add to the vectorstore.
Args:
@ -502,10 +497,10 @@ class Chroma(VectorStore):
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
@ -591,9 +586,9 @@ class Chroma(VectorStore):
self,
query: str,
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Run similarity search with Chroma.
Args:
@ -612,12 +607,12 @@ class Chroma(VectorStore):
def similarity_search_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
@ -642,12 +637,12 @@ class Chroma(VectorStore):
def similarity_search_by_vector_with_relevance_scores(
self,
embedding: List[float],
embedding: list[float],
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Return docs most similar to embedding vector and similarity score.
Args:
@ -675,10 +670,10 @@ class Chroma(VectorStore):
self,
query: str,
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Run similarity search with Chroma with distance.
Args:
@ -717,10 +712,10 @@ class Chroma(VectorStore):
self,
query: str,
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Tuple[Document, np.ndarray]]:
) -> list[tuple[Document, np.ndarray]]:
"""Run similarity search with Chroma with vectors.
Args:
@ -800,9 +795,9 @@ class Chroma(VectorStore):
self,
uri: str,
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Search for similar images based on the given image URI.
Args:
@ -844,9 +839,9 @@ class Chroma(VectorStore):
self,
uri: str,
k: int = DEFAULT_K,
filter: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Search for similar images based on the given image URI.
Args:
@ -886,14 +881,14 @@ class Chroma(VectorStore):
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = DEFAULT_K,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -942,10 +937,10 @@ class Chroma(VectorStore):
k: int = DEFAULT_K,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
where_document: Optional[Dict[str, str]] = None,
filter: Optional[dict[str, str]] = None,
where_document: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -1005,8 +1000,8 @@ class Chroma(VectorStore):
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Optional[List[str]] = None,
) -> Dict[str, Any]:
include: Optional[list[str]] = None,
) -> dict[str, Any]:
"""Gets the collection.
Args:
@ -1081,7 +1076,7 @@ class Chroma(VectorStore):
return self.update_documents([document_id], [document])
# type: ignore
def update_documents(self, ids: List[str], documents: List[Document]) -> None:
def update_documents(self, ids: list[str], documents: list[Document]) -> None:
"""Update a document in the collection.
Args:
@ -1129,16 +1124,16 @@ class Chroma(VectorStore):
@classmethod
def from_texts(
cls: Type[Chroma],
texts: List[str],
cls: type[Chroma],
texts: list[str],
embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
persist_directory: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
client: Optional[chromadb.ClientAPI] = None,
collection_metadata: Optional[Dict] = None,
collection_metadata: Optional[dict] = None,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a raw documents.
@ -1200,15 +1195,15 @@ class Chroma(VectorStore):
@classmethod
def from_documents(
cls: Type[Chroma],
documents: List[Document],
cls: type[Chroma],
documents: list[Document],
embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None,
ids: Optional[list[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
persist_directory: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
client: Optional[chromadb.ClientAPI] = None, # Add this line
collection_metadata: Optional[Dict] = None,
collection_metadata: Optional[dict] = None,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a list of documents.
@ -1249,7 +1244,7 @@ class Chroma(VectorStore):
**kwargs,
)
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
"""Delete by vector IDs.
Args:

View File

@ -58,8 +58,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = true
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201", "D"]
select = ["E", "F", "I", "T201", "D", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -1,7 +1,6 @@
"""Fake Embedding class for testing purposes."""
import math
from typing import List
from langchain_core.embeddings import Embeddings
@ -11,22 +10,22 @@ fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
return [[1.0] * 9 + [float(i)] for i in range(len(texts))]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)]
return [1.0] * 9 + [0.0]
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)
@ -35,22 +34,22 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
vectors for the same texts."""
def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = []
self.known_texts: list[str] = []
self.dimensionality = dimensionality
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
vector = [1.0] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
@ -61,13 +60,13 @@ class AngularTwoDimensionalEmbeddings(Embeddings):
From angles (as strings in units of pi) to unit embedding vectors on a circle.
"""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the

View File

@ -3,8 +3,8 @@
import os.path
import tempfile
import uuid
from collections.abc import Generator
from typing import (
Generator,
cast,
)
@ -222,7 +222,7 @@ def test_chroma_with_metadatas_with_scores_using_vector() -> None:
def test_chroma_search_filter() -> None:
"""Test end to end construction and search with metadata filtering."""
texts = ["far", "bar", "baz"]
metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
metadatas = [{"first_letter": f"{text[0]}"} for text in texts]
ids = [f"id_{i}" for i in range(len(texts))]
docsearch = Chroma.from_texts(
collection_name="test_collection",
@ -245,7 +245,7 @@ def test_chroma_search_filter() -> None:
def test_chroma_search_filter_with_scores() -> None:
"""Test end to end construction and scored search with metadata filtering."""
texts = ["far", "bar", "baz"]
metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
metadatas = [{"first_letter": f"{text[0]}"} for text in texts]
ids = [f"id_{i}" for i in range(len(texts))]
docsearch = Chroma.from_texts(
collection_name="test_collection",

View File

@ -1,4 +1,4 @@
from typing import Generator
from collections.abc import Generator
import pytest
from langchain_core.vectorstores import VectorStore

View File

@ -1,7 +1,8 @@
"""DeepSeek chat models."""
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union
from typing import Any, Literal, Optional, TypeVar, Union
import openai
from langchain_core.callbacks import (
@ -19,8 +20,8 @@ from typing_extensions import Self
DEFAULT_API_BASE = "https://api.deepseek.com/v1"
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]
class ChatDeepSeek(BaseChatOpenAI):
@ -178,7 +179,7 @@ class ChatDeepSeek(BaseChatOpenAI):
return "chat-deepseek"
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
"""A map of constructor argument names to secret ids."""
return {"api_key": "DEEPSEEK_API_KEY"}
@ -217,7 +218,7 @@ class ChatDeepSeek(BaseChatOpenAI):
def _create_chat_result(
self,
response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None,
generation_info: Optional[dict] = None,
) -> ChatResult:
rtn = super()._create_chat_result(response, generation_info)
@ -243,8 +244,8 @@ class ChatDeepSeek(BaseChatOpenAI):
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: Type,
base_generation_info: Optional[Dict],
default_chunk_class: type,
base_generation_info: Optional[dict],
) -> Optional[ChatGenerationChunk]:
generation_chunk = super()._convert_chunk_to_generation_chunk(
chunk,
@ -268,8 +269,8 @@ class ChatDeepSeek(BaseChatOpenAI):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -287,8 +288,8 @@ class ChatDeepSeek(BaseChatOpenAI):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:

View File

@ -45,8 +45,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -1,6 +1,6 @@
"""Test ChatDeepSeek chat model."""
from typing import Optional, Type
from typing import Optional
import pytest
from langchain_core.language_models import BaseChatModel
@ -13,7 +13,7 @@ from langchain_deepseek.chat_models import ChatDeepSeek
class TestChatDeepSeek(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[ChatDeepSeek]:
def chat_model_class(self) -> type[ChatDeepSeek]:
return ChatDeepSeek
@property

View File

@ -1,6 +1,6 @@
"""Test chat model integration."""
from typing import Any, Dict, Literal, Type, Union
from typing import Any, Literal, Union
from unittest.mock import MagicMock
from langchain_core.messages import AIMessageChunk
@ -28,9 +28,9 @@ class MockOpenAIResponse(BaseModel):
exclude_none: bool = False,
round_trip: bool = False,
warnings: Union[Literal["none", "warn", "error"], bool] = True,
context: Union[Dict[str, Any], None] = None,
context: Union[dict[str, Any], None] = None,
serialize_as_any: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
choices_list = []
for choice in self.choices:
if isinstance(choice.message, ChatCompletionMessage):
@ -57,7 +57,7 @@ class MockOpenAIResponse(BaseModel):
class TestChatDeepSeekUnit(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[ChatDeepSeek]:
def chat_model_class(self) -> type[ChatDeepSeek]:
return ChatDeepSeek
@property
@ -134,7 +134,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_with_reasoning_content(self) -> None:
"""Test that reasoning_content is properly extracted from streaming chunk."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = {
chunk: dict[str, Any] = {
"choices": [
{
"delta": {
@ -158,7 +158,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_with_reasoning(self) -> None:
"""Test that reasoning is properly extracted from streaming chunk."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = {
chunk: dict[str, Any] = {
"choices": [
{
"delta": {
@ -182,7 +182,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_without_reasoning(self) -> None:
"""Test that chunk without reasoning fields works correctly."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = {"choices": [{"delta": {"content": "Main content"}}]}
chunk: dict[str, Any] = {"choices": [{"delta": {"content": "Main content"}}]}
chunk_result = chat_model._convert_chunk_to_generation_chunk(
chunk, AIMessageChunk, None
@ -194,7 +194,7 @@ class TestChatDeepSeekCustomUnit:
def test_convert_chunk_with_empty_delta(self) -> None:
"""Test that chunk with empty delta works correctly."""
chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
chunk: Dict[str, Any] = {"choices": [{"delta": {}}]}
chunk: dict[str, Any] = {"choices": [{"delta": {}}]}
chunk_result = chat_model._convert_chunk_to_generation_chunk(
chunk, AIMessageChunk, None

View File

@ -1,11 +1,10 @@
import os # type: ignore[import-not-found]
from typing import Dict
from exa_py import Exa # type: ignore
from langchain_core.utils import convert_to_secret_str
def initialize_client(values: Dict) -> Dict:
def initialize_client(values: dict) -> dict:
"""Initialize the client."""
exa_api_key = values.get("exa_api_key") or os.environ.get("EXA_API_KEY") or ""
values["exa_api_key"] = convert_to_secret_str(exa_api_key)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
from exa_py import Exa # type: ignore[untyped-import]
from exa_py.api import (
@ -13,7 +13,7 @@ from pydantic import Field, SecretStr, model_validator
from langchain_exa._utilities import initialize_client
def _get_metadata(result: Any) -> Dict[str, Any]:
def _get_metadata(result: Any) -> dict[str, Any]:
"""Get the metadata from a result object."""
metadata = {
"title": result.title,
@ -35,9 +35,9 @@ class ExaSearchRetriever(BaseRetriever):
k: int = 10 # num_results
"""The number of search results to return."""
include_domains: Optional[List[str]] = None
include_domains: Optional[list[str]] = None
"""A list of domains to include in the search."""
exclude_domains: Optional[List[str]] = None
exclude_domains: Optional[list[str]] = None
"""A list of domains to exclude from the search."""
start_crawl_date: Optional[str] = None
"""The start date for the crawl (in YYYY-MM-DD format)."""
@ -62,14 +62,14 @@ class ExaSearchRetriever(BaseRetriever):
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
def validate_environment(cls, values: dict) -> Any:
"""Validate the environment."""
values = initialize_client(values)
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
response = self.client.search_and_contents( # type: ignore[misc]
query,
num_results=self.k,

View File

@ -1,6 +1,6 @@
"""Tool for the Exa Search API."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from exa_py import Exa # type: ignore[untyped-import]
from exa_py.api import (
@ -66,7 +66,7 @@ class ExaSearchResults(BaseTool): # type: ignore[override]
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
def validate_environment(cls, values: dict) -> Any:
"""Validate the environment."""
values = initialize_client(values)
return values
@ -77,15 +77,15 @@ class ExaSearchResults(BaseTool): # type: ignore[override]
num_results: int,
text_contents_options: Optional[Union[TextContentsOptions, bool]] = None,
highlights: Optional[Union[HighlightsContentsOptions, bool]] = None,
include_domains: Optional[List[str]] = None,
exclude_domains: Optional[List[str]] = None,
include_domains: Optional[list[str]] = None,
exclude_domains: Optional[list[str]] = None,
start_crawl_date: Optional[str] = None,
end_crawl_date: Optional[str] = None,
start_published_date: Optional[str] = None,
end_published_date: Optional[str] = None,
use_autoprompt: Optional[bool] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[List[Dict], str]:
) -> Union[list[dict], str]:
"""Use the tool."""
try:
return self.client.search_and_contents(
@ -120,7 +120,7 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override]
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
def validate_environment(cls, values: dict) -> Any:
"""Validate the environment."""
values = initialize_client(values)
return values
@ -131,8 +131,8 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override]
num_results: int,
text_contents_options: Optional[Union[TextContentsOptions, bool]] = None,
highlights: Optional[Union[HighlightsContentsOptions, bool]] = None,
include_domains: Optional[List[str]] = None,
exclude_domains: Optional[List[str]] = None,
include_domains: Optional[list[str]] = None,
exclude_domains: Optional[list[str]] = None,
start_crawl_date: Optional[str] = None,
end_crawl_date: Optional[str] = None,
start_published_date: Optional[str] = None,
@ -140,7 +140,7 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override]
exclude_source_domain: Optional[bool] = None,
category: Optional[str] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[List[Dict], str]:
) -> Union[list[dict], str]:
"""Use the tool."""
try:
return self.client.find_similar_and_contents(

View File

@ -45,8 +45,12 @@ langchain-core = { path = "../../core", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -4,20 +4,13 @@ from __future__ import annotations
import json
import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
@ -109,7 +102,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
# Fix for azure
# Also Fireworks returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
@ -157,7 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
message_dict: dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
@ -205,14 +198,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
choice = chunk["choices"][0]
_dict = choice["delta"]
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
tool_call_chunks: List[ToolCallChunk] = []
additional_kwargs: dict = {}
tool_call_chunks: list[ToolCallChunk] = []
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -290,17 +283,17 @@ class ChatFireworks(BaseChatModel):
"""
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
def lc_attributes(self) -> dict[str, Any]:
attributes: dict[str, Any] = {}
if self.fireworks_api_base:
attributes["fireworks_api_base"] = self.fireworks_api_base
@ -319,9 +312,9 @@ class ChatFireworks(BaseChatModel):
"""Model name to use."""
temperature: float = 0.0
"""What sampling temperature to use."""
stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field(
alias="api_key",
@ -344,7 +337,7 @@ class ChatFireworks(BaseChatModel):
)
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
@ -364,7 +357,7 @@ class ChatFireworks(BaseChatModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
@ -398,7 +391,7 @@ class ChatFireworks(BaseChatModel):
return self
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Fireworks API."""
params = {
"model": self.model_name,
@ -413,7 +406,7 @@ class ChatFireworks(BaseChatModel):
return params
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -429,7 +422,7 @@ class ChatFireworks(BaseChatModel):
ls_params["ls_stop"] = ls_stop
return ls_params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
@ -452,15 +445,15 @@ class ChatFireworks(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@ -487,8 +480,8 @@ class ChatFireworks(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@ -509,8 +502,8 @@ class ChatFireworks(BaseChatModel):
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = self._default_params
if stop is not None:
params["stop"] = stop
@ -547,15 +540,15 @@ class ChatFireworks(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@ -584,8 +577,8 @@ class ChatFireworks(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@ -607,13 +600,13 @@ class ChatFireworks(BaseChatModel):
return self._create_chat_result(response)
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {"model_name": self.model_name, **self._default_params}
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
"model": self.model_name,
@ -634,7 +627,7 @@ class ChatFireworks(BaseChatModel):
)
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
@ -690,7 +683,7 @@ class ChatFireworks(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool]
@ -738,14 +731,14 @@ class ChatFireworks(BaseChatModel):
def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
schema: Optional[Union[dict, type[BaseModel]]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:

View File

@ -1,5 +1,3 @@
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_core.utils import secret_from_env
from openai import OpenAI
@ -96,13 +94,13 @@ class FireworksEmbeddings(BaseModel, Embeddings):
)
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self.client.embeddings.create(input=texts, model=self.model).data
]
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
return self.embed_documents([text])[0]

View File

@ -1,7 +1,7 @@
"""Wrapper around Fireworks AI's Completion API."""
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import requests
from aiohttp import ClientSession
@ -63,7 +63,7 @@ class Fireworks(LLM):
for question answering or summarization. A value greater than 1 introduces more
randomness in the output.
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It
@ -90,7 +90,7 @@ class Fireworks(LLM):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
@ -109,7 +109,7 @@ class Fireworks(LLM):
return f"langchain-fireworks/{__version__}"
@property
def default_params(self) -> Dict[str, Any]:
def default_params(self) -> dict[str, Any]:
return {
"model": self.model,
"temperature": self.temperature,
@ -122,7 +122,7 @@ class Fireworks(LLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -139,7 +139,7 @@ class Fireworks(LLM):
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
@ -168,7 +168,7 @@ class Fireworks(LLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -185,7 +185,7 @@ class Fireworks(LLM):
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,

View File

@ -48,8 +48,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -4,12 +4,12 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
"""
import json
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict
from typing_extensions import TypedDict
from langchain_fireworks import ChatFireworks

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
@ -14,7 +12,7 @@ from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatFireworks
@property

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@ -10,7 +8,7 @@ from langchain_fireworks import FireworksEmbeddings
class TestFireworksStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
def embeddings_class(self) -> type[Embeddings]:
return FireworksEmbeddings
@property
@ -18,7 +16,7 @@ class TestFireworksStandard(EmbeddingsUnitTests):
return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found]
@ -12,7 +10,7 @@ from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatFireworks
@property
@ -20,7 +18,7 @@ class TestFireworksStandard(ChatModelUnitTests):
return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",

View File

@ -4,20 +4,13 @@ from __future__ import annotations
import json
import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
@ -307,9 +300,9 @@ class ChatGroq(BaseChatModel):
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
groq_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None)
@ -324,7 +317,7 @@ class ChatGroq(BaseChatModel):
groq_proxy: Optional[str] = Field(
default_factory=from_env("GROQ_PROXY", default=None)
)
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
@ -353,7 +346,7 @@ class ChatGroq(BaseChatModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@ -392,7 +385,7 @@ class ChatGroq(BaseChatModel):
self.default_headers or {}
)
client_params: Dict[str, Any] = {
client_params: dict[str, Any] = {
"api_key": (
self.groq_api_key.get_secret_value() if self.groq_api_key else None
),
@ -406,13 +399,13 @@ class ChatGroq(BaseChatModel):
try:
import groq
sync_specific: Dict[str, Any] = {"http_client": self.http_client}
sync_specific: dict[str, Any] = {"http_client": self.http_client}
if not self.client:
self.client = groq.Groq(
**client_params, **sync_specific
).chat.completions
if not self.async_client:
async_specific: Dict[str, Any] = {"http_client": self.http_async_client}
async_specific: dict[str, Any] = {"http_client": self.http_async_client}
self.async_client = groq.AsyncGroq(
**client_params, **async_specific
).chat.completions
@ -427,7 +420,7 @@ class ChatGroq(BaseChatModel):
# Serializable class method overrides
#
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"groq_api_key": "GROQ_API_KEY"}
@classmethod
@ -444,7 +437,7 @@ class ChatGroq(BaseChatModel):
return "groq-chat"
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -480,8 +473,8 @@ class ChatGroq(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -500,8 +493,8 @@ class ChatGroq(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -521,8 +514,8 @@ class ChatGroq(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -530,7 +523,7 @@ class ChatGroq(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@ -560,8 +553,8 @@ class ChatGroq(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@ -569,7 +562,7 @@ class ChatGroq(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await self.async_client.create(
messages=message_dicts, **params
):
@ -605,7 +598,7 @@ class ChatGroq(BaseChatModel):
# Internal methods
#
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Groq API."""
params = {
"model": self.model_name,
@ -652,15 +645,15 @@ class ChatGroq(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = self._default_params
if stop is not None:
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
@ -688,7 +681,7 @@ class ChatGroq(BaseChatModel):
)
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
@ -743,7 +736,7 @@ class ChatGroq(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool]
@ -791,12 +784,12 @@ class ChatGroq(BaseChatModel):
def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
schema: Optional[Union[dict, type[BaseModel]]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@ -1096,7 +1089,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
message_dict: dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
@ -1142,13 +1135,13 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
choice = chunk["choices"][0]
_dict = choice["delta"]
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -1202,7 +1195,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []

View File

@ -40,8 +40,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "W"]
select = ["E", "F", "I", "W", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
@ -17,7 +15,7 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.2)
class BaseTestGroq(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatGroq
@pytest.mark.xfail(reason="Not yet implemented.")

View File

@ -1,7 +1,7 @@
"""A fake callback handler for testing purposes."""
from itertools import chain
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
@ -15,7 +15,7 @@ class BaseFakeCallbackHandler(BaseModel):
starts: int = 0
ends: int = 0
errors: int = 0
errors_args: List[Any] = []
errors_args: list[Any] = []
text: int = 0
ignore_llm_: bool = False
ignore_chain_: bool = False
@ -264,8 +264,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests.chat_models import (
ChatModelUnitTests,
@ -12,7 +10,7 @@ from langchain_groq import ChatGroq
class TestGroqStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatGroq
@property

View File

@ -1,16 +1,13 @@
"""Hugging Face Chat Wrapper."""
import json
from collections.abc import Sequence
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
@ -46,8 +43,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
class TGI_RESPONSE:
"""Response from the TextGenInference API."""
choices: List[Any]
usage: Dict
choices: list[Any]
usage: dict
@dataclass
@ -56,12 +53,12 @@ class TGI_MESSAGE:
role: str
content: str
tool_calls: List[Dict]
tool_calls: list[dict]
def _convert_message_to_chat_message(
message: BaseMessage,
) -> Dict:
) -> dict:
if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage):
@ -104,7 +101,7 @@ def _convert_TGI_message_to_LC_message(
content = cast(str, _message.content)
if content is None:
content = ""
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if tool_calls := _message.tool_calls:
if "arguments" in tool_calls[0]["function"]:
functions = tool_calls[0]["function"].pop("arguments")
@ -358,8 +355,8 @@ class ChatHuggingFace(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -380,8 +377,8 @@ class ChatHuggingFace(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -398,7 +395,7 @@ class ChatHuggingFace(BaseChatModel):
def _to_chat_prompt(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
@ -472,7 +469,7 @@ class ChatHuggingFace(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required"], bool]
@ -529,8 +526,8 @@ class ChatHuggingFace(BaseChatModel):
return super().bind(tools=formatted_tools, **kwargs)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> List[Dict[Any, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> list[dict[Any, Any]]:
message_dicts = [_convert_message_to_chat_message(m) for m in messages]
return message_dicts
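As a rough illustration of what `_create_message_dicts` produces above, a standalone sketch of the same role/content mapping (simplified: the real `_convert_message_to_chat_message` also handles system, tool, and function messages plus tool calls):

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage


def to_role_content_dict(message: BaseMessage) -> dict:
    # Map LangChain message types onto the {"role": ..., "content": ...}
    # shape that the wrapped text-generation endpoint expects.
    if isinstance(message, SystemMessage):
        return {"role": "system", "content": message.content}
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if isinstance(message, AIMessage):
        return {"role": "assistant", "content": message.content}
    raise ValueError(f"Unknown message type: {type(message)}")


message_dicts = [
    to_role_content_dict(m)
    for m in [SystemMessage("Be brief."), HumanMessage("What is LangChain?")]
]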

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field
@ -40,16 +40,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
`prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method for the documents of
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
`precision`, `normalize_embeddings`, and more.
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method for the query of
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
`precision`, `normalize_embeddings`, and more.
@ -102,8 +102,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
)
def _embed(
self, texts: list[str], encode_kwargs: Dict[str, Any]
) -> List[List[float]]:
self, texts: list[str], encode_kwargs: dict[str, Any]
) -> list[list[float]]:
"""
Embed a text using the HuggingFace transformer model.
@ -138,7 +138,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
return embeddings.tolist()
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
@ -149,7 +149,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""
return self._embed(texts, self.encode_kwargs)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
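A brief usage sketch of the model/encode kwargs documented above (the model name and settings are illustrative, not values asserted by this commit):

from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True, "batch_size": 32},
)
doc_vectors = embeddings.embed_documents(["first document", "second document"])
query_vector = embeddings.embed_query("a search query")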

View File

@ -1,5 +1,5 @@
import os
from typing import Any, List, Optional
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.utils import from_env
@ -101,7 +101,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
)
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
Args:
@ -117,7 +117,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
responses = self.client.feature_extraction(text=texts, **_model_kwargs)
return responses
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
Args:
@ -134,7 +134,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
)
return responses
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
Args:
@ -146,7 +146,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
response = self.embed_documents([text])[0]
return response
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
Args:

View File

@ -2,7 +2,8 @@ import inspect
import json # type: ignore[import-not-found]
import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@ -96,7 +97,7 @@ class HuggingFaceEndpoint(LLM):
"""Whether to prepend the prompt to the generated text"""
truncate: Optional[int] = None
"""Truncate inputs tokens to the given size"""
stop_sequences: List[str] = Field(default_factory=list)
stop_sequences: list[str] = Field(default_factory=list)
"""Stop generating tokens if a member of `stop_sequences` is generated"""
seed: Optional[int] = None
"""Random sampling seed"""
@ -111,9 +112,9 @@ class HuggingFaceEndpoint(LLM):
watermark: bool = False
"""Watermarking with [A Watermark for Large Language Models]
(https://arxiv.org/abs/2301.10226)"""
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
server_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any text-generation-inference server parameters not explicitly specified"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified"""
model: str
client: Any = None #: :meta private:
@ -128,7 +129,7 @@ class HuggingFaceEndpoint(LLM):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@ -252,7 +253,7 @@ class HuggingFaceEndpoint(LLM):
return self
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling text generation inference API."""
return {
"max_new_tokens": self.max_new_tokens,
@ -285,8 +286,8 @@ class HuggingFaceEndpoint(LLM):
return "huggingface_endpoint"
def _invocation_params(
self, runtime_stop: Optional[List[str]], **kwargs: Any
) -> Dict[str, Any]:
self, runtime_stop: Optional[list[str]], **kwargs: Any
) -> dict[str, Any]:
params = {**self._default_params, **kwargs}
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
return params
@ -294,7 +295,7 @@ class HuggingFaceEndpoint(LLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -326,7 +327,7 @@ class HuggingFaceEndpoint(LLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -357,7 +358,7 @@ class HuggingFaceEndpoint(LLM):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@ -394,7 +395,7 @@ class HuggingFaceEndpoint(LLM):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:

View File

@ -2,7 +2,8 @@ from __future__ import annotations # type: ignore[import-not-found]
import importlib.util
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional
from collections.abc import Iterator, Mapping
from typing import Any, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
@ -82,7 +83,7 @@ class HuggingFacePipeline(BaseLLM):
@model_validator(mode="before")
@classmethod
def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Ensure model_id is set either by pipeline or user input."""
if "model_id" not in values:
if "pipeline" in values and values["pipeline"]:
@ -297,13 +298,13 @@ class HuggingFacePipeline(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
# List to hold all results
text_generations: List[str] = []
text_generations: list[str] = []
pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
skip_prompt = kwargs.get("skip_prompt", False)
@ -347,7 +348,7 @@ class HuggingFacePipeline(BaseLLM):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:

View File

@ -49,8 +49,12 @@ langchain-community = { path = "../../community", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -1,7 +1,5 @@
"""Test HuggingFace embeddings."""
from typing import Type
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
from langchain_huggingface.embeddings import (
@ -12,7 +10,7 @@ from langchain_huggingface.embeddings import (
class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
@property
def embeddings_class(self) -> Type[HuggingFaceEmbeddings]:
def embeddings_class(self) -> type[HuggingFaceEmbeddings]:
return HuggingFaceEmbeddings
@property
@ -22,7 +20,7 @@ class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
class TestHuggingFaceEndpointEmbeddings(EmbeddingsIntegrationTests):
@property
def embeddings_class(self) -> Type[HuggingFaceEndpointEmbeddings]:
def embeddings_class(self) -> type[HuggingFaceEndpointEmbeddings]:
return HuggingFaceEndpointEmbeddings
@property

View File

@ -1,4 +1,4 @@
from typing import Generator
from collections.abc import Generator
from langchain_huggingface.llms import HuggingFacePipeline

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
@ -12,7 +10,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatHuggingFace
@property

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List # type: ignore[import-not-found]
from typing import Any # type: ignore[import-not-found]
from unittest.mock import MagicMock, Mock, patch
import pytest # type: ignore[import-not-found]
@ -45,7 +45,7 @@ from langchain_huggingface.llms.huggingface_endpoint import (
],
)
def test_convert_message_to_chat_message(
message: BaseMessage, expected: Dict[str, str]
message: BaseMessage, expected: dict[str, str]
) -> None:
result = _convert_message_to_chat_message(message)
assert result == expected
@ -150,7 +150,7 @@ def test_create_chat_result(chat_hugging_face: Any) -> None:
],
)
def test_to_chat_prompt_errors(
chat_hugging_face: Any, messages: List[BaseMessage], expected_error: str
chat_hugging_face: Any, messages: list[BaseMessage], expected_error: str
) -> None:
with pytest.raises(ValueError) as e:
chat_hugging_face._to_chat_prompt(messages)
@ -194,7 +194,7 @@ def test_to_chat_prompt_valid_messages(chat_hugging_face: Any) -> None:
],
)
def test_to_chatml_format(
chat_hugging_face: Any, message: BaseMessage, expected: Dict[str, str]
chat_hugging_face: Any, message: BaseMessage, expected: dict[str, str]
) -> None:
result = chat_hugging_face._to_chatml_format(message)
assert result == expected
@ -207,7 +207,7 @@ def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
assert "Unknown message type:" in str(e.value)
def tool_mock() -> Dict:
def tool_mock() -> dict:
return {"function": {"name": "test_tool"}}
@ -232,7 +232,7 @@ def tool_mock() -> Dict:
)
def test_bind_tools_errors(
chat_hugging_face: Any,
tools: Dict[str, str],
tools: dict[str, str],
tool_choice: Any,
expected_exception: Any,
expected_message: str,

View File

@ -7,20 +7,14 @@ import os
import re
import ssl
import uuid
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
from operator import itemgetter
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -142,13 +136,13 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
def _convert_mistral_chat_message_to_message(
_message: Dict,
_message: dict,
) -> BaseMessage:
role = _message["role"]
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
content = cast(str, _message["content"])
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _message.get("tool_calls"):
@ -196,8 +190,8 @@ async def _araise_on_error(response: httpx.Response) -> None:
async def _aiter_sse(
event_source_mgr: AsyncContextManager[EventSource],
) -> AsyncIterator[Dict]:
event_source_mgr: AbstractAsyncContextManager[EventSource],
) -> AsyncIterator[dict]:
"""Iterate over the server-sent events."""
async with event_source_mgr as event_source:
await _araise_on_error(event_source.response)
@ -234,7 +228,7 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk(
chunk: Dict, default_class: Type[BaseMessageChunk]
chunk: dict, default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
_choice = chunk["choices"][0]
_delta = _choice["delta"]
@ -243,7 +237,7 @@ def _convert_chunk_to_message_chunk(
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
response_metadata = {}
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
@ -295,7 +289,7 @@ def _convert_chunk_to_message_chunk(
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
"""Format Langchain ToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
result: dict[str, Any] = {
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
@ -309,7 +303,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
result: dict[str, Any] = {
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
@ -323,13 +317,13 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> Dict:
) -> dict:
if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
message_dict: Dict[str, Any] = {"role": "assistant"}
message_dict: dict[str, Any] = {"role": "assistant"}
tool_calls = []
if message.tool_calls or message.invalid_tool_calls:
for tool_call in message.tool_calls:
@ -407,7 +401,7 @@ class ChatMistralAI(BaseChatModel):
random_seed: Optional[int] = None
safe_mode: Optional[bool] = None
streaming: bool = False
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any invocation parameters not explicitly specified."""
model_config = ConfigDict(
@ -417,14 +411,14 @@ class ChatMistralAI(BaseChatModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling the API."""
defaults = {
"model": self.model,
@ -439,7 +433,7 @@ class ChatMistralAI(BaseChatModel):
return filtered
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -456,7 +450,7 @@ class ChatMistralAI(BaseChatModel):
return ls_params
@property
def _client_params(self) -> Dict[str, Any]:
def _client_params(self) -> dict[str, Any]:
"""Get the parameters used for the client."""
return self._default_params
@ -473,7 +467,7 @@ class ChatMistralAI(BaseChatModel):
stream = kwargs["stream"]
if stream:
def iter_sse() -> Iterator[Dict]:
def iter_sse() -> Iterator[dict]:
with connect_sse(
self.client, "POST", "/chat/completions", json=kwargs
) as event_source:
@ -492,7 +486,7 @@ class ChatMistralAI(BaseChatModel):
rtn = _completion_with_retry(**kwargs)
return rtn
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
@ -557,8 +551,8 @@ class ChatMistralAI(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@ -577,7 +571,7 @@ class ChatMistralAI(BaseChatModel):
)
return self._create_chat_result(response)
def _create_chat_result(self, response: Dict) -> ChatResult:
def _create_chat_result(self, response: dict) -> ChatResult:
generations = []
token_usage = response.get("usage", {})
for res in response["choices"]:
@ -603,8 +597,8 @@ class ChatMistralAI(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict], Dict[str, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> tuple[list[dict], dict[str, Any]]:
params = self._client_params
if stop is not None or "stop" in params:
if "stop" in params:
@ -617,15 +611,15 @@ class ChatMistralAI(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
@ -643,15 +637,15 @@ class ChatMistralAI(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
@ -669,8 +663,8 @@ class ChatMistralAI(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@ -691,7 +685,7 @@ class ChatMistralAI(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@ -733,14 +727,14 @@ class ChatMistralAI(BaseChatModel):
def with_structured_output(
self,
schema: Optional[Union[Dict, Type]] = None,
schema: Optional[Union[dict, type]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@ -1048,7 +1042,7 @@ class ChatMistralAI(BaseChatModel):
return llm | output_parser
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@ -1058,7 +1052,7 @@ class ChatMistralAI(BaseChatModel):
return "mistralai-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"mistral_api_key": "MISTRAL_API_KEY"}
@classmethod
@ -1067,14 +1061,14 @@ class ChatMistralAI(BaseChatModel):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "mistralai"]
def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
) -> Dict:
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
) -> dict:
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
if (
isinstance(schema, dict)
@ -1094,8 +1088,10 @@ def _convert_to_openai_response_format(
function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function}
if strict is not None and strict is not response_format["json_schema"].get(
"strict"
if (
strict is not None
and strict is not response_format["json_schema"].get("strict")
and isinstance(schema, dict)
):
msg = (
f"Output schema already has 'strict' value set to "

View File

@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
from typing import Iterable, List
from collections.abc import Iterable
import httpx
from httpx import Response
@ -33,7 +33,7 @@ class DummyTokenizer:
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
@staticmethod
def encode_batch(texts: List[str]) -> List[List[str]]:
def encode_batch(texts: list[str]) -> list[list[str]]:
return [list(text) for text in texts]
@ -177,7 +177,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1"
)
except IOError: # huggingface_hub GatedRepoError
except OSError: # huggingface_hub GatedRepoError
warnings.warn(
"Could not download mistral tokenizer from Huggingface for "
"calculating batch sizes. Set a Huggingface token via the "
@ -187,10 +187,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = DummyTokenizer()
return self
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
"""Split a list of texts into batches of less than 16k tokens
for Mistral API."""
batch: List[str] = []
batch: list[str] = []
batch_tokens = 0
text_token_lengths = [
@ -211,7 +211,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
if batch:
yield batch
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts.
Args:
@ -230,7 +230,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
def _embed_batch(batch: List[str]) -> Response:
def _embed_batch(batch: list[str]) -> Response:
response = self.client.post(
url="/embeddings",
json=dict(
@ -252,7 +252,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
logger.error(f"An error occurred with MistralAI: {e}")
raise
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts.
Args:
@ -283,7 +283,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
logger.error(f"An error occurred with MistralAI: {e}")
raise
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed a single query text.
Args:
@ -294,7 +294,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
"""
return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Embed a single query text.
Args:
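The `_get_batches` helper changed above keeps each embeddings request under the API's token budget; below is a generic sketch of that greedy batching pattern (the whitespace tokenizer and the 16k limit from the docstring are stand-ins for the real tokenizer-based counting):

from collections.abc import Iterable

MAX_TOKENS_PER_BATCH = 16_000  # per the _get_batches docstring above


def get_batches(texts: list[str]) -> Iterable[list[str]]:
    # Greedy batching: start a new batch whenever adding the next text would
    # push the running token count over the limit.
    batch: list[str] = []
    batch_tokens = 0
    for text in texts:
        text_tokens = len(text.split())  # stand-in for the real tokenizer
        if batch and batch_tokens + text_tokens > MAX_TOKENS_PER_BATCH:
            yield batch
            batch = []
            batch_tokens = 0
        batch.append(text)
        batch_tokens += text_tokens
    if batch:
        yield batch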

View File

@ -44,8 +44,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
@ -12,7 +10,7 @@ from langchain_mistralai import ChatMistralAI
class TestMistralStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatMistralAI
@property

View File

@ -1,7 +1,8 @@
"""Test MistralAI Chat API wrapper."""
import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from collections.abc import AsyncGenerator, Generator
from typing import Any, cast
from unittest.mock import MagicMock, patch
import httpx
@ -104,13 +105,13 @@ def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
],
)
def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: Dict
message: BaseMessage, expected: dict
) -> None:
result = _convert_message_to_mistral_chat_message(message)
assert result == expected
def _make_completion_response_from_token(token: str) -> Dict:
def _make_completion_response_from_token(token: str) -> dict:
return dict(
id="abc123",
model="fake_model",
@ -236,7 +237,7 @@ def test__convert_dict_to_message_tool_call() -> None:
def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]:
def token_encoder(text: str) -> list[int]:
return [1, 2, 3]
llm = ChatMistralAI(custom_get_token_ids=token_encoder)

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found]
@ -12,5 +10,5 @@ from langchain_mistralai import ChatMistralAI
class TestMistralStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatMistralAI

View File

@ -1,5 +1,5 @@
import os
from typing import List, Literal, Optional, overload
from typing import Literal, Optional, overload
import nomic # type: ignore[import]
from langchain_core.embeddings import Embeddings
@ -86,7 +86,7 @@ class NomicEmbeddings(Embeddings):
self.device = device
self.vision_model = vision_model
def embed(self, texts: List[str], *, task_type: str) -> List[List[float]]:
def embed(self, texts: list[str], *, task_type: str) -> list[list[float]]:
"""Embed texts.
Args:
@ -105,7 +105,7 @@ class NomicEmbeddings(Embeddings):
)
return output["embeddings"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.
Args:
@ -116,7 +116,7 @@ class NomicEmbeddings(Embeddings):
task_type="search_document",
)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text.
Args:
@ -127,7 +127,7 @@ class NomicEmbeddings(Embeddings):
task_type="search_query",
)[0]
def embed_image(self, uris: List[str]) -> List[List[float]]:
def embed_image(self, uris: list[str]) -> list[list[float]]:
return embed.image(
images=uris,
model=self.vision_model,

View File

@ -40,13 +40,18 @@ dev = ["langchain-core"]
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
"UP", # pyupgrade
]
ignore = [ "UP007", ]
[tool.mypy]
disallow_untyped_defs = "True"

View File

@ -3,7 +3,6 @@
It provides infrastructure for interacting with the Ollama service.
"""
from importlib import metadata
from langchain_ollama.chat_models import ChatOllama

View File

@ -1,21 +1,14 @@
"""Ollama chat models."""
import json
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Final,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -153,7 +146,7 @@ def _parse_arguments_from_tool_call(
def _get_tool_calls_from_response(
response: Mapping[str, Any],
) -> List[ToolCall]:
) -> list[ToolCall]:
"""Get tool calls from ollama response."""
tool_calls = []
if "message" in response:
@ -341,7 +334,7 @@ class ChatOllama(BaseChatModel):
model: str
"""Model name to use."""
extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
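To illustrate the (start, end) token behaviour described here, a standalone sketch of splitting a fully assembled response into visible content and reasoning content (simplified: the real `_extract_reasoning` operates on streamed message chunks and tracks `is_thinking` state across them; the default tags below are assumed):

DEFAULT_THINK_TOKENS = ("<think>", "</think>")  # assumed defaults for this sketch


def split_reasoning(
    text: str, tokens: tuple[str, str] = DEFAULT_THINK_TOKENS
) -> tuple[str, str]:
    """Return (content, reasoning_content) for a complete response string."""
    start, end = tokens
    if start in text and end in text:
        before, rest = text.split(start, 1)
        reasoning, after = rest.split(end, 1)
        return (before + after).strip(), reasoning.strip()
    return text, ""


content, reasoning = split_reasoning("<think>compare both options first</think>Use option A.")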
@ -399,7 +392,7 @@ class ChatOllama(BaseChatModel):
to a specific number will make the model generate the same text for
the same prompt."""
stop: Optional[List[str]] = None
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
@ -443,10 +436,10 @@ class ChatOllama(BaseChatModel):
def _chat_params(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)
if self.stop is not None and stop is not None:
@ -499,13 +492,13 @@ class ChatOllama(BaseChatModel):
return self
def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
self, messages: list[BaseMessage]
) -> Sequence[Message]:
ollama_messages: List = []
ollama_messages: list = []
for message in messages:
role: Literal["user", "assistant", "system", "tool"]
tool_call_id: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_calls: Optional[list[dict[str, Any]]] = None
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
@ -531,7 +524,7 @@ class ChatOllama(BaseChatModel):
if isinstance(message.content, str):
content = message.content
else:
for content_part in cast(List[Dict], message.content):
for content_part in cast(list[dict], message.content):
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "tool_use":
@ -583,7 +576,7 @@ class ChatOllama(BaseChatModel):
def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> Tuple[BaseMessageChunk, bool]:
) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
@ -605,8 +598,8 @@ class ChatOllama(BaseChatModel):
async def _acreate_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
chat_params = self._chat_params(messages, stop, **kwargs)
@ -619,8 +612,8 @@ class ChatOllama(BaseChatModel):
def _create_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
chat_params = self._chat_params(messages, stop, **kwargs)
@ -632,8 +625,8 @@ class ChatOllama(BaseChatModel):
def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@ -657,8 +650,8 @@ class ChatOllama(BaseChatModel):
async def _achat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@ -681,7 +674,7 @@ class ChatOllama(BaseChatModel):
return final_chunk
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -697,8 +690,8 @@ class ChatOllama(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -719,8 +712,8 @@ class ChatOllama(BaseChatModel):
def _iterate_over_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
is_thinking = False
@ -758,8 +751,8 @@ class ChatOllama(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -773,8 +766,8 @@ class ChatOllama(BaseChatModel):
async def _aiterate_over_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False
@ -812,8 +805,8 @@ class ChatOllama(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@ -827,8 +820,8 @@ class ChatOllama(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -854,7 +847,7 @@ class ChatOllama(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
**kwargs: Any,
@ -877,12 +870,12 @@ class ChatOllama(BaseChatModel):
def with_structured_output(
self,
schema: Union[Dict, type],
schema: Union[dict, type],
*,
method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:

View File

@ -1,6 +1,6 @@
"""Ollama embeddings models."""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
@ -188,7 +188,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""The temperature of the model. Increasing the temperature will
make the model answer more creatively. (Default: 0.8)"""
stop: Optional[List[str]] = None
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
@ -211,7 +211,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
)
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"mirostat": self.mirostat,
@ -237,18 +237,18 @@ class OllamaEmbeddings(BaseModel, Embeddings):
self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"]
return embedded_docs
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
embedded_docs = (
await self._async_client.embed(
@ -257,6 +257,6 @@ class OllamaEmbeddings(BaseModel, Embeddings):
)["embeddings"]
return embedded_docs
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Embed query text."""
return (await self.aembed_documents([text]))[0]

View File

@ -1,13 +1,9 @@
"""Ollama large language models."""
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Union,
)
@ -89,7 +85,7 @@ class OllamaLLM(BaseLLM):
to a specific number will make the model generate the same text for
the same prompt."""
stop: Optional[List[str]] = None
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
@ -134,9 +130,9 @@ class OllamaLLM(BaseLLM):
def _generate_params(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
@ -181,7 +177,7 @@ class OllamaLLM(BaseLLM):
return "ollama-llm"
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
@ -200,7 +196,7 @@ class OllamaLLM(BaseLLM):
async def _acreate_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
async for part in await self._async_client.generate(
@ -211,7 +207,7 @@ class OllamaLLM(BaseLLM):
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
yield from self._client.generate(
@ -221,7 +217,7 @@ class OllamaLLM(BaseLLM):
async def _astream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@ -253,7 +249,7 @@ class OllamaLLM(BaseLLM):
def _stream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@ -284,8 +280,8 @@ class OllamaLLM(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -303,8 +299,8 @@ class OllamaLLM(BaseLLM):
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -323,7 +319,7 @@ class OllamaLLM(BaseLLM):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@ -345,7 +341,7 @@ class OllamaLLM(BaseLLM):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:

View File

@ -40,6 +40,9 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = [
"E", # pycodestyle
@ -47,8 +50,9 @@ select = [
"I", # isort
"T201", # print
"D", # pydocstyle
"UP", # pyupgrade
]
ignore = [ "UP007", ]
[tool.ruff.lint.pydocstyle]
convention = "google"

View File

@ -1,4 +1,5 @@
"""load multiple Python files specified as command line arguments."""
import sys
import traceback
from importlib.machinery import SourceFileLoader
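
The script touched here only gains a docstring, but its imports show the usual loader-based import check. A rough sketch of how such a checker typically looks; the body below is illustrative, not the file's actual contents:

import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    failed = False
    for path in sys.argv[1:]:
        try:
            # Importing the module is enough to surface broken imports.
            SourceFileLoader("check_imports_target", path).load_module()
        except Exception:
            failed = True
            traceback.print_exc()
    sys.exit(1 if failed else 0)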

View File

@ -1,10 +1,10 @@
"""Ollama specific chat model integration tests"""
from typing import List, Optional
from typing import Annotated, Optional
import pytest
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict
from typing_extensions import TypedDict
from langchain_ollama import ChatOllama
@ -78,7 +78,7 @@ def test_structured_output_deeply_nested(model: str) -> None:
class Data(BaseModel):
"""Extracted data about people."""
people: List[Person]
people: list[Person]
chat = llm.with_structured_output(Data) # type: ignore[arg-type]
text = (

View File

@ -1,7 +1,5 @@
"""Test chat model integration using standard integration tests."""
from typing import Type
from langchain_tests.integration_tests import ChatModelIntegrationTests
from langchain_ollama.chat_models import ChatOllama
@ -9,7 +7,7 @@ from langchain_ollama.chat_models import ChatOllama
class TestChatOllama(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[ChatOllama]:
def chat_model_class(self) -> type[ChatOllama]:
return ChatOllama
@property

View File

@ -1,7 +1,5 @@
"""Test Ollama embeddings."""
from typing import Type
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
from langchain_ollama.embeddings import OllamaEmbeddings
@ -9,7 +7,7 @@ from langchain_ollama.embeddings import OllamaEmbeddings
class TestOllamaEmbeddings(EmbeddingsIntegrationTests):
@property
def embeddings_class(self) -> Type[OllamaEmbeddings]:
def embeddings_class(self) -> type[OllamaEmbeddings]:
return OllamaEmbeddings
@property

View File

@ -1,6 +1,6 @@
"""Test chat model integration."""
import json
from typing import Dict, Type
from langchain_tests.unit_tests import ChatModelUnitTests
@ -9,11 +9,11 @@ from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_
class TestChatOllama(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[ChatOllama]:
def chat_model_class(self) -> type[ChatOllama]:
return ChatOllama
@property
def chat_model_params(self) -> Dict:
def chat_model_params(self) -> dict:
return {"model": "llama3-groq-tool-use"}

View File

@ -4,18 +4,8 @@ from __future__ import annotations
import logging
import os
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
from collections.abc import Awaitable
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
import openai
from langchain_core.language_models import LanguageModelInput
@ -34,8 +24,8 @@ logger = logging.getLogger(__name__)
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
_DictOrPydantic = Union[Dict, _BM]
_DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
_DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict):
@ -547,7 +537,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
Used for tracing and token counting. Does NOT affect completion.
"""
disabled_params: Optional[Dict[str, Any]] = Field(default=None)
disabled_params: Optional[dict[str, Any]] = Field(default=None)
"""Parameters of the OpenAI client or chat.completions endpoint that should be
disabled for the given model.
@ -570,12 +560,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"]
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {
"openai_api_key": "AZURE_OPENAI_API_KEY",
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
@ -672,7 +662,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
return self
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {
**{"azure_deployment": self.deployment_name},
@ -684,14 +674,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
return "azure-openai-chat"
@property
def lc_attributes(self) -> Dict[str, Any]:
def lc_attributes(self) -> dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get the parameters used to invoke the model."""
params = super()._get_ls_params(stop=stop, **kwargs)
@ -710,7 +700,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
def _create_chat_result(
self,
response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None,
generation_info: Optional[dict] = None,
) -> ChatResult:
chat_result = super()._create_chat_result(response, generation_info)

View File

@ -10,6 +10,7 @@ import re
import ssl
import sys
import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from functools import partial
from io import BytesIO
from json import JSONDecodeError
@ -18,17 +19,9 @@ from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
@ -137,7 +130,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
@ -243,7 +236,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)}
message_dict: dict[str, Any] = {"content": _format_message_content(message.content)}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
@ -304,12 +297,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
id_ = _dict.get("id")
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -418,8 +411,8 @@ class _FunctionCall(TypedDict):
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]
class _AllReturnType(TypedDict):
@ -437,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Model name to use."""
temperature: Optional[float] = None
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
@ -451,7 +444,7 @@ class BaseChatOpenAI(BaseChatModel):
openai_proxy: Optional[str] = Field(
default_factory=from_env("OPENAI_PROXY", default=None)
)
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
@ -476,7 +469,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true
if this parameter is used."""
logit_bias: Optional[Dict[int, int]] = None
logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
@ -517,14 +510,14 @@ class BaseChatOpenAI(BaseChatModel):
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
extra_body: Optional[Mapping[str, Any]] = None
"""Optional additional JSON properties to include in the request parameters when
making requests to OpenAI compatible APIs, such as vLLM."""
include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata."""
disabled_params: Optional[Dict[str, Any]] = Field(default=None)
disabled_params: Optional[dict[str, Any]] = Field(default=None)
"""Parameters of the OpenAI client or chat.completions endpoint that should be
disabled for the given model.
@ -554,7 +547,7 @@ class BaseChatOpenAI(BaseChatModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
@ -562,7 +555,7 @@ class BaseChatOpenAI(BaseChatModel):
@model_validator(mode="before")
@classmethod
def validate_temperature(cls, values: Dict[str, Any]) -> Any:
def validate_temperature(cls, values: dict[str, Any]) -> Any:
"""Currently o1 models only allow temperature=1."""
model = values.get("model_name") or values.get("model") or ""
if model.startswith("o1") and "temperature" not in values:
@ -642,7 +635,7 @@ class BaseChatOpenAI(BaseChatModel):
return self
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
exclude_if_none = {
"presence_penalty": self.presence_penalty,
@ -669,7 +662,7 @@ class BaseChatOpenAI(BaseChatModel):
return params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
@ -697,8 +690,8 @@ class BaseChatOpenAI(BaseChatModel):
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: Type,
base_generation_info: Optional[Dict],
default_chunk_class: type,
base_generation_info: Optional[dict],
) -> Optional[ChatGenerationChunk]:
if chunk.get("type") == "content.delta": # from beta.chat.completions.stream
return None
@ -749,8 +742,8 @@ class BaseChatOpenAI(BaseChatModel):
def _stream_responses(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -783,8 +776,8 @@ class BaseChatOpenAI(BaseChatModel):
async def _astream_responses(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@ -838,8 +831,8 @@ class BaseChatOpenAI(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
@ -850,7 +843,7 @@ class BaseChatOpenAI(BaseChatModel):
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload:
@ -908,8 +901,8 @@ class BaseChatOpenAI(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -965,7 +958,7 @@ class BaseChatOpenAI(BaseChatModel):
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
messages = self._convert_input(input_).to_messages()
@ -982,7 +975,7 @@ class BaseChatOpenAI(BaseChatModel):
def _create_chat_result(
self,
response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None,
generation_info: Optional[dict] = None,
) -> ChatResult:
generations = []
@ -1032,8 +1025,8 @@ class BaseChatOpenAI(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
@ -1044,7 +1037,7 @@ class BaseChatOpenAI(BaseChatModel):
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload:
@ -1106,8 +1099,8 @@ class BaseChatOpenAI(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -1160,13 +1153,13 @@ class BaseChatOpenAI(BaseChatModel):
)
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {"model_name": self.model_name, **self._default_params}
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
"model": self.model_name,
@ -1176,7 +1169,7 @@ class BaseChatOpenAI(BaseChatModel):
}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -1199,7 +1192,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Return type of chat model."""
return "openai-chat"
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
@ -1211,7 +1204,7 @@ class BaseChatOpenAI(BaseChatModel):
encoding = tiktoken.get_encoding(model)
return model, encoding
def get_token_ids(self, text: str) -> List[int]:
def get_token_ids(self, text: str) -> list[int]:
"""Get the tokens present in the text with tiktoken package."""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
@ -1223,9 +1216,9 @@ class BaseChatOpenAI(BaseChatModel):
def get_num_tokens_from_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@ -1327,7 +1320,7 @@ class BaseChatOpenAI(BaseChatModel):
)
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
@ -1380,7 +1373,7 @@ class BaseChatOpenAI(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
@ -1727,7 +1720,7 @@ class BaseChatOpenAI(BaseChatModel):
else:
return llm | output_parser
def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
def _filter_disabled_params(self, **kwargs: Any) -> dict[str, Any]:
if not self.disabled_params:
return kwargs
filtered = {}
@ -2301,17 +2294,17 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""Maximum number of tokens to generate."""
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "openai"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
def lc_attributes(self) -> dict[str, Any]:
attributes: dict[str, Any] = {}
if self.openai_organization:
attributes["openai_organization"] = self.openai_organization
@ -2330,7 +2323,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
return True
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
params = super()._default_params
if "max_tokens" in params:
@ -2342,7 +2335,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
@ -2719,7 +2712,7 @@ def _lc_invalid_tool_call_to_openai_tool_call(
}
def _url_to_size(image_source: str) -> Optional[Tuple[int, int]]:
def _url_to_size(image_source: str) -> Optional[tuple[int, int]]:
try:
from PIL import Image # type: ignore[import]
except ImportError:
@ -2771,7 +2764,7 @@ def _is_b64(s: str) -> bool:
return s.startswith("data:image")
def _resize(width: int, height: int) -> Tuple[int, int]:
def _resize(width: int, height: int) -> tuple[int, int]:
# larger side must be <= 2048
if width > 2048 or height > 2048:
if width > height:
@ -2792,8 +2785,8 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
) -> Union[Dict, TypeBaseModel]:
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
) -> Union[dict, TypeBaseModel]:
if isinstance(schema, type) and is_basemodel_subclass(schema):
return schema
@ -2815,8 +2808,10 @@ def _convert_to_openai_response_format(
function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function}
if strict is not None and strict is not response_format["json_schema"].get(
"strict"
if (
strict is not None
and strict is not response_format["json_schema"].get("strict")
and isinstance(schema, dict)
):
msg = (
f"Output schema already has 'strict' value set to "
@ -2829,7 +2824,7 @@ def _convert_to_openai_response_format(
def _oai_structured_outputs_parser(
ai_msg: AIMessage, schema: Type[_BM]
ai_msg: AIMessage, schema: type[_BM]
) -> Optional[PydanticBaseModel]:
if parsed := ai_msg.additional_kwargs.get("parsed"):
if isinstance(parsed, dict):
@ -3141,7 +3136,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
def _construct_lc_result_from_responses_api(
response: Response,
schema: Optional[Type[_BM]] = None,
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
@ -3278,7 +3273,7 @@ def _construct_lc_result_from_responses_api(
def _convert_responses_chunk_to_generation_chunk(
chunk: Any, schema: Optional[Type[_BM]] = None, metadata: Optional[dict] = None
chunk: Any, schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None
) -> Optional[ChatGenerationChunk]:
content = []
tool_call_chunks: list = []

View File

@ -2,12 +2,13 @@
from __future__ import annotations
from typing import Awaitable, Callable, Optional, Union
from collections.abc import Awaitable
from typing import Callable, Optional, Union, cast
import openai
from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast
from typing_extensions import Self
from langchain_openai.embeddings.base import OpenAIEmbeddings

View File

@ -2,20 +2,8 @@ from __future__ import annotations
import logging
import warnings
from typing import (
Any,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Union, cast
import openai
import tiktoken
@ -29,19 +17,19 @@ logger = logging.getLogger(__name__)
def _process_batched_chunked_embeddings(
num_texts: int,
tokens: List[Union[List[int], str]],
batched_embeddings: List[List[float]],
indices: List[int],
tokens: list[Union[list[int], str]],
batched_embeddings: list[list[float]],
indices: list[int],
skip_empty: bool,
) -> List[Optional[List[float]]]:
) -> list[Optional[list[float]]]:
# for each text, this is the list of embeddings (list of list of floats)
# corresponding to the chunks of the text
results: List[List[List[float]]] = [[] for _ in range(num_texts)]
results: list[list[list[float]]] = [[] for _ in range(num_texts)]
# for each text, this is the token length of each chunk
# for transformers tokenization, this is the string length
# for tiktoken, this is the number of tokens
num_tokens_in_batch: List[List[int]] = [[] for _ in range(num_texts)]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(num_texts)]
for i in range(len(indices)):
if skip_empty and len(batched_embeddings[i]) == 1:
@ -50,10 +38,10 @@ def _process_batched_chunked_embeddings(
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
# for each text, this is the final embedding
embeddings: List[Optional[List[float]]] = []
embeddings: list[Optional[list[float]]] = []
for i in range(num_texts):
# an embedding for each chunk
_result: List[List[float]] = results[i]
_result: list[list[float]] = results[i]
if len(_result) == 0:
# this will be populated with the embedding of an empty string
@ -213,13 +201,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
),
)
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
allowed_special: Union[Literal["all"], Set[str], None] = None
disallowed_special: Union[Literal["all"], Set[str], Sequence[str], None] = None
allowed_special: Union[Literal["all"], set[str], None] = None
disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None
chunk_size: int = 1000
"""Maximum number of texts to embed in each batch"""
max_retries: int = 2
"""Maximum number of retries to make when generating."""
request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
@ -240,7 +228,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
when tiktoken is called, you can specify a model name to use here."""
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
skip_empty: bool = False
"""Whether to skip empty strings when embedding or raise an error.
@ -270,7 +258,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@ -354,15 +342,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return self
@property
def _invocation_params(self) -> Dict[str, Any]:
params: Dict = {"model": self.model, **self.model_kwargs}
def _invocation_params(self) -> dict[str, Any]:
params: dict = {"model": self.model, **self.model_kwargs}
if self.dimensions is not None:
params["dimensions"] = self.dimensions
return params
def _tokenize(
self, texts: List[str], chunk_size: int
) -> Tuple[Iterable[int], List[Union[List[int], str]], List[int]]:
self, texts: list[str], chunk_size: int
) -> tuple[Iterable[int], list[Union[list[int], str]], list[int]]:
"""
Take the input `texts` and `chunk_size` and return 3 iterables as a tuple:
@ -383,8 +371,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
indices: An iterable of the same length as `tokens` that maps each token-array
to the index of the original text in `texts`.
"""
tokens: List[Union[List[int], str]] = []
indices: List[int] = []
tokens: list[Union[list[int], str]] = []
indices: list[int] = []
model_name = self.tiktoken_model_name or self.model
# If tiktoken flag set to False
@ -403,11 +391,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
)
for i, text in enumerate(texts):
# Tokenize the text using HuggingFace transformers
tokenized: List[int] = tokenizer.encode(text, add_special_tokens=False)
tokenized: list[int] = tokenizer.encode(text, add_special_tokens=False)
# Split tokens into chunks respecting the embedding_ctx_length
for j in range(0, len(tokenized), self.embedding_ctx_length):
token_chunk: List[int] = tokenized[
token_chunk: list[int] = tokenized[
j : j + self.embedding_ctx_length
]
@ -420,7 +408,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
encoder_kwargs: Dict[str, Any] = {
encoder_kwargs: dict[str, Any] = {
k: v
for k, v in {
"allowed_special": self.allowed_special,
@ -459,8 +447,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]:
self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
) -> list[list[float]]:
"""
Generate length-safe embeddings for a list of texts.
@ -478,7 +466,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
_chunk_size = chunk_size or self.chunk_size
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: List[List[float]] = []
batched_embeddings: list[list[float]] = []
for i in _iter:
response = self.client.create(
input=tokens[i : i + _chunk_size], **self._invocation_params
@ -490,9 +478,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings = _process_batched_chunked_embeddings(
len(texts), tokens, batched_embeddings, indices, self.skip_empty
)
_cached_empty_embedding: Optional[List[float]] = None
_cached_empty_embedding: Optional[list[float]] = None
def empty_embedding() -> List[float]:
def empty_embedding() -> list[float]:
nonlocal _cached_empty_embedding
if _cached_empty_embedding is None:
average_embedded = self.client.create(
@ -508,8 +496,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
async def _aget_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]:
self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
) -> list[list[float]]:
"""
Asynchronously generate length-safe embeddings for a list of texts.
@ -528,7 +516,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_chunk_size = chunk_size or self.chunk_size
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: List[List[float]] = []
batched_embeddings: list[list[float]] = []
_chunk_size = chunk_size or self.chunk_size
for i in range(0, len(tokens), _chunk_size):
response = await self.async_client.create(
@ -542,9 +530,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings = _process_batched_chunked_embeddings(
len(texts), tokens, batched_embeddings, indices, self.skip_empty
)
_cached_empty_embedding: Optional[List[float]] = None
_cached_empty_embedding: Optional[list[float]] = None
async def empty_embedding() -> List[float]:
async def empty_embedding() -> list[float]:
nonlocal _cached_empty_embedding
if _cached_empty_embedding is None:
average_embedded = await self.async_client.create(
@ -558,8 +546,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return [e if e is not None else await empty_embedding() for e in embeddings]
def embed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
self, texts: list[str], chunk_size: int | None = None
) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
Args:
@ -572,7 +560,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
embeddings: list[list[float]] = []
for i in range(0, len(texts), chunk_size_):
response = self.client.create(
input=texts[i : i + chunk_size_], **self._invocation_params
@ -588,8 +576,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return self._get_len_safe_embeddings(texts, engine=engine)
async def aembed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
self, texts: list[str], chunk_size: int | None = None
) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
Args:
@ -602,7 +590,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
embeddings: list[list[float]] = []
for i in range(0, len(texts), chunk_size_):
response = await self.async_client.create(
input=texts[i : i + chunk_size_], **self._invocation_params
@ -617,7 +605,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.
Args:
@ -628,7 +616,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Call out to OpenAI's embedding endpoint async for embedding query text.
Args:

View File

@ -1,13 +1,14 @@
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
from collections.abc import Awaitable, Mapping
from typing import Any, Callable, Optional, Union, cast
import openai
from langchain_core.language_models import LangSmithParams
from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast
from typing_extensions import Self
from langchain_openai.llms.base import BaseOpenAI
@ -91,12 +92,12 @@ class AzureOpenAI(BaseOpenAI):
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"]
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {
"openai_api_key": "AZURE_OPENAI_API_KEY",
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
@ -188,12 +189,12 @@ class AzureOpenAI(BaseOpenAI):
}
@property
def _invocation_params(self) -> Dict[str, Any]:
def _invocation_params(self) -> dict[str, Any]:
openai_params = {"model": self.deployment_name}
return {**openai_params, **super()._invocation_params}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
@ -209,7 +210,7 @@ class AzureOpenAI(BaseOpenAI):
return "azure"
@property
def lc_attributes(self) -> Dict[str, Any]:
def lc_attributes(self) -> dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,

View File

@ -2,21 +2,8 @@ from __future__ import annotations
import logging
import sys
from typing import (
AbstractSet,
Any,
AsyncIterator,
Collection,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Union,
)
from collections.abc import AsyncIterator, Collection, Iterator, Mapping
from typing import Any, Literal, Optional, Union
import openai
import tiktoken
@ -35,7 +22,7 @@ logger = logging.getLogger(__name__)
def _update_token_usage(
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
keys: set[str], response: dict[str, Any], token_usage: dict[str, Any]
) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response["usage"])
@ -47,7 +34,7 @@ def _update_token_usage(
def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
stream_response: dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
if not stream_response["choices"]:
@ -84,7 +71,7 @@ class BaseOpenAI(BaseLLM):
"""How many completions to generate for each prompt."""
best_of: int = 1
"""Generates best_of completions server-side and returns the "best"."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
@ -108,12 +95,12 @@ class BaseOpenAI(BaseLLM):
)
batch_size: int = 20
"""Batch size to use when passing multiple documents to generate."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
logit_bias: Optional[Dict[str, float]] = None
logit_bias: Optional[dict[str, float]] = None
"""Adjust the probability of specific tokens being generated."""
max_retries: int = 2
"""Maximum number of retries to make when generating."""
@ -124,7 +111,7 @@ class BaseOpenAI(BaseLLM):
as well the chosen tokens."""
streaming: bool = False
"""Whether to stream the results or not."""
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
allowed_special: Union[Literal["all"], set[str]] = set()
"""Set of special tokens that are allowed。"""
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。"""
@ -157,7 +144,7 @@ class BaseOpenAI(BaseLLM):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
@ -197,9 +184,9 @@ class BaseOpenAI(BaseLLM):
return self
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
normal_params: Dict[str, Any] = {
normal_params: dict[str, Any] = {
"temperature": self.temperature,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
@ -228,7 +215,7 @@ class BaseOpenAI(BaseLLM):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@ -255,7 +242,7 @@ class BaseOpenAI(BaseLLM):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
@ -283,8 +270,8 @@ class BaseOpenAI(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -307,7 +294,7 @@ class BaseOpenAI(BaseLLM):
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
token_usage: dict[str, int] = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
@ -363,8 +350,8 @@ class BaseOpenAI(BaseLLM):
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -373,7 +360,7 @@ class BaseOpenAI(BaseLLM):
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
token_usage: dict[str, int] = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
@ -419,10 +406,10 @@ class BaseOpenAI(BaseLLM):
def get_sub_prompts(
self,
params: Dict[str, Any],
prompts: List[str],
stop: Optional[List[str]] = None,
) -> List[List[str]]:
params: dict[str, Any],
prompts: list[str],
stop: Optional[list[str]] = None,
) -> list[list[str]]:
"""Get the sub prompts for llm call."""
if stop is not None:
params["stop"] = stop
@ -441,9 +428,9 @@ class BaseOpenAI(BaseLLM):
def create_llm_result(
self,
choices: Any,
prompts: List[str],
params: Dict[str, Any],
token_usage: Dict[str, int],
prompts: list[str],
params: dict[str, Any],
token_usage: dict[str, int],
*,
system_fingerprint: Optional[str] = None,
) -> LLMResult:
@ -470,7 +457,7 @@ class BaseOpenAI(BaseLLM):
return LLMResult(generations=generations, llm_output=llm_output)
@property
def _invocation_params(self) -> Dict[str, Any]:
def _invocation_params(self) -> dict[str, Any]:
"""Get the parameters used to invoke the model."""
return self._default_params
@ -484,7 +471,7 @@ class BaseOpenAI(BaseLLM):
"""Return type of llm."""
return "openai"
def get_token_ids(self, text: str) -> List[int]:
def get_token_ids(self, text: str) -> list[int]:
"""Get the token IDs using the tiktoken package."""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
@ -689,7 +676,7 @@ class OpenAI(BaseOpenAI):
""" # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"]
@ -699,16 +686,16 @@ class OpenAI(BaseOpenAI):
return True
@property
def _invocation_params(self) -> Dict[str, Any]:
def _invocation_params(self) -> dict[str, Any]:
return {**{"model": self.model_name}, **super()._invocation_params}
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
def lc_attributes(self) -> dict[str, Any]:
attributes: dict[str, Any] = {}
if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base

View File

@ -59,8 +59,12 @@ disallow_untyped_defs = "True"
module = "transformers"
ignore_missing_imports = true
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.ruff.format]
docstring-code-format = true
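
The same lint configuration is mirrored here, and the `ignore = [ "UP007", ]` line is why the diff rewrites container types but never `Optional`/`Union`: PEP 585 builtin generics evaluate fine at runtime on Python 3.9, whereas PEP 604 unions do not. A small illustrative snippet (the `head` function is made up):

from typing import Optional

# Works at runtime on 3.9: subscripted builtins are PEP 585.
def head(items: list[str]) -> Optional[str]:
    return items[0] if items else None

# The UP007 rewrite would produce the line below, which raises
# "TypeError: unsupported operand type(s) for |" at import time on 3.9
# unless the module adds `from __future__ import annotations`:
#
#     def head(items: list[str]) -> str | None: ...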

View File

@ -1,7 +1,6 @@
"""Standard LangChain interface tests"""
import os
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests
@ -14,7 +13,7 @@ OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
class TestAzureOpenAIStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return AzureChatOpenAI
@property
@ -40,7 +39,7 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
"""Test a legacy model."""
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return AzureChatOpenAI
@property

View File

@ -2,9 +2,10 @@
import base64
import json
from collections.abc import AsyncIterator
from pathlib import Path
from textwrap import dedent
from typing import Any, AsyncIterator, List, Literal, Optional, cast
from typing import Any, Literal, Optional, cast
import httpx
import openai
@ -531,14 +532,14 @@ class MakeASandwich(BaseModel):
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
condiments: list[str]
vegetables: list[str]
def test_tool_use() -> None:
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
msgs: list = [HumanMessage("Sally has green hair, what would her username be?")]
ai_msg = llm_with_tool.invoke(msgs)
assert isinstance(ai_msg, AIMessage)
@ -583,7 +584,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
msgs: list = [
HumanMessage("Sally has green hair, what would her username be?"),
AIMessage(
content="",
@ -1045,7 +1046,7 @@ def test_audio_output_modality() -> None:
},
)
history: List[BaseMessage] = [
history: list[BaseMessage] = [
HumanMessage("Make me a short audio clip of you yelling")
]

View File

@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""
from pathlib import Path
from typing import Dict, List, Literal, Type, cast
from typing import Literal, cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
@ -14,7 +14,7 @@ REPO_ROOT_DIR = Path(__file__).parents[6]
class TestOpenAIStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI
@property
@ -36,9 +36,9 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
@property
def supported_usage_metadata_details(
self,
) -> Dict[
) -> dict[
Literal["invoke", "stream"],
List[
list[
Literal[
"audio_input",
"audio_output",
@ -51,7 +51,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
return {"invoke": ["reasoning_output", "cache_read_input"], "stream": []}
def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
with open(REPO_ROOT_DIR / "README.md", "r") as f:
with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()
input_ = f"""What's langchain? Here's the langchain README:

View File

@ -2,7 +2,7 @@
import json
import os
from typing import Any, Optional, cast
from typing import Annotated, Any, Optional, cast
import openai
import pytest
@ -13,7 +13,7 @@ from langchain_core.messages import (
BaseMessageChunk,
)
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict
from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests for Responses API"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
@ -11,7 +9,7 @@ from tests.integration_tests.chat_models.test_base_standard import TestOpenAISta
class TestOpenAIResponses(TestOpenAIStandard):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI
@property

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Type
from langchain_core.embeddings import Embeddings
from langchain_tests.integration_tests.embeddings import EmbeddingsIntegrationTests
@ -10,7 +8,7 @@ from langchain_openai import OpenAIEmbeddings
class TestOpenAIStandard(EmbeddingsIntegrationTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
def embeddings_class(self) -> type[Embeddings]:
return OpenAIEmbeddings
@property

View File

@ -1,7 +1,8 @@
"""Test AzureOpenAI wrapper."""
import os
from typing import Any, Generator
from collections.abc import Generator
from typing import Any
import pytest
from langchain_core.callbacks import CallbackManager

View File

@ -1,6 +1,6 @@
"""Test OpenAI llm."""
from typing import Generator
from collections.abc import Generator
import pytest
from langchain_core.callbacks import CallbackManager

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
@ -12,7 +10,7 @@ from langchain_openai import AzureChatOpenAI
class TestOpenAIStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return AzureChatOpenAI
@property
@ -30,7 +28,7 @@ class TestOpenAIStandard(ChatModelUnitTests):
super().test_bind_tool_pydantic(model, my_adder_tool)
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"AZURE_OPENAI_API_KEY": "api_key",

View File

@ -3,7 +3,7 @@
import json
from functools import partial
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
from typing import Any, Literal, Optional, Union, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -241,7 +241,7 @@ class MockAsyncContextManager:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
@ -270,7 +270,7 @@ class MockSyncContextManager:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
@ -382,7 +382,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
@pytest.fixture
def mock_deepseek_completion() -> List[Dict]:
def mock_deepseek_completion() -> list[dict]:
list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
result_list = []
for msg in list_chunk_data:
@ -450,7 +450,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
@pytest.fixture
def mock_openai_completion() -> List[Dict]:
def mock_openai_completion() -> list[dict]:
list_chunk_data = OPENAI_STREAM_DATA.split("\n")
result_list = []
for msg in list_chunk_data:
@ -615,7 +615,7 @@ def test_openai_invoke_name(mock_client: MagicMock) -> None:
def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]:
def token_encoder(text: str) -> list[int]:
return [1, 2, 3]
llm = ChatOpenAI(custom_get_token_ids=token_encoder)
@ -662,8 +662,8 @@ class MakeASandwich(BaseModel):
bread_type: str
cheese_type: str
condiments: List[str]
vegetables: List[str]
condiments: list[str]
vegetables: list[str]
@pytest.mark.parametrize(
@ -695,7 +695,7 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
@pytest.mark.parametrize("include_raw", [True, False])
@pytest.mark.parametrize("strict", [True, False, None])
def test_with_structured_output(
schema: Union[Type, Dict[str, Any], None],
schema: Union[type, dict[str, Any], None],
method: Literal["function_calling", "json_mode", "json_schema"],
include_raw: bool,
strict: Optional[bool],
@ -787,7 +787,7 @@ class Foo(BaseModel):
# FooV1
],
)
def test_schema_from_with_structured_output(schema: Type) -> None:
def test_schema_from_with_structured_output(schema: type) -> None:
"""Test schema from with_structured_output."""
llm = ChatOpenAI(model="gpt-4o")

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,11 +8,11 @@ from langchain_openai import ChatOpenAI
class TestOpenAIStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,7 +8,7 @@ from langchain_openai import ChatOpenAI
class TestOpenAIResponses(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatOpenAI
@property
@ -18,7 +16,7 @@ class TestOpenAIResponses(ChatModelUnitTests):
return {"use_responses_api": True}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",

View File

@ -1,5 +1,3 @@
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@ -8,7 +6,7 @@ from langchain_openai import AzureOpenAIEmbeddings
class TestAzureOpenAIStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
def embeddings_class(self) -> type[Embeddings]:
return AzureOpenAIEmbeddings
@property
@ -16,7 +14,7 @@ class TestAzureOpenAIStandard(EmbeddingsUnitTests):
return {"api_key": "api_key", "azure_endpoint": "https://endpoint.com"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"AZURE_OPENAI_API_KEY": "api_key",

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@ -10,11 +8,11 @@ from langchain_openai import OpenAIEmbeddings
class TestOpenAIStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
def embeddings_class(self) -> type[Embeddings]:
return OpenAIEmbeddings
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",

View File

@ -1,7 +1,7 @@
"""A fake callback handler for testing purposes."""
from itertools import chain
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
@ -15,7 +15,7 @@ class BaseFakeCallbackHandler(BaseModel):
starts: int = 0
ends: int = 0
errors: int = 0
errors_args: List[Any] = []
errors_args: list[Any] = []
text: int = 0
ignore_llm_: bool = False
ignore_chain_: bool = False
@ -195,8 +195,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@ -1,5 +1,4 @@
import os
from typing import List
import pytest
@ -65,7 +64,7 @@ def test_get_token_ids(model: str) -> None:
def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]:
def token_encoder(text: str) -> list[int]:
return [1, 2, 3]
llm = OpenAI(custom_get_token_ids=token_encoder)
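
The test above swaps in a trivial token counter via `custom_get_token_ids`. A minimal usage sketch, assuming (as in langchain_core's `BaseLanguageModel`) that `get_num_tokens` simply returns the length of whatever `get_token_ids` yields; the `"fake-key"` value is a placeholder, not taken from this repo:

from langchain_openai import OpenAI

def token_encoder(text: str) -> list[int]:
    return [1, 2, 3]

# No real API key is needed just to count tokens locally.
llm = OpenAI(custom_get_token_ids=token_encoder, openai_api_key="fake-key")
assert llm.get_num_tokens("any text at all") == 3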

View File

@ -1,4 +1,4 @@
from typing import Type, cast
from typing import cast
import pytest
from langchain_core.load import dumpd
@ -72,7 +72,7 @@ def test_azure_openai_embeddings_secrets() -> None:
@pytest.mark.parametrize(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
)
def test_azure_openai_api_key_is_secret_string(model_class: Type) -> None:
def test_azure_openai_api_key_is_secret_string(model_class: type) -> None:
"""Test that the API key is stored as a SecretStr."""
model = model_class(
openai_api_key="secret-api-key",
@ -88,7 +88,7 @@ def test_azure_openai_api_key_is_secret_string(model_class: Type) -> None:
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
)
def test_azure_openai_api_key_masked_when_passed_from_env(
model_class: Type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
model_class: type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key")
@ -109,7 +109,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
)
def test_azure_openai_api_key_masked_when_passed_via_constructor(
model_class: Type, capsys: CaptureFixture
model_class: type, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed via the constructor."""
model = model_class(
@ -133,7 +133,7 @@ def test_azure_openai_api_key_masked_when_passed_via_constructor(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
)
def test_azure_openai_uses_actual_secret_value_from_secretstr(
model_class: Type,
model_class: type,
) -> None:
"""Test that the actual secret value is correctly retrieved."""
model = model_class(
@ -147,7 +147,7 @@ def test_azure_openai_uses_actual_secret_value_from_secretstr(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_is_secret_string(model_class: Type) -> None:
def test_openai_api_key_is_secret_string(model_class: type) -> None:
"""Test that the API key is stored as a SecretStr."""
model = model_class(openai_api_key="secret-api-key")
assert isinstance(model.openai_api_key, SecretStr)
@ -155,7 +155,7 @@ def test_openai_api_key_is_secret_string(model_class: Type) -> None:
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_masked_when_passed_from_env(
model_class: Type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
model_class: type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("OPENAI_API_KEY", "secret-api-key")
@ -168,7 +168,7 @@ def test_openai_api_key_masked_when_passed_from_env(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_masked_when_passed_via_constructor(
model_class: Type, capsys: CaptureFixture
model_class: type, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed via the constructor."""
model = model_class(openai_api_key="secret-api-key")
@ -179,14 +179,14 @@ def test_openai_api_key_masked_when_passed_via_constructor(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_uses_actual_secret_value_from_secretstr(model_class: Type) -> None:
def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> None:
"""Test that the actual secret value is correctly retrieved."""
model = model_class(openai_api_key="secret-api-key")
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
def test_azure_serialized_secrets(model_class: Type) -> None:
def test_azure_serialized_secrets(model_class: type) -> None:
"""Test that the actual secret value is correctly retrieved."""
model = model_class(
openai_api_key="secret-api-key", api_version="foo", azure_endpoint="foo"
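
These secret-handling tests all hinge on pydantic's `SecretStr`: the API-key fields hold a `SecretStr`, so printing the model or the field shows a mask, and the real value is only reachable through `get_secret_value()`. A minimal standalone sketch of that behaviour (not tied to any particular model class here):

from pydantic import SecretStr

key = SecretStr("secret-api-key")
print(key)                     # ********** (masked in str/repr output)
print(key.get_secret_value())  # secret-api-key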

View File

@ -3,20 +3,9 @@
from __future__ import annotations
import logging
from collections.abc import Iterator, Mapping
from operator import itemgetter
from typing import (
Any,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Literal, Optional, TypeVar, Union
import openai
from langchain_core.callbacks import CallbackManagerForLLMRun
@ -50,8 +39,8 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]
logger = logging.getLogger(__name__)
@ -162,14 +151,14 @@ class ChatPerplexity(BaseChatModel):
"""Model name."""
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
pplx_api_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key"
)
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = Field(
request_timeout: Optional[Union[float, tuple[float, float]]] = Field(
None, alias="timeout"
)
"""Timeout for requests to PerplexityChat completion API. Default is None."""
@ -183,12 +172,12 @@ class ChatPerplexity(BaseChatModel):
model_config = ConfigDict(populate_by_name=True)
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"pplx_api_key": "PPLX_API_KEY"}
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@ -232,7 +221,7 @@ class ChatPerplexity(BaseChatModel):
return self
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling PerplexityChat API."""
return {
"max_tokens": self.max_tokens,
@ -241,7 +230,7 @@ class ChatPerplexity(BaseChatModel):
**self.model_kwargs,
}
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
def _convert_message_to_dict(self, message: BaseMessage) -> dict[str, Any]:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, SystemMessage):
@ -255,8 +244,8 @@ class ChatPerplexity(BaseChatModel):
return message_dict
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = dict(self._invocation_params)
if stop is not None:
if "stop" in params:
@ -266,11 +255,11 @@ class ChatPerplexity(BaseChatModel):
return message_dicts, params
def _convert_delta_to_message_chunk(
self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
self, _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -296,8 +285,8 @@ class ChatPerplexity(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -367,8 +356,8 @@ class ChatPerplexity(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -402,7 +391,7 @@ class ChatPerplexity(BaseChatModel):
@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
pplx_creds: Dict[str, Any] = {"model": self.model}
pplx_creds: dict[str, Any] = {"model": self.model}
return {**pplx_creds, **self._default_params}
@property

View File

@ -55,8 +55,12 @@ plugins = ['pydantic.mypy']
module = "transformers"
ignore_missing_imports = true
[tool.ruff]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.ruff.format]
docstring-code-format = true
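
The new `"UP"` entry enables ruff's pyupgrade rules, and `target-version = "py39"` tells them which syntax is safe to emit. That combination drives most of this diff: with PEP 585 (Python 3.9), the builtin containers are subscriptable in annotations and the `collections.abc` ABCs replace their `typing` aliases. A minimal before/after sketch with an illustrative function name, not taken from any file in this commit:

# before: pre-3.9 typing aliases (what most of this diff removes)
from typing import Dict, List, Optional, Tuple

def bucket(items: List[str], sep: Optional[str] = None) -> Dict[str, Tuple[int, int]]:
    ...

# after: PEP 585 builtin generics, valid from Python 3.9
from typing import Optional

def bucket(items: list[str], sep: Optional[str] = None) -> dict[str, tuple[int, int]]:
    ...

`Union` and `Optional` are left alone because `UP007`, the rule that rewrites them to PEP 604 `X | Y` syntax, sits in the ignore list; that spelling only evaluates at runtime on Python 3.10+.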

View File

@ -1,7 +1,5 @@
"""Standard LangChain interface tests."""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests
@ -11,7 +9,7 @@ from langchain_perplexity import ChatPerplexity
class TestPerplexityStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatPerplexity
@property

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from unittest.mock import MagicMock
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
@ -51,7 +51,7 @@ def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
"choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}],
"citations": ["example.com", "example2.com"],
}
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object(
@ -103,7 +103,7 @@ def test_perplexity_stream_includes_citations_and_images(mocker: MockerFixture)
}
],
}
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object(
@ -162,7 +162,7 @@ def test_perplexity_stream_includes_citations_and_related_questions(
"citations": ["example.com", "example2.com"],
"related_questions": ["example_question_1", "example_question_2"],
}
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks
patcher = mocker.patch.object(

View File

@ -1,7 +1,5 @@
"""Test Perplexity Chat API wrapper."""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests
@ -10,9 +8,9 @@ from langchain_perplexity import ChatPerplexity
class TestPerplexityStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatPerplexity
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return ({"PPLX_API_KEY": "api_key"}, {}, {"pplx_api_key": "api_key"})

View File

@ -5,7 +5,7 @@ import json
import os
import re
from pathlib import Path
from typing import Any, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
from typing import Any, Generic, Literal, Optional, TypeVar, Union
import yaml
from pydantic import BaseModel, ConfigDict, Field, FilePath
@ -24,7 +24,7 @@ class PropertySettings(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
type: Literal["string", "number", "array", "object", "boolean"]
default: Union[str, int, float, List, Dict, bool, None] = Field(default=None)
default: Union[str, int, float, list, dict, bool, None] = Field(default=None)
description: str = Field(default="")
@ -58,8 +58,8 @@ class Prompty(BaseModel):
# metadata
name: str = Field(default="")
description: str = Field(default="")
authors: List[str] = Field(default=[])
tags: List[str] = Field(default=[])
authors: list[str] = Field(default=[])
tags: list[str] = Field(default=[])
version: str = Field(default="")
base: str = Field(default="")
basePrompty: Optional[Prompty] = Field(default=None)
@ -70,8 +70,8 @@ class Prompty(BaseModel):
sample: dict = Field(default={})
# input / output
inputs: Dict[str, PropertySettings] = Field(default={})
outputs: Dict[str, PropertySettings] = Field(default={})
inputs: dict[str, PropertySettings] = Field(default={})
outputs: dict[str, PropertySettings] = Field(default={})
# template
template: TemplateSettings
@ -79,7 +79,7 @@ class Prompty(BaseModel):
file: FilePath = Field(default="") # type: ignore[assignment]
content: str = Field(default="")
def to_safe_dict(self) -> Dict[str, Any]:
def to_safe_dict(self) -> dict[str, Any]:
d = {}
for k, v in self:
if v != "" and v != {} and v != [] and v is not None:
@ -130,7 +130,7 @@ class Prompty(BaseModel):
attribute.startswith("file:")
and Path(parent / attribute.split(":")[1]).exists()
):
with open(parent / attribute.split(":")[1], "r") as f:
with open(parent / attribute.split(":")[1]) as f:
items = json.load(f)
if isinstance(items, list):
return [Prompty.normalize(value, parent) for value in items]
@ -155,8 +155,8 @@ class Prompty(BaseModel):
def param_hoisting(
top: Dict[str, Any], bottom: Dict[str, Any], top_key: Any = None
) -> Dict[str, Any]:
top: dict[str, Any], bottom: dict[str, Any], top_key: Any = None
) -> dict[str, Any]:
"""Merge two dictionaries with hoisting of parameters from bottom to top.
Args:
@ -198,18 +198,18 @@ class NoOpParser(Invoker):
return data
class InvokerFactory(object):
class InvokerFactory:
"""Factory for creating invokers."""
_instance = None
_renderers: Dict[str, Type[Invoker]] = {}
_parsers: Dict[str, Type[Invoker]] = {}
_executors: Dict[str, Type[Invoker]] = {}
_processors: Dict[str, Type[Invoker]] = {}
_renderers: dict[str, type[Invoker]] = {}
_parsers: dict[str, type[Invoker]] = {}
_executors: dict[str, type[Invoker]] = {}
_processors: dict[str, type[Invoker]] = {}
def __new__(cls) -> InvokerFactory:
if cls._instance is None:
cls._instance = super(InvokerFactory, cls).__new__(cls)
cls._instance = super().__new__(cls)
# Add NOOP invokers
cls._renderers["NOOP"] = NoOpParser
cls._parsers["NOOP"] = NoOpParser
@ -221,7 +221,7 @@ class InvokerFactory(object):
self,
type: Literal["renderer", "parser", "executor", "processor"],
name: str,
invoker: Type[Invoker],
invoker: type[Invoker],
) -> None:
if type == "renderer":
self._renderers[name] = invoker
@ -264,7 +264,7 @@ class InvokerFactory(object):
else:
raise ValueError(f"Invalid type {type}")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"renderers": {
k: f"{v.__module__}.{v.__name__}" for k, v in self._renderers.items()
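
Beyond the typing rewrites, the hunks above pick up three other pyupgrade-style modernizations: dropping the redundant `object` base class, calling `super()` without arguments, and dropping the redundant `"r"` mode from `open()` (ruff's useless-object-inheritance, super-call-with-parameters, and redundant-open-modes rules, respectively). A condensed sketch of the same pattern; `Factory` and `read_config` are illustrative names, not from the source:

# before
class Factory(object):
    def __new__(cls):
        return super(Factory, cls).__new__(cls)

def read_config(path: str) -> str:
    with open(path, "r") as f:
        return f.read()

# after the UP rules run
class Factory:
    def __new__(cls):
        return super().__new__(cls)

def read_config(path: str) -> str:
    with open(path) as f:
        return f.read()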

View File

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableLambda
@ -10,10 +10,10 @@ from .utils import load, prepare
def create_chat_prompt(
path: str,
input_name_agent_scratchpad: str = "agent_scratchpad",
) -> Runnable[Dict[str, Any], ChatPromptTemplate]:
) -> Runnable[dict[str, Any], ChatPromptTemplate]:
"""Create a chat prompt from a Langchain schema."""
def runnable_chat_lambda(inputs: Dict[str, Any]) -> ChatPromptTemplate:
def runnable_chat_lambda(inputs: dict[str, Any]) -> ChatPromptTemplate:
p = load(path)
parsed = prepare(p, inputs)
# Parsed messages have been templated

View File

@ -1,6 +1,6 @@
import base64
import re
from typing import Dict, List, Type, Union
from typing import Union
from langchain_core.messages import (
AIMessage,
@ -15,7 +15,7 @@ from .core import Invoker, Prompty, SimpleModel
class RoleMap:
_ROLE_MAP: Dict[str, Type[BaseMessage]] = {
_ROLE_MAP: dict[str, type[BaseMessage]] = {
"system": SystemMessage,
"user": HumanMessage,
"human": HumanMessage,
@ -26,7 +26,7 @@ class RoleMap:
ROLES = _ROLE_MAP.keys()
@classmethod
def get_message_class(cls, role: str) -> Type[BaseMessage]:
def get_message_class(cls, role: str) -> type[BaseMessage]:
return cls._ROLE_MAP[role]
@ -60,7 +60,7 @@ class PromptyChatParser(Invoker):
"and .jpg / .jpeg are supported."
)
def parse_content(self, content: str) -> Union[str, List]:
def parse_content(self, content: str) -> Union[str, list]:
"""for parsing inline images"""
# regular expression to parse markdown images
image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)"

View File

@ -1,6 +1,6 @@
import traceback
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any, Union
from .core import (
Frontmatter,
@ -120,7 +120,7 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
def prepare(
prompt: Prompty,
inputs: Dict[str, Any] = {},
inputs: dict[str, Any] = {},
) -> Any:
"""Prepare the inputs for the prompty.
@ -166,9 +166,9 @@ def prepare(
def run(
prompt: Prompty,
content: Union[Dict, List, str],
configuration: Dict[str, Any] = {},
parameters: Dict[str, Any] = {},
content: Union[dict, list, str],
configuration: dict[str, Any] = {},
parameters: dict[str, Any] = {},
raw: bool = False,
) -> Any:
"""Run the prompty.
@ -219,9 +219,9 @@ def run(
def execute(
prompt: Union[str, Prompty],
configuration: Dict[str, Any] = {},
parameters: Dict[str, Any] = {},
inputs: Dict[str, Any] = {},
configuration: dict[str, Any] = {},
parameters: dict[str, Any] = {},
inputs: dict[str, Any] = {},
raw: bool = False,
connection: str = "default",
) -> Any:

View File

@ -45,9 +45,12 @@ langchain-core = { path = "../../core", editable = true }
langchain-text-splitters = { path = "../../text-splitters", editable = true }
langchain = { path = "../../langchain", editable = true }
[tool.ruff]
select = ["E", "F", "I"]
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]
[tool.mypy]
disallow_untyped_defs = "True"
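
This package gets the same lint setup as the others: `target-version = "py39"` plus the `UP` (pyupgrade) rule group, with `UP007` ignored so `Union`/`Optional` annotations are not rewritten to PEP 604 `X | Y` syntax. The reason is a runtime one: on Python 3.9 the `X | Y` annotation expression raises a `TypeError` as soon as it is evaluated, which matters for code such as pydantic models whose annotations are evaluated rather than deferred. A hypothetical sketch of the failure mode being avoided (the model name is illustrative):

from typing import Optional
from pydantic import BaseModel

class ExampleSettings(BaseModel):
    # Fine on Python 3.9: Optional[...] comes from typing.
    timeout: Optional[float] = None
    # The PEP 604 spelling below is what UP007 would produce; on 3.9 it
    # raises "TypeError: unsupported operand type(s) for |" when the class
    # body is evaluated, so the rule stays ignored until 3.10 is the floor.
    # retries: int | None = None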

Some files were not shown because too many files have changed in this diff.