Mirror of https://github.com/hwchase17/langchain.git (synced 2025-04-27 19:46:55 +00:00)

partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)

Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).

This commit is contained in: parent e72f3c26a0, commit 8c6734325b
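The pattern repeated throughout the diff below is the usual modernization for a Python 3.9 baseline: deprecated `typing` aliases (`Dict`, `List`, `Tuple`, `Type`) become the builtin generics, and ABCs such as `Iterator`, `Mapping`, and `Sequence` are imported from `collections.abc` instead of `typing`. Because the `ruff` configs in this commit ignore `UP007`, `Optional`/`Union` annotations are left as-is rather than rewritten to the `X | Y` form. A minimal before/after sketch of that shape (illustrative only; the function and names here are hypothetical and not taken from the diff):

```python
# Before: Python 3.8-era annotations using deprecated typing aliases
from typing import Dict, Iterator, List, Optional

def build_index(lines: Iterator[str], stop: Optional[List[str]] = None) -> Dict[str, int]:
    ...

# After the 3.9-style rewrite: builtin generics, ABCs from collections.abc;
# Optional is kept because UP007 is ignored in the ruff configs below.
from collections.abc import Iterator
from typing import Optional

def build_index(lines: Iterator[str], stop: Optional[list[str]] = None) -> dict[str, int]:
    ...
```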
@@ -1,21 +1,14 @@
 import copy
 import re
 import warnings
+from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from functools import cached_property
 from operator import itemgetter
 from typing import (
 Any,
-AsyncIterator,
 Callable,
-Dict,
-Iterator,
-List,
 Literal,
-Mapping,
 Optional,
-Sequence,
-Tuple,
-Type,
 Union,
 cast,
 )
@@ -89,8 +82,8 @@ class AnthropicTool(TypedDict):

 name: str
 description: str
-input_schema: Dict[str, Any]
-cache_control: NotRequired[Dict[str, str]]
+input_schema: dict[str, Any]
+cache_control: NotRequired[dict[str, str]]


 def _is_builtin_tool(tool: Any) -> bool:
@@ -109,7 +102,7 @@ def _is_builtin_tool(tool: Any) -> bool:
 return any(tool_type.startswith(prefix) for prefix in _builtin_tool_prefixes)


-def _format_image(image_url: str) -> Dict:
+def _format_image(image_url: str) -> dict:
 """
 Formats an image of format data:image/jpeg;base64,{b64_string}
 to a dict for anthropic api
@@ -138,7 +131,7 @@ def _format_image(image_url: str) -> Dict:

 def _merge_messages(
 messages: Sequence[BaseMessage],
-) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
+) -> list[Union[SystemMessage, AIMessage, HumanMessage]]:
 """Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
 merged: list = []
 for curr in messages:
@@ -169,7 +162,7 @@ def _merge_messages(
 for c in (SystemMessage, HumanMessage)
 ):
 if isinstance(cast(BaseMessage, last).content, str):
-new_content: List = [
+new_content: list = [
 {"type": "text", "text": cast(BaseMessage, last).content}
 ]
 else:
@@ -185,8 +178,8 @@ def _merge_messages(


 def _format_messages(
-messages: List[BaseMessage],
-) -> Tuple[Union[str, List[Dict], None], List[Dict]]:
+messages: list[BaseMessage],
+) -> tuple[Union[str, list[dict], None], list[dict]]:
 """Format messages for anthropic."""

 """
@@ -198,8 +191,8 @@ def _format_messages(
 for m in messages
 ]
 """
-system: Union[str, List[Dict], None] = None
-formatted_messages: List[Dict] = []
+system: Union[str, list[dict], None] = None
+formatted_messages: list[dict] = []

 merged_messages = _merge_messages(messages)
 for i, message in enumerate(merged_messages):
@@ -220,7 +213,7 @@ def _format_messages(
 continue

 role = _message_type_lookups[message.type]
-content: Union[str, List]
+content: Union[str, list]

 if not isinstance(message.content, str):
 # parse as dict
@@ -830,7 +823,7 @@ class ChatAnthropic(BaseChatModel):
 max_retries: int = 2
 """Number of retries allowed for requests sent to the Anthropic Completion API."""

-stop_sequences: Optional[List[str]] = Field(None, alias="stop")
+stop_sequences: Optional[list[str]] = Field(None, alias="stop")
 """Default stop sequences."""

 anthropic_api_url: Optional[str] = Field(
@@ -858,7 +851,7 @@ class ChatAnthropic(BaseChatModel):
 default_headers: Optional[Mapping[str, str]] = None
 """Headers to pass to the Anthropic clients, will be used for every API call."""

-model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+model_kwargs: dict[str, Any] = Field(default_factory=dict)

 streaming: bool = False
 """Whether to use streaming or not."""
@@ -868,7 +861,7 @@ class ChatAnthropic(BaseChatModel):
 message chunks will be generated during the stream including usage metadata.
 """

-thinking: Optional[Dict[str, Any]] = Field(default=None)
+thinking: Optional[dict[str, Any]] = Field(default=None)
 """Parameters for Claude reasoning,
 e.g., ``{"type": "enabled", "budget_tokens": 10_000}``"""

@@ -878,7 +871,7 @@ class ChatAnthropic(BaseChatModel):
 return "anthropic-chat"

 @property
-def lc_secrets(self) -> Dict[str, str]:
+def lc_secrets(self) -> dict[str, str]:
 return {"anthropic_api_key": "ANTHROPIC_API_KEY"}

 @classmethod
@@ -886,12 +879,12 @@ class ChatAnthropic(BaseChatModel):
 return True

 @classmethod
-def get_lc_namespace(cls) -> List[str]:
+def get_lc_namespace(cls) -> list[str]:
 """Get the namespace of the langchain object."""
 return ["langchain", "chat_models", "anthropic"]

 @property
-def _identifying_params(self) -> Dict[str, Any]:
+def _identifying_params(self) -> dict[str, Any]:
 """Get the identifying parameters."""
 return {
 "model": self.model,
@@ -907,7 +900,7 @@ class ChatAnthropic(BaseChatModel):
 }

 def _get_ls_params(
-self, stop: Optional[List[str]] = None, **kwargs: Any
+self, stop: Optional[list[str]] = None, **kwargs: Any
 ) -> LangSmithParams:
 """Get standard params for tracing."""
 params = self._get_invocation_params(stop=stop, **kwargs)
@@ -925,14 +918,14 @@ class ChatAnthropic(BaseChatModel):

 @model_validator(mode="before")
 @classmethod
-def build_extra(cls, values: Dict) -> Any:
+def build_extra(cls, values: dict) -> Any:
 all_required_field_names = get_pydantic_field_names(cls)
 values = _build_model_kwargs(values, all_required_field_names)
 return values

 @cached_property
-def _client_params(self) -> Dict[str, Any]:
-client_params: Dict[str, Any] = {
+def _client_params(self) -> dict[str, Any]:
+client_params: dict[str, Any] = {
 "api_key": self.anthropic_api_key.get_secret_value(),
 "base_url": self.anthropic_api_url,
 "max_retries": self.max_retries,
@@ -958,9 +951,9 @@ class ChatAnthropic(BaseChatModel):
 self,
 input_: LanguageModelInput,
 *,
-stop: Optional[List[str]] = None,
-**kwargs: Dict,
-) -> Dict:
+stop: Optional[list[str]] = None,
+**kwargs: dict,
+) -> dict:
 messages = self._convert_input(input_).to_messages()
 system, formatted_messages = _format_messages(messages)
 payload = {
@@ -981,8 +974,8 @@ class ChatAnthropic(BaseChatModel):

 def _stream(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 *,
 stream_usage: Optional[bool] = None,
@@ -1012,8 +1005,8 @@ class ChatAnthropic(BaseChatModel):

 async def _astream(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 *,
 stream_usage: Optional[bool] = None,
@@ -1088,8 +1081,8 @@ class ChatAnthropic(BaseChatModel):

 def _generate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -1104,8 +1097,8 @@ class ChatAnthropic(BaseChatModel):

 async def _agenerate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -1120,7 +1113,7 @@ class ChatAnthropic(BaseChatModel):

 def _get_llm_for_structured_output_when_thinking_is_enabled(
 self,
-schema: Union[Dict, type],
+schema: Union[dict, type],
 formatted_tool: AnthropicTool,
 ) -> Runnable[LanguageModelInput, BaseMessage]:
 thinking_admonition = (
@@ -1148,10 +1141,10 @@ class ChatAnthropic(BaseChatModel):

 def bind_tools(
 self,
-tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
 *,
 tool_choice: Optional[
-Union[Dict[str, str], Literal["any", "auto"], str]
+Union[dict[str, str], Literal["any", "auto"], str]
 ] = None,
 parallel_tool_calls: Optional[bool] = None,
 **kwargs: Any,
@@ -1326,11 +1319,11 @@ class ChatAnthropic(BaseChatModel):

 def with_structured_output(
 self,
-schema: Union[Dict, type],
+schema: Union[dict, type],
 *,
 include_raw: bool = False,
 **kwargs: Any,
-) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
 """Model wrapper that returns outputs formatted to match the given schema.

 Args:
@@ -1483,9 +1476,9 @@ class ChatAnthropic(BaseChatModel):
 @beta()
 def get_num_tokens_from_messages(
 self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
 tools: Optional[
-Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
 ] = None,
 ) -> int:
 """Count tokens in a sequence of input messages.
@@ -1546,7 +1539,7 @@ class ChatAnthropic(BaseChatModel):
 https://docs.anthropic.com/en/docs/build-with-claude/token-counting
 """
 formatted_system, formatted_messages = _format_messages(messages)
-kwargs: Dict[str, Any] = {}
+kwargs: dict[str, Any] = {}
 if isinstance(formatted_system, str):
 kwargs["system"] = formatted_system
 if tools:
@@ -1562,7 +1555,7 @@ class ChatAnthropic(BaseChatModel):


 def convert_to_anthropic_tool(
-tool: Union[Dict[str, Any], Type, Callable, BaseTool],
+tool: Union[dict[str, Any], type, Callable, BaseTool],
 ) -> AnthropicTool:
 """Convert a tool-like object to an Anthropic tool definition."""
 # already in Anthropic tool format
@@ -1611,8 +1604,8 @@ class _AnthropicToolUse(TypedDict):


 def _lc_tool_calls_to_anthropic_tool_use_blocks(
-tool_calls: List[ToolCall],
-) -> List[_AnthropicToolUse]:
+tool_calls: list[ToolCall],
+) -> list[_AnthropicToolUse]:
 blocks = []
 for tool_call in tool_calls:
 blocks.append(
@@ -1735,7 +1728,7 @@ class ChatAnthropicMessages(ChatAnthropic):


 def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
-input_token_details: Dict = {
+input_token_details: dict = {
 "cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
 "cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
 }
@@ -1,8 +1,6 @@
 import json
 from typing import (
 Any,
-Dict,
-List,
 Union,
 )

@@ -44,7 +42,7 @@ TOOL_PARAMETER_FORMAT = """<parameter>
 </parameter>"""


-def _get_type(parameter: Dict[str, Any]) -> str:
+def _get_type(parameter: dict[str, Any]) -> str:
 if "type" in parameter:
 return parameter["type"]
 if "anyOf" in parameter:
@@ -54,9 +52,9 @@ def _get_type(parameter: Dict[str, Any]) -> str:
 return json.dumps(parameter)


-def get_system_message(tools: List[Dict]) -> str:
+def get_system_message(tools: list[dict]) -> str:
 """Generate a system message that describes the available tools."""
-tools_data: List[Dict] = [
+tools_data: list[dict] = [
 {
 "tool_name": tool["name"],
 "tool_description": tool["description"],
@@ -86,13 +84,13 @@ def get_system_message(tools: List[Dict]) -> str:
 return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)


-def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
+def _xml_to_dict(t: Any) -> Union[str, dict[str, Any]]:
 # Base case: If the element has no children, return its text or an empty string.
 if len(t) == 0:
 return t.text or ""

 # Recursive case: The element has children. Convert them into a dictionary.
-d: Dict[str, Any] = {}
+d: dict[str, Any] = {}
 for child in t:
 if child.tag not in d:
 d[child.tag] = _xml_to_dict(child)
@@ -104,7 +102,7 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
 return d


-def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
+def _xml_to_function_call(invoke: Any, tools: list[dict]) -> dict[str, Any]:
 name = invoke.find("tool_name").text
 arguments = _xml_to_dict(invoke.find("parameters"))

@@ -135,7 +133,7 @@ def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
 }


-def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
+def _xml_to_tool_calls(elem: Any, tools: list[dict]) -> list[dict[str, Any]]:
 """
 Convert an XML element and its children into a dictionary of dictionaries.
 """
@@ -1,13 +1,9 @@
 import re
 import warnings
+from collections.abc import AsyncIterator, Iterator, Mapping
 from typing import (
 Any,
-AsyncIterator,
 Callable,
-Dict,
-Iterator,
-List,
-Mapping,
 Optional,
 )

@@ -83,11 +79,11 @@ class _AnthropicCommon(BaseLanguageModel):
 HUMAN_PROMPT: Optional[str] = None
 AI_PROMPT: Optional[str] = None
 count_tokens: Optional[Callable[[str], int]] = None
-model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+model_kwargs: dict[str, Any] = Field(default_factory=dict)

 @model_validator(mode="before")
 @classmethod
-def build_extra(cls, values: Dict) -> Any:
+def build_extra(cls, values: dict) -> Any:
 all_required_field_names = get_pydantic_field_names(cls)
 values = _build_model_kwargs(values, all_required_field_names)
 return values
@@ -131,7 +127,7 @@ class _AnthropicCommon(BaseLanguageModel):
 """Get the identifying parameters."""
 return {**{}, **self._default_params}

-def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
+def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]:
 if not self.HUMAN_PROMPT or not self.AI_PROMPT:
 raise NameError("Please ensure the anthropic package is loaded")

@@ -165,7 +161,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):

 @model_validator(mode="before")
 @classmethod
-def raise_warning(cls, values: Dict) -> Any:
+def raise_warning(cls, values: dict) -> Any:
 """Raise warning that this class is deprecated."""
 warnings.warn(
 "This Anthropic LLM is deprecated. "
@@ -180,7 +176,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 return "anthropic-llm"

 @property
-def lc_secrets(self) -> Dict[str, str]:
+def lc_secrets(self) -> dict[str, str]:
 return {"anthropic_api_key": "ANTHROPIC_API_KEY"}

 @classmethod
@@ -188,7 +184,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 return True

 @property
-def _identifying_params(self) -> Dict[str, Any]:
+def _identifying_params(self) -> dict[str, Any]:
 """Get the identifying parameters."""
 return {
 "model": self.model,
@@ -203,7 +199,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 }

 def _get_ls_params(
-self, stop: Optional[List[str]] = None, **kwargs: Any
+self, stop: Optional[list[str]] = None, **kwargs: Any
 ) -> LangSmithParams:
 """Get standard params for tracing."""
 params = super()._get_ls_params(stop=stop, **kwargs)
@@ -233,7 +229,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 def _call(
 self,
 prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> str:
@@ -277,7 +273,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 async def _acall(
 self,
 prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> str:
@@ -303,7 +299,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 def _stream(
 self,
 prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> Iterator[GenerationChunk]:
@@ -338,7 +334,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
 async def _astream(
 self,
 prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> AsyncIterator[GenerationChunk]:
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Type, Union, cast
+from typing import Any, Optional, Union, cast

 from langchain_core.messages import AIMessage, ToolCall
 from langchain_core.messages.tool import tool_call
@@ -14,14 +14,14 @@ class ToolsOutputParser(BaseGenerationOutputParser):
 """Whether to return only the first tool call."""
 args_only: bool = False
 """Whether to return only the arguments of the tool calls."""
-pydantic_schemas: Optional[List[Type[BaseModel]]] = None
+pydantic_schemas: Optional[list[type[BaseModel]]] = None
 """Pydantic schemas to parse tool calls into."""

 model_config = ConfigDict(
 extra="forbid",
 )

-def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
 """Parse a list of candidate model Generations into a specific format.

 Args:
@@ -34,7 +34,7 @@ class ToolsOutputParser(BaseGenerationOutputParser):
 if not result or not isinstance(result[0], ChatGeneration):
 return None if self.first_tool_only else []
 message = cast(AIMessage, result[0].message)
-tool_calls: List = [
+tool_calls: list = [
 dict(tc) for tc in _extract_tool_calls_from_message(message)
 ]
 if isinstance(message.content, list):
@@ -64,14 +64,14 @@ class ToolsOutputParser(BaseGenerationOutputParser):
 return cls_(**tool_call["args"])


-def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]:
+def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]:
 """Extract tool calls from a list of content blocks."""
 if message.tool_calls:
 return message.tool_calls
 return extract_tool_calls(message.content)


-def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
+def extract_tool_calls(content: Union[str, list[Union[str, dict]]]) -> list[ToolCall]:
 """Extract tool calls from a list of content blocks."""
 if isinstance(content, list):
 tool_calls = []
@@ -55,8 +55,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
 disallow_untyped_defs = "True"
 plugins = ['pydantic.mypy']

+[tool.ruff]
+target-version = "py39"
+
 [tool.ruff.lint]
-select = ["E", "F", "I", "T201"]
+select = ["E", "F", "I", "T201", "UP"]
+ignore = [ "UP007", ]

 [tool.coverage.run]
 omit = ["tests/*"]
@@ -2,7 +2,7 @@

 import json
 from base64 import b64encode
-from typing import List, Optional
+from typing import Optional

 import httpx
 import pytest
@@ -270,7 +270,7 @@ def test_anthropic_call() -> None:
 def test_anthropic_generate() -> None:
 """Test generate method of anthropic."""
 chat = ChatAnthropic(model=MODEL_NAME)
-chat_messages: List[List[BaseMessage]] = [
+chat_messages: list[list[BaseMessage]] = [
 [HumanMessage(content="How many toes do dogs have?")]
 ]
 messages_copy = [messages.copy() for messages in chat_messages]
@@ -318,7 +318,7 @@ async def test_anthropic_async_streaming_callback() -> None:
 callback_manager=callback_manager,
 verbose=True,
 )
-chat_messages: List[BaseMessage] = [
+chat_messages: list[BaseMessage] = [
 HumanMessage(content="How many toes do dogs have?")
 ]
 async for token in chat.astream(chat_messages):
@@ -809,7 +809,7 @@ def test_image_tool_calling() -> None:

 fav_color: str

-human_content: List[dict] = [
+human_content: list[dict] = [
 {
 "type": "text",
 "text": "what's your favorite color in this image",
@@ -1,7 +1,7 @@
 """Test ChatAnthropic chat model."""

 from enum import Enum
-from typing import List, Optional
+from typing import Optional

 from langchain_core.prompts import ChatPromptTemplate
 from pydantic import BaseModel, Field
@@ -136,7 +136,7 @@ def test_anthropic_complex_structured_output() -> None:
 sender_address: Optional[str] = Field(
 None, description="The sender's address, if available"
 )
-action_items: List[str] = Field(
+action_items: list[str] = Field(
 ..., description="A list of action items requested by the email"
 )
 topic: str = Field(
@@ -1,6 +1,6 @@
 """Test Anthropic API wrapper."""

-from typing import Generator
+from collections.abc import Generator

 import pytest
 from langchain_core.callbacks import CallbackManager
@@ -1,7 +1,7 @@
 """Standard LangChain interface tests"""

 from pathlib import Path
-from typing import Dict, List, Literal, Type, cast
+from typing import Literal, cast

 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage
@@ -14,7 +14,7 @@ REPO_ROOT_DIR = Path(__file__).parents[5]

 class TestAnthropicStandard(ChatModelIntegrationTests):
 @property
-def chat_model_class(self) -> Type[BaseChatModel]:
+def chat_model_class(self) -> type[BaseChatModel]:
 return ChatAnthropic

 @property
@@ -36,9 +36,9 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
 @property
 def supported_usage_metadata_details(
 self,
-) -> Dict[
+) -> dict[
 Literal["invoke", "stream"],
-List[
+list[
 Literal[
 "audio_input",
 "audio_output",
@@ -58,7 +58,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
 model="claude-3-5-sonnet-20240620",  # type: ignore[call-arg]
 extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},  # type: ignore[call-arg]
 )
-with open(REPO_ROOT_DIR / "README.md", "r") as f:
+with open(REPO_ROOT_DIR / "README.md") as f:
 readme = f.read()

 input_ = f"""What's langchain? Here's the langchain README:
@@ -87,7 +87,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
 model="claude-3-5-sonnet-20240620",  # type: ignore[call-arg]
 extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},  # type: ignore[call-arg]
 )
-with open(REPO_ROOT_DIR / "README.md", "r") as f:
+with open(REPO_ROOT_DIR / "README.md") as f:
 readme = f.read()

 input_ = f"""What's langchain? Here's the langchain README:
@@ -1,7 +1,7 @@
 """Test chat model integration."""

 import os
-from typing import Any, Callable, Dict, Literal, Type, cast
+from typing import Any, Callable, Literal, cast

 import pytest
 from anthropic.types import Message, TextBlock, Usage
@@ -297,7 +297,7 @@ def test__merge_messages_mutation() -> None:


 @pytest.fixture()
-def pydantic() -> Type[BaseModel]:
+def pydantic() -> type[BaseModel]:
 class dummy_function(BaseModel):
 """dummy function"""

@@ -328,7 +328,7 @@ def dummy_tool() -> BaseTool:
 arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

 class DummyFunction(BaseTool):  # type: ignore[override]
-args_schema: Type[BaseModel] = Schema
+args_schema: type[BaseModel] = Schema
 name: str = "dummy_function"
 description: str = "dummy function"

@@ -339,7 +339,7 @@ def dummy_tool() -> BaseTool:


 @pytest.fixture()
-def json_schema() -> Dict:
+def json_schema() -> dict:
 return {
 "title": "dummy_function",
 "description": "dummy function",
@@ -357,7 +357,7 @@ def json_schema() -> Dict:


 @pytest.fixture()
-def openai_function() -> Dict:
+def openai_function() -> dict:
 return {
 "name": "dummy_function",
 "description": "dummy function",
@@ -377,11 +377,11 @@ def openai_function() -> Dict:


 def test_convert_to_anthropic_tool(
-pydantic: Type[BaseModel],
+pydantic: type[BaseModel],
 function: Callable,
 dummy_tool: BaseTool,
-json_schema: Dict,
-openai_function: Dict,
+json_schema: dict,
+openai_function: dict,
 ) -> None:
 expected = {
 "name": "dummy_function",
@@ -1,4 +1,4 @@
-from typing import Any, List, Literal
+from typing import Any, Literal

 from langchain_core.messages import AIMessage
 from langchain_core.outputs import ChatGeneration
@@ -6,7 +6,7 @@ from pydantic import BaseModel

 from langchain_anthropic.output_parsers import ToolsOutputParser

-_CONTENT: List = [
+_CONTENT: list = [
 {
 "type": "text",
 "text": "thought",
@@ -19,7 +19,7 @@ _CONTENT: List = [
 {"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
 ]

-_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))]  # type: ignore[misc]
+_RESULT: list = [ChatGeneration(message=AIMessage(_CONTENT))]  # type: ignore[misc]


 class _Foo1(BaseModel):
@@ -1,7 +1,5 @@
 """Standard LangChain interface tests"""

-from typing import Type
-
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.unit_tests import ChatModelUnitTests

@@ -10,7 +8,7 @@ from langchain_anthropic import ChatAnthropic

 class TestAnthropicStandard(ChatModelUnitTests):
 @property
-def chat_model_class(self) -> Type[BaseChatModel]:
+def chat_model_class(self) -> type[BaseChatModel]:
 return ChatAnthropic

 @property
@@ -8,17 +8,12 @@ from __future__ import annotations
 import base64
 import logging
 import uuid
+from collections.abc import Iterable, Sequence
 from typing import (
 TYPE_CHECKING,
 Any,
 Callable,
-Dict,
-Iterable,
-List,
 Optional,
-Sequence,
-Tuple,
-Type,
 Union,
 )

@@ -37,11 +32,11 @@ logger = logging.getLogger()
 DEFAULT_K = 4  # Number of Documents to return.


-def _results_to_docs(results: Any) -> List[Document]:
+def _results_to_docs(results: Any) -> list[Document]:
 return [doc for doc, _ in _results_to_docs_and_scores(results)]


-def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
+def _results_to_docs_and_scores(results: Any) -> list[tuple[Document, float]]:
 return [
 # TODO: Chroma can do batch querying,
 # we shouldn't hard code to the 1st result
@@ -58,7 +53,7 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
 ]


-def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarray]]:
+def _results_to_docs_and_vectors(results: Any) -> list[tuple[Document, np.ndarray]]:
 return [
 (Document(page_content=result[0], metadata=result[1] or {}), result[2])
 for result in zip(
@@ -69,7 +64,7 @@ def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarra
 ]


-Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
+Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]


 def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
@@ -104,7 +99,7 @@ def maximal_marginal_relevance(
 embedding_list: list,
 lambda_mult: float = 0.5,
 k: int = 4,
-) -> List[int]:
+) -> list[int]:
 """Calculate maximal marginal relevance.

 Args:
@@ -287,7 +282,7 @@ class Chroma(VectorStore):
 embedding_function: Optional[Embeddings] = None,
 persist_directory: Optional[str] = None,
 client_settings: Optional[chromadb.config.Settings] = None,
-collection_metadata: Optional[Dict] = None,
+collection_metadata: Optional[dict] = None,
 client: Optional[chromadb.ClientAPI] = None,
 relevance_score_fn: Optional[Callable[[float], float]] = None,
 create_collection_if_not_exists: Optional[bool] = True,
@@ -370,13 +365,13 @@ class Chroma(VectorStore):
 @xor_args(("query_texts", "query_embeddings"))
 def __query_collection(
 self,
-query_texts: Optional[List[str]] = None,
-query_embeddings: Optional[List[List[float]]] = None,
+query_texts: Optional[list[str]] = None,
+query_embeddings: Optional[list[list[float]]] = None,
 n_results: int = 4,
-where: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+where: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> Union[List[Document], chromadb.QueryResult]:
+) -> Union[list[Document], chromadb.QueryResult]:
 """Query the chroma collection.

 Args:
@@ -411,11 +406,11 @@ class Chroma(VectorStore):

 def add_images(
 self,
-uris: List[str],
-metadatas: Optional[List[dict]] = None,
-ids: Optional[List[str]] = None,
+uris: list[str],
+metadatas: Optional[list[dict]] = None,
+ids: Optional[list[str]] = None,
 **kwargs: Any,
-) -> List[str]:
+) -> list[str]:
 """Run more images through the embeddings and add to the vectorstore.

 Args:
@@ -502,10 +497,10 @@ class Chroma(VectorStore):
 def add_texts(
 self,
 texts: Iterable[str],
-metadatas: Optional[List[dict]] = None,
-ids: Optional[List[str]] = None,
+metadatas: Optional[list[dict]] = None,
+ids: Optional[list[str]] = None,
 **kwargs: Any,
-) -> List[str]:
+) -> list[str]:
 """Run more texts through the embeddings and add to the vectorstore.

 Args:
@@ -591,9 +586,9 @@ class Chroma(VectorStore):
 self,
 query: str,
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
 """Run similarity search with Chroma.

 Args:
@@ -612,12 +607,12 @@ class Chroma(VectorStore):

 def similarity_search_by_vector(
 self,
-embedding: List[float],
+embedding: list[float],
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
 """Return docs most similar to embedding vector.

 Args:
@@ -642,12 +637,12 @@ class Chroma(VectorStore):

 def similarity_search_by_vector_with_relevance_scores(
 self,
-embedding: List[float],
+embedding: list[float],
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Tuple[Document, float]]:
+) -> list[tuple[Document, float]]:
 """Return docs most similar to embedding vector and similarity score.

 Args:
@@ -675,10 +670,10 @@ class Chroma(VectorStore):
 self,
 query: str,
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Tuple[Document, float]]:
+) -> list[tuple[Document, float]]:
 """Run similarity search with Chroma with distance.

 Args:
@@ -717,10 +712,10 @@ class Chroma(VectorStore):
 self,
 query: str,
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Tuple[Document, np.ndarray]]:
+) -> list[tuple[Document, np.ndarray]]:
 """Run similarity search with Chroma with vectors.

 Args:
@@ -800,9 +795,9 @@ class Chroma(VectorStore):
 self,
 uri: str,
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
 """Search for similar images based on the given image URI.

 Args:
@@ -844,9 +839,9 @@ class Chroma(VectorStore):
 self,
 uri: str,
 k: int = DEFAULT_K,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Tuple[Document, float]]:
+) -> list[tuple[Document, float]]:
 """Search for similar images based on the given image URI.

 Args:
@@ -886,14 +881,14 @@ class Chroma(VectorStore):

 def max_marginal_relevance_search_by_vector(
 self,
-embedding: List[float],
+embedding: list[float],
 k: int = DEFAULT_K,
 fetch_k: int = 20,
 lambda_mult: float = 0.5,
-filter: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
 """Return docs selected using the maximal marginal relevance.

 Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -942,10 +937,10 @@ class Chroma(VectorStore):
 k: int = DEFAULT_K,
 fetch_k: int = 20,
 lambda_mult: float = 0.5,
-filter: Optional[Dict[str, str]] = None,
-where_document: Optional[Dict[str, str]] = None,
+filter: Optional[dict[str, str]] = None,
+where_document: Optional[dict[str, str]] = None,
 **kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
 """Return docs selected using the maximal marginal relevance.

 Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -1005,8 +1000,8 @@ class Chroma(VectorStore):
 limit: Optional[int] = None,
 offset: Optional[int] = None,
 where_document: Optional[WhereDocument] = None,
-include: Optional[List[str]] = None,
-) -> Dict[str, Any]:
+include: Optional[list[str]] = None,
+) -> dict[str, Any]:
 """Gets the collection.

 Args:
@@ -1081,7 +1076,7 @@ class Chroma(VectorStore):
 return self.update_documents([document_id], [document])

 # type: ignore
-def update_documents(self, ids: List[str], documents: List[Document]) -> None:
+def update_documents(self, ids: list[str], documents: list[Document]) -> None:
 """Update a document in the collection.

 Args:
@@ -1129,16 +1124,16 @@ class Chroma(VectorStore):

 @classmethod
 def from_texts(
-cls: Type[Chroma],
-texts: List[str],
+cls: type[Chroma],
+texts: list[str],
 embedding: Optional[Embeddings] = None,
-metadatas: Optional[List[dict]] = None,
-ids: Optional[List[str]] = None,
+metadatas: Optional[list[dict]] = None,
+ids: Optional[list[str]] = None,
 collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
 persist_directory: Optional[str] = None,
 client_settings: Optional[chromadb.config.Settings] = None,
 client: Optional[chromadb.ClientAPI] = None,
-collection_metadata: Optional[Dict] = None,
+collection_metadata: Optional[dict] = None,
 **kwargs: Any,
 ) -> Chroma:
 """Create a Chroma vectorstore from a raw documents.
@@ -1200,15 +1195,15 @@ class Chroma(VectorStore):

 @classmethod
 def from_documents(
-cls: Type[Chroma],
-documents: List[Document],
+cls: type[Chroma],
+documents: list[Document],
 embedding: Optional[Embeddings] = None,
-ids: Optional[List[str]] = None,
+ids: Optional[list[str]] = None,
 collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
 persist_directory: Optional[str] = None,
 client_settings: Optional[chromadb.config.Settings] = None,
 client: Optional[chromadb.ClientAPI] = None,  # Add this line
-collection_metadata: Optional[Dict] = None,
+collection_metadata: Optional[dict] = None,
 **kwargs: Any,
 ) -> Chroma:
 """Create a Chroma vectorstore from a list of documents.
@@ -1249,7 +1244,7 @@ class Chroma(VectorStore):
 **kwargs,
 )

-def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
 """Delete by vector IDs.

 Args:
@@ -58,8 +58,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
 [tool.mypy]
 disallow_untyped_defs = true

+[tool.ruff]
+target-version = "py39"
+
 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "D"]
+select = ["E", "F", "I", "T201", "D", "UP"]
+ignore = [ "UP007", ]

 [tool.coverage.run]
 omit = ["tests/*"]
@@ -1,7 +1,6 @@
 """Fake Embedding class for testing purposes."""

 import math
-from typing import List

 from langchain_core.embeddings import Embeddings

@@ -11,22 +10,22 @@ fake_texts = ["foo", "bar", "baz"]
 class FakeEmbeddings(Embeddings):
 """Fake embeddings functionality for testing."""

-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
 """Return simple embeddings.
 Embeddings encode each text as its index."""
-return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
+return [[1.0] * 9 + [float(i)] for i in range(len(texts))]

-async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
 return self.embed_documents(texts)

-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
 """Return constant query embeddings.
 Embeddings are identical to embed_documents(texts)[0].
 Distance to each text will be that text's index,
 as it was passed to embed_documents."""
-return [float(1.0)] * 9 + [float(0.0)]
+return [1.0] * 9 + [0.0]

-async def aembed_query(self, text: str) -> List[float]:
+async def aembed_query(self, text: str) -> list[float]:
 return self.embed_query(text)


@@ -35,22 +34,22 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
 vectors for the same texts."""

 def __init__(self, dimensionality: int = 10) -> None:
-self.known_texts: List[str] = []
+self.known_texts: list[str] = []
 self.dimensionality = dimensionality

-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
 """Return consistent embeddings for each text seen so far."""
 out_vectors = []
 for text in texts:
 if text not in self.known_texts:
 self.known_texts.append(text)
-vector = [float(1.0)] * (self.dimensionality - 1) + [
+vector = [1.0] * (self.dimensionality - 1) + [
 float(self.known_texts.index(text))
 ]
 out_vectors.append(vector)
 return out_vectors

-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
 """Return consistent embeddings for the text, if seen before, or a constant
 one if the text is unknown."""
 return self.embed_documents([text])[0]
@@ -61,13 +60,13 @@ class AngularTwoDimensionalEmbeddings(Embeddings):
 From angles (as strings in units of pi) to unit embedding vectors on a circle.
 """

-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
 """
 Make a list of texts into a list of embedding vectors.
 """
 return [self.embed_query(text) for text in texts]

-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
 """
 Convert input text to a 'vector' (list of floats).
 If the text is a number, use it as the angle for the
@@ -3,8 +3,8 @@
 import os.path
 import tempfile
 import uuid
+from collections.abc import Generator
 from typing import (
-Generator,
 cast,
 )

@@ -222,7 +222,7 @@ def test_chroma_with_metadatas_with_scores_using_vector() -> None:
 def test_chroma_search_filter() -> None:
 """Test end to end construction and search with metadata filtering."""
 texts = ["far", "bar", "baz"]
-metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
+metadatas = [{"first_letter": f"{text[0]}"} for text in texts]
 ids = [f"id_{i}" for i in range(len(texts))]
 docsearch = Chroma.from_texts(
 collection_name="test_collection",
@@ -245,7 +245,7 @@ def test_chroma_search_filter() -> None:
 def test_chroma_search_filter_with_scores() -> None:
 """Test end to end construction and scored search with metadata filtering."""
 texts = ["far", "bar", "baz"]
-metadatas = [{"first_letter": "{}".format(text[0])} for text in texts]
+metadatas = [{"first_letter": f"{text[0]}"} for text in texts]
 ids = [f"id_{i}" for i in range(len(texts))]
 docsearch = Chroma.from_texts(
 collection_name="test_collection",
@@ -1,4 +1,4 @@
-from typing import Generator
+from collections.abc import Generator

 import pytest
 from langchain_core.vectorstores import VectorStore
@@ -1,7 +1,8 @@
 """DeepSeek chat models."""

+from collections.abc import Iterator
 from json import JSONDecodeError
-from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, Optional, TypeVar, Union

 import openai
 from langchain_core.callbacks import (
@@ -19,8 +20,8 @@ from typing_extensions import Self
 DEFAULT_API_BASE = "https://api.deepseek.com/v1"

 _BM = TypeVar("_BM", bound=BaseModel)
-_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
-_DictOrPydantic = Union[Dict, _BM]
+_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
+_DictOrPydantic = Union[dict, _BM]


 class ChatDeepSeek(BaseChatOpenAI):
@@ -178,7 +179,7 @@ class ChatDeepSeek(BaseChatOpenAI):
 return "chat-deepseek"

 @property
-def lc_secrets(self) -> Dict[str, str]:
+def lc_secrets(self) -> dict[str, str]:
 """A map of constructor argument names to secret ids."""
 return {"api_key": "DEEPSEEK_API_KEY"}

@@ -217,7 +218,7 @@ class ChatDeepSeek(BaseChatOpenAI):
 def _create_chat_result(
 self,
 response: Union[dict, openai.BaseModel],
-generation_info: Optional[Dict] = None,
+generation_info: Optional[dict] = None,
 ) -> ChatResult:
 rtn = super()._create_chat_result(response, generation_info)

@@ -243,8 +244,8 @@ class ChatDeepSeek(BaseChatOpenAI):
 def _convert_chunk_to_generation_chunk(
 self,
 chunk: dict,
-default_chunk_class: Type,
-base_generation_info: Optional[Dict],
+default_chunk_class: type,
+base_generation_info: Optional[dict],
 ) -> Optional[ChatGenerationChunk]:
 generation_chunk = super()._convert_chunk_to_generation_chunk(
 chunk,
@@ -268,8 +269,8 @@ class ChatDeepSeek(BaseChatOpenAI):

 def _stream(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> Iterator[ChatGenerationChunk]:
@@ -287,8 +288,8 @@ class ChatDeepSeek(BaseChatOpenAI):

 def _generate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -45,8 +45,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
 [tool.mypy]
 disallow_untyped_defs = "True"

+[tool.ruff]
+target-version = "py39"
+
 [tool.ruff.lint]
-select = ["E", "F", "I", "T201"]
+select = ["E", "F", "I", "T201", "UP"]
+ignore = [ "UP007", ]

 [tool.coverage.run]
 omit = ["tests/*"]
@@ -1,6 +1,6 @@
 """Test ChatDeepSeek chat model."""

-from typing import Optional, Type
+from typing import Optional

 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -13,7 +13,7 @@ from langchain_deepseek.chat_models import ChatDeepSeek

 class TestChatDeepSeek(ChatModelIntegrationTests):
 @property
-def chat_model_class(self) -> Type[ChatDeepSeek]:
+def chat_model_class(self) -> type[ChatDeepSeek]:
 return ChatDeepSeek

 @property
@@ -1,6 +1,6 @@
 """Test chat model integration."""

-from typing import Any, Dict, Literal, Type, Union
+from typing import Any, Literal, Union
 from unittest.mock import MagicMock

 from langchain_core.messages import AIMessageChunk
@@ -28,9 +28,9 @@ class MockOpenAIResponse(BaseModel):
 exclude_none: bool = False,
 round_trip: bool = False,
 warnings: Union[Literal["none", "warn", "error"], bool] = True,
-context: Union[Dict[str, Any], None] = None,
+context: Union[dict[str, Any], None] = None,
 serialize_as_any: bool = False,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
 choices_list = []
 for choice in self.choices:
 if isinstance(choice.message, ChatCompletionMessage):
@@ -57,7 +57,7 @@ class MockOpenAIResponse(BaseModel):

 class TestChatDeepSeekUnit(ChatModelUnitTests):
 @property
-def chat_model_class(self) -> Type[ChatDeepSeek]:
+def chat_model_class(self) -> type[ChatDeepSeek]:
 return ChatDeepSeek

 @property
@@ -134,7 +134,7 @@ class TestChatDeepSeekCustomUnit:
 def test_convert_chunk_with_reasoning_content(self) -> None:
 """Test that reasoning_content is properly extracted from streaming chunk."""
 chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
-chunk: Dict[str, Any] = {
+chunk: dict[str, Any] = {
 "choices": [
 {
 "delta": {
@@ -158,7 +158,7 @@ class TestChatDeepSeekCustomUnit:
 def test_convert_chunk_with_reasoning(self) -> None:
 """Test that reasoning is properly extracted from streaming chunk."""
 chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
-chunk: Dict[str, Any] = {
+chunk: dict[str, Any] = {
 "choices": [
 {
 "delta": {
@@ -182,7 +182,7 @@ class TestChatDeepSeekCustomUnit:
 def test_convert_chunk_without_reasoning(self) -> None:
 """Test that chunk without reasoning fields works correctly."""
 chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
-chunk: Dict[str, Any] = {"choices": [{"delta": {"content": "Main content"}}]}
+chunk: dict[str, Any] = {"choices": [{"delta": {"content": "Main content"}}]}

 chunk_result = chat_model._convert_chunk_to_generation_chunk(
 chunk, AIMessageChunk, None
@@ -194,7 +194,7 @@ class TestChatDeepSeekCustomUnit:
 def test_convert_chunk_with_empty_delta(self) -> None:
 """Test that chunk with empty delta works correctly."""
 chat_model = ChatDeepSeek(model="deepseek-chat", api_key=SecretStr("api_key"))
-chunk: Dict[str, Any] = {"choices": [{"delta": {}}]}
+chunk: dict[str, Any] = {"choices": [{"delta": {}}]}

 chunk_result = chat_model._convert_chunk_to_generation_chunk(
 chunk, AIMessageChunk, None
@@ -1,11 +1,10 @@
 import os  # type: ignore[import-not-found]
-from typing import Dict

 from exa_py import Exa  # type: ignore
 from langchain_core.utils import convert_to_secret_str


-def initialize_client(values: Dict) -> Dict:
+def initialize_client(values: dict) -> dict:
 """Initialize the client."""
 exa_api_key = values.get("exa_api_key") or os.environ.get("EXA_API_KEY") or ""
 values["exa_api_key"] = convert_to_secret_str(exa_api_key)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union

 from exa_py import Exa  # type: ignore[untyped-import]
 from exa_py.api import (
@@ -13,7 +13,7 @@ from pydantic import Field, SecretStr, model_validator
 from langchain_exa._utilities import initialize_client


-def _get_metadata(result: Any) -> Dict[str, Any]:
+def _get_metadata(result: Any) -> dict[str, Any]:
 """Get the metadata from a result object."""
 metadata = {
 "title": result.title,
@@ -35,9 +35,9 @@ class ExaSearchRetriever(BaseRetriever):

 k: int = 10  # num_results
 """The number of search results to return."""
-include_domains: Optional[List[str]] = None
+include_domains: Optional[list[str]] = None
 """A list of domains to include in the search."""
-exclude_domains: Optional[List[str]] = None
+exclude_domains: Optional[list[str]] = None
 """A list of domains to exclude from the search."""
 start_crawl_date: Optional[str] = None
 """The start date for the crawl (in YYYY-MM-DD format)."""
@@ -62,14 +62,14 @@ class ExaSearchRetriever(BaseRetriever):

 @model_validator(mode="before")
 @classmethod
-def validate_environment(cls, values: Dict) -> Any:
+def validate_environment(cls, values: dict) -> Any:
 """Validate the environment."""
 values = initialize_client(values)
 return values

 def _get_relevant_documents(
 self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-) -> List[Document]:
+) -> list[Document]:
 response = self.client.search_and_contents(  # type: ignore[misc]
 query,
 num_results=self.k,
@@ -1,6 +1,6 @@
 """Tool for the Exa Search API."""

-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from exa_py import Exa  # type: ignore[untyped-import]
 from exa_py.api import (
@@ -66,7 +66,7 @@ class ExaSearchResults(BaseTool):  # type: ignore[override]

 @model_validator(mode="before")
 @classmethod
-def validate_environment(cls, values: Dict) -> Any:
+def validate_environment(cls, values: dict) -> Any:
 """Validate the environment."""
 values = initialize_client(values)
 return values
@@ -77,15 +77,15 @@ class ExaSearchResults(BaseTool):  # type: ignore[override]
 num_results: int,
 text_contents_options: Optional[Union[TextContentsOptions, bool]] = None,
 highlights: Optional[Union[HighlightsContentsOptions, bool]] = None,
-include_domains: Optional[List[str]] = None,
-exclude_domains: Optional[List[str]] = None,
+include_domains: Optional[list[str]] = None,
+exclude_domains: Optional[list[str]] = None,
 start_crawl_date: Optional[str] = None,
 end_crawl_date: Optional[str] = None,
 start_published_date: Optional[str] = None,
 end_published_date: Optional[str] = None,
 use_autoprompt: Optional[bool] = None,
 run_manager: Optional[CallbackManagerForToolRun] = None,
-) -> Union[List[Dict], str]:
+) -> Union[list[dict], str]:
 """Use the tool."""
 try:
 return self.client.search_and_contents(
@@ -120,7 +120,7 @@ class ExaFindSimilarResults(BaseTool):  # type: ignore[override]

 @model_validator(mode="before")
 @classmethod
-def validate_environment(cls, values: Dict) -> Any:
+def validate_environment(cls, values: dict) -> Any:
 """Validate the environment."""
 values = initialize_client(values)
 return values
@@ -131,8 +131,8 @@ class ExaFindSimilarResults(BaseTool):  # type: ignore[override]
 num_results: int,
 text_contents_options: Optional[Union[TextContentsOptions, bool]] = None,
 highlights: Optional[Union[HighlightsContentsOptions, bool]] = None,
-include_domains: Optional[List[str]] = None,
-exclude_domains: Optional[List[str]] = None,
+include_domains: Optional[list[str]] = None,
+exclude_domains: Optional[list[str]] = None,
 start_crawl_date: Optional[str] = None,
 end_crawl_date: Optional[str] = None,
 start_published_date: Optional[str] = None,
@@ -140,7 +140,7 @@ class ExaFindSimilarResults(BaseTool):  # type: ignore[override]
 exclude_source_domain: Optional[bool] = None,
 category: Optional[str] = None,
 run_manager: Optional[CallbackManagerForToolRun] = None,
-) -> Union[List[Dict], str]:
+) -> Union[list[dict], str]:
 """Use the tool."""
 try:
 return self.client.find_similar_and_contents(
@ -45,8 +45,12 @@ langchain-core = { path = "../../core", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

[tool.ruff]
target-version = "py39"

[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]

[tool.coverage.run]
omit = ["tests/*"]
@ -4,20 +4,13 @@ from __future__ import annotations

import json
import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
@ -109,7 +102,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
# Fix for azure
# Also Fireworks returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
@ -157,7 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
message_dict: dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
@ -205,14 +198,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:


def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
choice = chunk["choices"][0]
_dict = choice["delta"]
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
tool_call_chunks: List[ToolCallChunk] = []
additional_kwargs: dict = {}
tool_call_chunks: list[ToolCallChunk] = []
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -290,17 +283,17 @@ class ChatFireworks(BaseChatModel):
"""

@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}

@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
def lc_attributes(self) -> dict[str, Any]:
attributes: dict[str, Any] = {}
if self.fireworks_api_base:
attributes["fireworks_api_base"] = self.fireworks_api_base

@ -319,9 +312,9 @@ class ChatFireworks(BaseChatModel):
"""Model name to use."""
temperature: float = 0.0
"""What sampling temperature to use."""
stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field(
alias="api_key",
@ -344,7 +337,7 @@ class ChatFireworks(BaseChatModel):
)
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
@ -364,7 +357,7 @@ class ChatFireworks(BaseChatModel):

@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
@ -398,7 +391,7 @@ class ChatFireworks(BaseChatModel):
return self

@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Fireworks API."""
params = {
"model": self.model_name,
@ -413,7 +406,7 @@ class ChatFireworks(BaseChatModel):
return params

def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@ -429,7 +422,7 @@ class ChatFireworks(BaseChatModel):
ls_params["ls_stop"] = ls_stop
return ls_params

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
@ -452,15 +445,15 @@ class ChatFireworks(BaseChatModel):

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}

default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@ -487,8 +480,8 @@ class ChatFireworks(BaseChatModel):

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@ -509,8 +502,8 @@ class ChatFireworks(BaseChatModel):
return self._create_chat_result(response)

def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = self._default_params
if stop is not None:
params["stop"] = stop
@ -547,15 +540,15 @@ class ChatFireworks(BaseChatModel):

async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}

default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@ -584,8 +577,8 @@ class ChatFireworks(BaseChatModel):

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@ -607,13 +600,13 @@ class ChatFireworks(BaseChatModel):
return self._create_chat_result(response)

@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return {"model_name": self.model_name, **self._default_params}

def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
"model": self.model_name,
@ -634,7 +627,7 @@ class ChatFireworks(BaseChatModel):
)
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
@ -690,7 +683,7 @@ class ChatFireworks(BaseChatModel):

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool]
@ -738,14 +731,14 @@ class ChatFireworks(BaseChatModel):

def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
schema: Optional[Union[dict, type[BaseModel]]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.

Args:
@ -1,5 +1,3 @@
from typing import List

from langchain_core.embeddings import Embeddings
from langchain_core.utils import secret_from_env
from openai import OpenAI
@ -96,13 +94,13 @@ class FireworksEmbeddings(BaseModel, Embeddings):
)
return self

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self.client.embeddings.create(input=texts, model=self.model).data
]

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
@ -1,7 +1,7 @@
"""Wrapper around Fireworks AI's Completion API."""

import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import requests
from aiohttp import ClientSession
@ -63,7 +63,7 @@ class Fireworks(LLM):
for question answering or summarization. A value greater than 1 introduces more
randomness in the output.
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It
@ -90,7 +90,7 @@ class Fireworks(LLM):

@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
@ -109,7 +109,7 @@ class Fireworks(LLM):
return f"langchain-fireworks/{__version__}"

@property
def default_params(self) -> Dict[str, Any]:
def default_params(self) -> dict[str, Any]:
return {
"model": self.model,
"temperature": self.temperature,
@ -122,7 +122,7 @@ class Fireworks(LLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -139,7 +139,7 @@ class Fireworks(LLM):
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
@ -168,7 +168,7 @@ class Fireworks(LLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -185,7 +185,7 @@ class Fireworks(LLM):
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
@ -48,8 +48,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

[tool.ruff]
target-version = "py39"

[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]

[tool.coverage.run]
omit = ["tests/*"]
@ -4,12 +4,12 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
"""

import json
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional

import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict
from typing_extensions import TypedDict

from langchain_fireworks import ChatFireworks
@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
@ -14,7 +12,7 @@ from langchain_fireworks import ChatFireworks

class TestFireworksStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatFireworks

@property
@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""

from typing import Tuple, Type

from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests

@ -10,7 +8,7 @@ from langchain_fireworks import FireworksEmbeddings

class TestFireworksStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
def embeddings_class(self) -> type[Embeddings]:
return FireworksEmbeddings

@property
@ -18,7 +16,7 @@ class TestFireworksStandard(EmbeddingsUnitTests):
return {"api_key": "test_api_key"}

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""

from typing import Tuple, Type

from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found]
@ -12,7 +10,7 @@ from langchain_fireworks import ChatFireworks

class TestFireworksStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatFireworks

@property
@ -20,7 +18,7 @@ class TestFireworksStandard(ChatModelUnitTests):
return {"api_key": "test_api_key"}

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
def init_from_env_params(self) -> tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
@ -4,20 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
@ -307,9 +300,9 @@ class ChatGroq(BaseChatModel):
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
groq_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None)
|
||||
@ -324,7 +317,7 @@ class ChatGroq(BaseChatModel):
|
||||
groq_proxy: Optional[str] = Field(
|
||||
default_factory=from_env("GROQ_PROXY", default=None)
|
||||
)
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
|
||||
@ -353,7 +346,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@ -392,7 +385,7 @@ class ChatGroq(BaseChatModel):
|
||||
self.default_headers or {}
|
||||
)
|
||||
|
||||
client_params: Dict[str, Any] = {
|
||||
client_params: dict[str, Any] = {
|
||||
"api_key": (
|
||||
self.groq_api_key.get_secret_value() if self.groq_api_key else None
|
||||
),
|
||||
@ -406,13 +399,13 @@ class ChatGroq(BaseChatModel):
|
||||
try:
|
||||
import groq
|
||||
|
||||
sync_specific: Dict[str, Any] = {"http_client": self.http_client}
|
||||
sync_specific: dict[str, Any] = {"http_client": self.http_client}
|
||||
if not self.client:
|
||||
self.client = groq.Groq(
|
||||
**client_params, **sync_specific
|
||||
).chat.completions
|
||||
if not self.async_client:
|
||||
async_specific: Dict[str, Any] = {"http_client": self.http_async_client}
|
||||
async_specific: dict[str, Any] = {"http_client": self.http_async_client}
|
||||
self.async_client = groq.AsyncGroq(
|
||||
**client_params, **async_specific
|
||||
).chat.completions
|
||||
@ -427,7 +420,7 @@ class ChatGroq(BaseChatModel):
|
||||
# Serializable class method overrides
|
||||
#
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"groq_api_key": "GROQ_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
@ -444,7 +437,7 @@ class ChatGroq(BaseChatModel):
|
||||
return "groq-chat"
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@ -480,8 +473,8 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -500,8 +493,8 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -521,8 +514,8 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -530,7 +523,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
@ -560,8 +553,8 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
@ -569,7 +562,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await self.async_client.create(
|
||||
messages=message_dicts, **params
|
||||
):
|
||||
@ -605,7 +598,7 @@ class ChatGroq(BaseChatModel):
|
||||
# Internal methods
|
||||
#
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling Groq API."""
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
@ -652,15 +645,15 @@ class ChatGroq(BaseChatModel):
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
for output in llm_outputs:
|
||||
@ -688,7 +681,7 @@ class ChatGroq(BaseChatModel):
|
||||
)
|
||||
def bind_functions(
|
||||
self,
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||
] = None,
|
||||
@ -743,7 +736,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "any", "none"], bool]
|
||||
@ -791,12 +784,12 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
@ -1096,7 +1089,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any]
|
||||
message_dict: dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
@ -1142,13 +1135,13 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
choice = chunk["choices"][0]
|
||||
_dict = choice["delta"]
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@ -1202,7 +1195,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
return HumanMessage(content=_dict.get("content", ""))
|
||||
elif role == "assistant":
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if function_call := _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
tool_calls = []
|
||||
|
@ -40,8 +40,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "W"]
|
||||
select = ["E", "F", "I", "W", "UP"]
|
||||
ignore = [ "UP007", ]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
@ -17,7 +15,7 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.2)
|
||||
|
||||
class BaseTestGroq(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatGroq
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
@ -15,7 +15,7 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
starts: int = 0
|
||||
ends: int = 0
|
||||
errors: int = 0
|
||||
errors_args: List[Any] = []
|
||||
errors_args: list[Any] = []
|
||||
text: int = 0
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
@ -264,8 +264,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.unit_tests.chat_models import (
|
||||
ChatModelUnitTests,
|
||||
@ -12,7 +10,7 @@ from langchain_groq import ChatGroq
|
||||
|
||||
class TestGroqStandard(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatGroq
|
||||
|
||||
@property
|
||||
|
@ -1,16 +1,13 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -46,8 +43,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
|
||||
class TGI_RESPONSE:
|
||||
"""Response from the TextGenInference API."""
|
||||
|
||||
choices: List[Any]
|
||||
usage: Dict
|
||||
choices: list[Any]
|
||||
usage: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -56,12 +53,12 @@ class TGI_MESSAGE:
|
||||
|
||||
role: str
|
||||
content: str
|
||||
tool_calls: List[Dict]
|
||||
tool_calls: list[dict]
|
||||
|
||||
|
||||
def _convert_message_to_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
@ -104,7 +101,7 @@ def _convert_TGI_message_to_LC_message(
|
||||
content = cast(str, _message.content)
|
||||
if content is None:
|
||||
content = ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if tool_calls := _message.tool_calls:
|
||||
if "arguments" in tool_calls[0]["function"]:
|
||||
functions = tool_calls[0]["function"].pop("arguments")
|
||||
@ -358,8 +355,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -380,8 +377,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -398,7 +395,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
@ -472,7 +469,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required"], bool]
|
||||
@ -529,8 +526,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> List[Dict[Any, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> list[dict[Any, Any]]:
|
||||
message_dicts = [_convert_message_to_chat_message(m) for m in messages]
|
||||
return message_dicts
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@ -40,16 +40,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
|
||||
`prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the documents of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the query of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
@ -102,8 +102,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
|
||||
def _embed(
|
||||
self, texts: list[str], encode_kwargs: Dict[str, Any]
|
||||
) -> List[List[float]]:
|
||||
self, texts: list[str], encode_kwargs: dict[str, Any]
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Embed a text using the HuggingFace transformer model.
|
||||
|
||||
@ -138,7 +138,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
@ -149,7 +149,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self._embed(texts, self.encode_kwargs)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import from_env
|
||||
@ -101,7 +101,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@ -117,7 +117,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
responses = self.client.feature_extraction(text=texts, **_model_kwargs)
|
||||
return responses
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@ -134,7 +134,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return responses
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
@ -146,7 +146,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
response = self.embed_documents([text])[0]
|
||||
return response
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
|
@ -2,7 +2,8 @@ import inspect
|
||||
import json # type: ignore[import-not-found]
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -96,7 +97,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
"""Whether to prepend the prompt to the generated text"""
|
||||
truncate: Optional[int] = None
|
||||
"""Truncate inputs tokens to the given size"""
|
||||
stop_sequences: List[str] = Field(default_factory=list)
|
||||
stop_sequences: list[str] = Field(default_factory=list)
|
||||
"""Stop generating tokens if a member of `stop_sequences` is generated"""
|
||||
seed: Optional[int] = None
|
||||
"""Random sampling seed"""
|
||||
@ -111,9 +112,9 @@ class HuggingFaceEndpoint(LLM):
|
||||
watermark: bool = False
|
||||
"""Watermarking with [A Watermark for Large Language Models]
|
||||
(https://arxiv.org/abs/2301.10226)"""
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
server_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any text-generation-inference server parameters not explicitly specified"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `call` not explicitly specified"""
|
||||
model: str
|
||||
client: Any = None #: :meta private:
|
||||
@ -128,7 +129,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@ -252,7 +253,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling text generation inference API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
@ -285,8 +286,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _invocation_params(
|
||||
self, runtime_stop: Optional[List[str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
self, runtime_stop: Optional[list[str]], **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
params = {**self._default_params, **kwargs}
|
||||
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
|
||||
return params
|
||||
@ -294,7 +295,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -326,7 +327,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -357,7 +358,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@ -394,7 +395,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
|
@ -2,7 +2,8 @@ from __future__ import annotations # type: ignore[import-not-found]
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
@ -82,7 +83,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Ensure model_id is set either by pipeline or user input."""
|
||||
if "model_id" not in values:
|
||||
if "pipeline" in values and values["pipeline"]:
|
||||
@ -297,13 +298,13 @@ class HuggingFacePipeline(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
# List to hold all results
|
||||
text_generations: List[str] = []
|
||||
text_generations: list[str] = []
|
||||
pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
|
||||
skip_prompt = kwargs.get("skip_prompt", False)
|
||||
|
||||
@ -347,7 +348,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
|
@ -49,8 +49,12 @@ langchain-community = { path = "../../community", editable = true }
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "T201"]
|
||||
select = ["E", "F", "I", "T201", "UP"]
|
||||
ignore = [ "UP007", ]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Test HuggingFace embeddings."""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
|
||||
|
||||
from langchain_huggingface.embeddings import (
|
||||
@ -12,7 +10,7 @@ from langchain_huggingface.embeddings import (
|
||||
|
||||
class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[HuggingFaceEmbeddings]:
|
||||
def embeddings_class(self) -> type[HuggingFaceEmbeddings]:
|
||||
return HuggingFaceEmbeddings
|
||||
|
||||
@property
|
||||
@ -22,7 +20,7 @@ class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
|
||||
|
||||
class TestHuggingFaceEndpointEmbeddings(EmbeddingsIntegrationTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[HuggingFaceEndpointEmbeddings]:
|
||||
def embeddings_class(self) -> type[HuggingFaceEndpointEmbeddings]:
|
||||
return HuggingFaceEndpointEmbeddings
|
||||
|
||||
@property
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Generator
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_huggingface.llms import HuggingFacePipeline
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
@ -12,7 +10,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
|
||||
|
||||
class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatHuggingFace
|
||||
|
||||
@property
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List # type: ignore[import-not-found]
|
||||
from typing import Any # type: ignore[import-not-found]
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
@ -45,7 +45,7 @@ from langchain_huggingface.llms.huggingface_endpoint import (
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_chat_message(
|
||||
message: BaseMessage, expected: Dict[str, str]
|
||||
message: BaseMessage, expected: dict[str, str]
|
||||
) -> None:
|
||||
result = _convert_message_to_chat_message(message)
|
||||
assert result == expected
|
||||
@ -150,7 +150,7 @@ def test_create_chat_result(chat_hugging_face: Any) -> None:
|
||||
],
|
||||
)
|
||||
def test_to_chat_prompt_errors(
|
||||
chat_hugging_face: Any, messages: List[BaseMessage], expected_error: str
|
||||
chat_hugging_face: Any, messages: list[BaseMessage], expected_error: str
|
||||
) -> None:
|
||||
with pytest.raises(ValueError) as e:
|
||||
chat_hugging_face._to_chat_prompt(messages)
|
||||
@ -194,7 +194,7 @@ def test_to_chat_prompt_valid_messages(chat_hugging_face: Any) -> None:
|
||||
],
|
||||
)
|
||||
def test_to_chatml_format(
|
||||
chat_hugging_face: Any, message: BaseMessage, expected: Dict[str, str]
|
||||
chat_hugging_face: Any, message: BaseMessage, expected: dict[str, str]
|
||||
) -> None:
|
||||
result = chat_hugging_face._to_chatml_format(message)
|
||||
assert result == expected
|
||||
@ -207,7 +207,7 @@ def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
|
||||
assert "Unknown message type:" in str(e.value)
|
||||
|
||||
|
||||
def tool_mock() -> Dict:
|
||||
def tool_mock() -> dict:
|
||||
return {"function": {"name": "test_tool"}}
|
||||
|
||||
|
||||
@ -232,7 +232,7 @@ def tool_mock() -> Dict:
|
||||
)
|
||||
def test_bind_tools_errors(
|
||||
chat_hugging_face: Any,
|
||||
tools: Dict[str, str],
|
||||
tools: dict[str, str],
|
||||
tool_choice: Any,
|
||||
expected_exception: Any,
|
||||
expected_message: str,
|
||||
|
@ -7,20 +7,14 @@ import os
|
||||
import re
|
||||
import ssl
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -142,13 +136,13 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
|
||||
|
||||
|
||||
def _convert_mistral_chat_message_to_message(
|
||||
_message: Dict,
|
||||
_message: dict,
|
||||
) -> BaseMessage:
|
||||
role = _message["role"]
|
||||
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
|
||||
content = cast(str, _message["content"])
|
||||
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
if raw_tool_calls := _message.get("tool_calls"):
|
||||
@ -196,8 +190,8 @@ async def _araise_on_error(response: httpx.Response) -> None:
|
||||
|
||||
|
||||
async def _aiter_sse(
|
||||
event_source_mgr: AsyncContextManager[EventSource],
|
||||
) -> AsyncIterator[Dict]:
|
||||
event_source_mgr: AbstractAsyncContextManager[EventSource],
|
||||
) -> AsyncIterator[dict]:
|
||||
"""Iterate over the server-sent events."""
|
||||
async with event_source_mgr as event_source:
|
||||
await _araise_on_error(event_source.response)
|
||||
@ -234,7 +228,7 @@ async def acompletion_with_retry(
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: Dict, default_class: Type[BaseMessageChunk]
|
||||
chunk: dict, default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
_choice = chunk["choices"][0]
|
||||
_delta = _choice["delta"]
|
||||
@ -243,7 +237,7 @@ def _convert_chunk_to_message_chunk(
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
response_metadata = {}
|
||||
if raw_tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
@ -295,7 +289,7 @@ def _convert_chunk_to_message_chunk(
|
||||
|
||||
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
"""Format Langchain ToolCall to dict expected by Mistral."""
|
||||
result: Dict[str, Any] = {
|
||||
result: dict[str, Any] = {
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
@ -309,7 +303,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
|
||||
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
|
||||
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
|
||||
result: Dict[str, Any] = {
|
||||
result: dict[str, Any] = {
|
||||
"function": {
|
||||
"name": invalid_tool_call["name"],
|
||||
"arguments": invalid_tool_call["args"],
|
||||
@ -323,13 +317,13 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict: Dict[str, Any] = {"role": "assistant"}
|
||||
message_dict: dict[str, Any] = {"role": "assistant"}
|
||||
tool_calls = []
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
@ -407,7 +401,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
random_seed: Optional[int] = None
|
||||
safe_mode: Optional[bool] = None
|
||||
streaming: bool = False
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any invocation parameters not explicitly specified."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@ -417,14 +411,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling the API."""
|
||||
defaults = {
|
||||
"model": self.model,
|
||||
@ -439,7 +433,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return filtered
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@ -456,7 +450,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return ls_params
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
def _client_params(self) -> dict[str, Any]:
|
||||
"""Get the parameters used for the client."""
|
||||
return self._default_params
|
||||
|
||||
@ -473,7 +467,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream = kwargs["stream"]
|
||||
if stream:
|
||||
|
||||
def iter_sse() -> Iterator[Dict]:
|
||||
def iter_sse() -> Iterator[dict]:
|
||||
with connect_sse(
|
||||
self.client, "POST", "/chat/completions", json=kwargs
|
||||
) as event_source:
|
||||
@ -492,7 +486,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
rtn = _completion_with_retry(**kwargs)
|
||||
return rtn
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
@ -557,8 +551,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
@ -577,7 +571,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_chat_result(self, response: Dict) -> ChatResult:
|
||||
def _create_chat_result(self, response: dict) -> ChatResult:
|
||||
generations = []
|
||||
token_usage = response.get("usage", {})
|
||||
for res in response["choices"]:
|
||||
@ -603,8 +597,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict], Dict[str, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> tuple[list[dict], dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None or "stop" in params:
|
||||
if "stop" in params:
|
||||
@ -617,15 +611,15 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
@ -643,15 +637,15 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
@ -669,8 +663,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
@ -691,7 +685,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
@ -733,14 +727,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[Dict, Type]] = None,
|
||||
schema: Optional[Union[dict, type]] = None,
|
||||
*,
|
||||
method: Literal[
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
@ -1048,7 +1042,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@ -1058,7 +1052,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return "mistralai-chat"
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"mistral_api_key": "MISTRAL_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
@ -1067,14 +1061,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "mistralai"]
|
||||
|
||||
|
||||
def _convert_to_openai_response_format(
|
||||
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
|
||||
) -> Dict:
|
||||
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
|
||||
) -> dict:
|
||||
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
|
||||
if (
|
||||
isinstance(schema, dict)
|
||||
@ -1094,8 +1088,10 @@ def _convert_to_openai_response_format(
|
||||
function["schema"] = function.pop("parameters")
|
||||
response_format = {"type": "json_schema", "json_schema": function}
|
||||
|
||||
if strict is not None and strict is not response_format["json_schema"].get(
|
||||
"strict"
|
||||
if (
|
||||
strict is not None
|
||||
and strict is not response_format["json_schema"].get("strict")
|
||||
and isinstance(schema, dict)
|
||||
):
|
||||
msg = (
|
||||
f"Output schema already has 'strict' value set to "
@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
from typing import Iterable, List
from collections.abc import Iterable

import httpx
from httpx import Response
@ -33,7 +33,7 @@ class DummyTokenizer:
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""

@staticmethod
def encode_batch(texts: List[str]) -> List[List[str]]:
def encode_batch(texts: list[str]) -> list[list[str]]:
return [list(text) for text in texts]


@ -177,7 +177,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1"
)
except IOError: # huggingface_hub GatedRepoError
except OSError: # huggingface_hub GatedRepoError
warnings.warn(
"Could not download mistral tokenizer from Huggingface for "
"calculating batch sizes. Set a Huggingface token via the "
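The IOError to OSError swap above is pyupgrade's OSError-alias fix (UP024 in ruff's numbering): IOError has been a plain alias of OSError since Python 3.3, so the modern spelling catches exactly the same exceptions, including the gated-repo download failure handled here. A minimal illustration, not taken from the package:

try:
    with open("/nonexistent/tokenizer.json") as f:
        f.read()
except OSError:  # also catches everything older code spelled as IOError
    print("falling back to DummyTokenizer")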
@ -187,10 +187,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = DummyTokenizer()
return self

def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
"""Split a list of texts into batches of less than 16k tokens
for Mistral API."""
batch: List[str] = []
batch: list[str] = []
batch_tokens = 0

text_token_lengths = [
@ -211,7 +211,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
if batch:
yield batch

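Only the annotations on _get_batches change here, but the pattern it implements is worth spelling out: texts are packed greedily into batches while the running token count stays below the 16k-token request limit named in the docstring. A rough, self-contained sketch of that pattern, assuming any tokenizer that exposes an encode_batch method (as the DummyTokenizer above does), not the package's exact implementation:

from collections.abc import Iterable

MAX_TOKENS_PER_BATCH = 16_000  # per-request budget described in the docstring above


def get_batches(texts: list[str], tokenizer) -> Iterable[list[str]]:
    """Greedily pack texts so each batch stays under the token budget."""
    token_lengths = [len(encoded) for encoded in tokenizer.encode_batch(texts)]
    batch: list[str] = []
    batch_tokens = 0
    for text, n_tokens in zip(texts, token_lengths):
        if batch and batch_tokens + n_tokens > MAX_TOKENS_PER_BATCH:
            yield batch
            batch, batch_tokens = [], 0
        batch.append(text)
        batch_tokens += n_tokens
    if batch:
        yield batch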
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of document texts.
|
||||
|
||||
Args:
|
||||
@ -230,7 +230,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
wait=wait_fixed(self.wait_time),
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
)
|
||||
def _embed_batch(batch: List[str]) -> Response:
|
||||
def _embed_batch(batch: list[str]) -> Response:
|
||||
response = self.client.post(
|
||||
url="/embeddings",
|
||||
json=dict(
|
||||
@ -252,7 +252,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
raise
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of document texts.
|
||||
|
||||
Args:
|
||||
@ -283,7 +283,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
raise
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
@ -294,7 +294,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
|
@ -44,8 +44,12 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

[tool.ruff]
target-version = "py39"

[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]

[tool.coverage.run]
omit = ["tests/*"]

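Enabling the UP (pyupgrade) rules while keeping target-version = "py39" is what drives the annotation rewrites throughout this diff: deprecated typing aliases become the PEP 585 builtins (list, dict, tuple, type) and collections.abc imports, while UP007 stays ignored because PEP 604 unions (X | Y in annotations) only work at runtime from Python 3.10. A small illustrative before/after, not taken from this diff:

# Before (flagged by the UP rules):
#   from typing import Dict, List, Sequence
#   def count(texts: List[str]) -> Dict[str, int]: ...
from collections.abc import Sequence
from typing import Optional


def count(texts: list[str]) -> dict[str, int]:
    """PEP 585 builtins replace typing.List / typing.Dict on Python 3.9."""
    return {text: len(text.split()) for text in texts}


def first(texts: Sequence[str]) -> Optional[str]:
    """With UP007 ignored, Optional[...] is not rewritten to `str | None`."""
    return texts[0] if texts else None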
@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""

from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
@ -12,7 +10,7 @@ from langchain_mistralai import ChatMistralAI

class TestMistralStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatMistralAI

@property
@ -1,7 +1,8 @@
"""Test MistralAI Chat API wrapper."""

import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from collections.abc import AsyncGenerator, Generator
from typing import Any, cast
from unittest.mock import MagicMock, patch

import httpx
@ -104,13 +105,13 @@ def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
],
)
def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: Dict
message: BaseMessage, expected: dict
) -> None:
result = _convert_message_to_mistral_chat_message(message)
assert result == expected


def _make_completion_response_from_token(token: str) -> Dict:
def _make_completion_response_from_token(token: str) -> dict:
return dict(
id="abc123",
model="fake_model",
@ -236,7 +237,7 @@ def test__convert_dict_to_message_tool_call() -> None:


def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]:
def token_encoder(text: str) -> list[int]:
return [1, 2, 3]

llm = ChatMistralAI(custom_get_token_ids=token_encoder)
@ -1,7 +1,5 @@
"""Standard LangChain interface tests"""

from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found]
@ -12,5 +10,5 @@ from langchain_mistralai import ChatMistralAI

class TestMistralStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_class(self) -> type[BaseChatModel]:
return ChatMistralAI
@ -1,5 +1,5 @@
import os
from typing import List, Literal, Optional, overload
from typing import Literal, Optional, overload

import nomic # type: ignore[import]
from langchain_core.embeddings import Embeddings
@ -86,7 +86,7 @@ class NomicEmbeddings(Embeddings):
self.device = device
self.vision_model = vision_model

def embed(self, texts: List[str], *, task_type: str) -> List[List[float]]:
def embed(self, texts: list[str], *, task_type: str) -> list[list[float]]:
"""Embed texts.

Args:
@ -105,7 +105,7 @@ class NomicEmbeddings(Embeddings):
)
return output["embeddings"]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.

Args:
@ -116,7 +116,7 @@ class NomicEmbeddings(Embeddings):
task_type="search_document",
)

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text.

Args:
@ -127,7 +127,7 @@ class NomicEmbeddings(Embeddings):
task_type="search_query",
)[0]

def embed_image(self, uris: List[str]) -> List[List[float]]:
def embed_image(self, uris: list[str]) -> list[list[float]]:
return embed.image(
images=uris,
model=self.vision_model,
@ -40,13 +40,18 @@ dev = ["langchain-core"]
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }

[tool.ruff]
target-version = "py39"

[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
"UP", # pyupgrade
]
ignore = [ "UP007", ]

[tool.mypy]
disallow_untyped_defs = "True"
@ -3,7 +3,6 @@
It provides infrastructure for interacting with the Ollama service.
"""


from importlib import metadata

from langchain_ollama.chat_models import ChatOllama
@ -1,21 +1,14 @@
"""Ollama chat models."""

import json
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Final,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -153,7 +146,7 @@ def _parse_arguments_from_tool_call(

def _get_tool_calls_from_response(
response: Mapping[str, Any],
) -> List[ToolCall]:
) -> list[ToolCall]:
"""Get tool calls from ollama response."""
tool_calls = []
if "message" in response:
@ -341,7 +334,7 @@ class ChatOllama(BaseChatModel):
model: str
"""Model name to use."""

extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
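The extract_reasoning option documented above takes either a boolean or a custom (start, end) token pair. A minimal sketch of the kind of stateful splitting this implies (a hypothetical helper, not ChatOllama's actual implementation), assuming <think>/</think> as the default markers:

DEFAULT_THINK_TOKENS = ("<think>", "</think>")  # assumed defaults for this sketch


def split_reasoning(
    text: str,
    is_thinking: bool,
    tokens: tuple[str, str] = DEFAULT_THINK_TOKENS,
) -> tuple[str, str, bool]:
    """Split one streamed chunk into (visible, reasoning, still_thinking)."""
    start, end = tokens
    visible, reasoning = "", ""
    while text:
        if is_thinking:
            head, sep, text = text.partition(end)
            reasoning += head
            is_thinking = not sep  # still inside the block unless `end` was found
        else:
            head, sep, text = text.partition(start)
            visible += head
            is_thinking = bool(sep)
    return visible, reasoning, is_thinking

A streaming caller would carry is_thinking from one chunk to the next so a think block that spans chunk boundaries still ends up in additional_kwargs.reasoning_content.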
@ -399,7 +392,7 @@ class ChatOllama(BaseChatModel):
|
||||
to a specific number will make the model generate the same text for
|
||||
the same prompt."""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
stop: Optional[list[str]] = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
tfs_z: Optional[float] = None
|
||||
@ -443,10 +436,10 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _chat_params(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
ollama_messages = self._convert_messages_to_ollama_messages(messages)
|
||||
|
||||
if self.stop is not None and stop is not None:
|
||||
@ -499,13 +492,13 @@ class ChatOllama(BaseChatModel):
|
||||
return self
|
||||
|
||||
def _convert_messages_to_ollama_messages(
|
||||
self, messages: List[BaseMessage]
|
||||
self, messages: list[BaseMessage]
|
||||
) -> Sequence[Message]:
|
||||
ollama_messages: List = []
|
||||
ollama_messages: list = []
|
||||
for message in messages:
|
||||
role: Literal["user", "assistant", "system", "tool"]
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None
|
||||
if isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
@ -531,7 +524,7 @@ class ChatOllama(BaseChatModel):
|
||||
if isinstance(message.content, str):
|
||||
content = message.content
|
||||
else:
|
||||
for content_part in cast(List[Dict], message.content):
|
||||
for content_part in cast(list[dict], message.content):
|
||||
if content_part.get("type") == "text":
|
||||
content += f"\n{content_part['text']}"
|
||||
elif content_part.get("type") == "tool_use":
|
||||
@ -583,7 +576,7 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _extract_reasoning(
|
||||
self, message_chunk: BaseMessageChunk, is_thinking: bool
|
||||
) -> Tuple[BaseMessageChunk, bool]:
|
||||
) -> tuple[BaseMessageChunk, bool]:
|
||||
"""Mutate a message chunk to extract reasoning content."""
|
||||
if not self.extract_reasoning:
|
||||
return message_chunk, is_thinking
|
||||
@ -605,8 +598,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _acreate_chat_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
@ -619,8 +612,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _create_chat_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
@ -632,8 +625,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _chat_stream_with_aggregation(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@ -657,8 +650,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _achat_stream_with_aggregation(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@ -681,7 +674,7 @@ class ChatOllama(BaseChatModel):
|
||||
return final_chunk
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@ -697,8 +690,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -719,8 +712,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _iterate_over_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
is_thinking = False
|
||||
@ -758,8 +751,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -773,8 +766,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _aiterate_over_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
is_thinking = False
|
||||
@ -812,8 +805,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
@ -827,8 +820,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -854,7 +847,7 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
|
||||
**kwargs: Any,
|
||||
@ -877,12 +870,12 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, type],
|
||||
schema: Union[dict, type],
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Ollama embeddings models."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from ollama import AsyncClient, Client
|
||||
@ -188,7 +188,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
"""The temperature of the model. Increasing the temperature will
|
||||
make the model answer more creatively. (Default: 0.8)"""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
stop: Optional[list[str]] = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
tfs_z: Optional[float] = None
|
||||
@ -211,7 +211,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling Ollama."""
|
||||
return {
|
||||
"mirostat": self.mirostat,
|
||||
@ -237,18 +237,18 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs."""
|
||||
embedded_docs = self._client.embed(
|
||||
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
|
||||
)["embeddings"]
|
||||
return embedded_docs
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs."""
|
||||
embedded_docs = (
|
||||
await self._async_client.embed(
|
||||
@ -257,6 +257,6 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
)["embeddings"]
|
||||
return embedded_docs
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
return (await self.aembed_documents([text]))[0]
|
||||
|
@ -1,13 +1,9 @@
|
||||
"""Ollama large language models."""
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
@ -89,7 +85,7 @@ class OllamaLLM(BaseLLM):
|
||||
to a specific number will make the model generate the same text for
|
||||
the same prompt."""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
stop: Optional[list[str]] = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
tfs_z: Optional[float] = None
|
||||
@ -134,9 +130,9 @@ class OllamaLLM(BaseLLM):
|
||||
def _generate_params(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
@ -181,7 +177,7 @@ class OllamaLLM(BaseLLM):
|
||||
return "ollama-llm"
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
@ -200,7 +196,7 @@ class OllamaLLM(BaseLLM):
|
||||
async def _acreate_generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
async for part in await self._async_client.generate(
|
||||
@ -211,7 +207,7 @@ class OllamaLLM(BaseLLM):
|
||||
def _create_generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
yield from self._client.generate(
|
||||
@ -221,7 +217,7 @@ class OllamaLLM(BaseLLM):
|
||||
async def _astream_with_aggregation(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@ -253,7 +249,7 @@ class OllamaLLM(BaseLLM):
|
||||
def _stream_with_aggregation(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@ -284,8 +280,8 @@ class OllamaLLM(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -303,8 +299,8 @@ class OllamaLLM(BaseLLM):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -323,7 +319,7 @@ class OllamaLLM(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@ -345,7 +341,7 @@ class OllamaLLM(BaseLLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
|
@ -40,6 +40,9 @@ langchain-tests = { path = "../../standard-tests", editable = true }
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
@ -47,8 +50,9 @@ select = [
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
"D", # pydocstyle
|
||||
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
ignore = [ "UP007", ]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""load multiple Python files specified as command line arguments."""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
@ -1,10 +1,10 @@
|
||||
"""Ollama specific chat model integration tests"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
@ -78,7 +78,7 @@ def test_structured_output_deeply_nested(model: str) -> None:
|
||||
class Data(BaseModel):
|
||||
"""Extracted data about people."""
|
||||
|
||||
people: List[Person]
|
||||
people: list[Person]
|
||||
|
||||
chat = llm.with_structured_output(Data) # type: ignore[arg-type]
|
||||
text = (
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Test chat model integration using standard integration tests."""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
@ -9,7 +7,7 @@ from langchain_ollama.chat_models import ChatOllama
|
||||
|
||||
class TestChatOllama(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[ChatOllama]:
|
||||
def chat_model_class(self) -> type[ChatOllama]:
|
||||
return ChatOllama
|
||||
|
||||
@property
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Test Ollama embeddings."""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
|
||||
|
||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
@ -9,7 +7,7 @@ from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
|
||||
class TestOllamaEmbeddings(EmbeddingsIntegrationTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[OllamaEmbeddings]:
|
||||
def embeddings_class(self) -> type[OllamaEmbeddings]:
|
||||
return OllamaEmbeddings
|
||||
|
||||
@property
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import json
|
||||
from typing import Dict, Type
|
||||
|
||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -9,11 +9,11 @@ from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_
|
||||
|
||||
class TestChatOllama(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[ChatOllama]:
|
||||
def chat_model_class(self) -> type[ChatOllama]:
|
||||
return ChatOllama
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> Dict:
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "llama3-groq-tool-use"}
|
||||
|
||||
|
||||
|
@ -4,18 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
@ -34,8 +24,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
|
||||
_DictOrPydantic = Union[Dict, _BM]
|
||||
_DictOrPydanticClass = Union[dict[str, Any], type[_BM]]
|
||||
_DictOrPydantic = Union[dict, _BM]
|
||||
|
||||
|
||||
class _AllReturnType(TypedDict):
|
||||
@ -547,7 +537,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
Used for tracing and token counting. Does NOT affect completion.
|
||||
"""
|
||||
|
||||
disabled_params: Optional[Dict[str, Any]] = Field(default=None)
|
||||
disabled_params: Optional[dict[str, Any]] = Field(default=None)
|
||||
"""Parameters of the OpenAI client or chat.completions endpoint that should be
|
||||
disabled for the given model.
|
||||
|
||||
@ -570,12 +560,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "azure_openai"]
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {
|
||||
"openai_api_key": "AZURE_OPENAI_API_KEY",
|
||||
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
|
||||
@ -672,7 +662,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"azure_deployment": self.deployment_name},
|
||||
@ -684,14 +674,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
return "azure-openai-chat"
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_api_type": self.openai_api_type,
|
||||
"openai_api_version": self.openai_api_version,
|
||||
}
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
@ -710,7 +700,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: Union[dict, openai.BaseModel],
|
||||
generation_info: Optional[Dict] = None,
|
||||
generation_info: Optional[dict] = None,
|
||||
) -> ChatResult:
|
||||
chat_result = super()._create_chat_result(response, generation_info)
|
||||
|
||||
|
@ -10,6 +10,7 @@ import re
|
||||
import ssl
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from json import JSONDecodeError
|
||||
@ -18,17 +19,9 @@ from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -137,7 +130,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if function_call := _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
tool_calls = []
|
||||
@ -243,7 +236,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)}
|
||||
message_dict: dict[str, Any] = {"content": _format_message_content(message.content)}
|
||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||
message_dict["name"] = name
|
||||
|
||||
@ -304,12 +297,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
id_ = _dict.get("id")
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@ -418,8 +411,8 @@ class _FunctionCall(TypedDict):
|
||||
|
||||
|
||||
_BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
|
||||
_DictOrPydantic = Union[Dict, _BM]
|
||||
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
|
||||
_DictOrPydantic = Union[dict, _BM]
|
||||
|
||||
|
||||
class _AllReturnType(TypedDict):
|
||||
@ -437,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"""Model name to use."""
|
||||
temperature: Optional[float] = None
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
openai_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
|
||||
@ -451,7 +444,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
openai_proxy: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_PROXY", default=None)
|
||||
)
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
@ -476,7 +469,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"""Number of most likely tokens to return at each token position, each with
|
||||
an associated log probability. `logprobs` must be set to true
|
||||
if this parameter is used."""
|
||||
logit_bias: Optional[Dict[int, int]] = None
|
||||
logit_bias: Optional[dict[int, int]] = None
|
||||
"""Modify the likelihood of specified tokens appearing in the completion."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
@ -517,14 +510,14 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
extra_body: Optional[Mapping[str, Any]] = None
|
||||
"""Optional additional JSON properties to include in the request parameters when
|
||||
making requests to OpenAI compatible APIs, such as vLLM."""
|
||||
include_response_headers: bool = False
|
||||
"""Whether to include response headers in the output message response_metadata."""
|
||||
disabled_params: Optional[Dict[str, Any]] = Field(default=None)
|
||||
disabled_params: Optional[dict[str, Any]] = Field(default=None)
|
||||
"""Parameters of the OpenAI client or chat.completions endpoint that should be
|
||||
disabled for the given model.
|
||||
|
||||
@ -554,7 +547,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
@ -562,7 +555,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_temperature(cls, values: Dict[str, Any]) -> Any:
|
||||
def validate_temperature(cls, values: dict[str, Any]) -> Any:
|
||||
"""Currently o1 models only allow temperature=1."""
|
||||
model = values.get("model_name") or values.get("model") or ""
|
||||
if model.startswith("o1") and "temperature" not in values:
|
||||
@ -642,7 +635,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
exclude_if_none = {
|
||||
"presence_penalty": self.presence_penalty,
|
||||
@ -669,7 +662,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
return params
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
for output in llm_outputs:
|
||||
@ -697,8 +690,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: Type,
|
||||
base_generation_info: Optional[Dict],
|
||||
default_chunk_class: type,
|
||||
base_generation_info: Optional[dict],
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
if chunk.get("type") == "content.delta": # from beta.chat.completions.stream
|
||||
return None
|
||||
@ -749,8 +742,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def _stream_responses(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -783,8 +776,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
async def _astream_responses(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
@ -838,8 +831,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
stream_usage: Optional[bool] = None,
|
||||
@ -850,7 +843,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if stream_usage:
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
base_generation_info = {}
|
||||
|
||||
if "response_format" in payload:
|
||||
@ -908,8 +901,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -965,7 +958,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
@ -982,7 +975,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: Union[dict, openai.BaseModel],
|
||||
generation_info: Optional[Dict] = None,
|
||||
generation_info: Optional[dict] = None,
|
||||
) -> ChatResult:
|
||||
generations = []
|
||||
|
||||
@ -1032,8 +1025,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
stream_usage: Optional[bool] = None,
|
||||
@ -1044,7 +1037,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if stream_usage:
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
base_generation_info = {}
|
||||
|
||||
if "response_format" in payload:
|
||||
@ -1106,8 +1099,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -1160,13 +1153,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"model_name": self.model_name, **self._default_params}
|
||||
|
||||
def _get_invocation_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
@ -1176,7 +1169,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
}
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@ -1199,7 +1192,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"""Return type of chat model."""
|
||||
return "openai-chat"
|
||||
|
||||
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
|
||||
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
else:
|
||||
@ -1211,7 +1204,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
return model, encoding
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
def get_token_ids(self, text: str) -> list[int]:
|
||||
"""Get the tokens present in the text with tiktoken package."""
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
@ -1223,9 +1216,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
tools: Optional[
|
||||
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
|
||||
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
|
||||
] = None,
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
@ -1327,7 +1320,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
def bind_functions(
|
||||
self,
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||
] = None,
|
||||
@ -1380,7 +1373,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
@ -1727,7 +1720,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
def _filter_disabled_params(self, **kwargs: Any) -> dict[str, Any]:
|
||||
if not self.disabled_params:
|
||||
return kwargs
|
||||
filtered = {}
|
||||
@ -2301,17 +2294,17 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "openai"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
attributes: dict[str, Any] = {}
|
||||
|
||||
if self.openai_organization:
|
||||
attributes["openai_organization"] = self.openai_organization
|
||||
@ -2330,7 +2323,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
return True
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
params = super()._default_params
|
||||
if "max_tokens" in params:
|
||||
@ -2342,7 +2335,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
@ -2719,7 +2712,7 @@ def _lc_invalid_tool_call_to_openai_tool_call(
|
||||
}
|
||||
|
||||
|
||||
def _url_to_size(image_source: str) -> Optional[Tuple[int, int]]:
|
||||
def _url_to_size(image_source: str) -> Optional[tuple[int, int]]:
|
||||
try:
|
||||
from PIL import Image # type: ignore[import]
|
||||
except ImportError:
|
||||
@ -2771,7 +2764,7 @@ def _is_b64(s: str) -> bool:
|
||||
return s.startswith("data:image")
|
||||
|
||||
|
||||
def _resize(width: int, height: int) -> Tuple[int, int]:
|
||||
def _resize(width: int, height: int) -> tuple[int, int]:
|
||||
# larger side must be <= 2048
|
||||
if width > 2048 or height > 2048:
|
||||
if width > height:
|
||||
@ -2792,8 +2785,8 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
|
||||
|
||||
|
||||
def _convert_to_openai_response_format(
|
||||
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
|
||||
) -> Union[Dict, TypeBaseModel]:
|
||||
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
|
||||
) -> Union[dict, TypeBaseModel]:
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
return schema
|
||||
|
||||
@ -2815,8 +2808,10 @@ def _convert_to_openai_response_format(
|
||||
function["schema"] = function.pop("parameters")
|
||||
response_format = {"type": "json_schema", "json_schema": function}
|
||||
|
||||
if strict is not None and strict is not response_format["json_schema"].get(
|
||||
"strict"
|
||||
if (
|
||||
strict is not None
|
||||
and strict is not response_format["json_schema"].get("strict")
|
||||
and isinstance(schema, dict)
|
||||
):
|
||||
msg = (
|
||||
f"Output schema already has 'strict' value set to "
|
||||
@ -2829,7 +2824,7 @@ def _convert_to_openai_response_format(
|
||||
|
||||
|
||||
def _oai_structured_outputs_parser(
|
||||
ai_msg: AIMessage, schema: Type[_BM]
|
||||
ai_msg: AIMessage, schema: type[_BM]
|
||||
) -> Optional[PydanticBaseModel]:
|
||||
if parsed := ai_msg.additional_kwargs.get("parsed"):
|
||||
if isinstance(parsed, dict):
|
||||
@ -3141,7 +3136,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
|
||||
def _construct_lc_result_from_responses_api(
|
||||
response: Response,
|
||||
schema: Optional[Type[_BM]] = None,
|
||||
schema: Optional[type[_BM]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> ChatResult:
|
||||
"""Construct ChatResponse from OpenAI Response API response."""
|
||||
@ -3278,7 +3273,7 @@ def _construct_lc_result_from_responses_api(
|
||||
|
||||
|
||||
def _convert_responses_chunk_to_generation_chunk(
|
||||
chunk: Any, schema: Optional[Type[_BM]] = None, metadata: Optional[dict] = None
|
||||
chunk: Any, schema: Optional[type[_BM]] = None, metadata: Optional[dict] = None
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
content = []
|
||||
tool_call_chunks: list = []
|
||||
|
@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Awaitable, Callable, Optional, Union
|
||||
from collections.abc import Awaitable
|
||||
from typing import Callable, Optional, Union, cast
|
||||
|
||||
import openai
|
||||
from langchain_core.utils import from_env, secret_from_env
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
from typing_extensions import Self, cast
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_openai.embeddings.base import OpenAIEmbeddings
|
||||
|
||||
|
@ -2,20 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
@ -29,19 +17,19 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def _process_batched_chunked_embeddings(
|
||||
num_texts: int,
|
||||
tokens: List[Union[List[int], str]],
|
||||
batched_embeddings: List[List[float]],
|
||||
indices: List[int],
|
||||
tokens: list[Union[list[int], str]],
|
||||
batched_embeddings: list[list[float]],
|
||||
indices: list[int],
|
||||
skip_empty: bool,
|
||||
) -> List[Optional[List[float]]]:
|
||||
) -> list[Optional[list[float]]]:
|
||||
# for each text, this is the list of embeddings (list of list of floats)
|
||||
# corresponding to the chunks of the text
|
||||
results: List[List[List[float]]] = [[] for _ in range(num_texts)]
|
||||
results: list[list[list[float]]] = [[] for _ in range(num_texts)]
|
||||
|
||||
# for each text, this is the token length of each chunk
|
||||
# for transformers tokenization, this is the string length
|
||||
# for tiktoken, this is the number of tokens
|
||||
num_tokens_in_batch: List[List[int]] = [[] for _ in range(num_texts)]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(num_texts)]
|
||||
|
||||
for i in range(len(indices)):
|
||||
if skip_empty and len(batched_embeddings[i]) == 1:
|
||||
@ -50,10 +38,10 @@ def _process_batched_chunked_embeddings(
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
# for each text, this is the final embedding
|
||||
embeddings: List[Optional[List[float]]] = []
|
||||
embeddings: list[Optional[list[float]]] = []
|
||||
for i in range(num_texts):
|
||||
# an embedding for each chunk
|
||||
_result: List[List[float]] = results[i]
|
||||
_result: list[list[float]] = results[i]
|
||||
|
||||
if len(_result) == 0:
|
||||
# this will be populated with the embedding of an empty string
|
||||
@ -213,13 +201,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
),
|
||||
)
|
||||
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
|
||||
allowed_special: Union[Literal["all"], Set[str], None] = None
|
||||
disallowed_special: Union[Literal["all"], Set[str], Sequence[str], None] = None
|
||||
allowed_special: Union[Literal["all"], set[str], None] = None
|
||||
disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None
|
||||
chunk_size: int = 1000
|
||||
"""Maximum number of texts to embed in each batch"""
|
||||
max_retries: int = 2
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
|
||||
request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
@ -240,7 +228,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
skip_empty: bool = False
|
||||
"""Whether to skip empty strings when embedding or raise an error.
|
||||
@ -270,7 +258,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@ -354,15 +342,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
params: Dict = {"model": self.model, **self.model_kwargs}
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
params: dict = {"model": self.model, **self.model_kwargs}
|
||||
if self.dimensions is not None:
|
||||
params["dimensions"] = self.dimensions
|
||||
return params
|
||||
|
||||
def _tokenize(
|
||||
self, texts: List[str], chunk_size: int
|
||||
) -> Tuple[Iterable[int], List[Union[List[int], str]], List[int]]:
|
||||
self, texts: list[str], chunk_size: int
|
||||
) -> tuple[Iterable[int], list[Union[list[int], str]], list[int]]:
|
||||
"""
|
||||
Take the input `texts` and `chunk_size` and return 3 iterables as a tuple:
|
||||
|
||||
@ -383,8 +371,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
indices: An iterable of the same length as `tokens` that maps each token-array
|
||||
to the index of the original text in `texts`.
|
||||
"""
|
||||
tokens: List[Union[List[int], str]] = []
|
||||
indices: List[int] = []
|
||||
tokens: list[Union[list[int], str]] = []
|
||||
indices: list[int] = []
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
|
||||
# If tiktoken flag set to False
|
||||
@ -403,11 +391,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
for i, text in enumerate(texts):
|
||||
# Tokenize the text using HuggingFace transformers
|
||||
tokenized: List[int] = tokenizer.encode(text, add_special_tokens=False)
|
||||
tokenized: list[int] = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Split tokens into chunks respecting the embedding_ctx_length
|
||||
for j in range(0, len(tokenized), self.embedding_ctx_length):
|
||||
token_chunk: List[int] = tokenized[
|
||||
token_chunk: list[int] = tokenized[
|
||||
j : j + self.embedding_ctx_length
|
||||
]
|
||||
|
||||
@ -420,7 +408,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
encoder_kwargs: Dict[str, Any] = {
|
||||
encoder_kwargs: dict[str, Any] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"allowed_special": self.allowed_special,
|
||||
@ -459,8 +447,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# please refer to
|
||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||
def _get_len_safe_embeddings(
|
||||
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
|
||||
) -> List[List[float]]:
|
||||
self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate length-safe embeddings for a list of texts.
|
||||
|
||||
@ -478,7 +466,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
|
||||
batched_embeddings: List[List[float]] = []
|
||||
batched_embeddings: list[list[float]] = []
|
||||
for i in _iter:
|
||||
response = self.client.create(
|
||||
input=tokens[i : i + _chunk_size], **self._invocation_params
|
||||
@ -490,9 +478,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = _process_batched_chunked_embeddings(
|
||||
len(texts), tokens, batched_embeddings, indices, self.skip_empty
|
||||
)
|
||||
_cached_empty_embedding: Optional[List[float]] = None
|
||||
_cached_empty_embedding: Optional[list[float]] = None
|
||||
|
||||
def empty_embedding() -> List[float]:
|
||||
def empty_embedding() -> list[float]:
|
||||
nonlocal _cached_empty_embedding
|
||||
if _cached_empty_embedding is None:
|
||||
average_embedded = self.client.create(
|
||||
@ -508,8 +496,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# please refer to
|
||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||
async def _aget_len_safe_embeddings(
|
||||
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
|
||||
) -> List[List[float]]:
|
||||
self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Asynchronously generate length-safe embeddings for a list of texts.
|
||||
|
||||
@ -528,7 +516,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
|
||||
batched_embeddings: List[List[float]] = []
|
||||
batched_embeddings: list[list[float]] = []
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
for i in range(0, len(tokens), _chunk_size):
|
||||
response = await self.async_client.create(
|
||||
@ -542,9 +530,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = _process_batched_chunked_embeddings(
|
||||
len(texts), tokens, batched_embeddings, indices, self.skip_empty
|
||||
)
|
||||
_cached_empty_embedding: Optional[List[float]] = None
|
||||
_cached_empty_embedding: Optional[list[float]] = None
|
||||
|
||||
async def empty_embedding() -> List[float]:
|
||||
async def empty_embedding() -> list[float]:
|
||||
nonlocal _cached_empty_embedding
|
||||
if _cached_empty_embedding is None:
|
||||
average_embedded = await self.async_client.create(
|
||||
@ -558,8 +546,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
return [e if e is not None else await empty_embedding() for e in embeddings]
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int | None = None
|
||||
) -> List[List[float]]:
|
||||
self, texts: list[str], chunk_size: int | None = None
|
||||
) -> list[list[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@ -572,7 +560,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
chunk_size_ = chunk_size or self.chunk_size
|
||||
if not self.check_embedding_ctx_length:
|
||||
embeddings: List[List[float]] = []
|
||||
embeddings: list[list[float]] = []
|
||||
for i in range(0, len(texts), chunk_size_):
|
||||
response = self.client.create(
|
||||
input=texts[i : i + chunk_size_], **self._invocation_params
|
||||
@ -588,8 +576,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: int | None = None
|
||||
) -> List[List[float]]:
|
||||
self, texts: list[str], chunk_size: int | None = None
|
||||
) -> list[list[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
|
||||
|
||||
Args:
|
||||
@ -602,7 +590,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
chunk_size_ = chunk_size or self.chunk_size
|
||||
if not self.check_embedding_ctx_length:
|
||||
embeddings: List[List[float]] = []
|
||||
embeddings: list[list[float]] = []
|
||||
for i in range(0, len(texts), chunk_size_):
|
||||
response = await self.async_client.create(
|
||||
input=texts[i : i + chunk_size_], **self._invocation_params
|
||||
@ -617,7 +605,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
engine = cast(str, self.deployment)
|
||||
return await self._aget_len_safe_embeddings(texts, engine=engine)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
@ -628,7 +616,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding query text.
|
||||
|
||||
Args:
|
||||
|
@ -1,13 +1,14 @@
from __future__ import annotations

import logging
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
from collections.abc import Awaitable, Mapping
from typing import Any, Callable, Optional, Union, cast

import openai
from langchain_core.language_models import LangSmithParams
from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast
from typing_extensions import Self

from langchain_openai.llms.base import BaseOpenAI
@ -91,12 +92,12 @@ class AzureOpenAI(BaseOpenAI):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "openai"]
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {
|
||||
"openai_api_key": "AZURE_OPENAI_API_KEY",
|
||||
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
|
||||
@ -188,12 +189,12 @@ class AzureOpenAI(BaseOpenAI):
|
||||
}
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
openai_params = {"model": self.deployment_name}
|
||||
return {**openai_params, **super()._invocation_params}
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
@ -209,7 +210,7 @@ class AzureOpenAI(BaseOpenAI):
|
||||
return "azure"
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_api_type": self.openai_api_type,
|
||||
"openai_api_version": self.openai_api_version,
|
||||
|
@ -2,21 +2,8 @@ from __future__ import annotations

import logging
import sys
from typing import (
AbstractSet,
Any,
AsyncIterator,
Collection,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Union,
)
from collections.abc import AsyncIterator, Collection, Iterator, Mapping
from typing import Any, Literal, Optional, Union

import openai
import tiktoken
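The same applies to the container ABCs: `typing.AsyncIterator`, `Collection`, `Iterator`, and `Mapping` have been deprecated aliases of the `collections.abc` classes since Python 3.9, and the `collections.abc` versions are subscriptable for annotations. A rough sketch of the resulting idiom (assumed example, not from this commit):

from collections.abc import Iterator, Mapping


def iter_params(params: Mapping[str, object]) -> Iterator[str]:
    # collections.abc classes support subscription in annotations on 3.9+.
    for key, value in params.items():
        yield f"{key}={value}"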
@ -35,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _update_token_usage(
|
||||
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
|
||||
keys: set[str], response: dict[str, Any], token_usage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Update token usage."""
|
||||
_keys_to_use = keys.intersection(response["usage"])
|
||||
@ -47,7 +34,7 @@ def _update_token_usage(
|
||||
|
||||
|
||||
def _stream_response_to_generation_chunk(
|
||||
stream_response: Dict[str, Any],
|
||||
stream_response: dict[str, Any],
|
||||
) -> GenerationChunk:
|
||||
"""Convert a stream response to a generation chunk."""
|
||||
if not stream_response["choices"]:
|
||||
@ -84,7 +71,7 @@ class BaseOpenAI(BaseLLM):
|
||||
"""How many completions to generate for each prompt."""
|
||||
best_of: int = 1
|
||||
"""Generates best_of completions server-side and returns the "best"."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
openai_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
|
||||
@ -108,12 +95,12 @@ class BaseOpenAI(BaseLLM):
|
||||
)
|
||||
batch_size: int = 20
|
||||
"""Batch size to use when passing multiple documents to generate."""
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logit_bias: Optional[dict[str, float]] = None
|
||||
"""Adjust the probability of specific tokens being generated."""
|
||||
max_retries: int = 2
|
||||
"""Maximum number of retries to make when generating."""
|
||||
@ -124,7 +111,7 @@ class BaseOpenAI(BaseLLM):
|
||||
as well the chosen tokens."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||
allowed_special: Union[Literal["all"], set[str]] = set()
|
||||
"""Set of special tokens that are allowed。"""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
@ -157,7 +144,7 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
@ -197,9 +184,9 @@ class BaseOpenAI(BaseLLM):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
normal_params: Dict[str, Any] = {
|
||||
normal_params: dict[str, Any] = {
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
@ -228,7 +215,7 @@ class BaseOpenAI(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@ -255,7 +242,7 @@ class BaseOpenAI(BaseLLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
@ -283,8 +270,8 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -307,7 +294,7 @@ class BaseOpenAI(BaseLLM):
|
||||
params = {**params, **kwargs}
|
||||
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
||||
choices = []
|
||||
token_usage: Dict[str, int] = {}
|
||||
token_usage: dict[str, int] = {}
|
||||
# Get the token usage from the response.
|
||||
# Includes prompt, completion, and total tokens used.
|
||||
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||
@ -363,8 +350,8 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -373,7 +360,7 @@ class BaseOpenAI(BaseLLM):
|
||||
params = {**params, **kwargs}
|
||||
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
||||
choices = []
|
||||
token_usage: Dict[str, int] = {}
|
||||
token_usage: dict[str, int] = {}
|
||||
# Get the token usage from the response.
|
||||
# Includes prompt, completion, and total tokens used.
|
||||
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||
@ -419,10 +406,10 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
def get_sub_prompts(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> List[List[str]]:
|
||||
params: dict[str, Any],
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> list[list[str]]:
|
||||
"""Get the sub prompts for llm call."""
|
||||
if stop is not None:
|
||||
params["stop"] = stop
|
||||
@ -441,9 +428,9 @@ class BaseOpenAI(BaseLLM):
|
||||
def create_llm_result(
|
||||
self,
|
||||
choices: Any,
|
||||
prompts: List[str],
|
||||
params: Dict[str, Any],
|
||||
token_usage: Dict[str, int],
|
||||
prompts: list[str],
|
||||
params: dict[str, Any],
|
||||
token_usage: dict[str, int],
|
||||
*,
|
||||
system_fingerprint: Optional[str] = None,
|
||||
) -> LLMResult:
|
||||
@ -470,7 +457,7 @@ class BaseOpenAI(BaseLLM):
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return self._default_params
|
||||
|
||||
@ -484,7 +471,7 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Return type of llm."""
|
||||
return "openai"
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
def get_token_ids(self, text: str) -> list[int]:
|
||||
"""Get the token IDs using the tiktoken package."""
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
@ -689,7 +676,7 @@ class OpenAI(BaseOpenAI):
|
||||
""" # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "openai"]
|
||||
|
||||
@ -699,16 +686,16 @@ class OpenAI(BaseOpenAI):
|
||||
return True
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
return {**{"model": self.model_name}, **super()._invocation_params}
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
attributes: dict[str, Any] = {}
|
||||
if self.openai_api_base:
|
||||
attributes["openai_api_base"] = self.openai_api_base
|
||||
|
||||
|
@ -59,8 +59,12 @@ disallow_untyped_defs = "True"
module = "transformers"
ignore_missing_imports = true

[tool.ruff]
target-version = "py39"

[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "T201", "UP"]
ignore = [ "UP007", ]

[tool.ruff.format]
docstring-code-format = true
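Adding "UP" to ruff's `select` enables its pyupgrade rules, and `target-version = "py39"` sets the syntax floor they assume; running the autofix (or the standalone `pyupgrade` tool) produces exactly the kind of rewrites shown throughout this diff. A small before/after sketch under those assumptions (illustrative only, not code from this commit):

# Before: 3.8-era typing imports
from typing import Dict, List, Optional, Sequence


def count_tokens(texts: Sequence[str], default: Optional[int] = None) -> Dict[str, List[int]]:
    return {text: [default or 0] for text in texts}


# After the UP rules with target-version = "py39": builtin generics replace the
# deprecated typing aliases, and ABCs come from collections.abc. Optional/Union
# are left alone because UP007 is ignored (see the note below).
from collections.abc import Sequence
from typing import Optional


def count_tokens(texts: Sequence[str], default: Optional[int] = None) -> dict[str, list[int]]:
    return {text: [default or 0] for text in texts}

The `ignore = [ "UP007", ]` line is the companion decision: rewriting `Union[X, Y]` / `Optional[X]` to the PEP 604 `X | Y` form needs Python 3.10 at runtime (or postponed annotation evaluation), and tools such as Pydantic that evaluate annotations at runtime would typically break on a 3.9 floor, so those spellings are kept.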
@ -1,7 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
import os
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
@ -14,7 +13,7 @@ OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
|
||||
|
||||
class TestAzureOpenAIStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return AzureChatOpenAI
|
||||
|
||||
@property
|
||||
@ -40,7 +39,7 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
|
||||
"""Test a legacy model."""
|
||||
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return AzureChatOpenAI
|
||||
|
||||
@property
|
||||
|
@ -2,9 +2,10 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Any, AsyncIterator, List, Literal, Optional, cast
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
@ -531,14 +532,14 @@ class MakeASandwich(BaseModel):
|
||||
|
||||
bread_type: str
|
||||
cheese_type: str
|
||||
condiments: List[str]
|
||||
vegetables: List[str]
|
||||
condiments: list[str]
|
||||
vegetables: list[str]
|
||||
|
||||
|
||||
def test_tool_use() -> None:
|
||||
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
|
||||
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
|
||||
msgs: list = [HumanMessage("Sally has green hair, what would her username be?")]
|
||||
ai_msg = llm_with_tool.invoke(msgs)
|
||||
|
||||
assert isinstance(ai_msg, AIMessage)
|
||||
@ -583,7 +584,7 @@ def test_manual_tool_call_msg(use_responses_api: bool) -> None:
|
||||
model="gpt-3.5-turbo-0125", temperature=0, use_responses_api=use_responses_api
|
||||
)
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
|
||||
msgs: List = [
|
||||
msgs: list = [
|
||||
HumanMessage("Sally has green hair, what would her username be?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
@ -1045,7 +1046,7 @@ def test_audio_output_modality() -> None:
|
||||
},
|
||||
)
|
||||
|
||||
history: List[BaseMessage] = [
|
||||
history: list[BaseMessage] = [
|
||||
HumanMessage("Make me a short audio clip of you yelling")
|
||||
]
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Type, cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
@ -14,7 +14,7 @@ REPO_ROOT_DIR = Path(__file__).parents[6]
|
||||
|
||||
class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
||||
@property
|
||||
@ -36,9 +36,9 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def supported_usage_metadata_details(
|
||||
self,
|
||||
) -> Dict[
|
||||
) -> dict[
|
||||
Literal["invoke", "stream"],
|
||||
List[
|
||||
list[
|
||||
Literal[
|
||||
"audio_input",
|
||||
"audio_output",
|
||||
@ -51,7 +51,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
return {"invoke": ["reasoning_output", "cache_read_input"], "stream": []}
|
||||
|
||||
def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
|
||||
with open(REPO_ROOT_DIR / "README.md", "r") as f:
|
||||
with open(REPO_ROOT_DIR / "README.md") as f:
|
||||
readme = f.read()
|
||||
|
||||
input_ = f"""What's langchain? Here's the langchain README:
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Annotated, Any, Optional, cast
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
@ -13,7 +13,7 @@ from langchain_core.messages import (
|
||||
BaseMessageChunk,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests for Responses API"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
@ -11,7 +9,7 @@ from tests.integration_tests.chat_models.test_base_standard import TestOpenAISta
|
||||
|
||||
class TestOpenAIResponses(TestOpenAIStandard):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
||||
@property
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_tests.integration_tests.embeddings import EmbeddingsIntegrationTests
|
||||
|
||||
@ -10,7 +8,7 @@ from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
class TestOpenAIStandard(EmbeddingsIntegrationTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[Embeddings]:
|
||||
def embeddings_class(self) -> type[Embeddings]:
|
||||
return OpenAIEmbeddings
|
||||
|
||||
@property
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""Test AzureOpenAI wrapper."""
|
||||
|
||||
import os
|
||||
from typing import Any, Generator
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test OpenAI llm."""
|
||||
|
||||
from typing import Generator
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
@ -12,7 +10,7 @@ from langchain_openai import AzureChatOpenAI
|
||||
|
||||
class TestOpenAIStandard(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return AzureChatOpenAI
|
||||
|
||||
@property
|
||||
@ -30,7 +28,7 @@ class TestOpenAIStandard(ChatModelUnitTests):
|
||||
super().test_bind_tool_pydantic(model, my_adder_tool)
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
return (
|
||||
{
|
||||
"AZURE_OPENAI_API_KEY": "api_key",
|
||||
|
@ -3,7 +3,7 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -241,7 +241,7 @@ class MockAsyncContextManager:
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
@ -270,7 +270,7 @@ class MockSyncContextManager:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
@ -382,7 +382,7 @@ DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deepseek_completion() -> List[Dict]:
|
||||
def mock_deepseek_completion() -> list[dict]:
|
||||
list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
|
||||
result_list = []
|
||||
for msg in list_chunk_data:
|
||||
@ -450,7 +450,7 @@ OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_completion() -> List[Dict]:
|
||||
def mock_openai_completion() -> list[dict]:
|
||||
list_chunk_data = OPENAI_STREAM_DATA.split("\n")
|
||||
result_list = []
|
||||
for msg in list_chunk_data:
|
||||
@ -615,7 +615,7 @@ def test_openai_invoke_name(mock_client: MagicMock) -> None:
|
||||
|
||||
|
||||
def test_custom_token_counting() -> None:
|
||||
def token_encoder(text: str) -> List[int]:
|
||||
def token_encoder(text: str) -> list[int]:
|
||||
return [1, 2, 3]
|
||||
|
||||
llm = ChatOpenAI(custom_get_token_ids=token_encoder)
|
||||
@ -662,8 +662,8 @@ class MakeASandwich(BaseModel):
|
||||
|
||||
bread_type: str
|
||||
cheese_type: str
|
||||
condiments: List[str]
|
||||
vegetables: List[str]
|
||||
condiments: list[str]
|
||||
vegetables: list[str]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -695,7 +695,7 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non
|
||||
@pytest.mark.parametrize("include_raw", [True, False])
|
||||
@pytest.mark.parametrize("strict", [True, False, None])
|
||||
def test_with_structured_output(
|
||||
schema: Union[Type, Dict[str, Any], None],
|
||||
schema: Union[type, dict[str, Any], None],
|
||||
method: Literal["function_calling", "json_mode", "json_schema"],
|
||||
include_raw: bool,
|
||||
strict: Optional[bool],
|
||||
@ -787,7 +787,7 @@ class Foo(BaseModel):
|
||||
# FooV1
|
||||
],
|
||||
)
|
||||
def test_schema_from_with_structured_output(schema: Type) -> None:
|
||||
def test_schema_from_with_structured_output(schema: type) -> None:
|
||||
"""Test schema from with_structured_output."""
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o")
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,11 +8,11 @@ from langchain_openai import ChatOpenAI
|
||||
|
||||
class TestOpenAIStandard(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
return (
|
||||
{
|
||||
"OPENAI_API_KEY": "api_key",
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,7 +8,7 @@ from langchain_openai import ChatOpenAI
|
||||
|
||||
class TestOpenAIResponses(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
||||
@property
|
||||
@ -18,7 +16,7 @@ class TestOpenAIResponses(ChatModelUnitTests):
|
||||
return {"use_responses_api": True}
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
return (
|
||||
{
|
||||
"OPENAI_API_KEY": "api_key",
|
||||
|
@ -1,5 +1,3 @@
|
||||
from typing import Tuple, Type
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
|
||||
|
||||
@ -8,7 +6,7 @@ from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
class TestAzureOpenAIStandard(EmbeddingsUnitTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[Embeddings]:
|
||||
def embeddings_class(self) -> type[Embeddings]:
|
||||
return AzureOpenAIEmbeddings
|
||||
|
||||
@property
|
||||
@ -16,7 +14,7 @@ class TestAzureOpenAIStandard(EmbeddingsUnitTests):
|
||||
return {"api_key": "api_key", "azure_endpoint": "https://endpoint.com"}
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
return (
|
||||
{
|
||||
"AZURE_OPENAI_API_KEY": "api_key",
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
|
||||
|
||||
@ -10,11 +8,11 @@ from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
class TestOpenAIStandard(EmbeddingsUnitTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[Embeddings]:
|
||||
def embeddings_class(self) -> type[Embeddings]:
|
||||
return OpenAIEmbeddings
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
return (
|
||||
{
|
||||
"OPENAI_API_KEY": "api_key",
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
@ -15,7 +15,7 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
starts: int = 0
|
||||
ends: int = 0
|
||||
errors: int = 0
|
||||
errors_args: List[Any] = []
|
||||
errors_args: list[Any] = []
|
||||
text: int = 0
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
@ -195,8 +195,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
|
@ -1,5 +1,4 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
@ -65,7 +64,7 @@ def test_get_token_ids(model: str) -> None:
|
||||
|
||||
|
||||
def test_custom_token_counting() -> None:
|
||||
def token_encoder(text: str) -> List[int]:
|
||||
def token_encoder(text: str) -> list[int]:
|
||||
return [1, 2, 3]
|
||||
|
||||
llm = OpenAI(custom_get_token_ids=token_encoder)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Type, cast
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.load import dumpd
|
||||
@ -72,7 +72,7 @@ def test_azure_openai_embeddings_secrets() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
|
||||
)
|
||||
def test_azure_openai_api_key_is_secret_string(model_class: Type) -> None:
|
||||
def test_azure_openai_api_key_is_secret_string(model_class: type) -> None:
|
||||
"""Test that the API key is stored as a SecretStr."""
|
||||
model = model_class(
|
||||
openai_api_key="secret-api-key",
|
||||
@ -88,7 +88,7 @@ def test_azure_openai_api_key_is_secret_string(model_class: Type) -> None:
|
||||
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
|
||||
)
|
||||
def test_azure_openai_api_key_masked_when_passed_from_env(
|
||||
model_class: Type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
model_class: type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed from an environment variable."""
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key")
|
||||
@ -109,7 +109,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
|
||||
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
|
||||
)
|
||||
def test_azure_openai_api_key_masked_when_passed_via_constructor(
|
||||
model_class: Type, capsys: CaptureFixture
|
||||
model_class: type, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed via the constructor."""
|
||||
model = model_class(
|
||||
@ -133,7 +133,7 @@ def test_azure_openai_api_key_masked_when_passed_via_constructor(
|
||||
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
|
||||
)
|
||||
def test_azure_openai_uses_actual_secret_value_from_secretstr(
|
||||
model_class: Type,
|
||||
model_class: type,
|
||||
) -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
model = model_class(
|
||||
@ -147,7 +147,7 @@ def test_azure_openai_uses_actual_secret_value_from_secretstr(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
|
||||
def test_openai_api_key_is_secret_string(model_class: Type) -> None:
|
||||
def test_openai_api_key_is_secret_string(model_class: type) -> None:
|
||||
"""Test that the API key is stored as a SecretStr."""
|
||||
model = model_class(openai_api_key="secret-api-key")
|
||||
assert isinstance(model.openai_api_key, SecretStr)
|
||||
@ -155,7 +155,7 @@ def test_openai_api_key_is_secret_string(model_class: Type) -> None:
|
||||
|
||||
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
|
||||
def test_openai_api_key_masked_when_passed_from_env(
|
||||
model_class: Type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
model_class: type, monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed from an environment variable."""
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "secret-api-key")
|
||||
@ -168,7 +168,7 @@ def test_openai_api_key_masked_when_passed_from_env(
|
||||
|
||||
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
|
||||
def test_openai_api_key_masked_when_passed_via_constructor(
|
||||
model_class: Type, capsys: CaptureFixture
|
||||
model_class: type, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed via the constructor."""
|
||||
model = model_class(openai_api_key="secret-api-key")
|
||||
@ -179,14 +179,14 @@ def test_openai_api_key_masked_when_passed_via_constructor(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
|
||||
def test_openai_uses_actual_secret_value_from_secretstr(model_class: Type) -> None:
|
||||
def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
model = model_class(openai_api_key="secret-api-key")
|
||||
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
|
||||
def test_azure_serialized_secrets(model_class: Type) -> None:
|
||||
def test_azure_serialized_secrets(model_class: type) -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
model = model_class(
|
||||
openai_api_key="secret-api-key", api_version="foo", azure_endpoint="foo"
|
||||
|
@ -3,20 +3,9 @@

from __future__ import annotations

import logging
from collections.abc import Iterator, Mapping
from operator import itemgetter
from typing import (
Any,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Literal, Optional, TypeVar, Union

import openai
from langchain_core.callbacks import CallbackManagerForLLMRun
@ -50,8 +39,8 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]

logger = logging.getLogger(__name__)
@ -162,14 +151,14 @@ class ChatPerplexity(BaseChatModel):
|
||||
"""Model name."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
pplx_api_key: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key"
|
||||
)
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
||||
request_timeout: Optional[Union[float, tuple[float, float]]] = Field(
|
||||
None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to PerplexityChat completion API. Default is None."""
|
||||
@ -183,12 +172,12 @@ class ChatPerplexity(BaseChatModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"pplx_api_key": "PPLX_API_KEY"}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@ -232,7 +221,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling PerplexityChat API."""
|
||||
return {
|
||||
"max_tokens": self.max_tokens,
|
||||
@ -241,7 +230,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict[str, Any]:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
@ -255,8 +244,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
return message_dict
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
params = dict(self._invocation_params)
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
@ -266,11 +255,11 @@ class ChatPerplexity(BaseChatModel):
|
||||
return message_dicts, params
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
self, _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@ -296,8 +285,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -367,8 +356,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -402,7 +391,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
@property
|
||||
def _invocation_params(self) -> Mapping[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
pplx_creds: Dict[str, Any] = {"model": self.model}
|
||||
pplx_creds: dict[str, Any] = {"model": self.model}
|
||||
return {**pplx_creds, **self._default_params}
|
||||
|
||||
@property
|
||||
|
@ -55,8 +55,12 @@ plugins = ['pydantic.mypy']
|
||||
module = "transformers"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "T201"]
|
||||
select = ["E", "F", "I", "T201", "UP"]
|
||||
ignore = [ "UP007", ]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Standard LangChain interface tests."""
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
@ -11,7 +9,7 @@ from langchain_perplexity import ChatPerplexity
|
||||
|
||||
class TestPerplexityStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatPerplexity
|
||||
|
||||
@property
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
|
||||
@ -51,7 +51,7 @@ def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
|
||||
"choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}],
|
||||
"citations": ["example.com", "example2.com"],
|
||||
}
|
||||
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.__iter__.return_value = mock_chunks
|
||||
patcher = mocker.patch.object(
|
||||
@ -103,7 +103,7 @@ def test_perplexity_stream_includes_citations_and_images(mocker: MockerFixture)
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.__iter__.return_value = mock_chunks
|
||||
patcher = mocker.patch.object(
|
||||
@ -162,7 +162,7 @@ def test_perplexity_stream_includes_citations_and_related_questions(
|
||||
"citations": ["example.com", "example2.com"],
|
||||
"related_questions": ["example_question_1", "example_question_2"],
|
||||
}
|
||||
mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.__iter__.return_value = mock_chunks
|
||||
patcher = mocker.patch.object(
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Test Perplexity Chat API wrapper."""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,9 +8,9 @@ from langchain_perplexity import ChatPerplexity
|
||||
|
||||
class TestPerplexityStandard(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return ChatPerplexity
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
return ({"PPLX_API_KEY": "api_key"}, {}, {"pplx_api_key": "api_key"})
|
||||
|
@ -5,7 +5,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Generic, Literal, Optional, TypeVar, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, FilePath
|
||||
@ -24,7 +24,7 @@ class PropertySettings(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
type: Literal["string", "number", "array", "object", "boolean"]
|
||||
default: Union[str, int, float, List, Dict, bool, None] = Field(default=None)
|
||||
default: Union[str, int, float, list, dict, bool, None] = Field(default=None)
|
||||
description: str = Field(default="")
|
||||
|
||||
|
||||
@ -58,8 +58,8 @@ class Prompty(BaseModel):
|
||||
# metadata
|
||||
name: str = Field(default="")
|
||||
description: str = Field(default="")
|
||||
authors: List[str] = Field(default=[])
|
||||
tags: List[str] = Field(default=[])
|
||||
authors: list[str] = Field(default=[])
|
||||
tags: list[str] = Field(default=[])
|
||||
version: str = Field(default="")
|
||||
base: str = Field(default="")
|
||||
basePrompty: Optional[Prompty] = Field(default=None)
|
||||
@ -70,8 +70,8 @@ class Prompty(BaseModel):
|
||||
sample: dict = Field(default={})
|
||||
|
||||
# input / output
|
||||
inputs: Dict[str, PropertySettings] = Field(default={})
|
||||
outputs: Dict[str, PropertySettings] = Field(default={})
|
||||
inputs: dict[str, PropertySettings] = Field(default={})
|
||||
outputs: dict[str, PropertySettings] = Field(default={})
|
||||
|
||||
# template
|
||||
template: TemplateSettings
|
||||
@ -79,7 +79,7 @@ class Prompty(BaseModel):
|
||||
file: FilePath = Field(default="") # type: ignore[assignment]
|
||||
content: str = Field(default="")
|
||||
|
||||
def to_safe_dict(self) -> Dict[str, Any]:
|
||||
def to_safe_dict(self) -> dict[str, Any]:
|
||||
d = {}
|
||||
for k, v in self:
|
||||
if v != "" and v != {} and v != [] and v is not None:
|
||||
@ -130,7 +130,7 @@ class Prompty(BaseModel):
|
||||
attribute.startswith("file:")
|
||||
and Path(parent / attribute.split(":")[1]).exists()
|
||||
):
|
||||
with open(parent / attribute.split(":")[1], "r") as f:
|
||||
with open(parent / attribute.split(":")[1]) as f:
|
||||
items = json.load(f)
|
||||
if isinstance(items, list):
|
||||
return [Prompty.normalize(value, parent) for value in items]
|
||||
@ -155,8 +155,8 @@ class Prompty(BaseModel):
|
||||
|
||||
|
||||
def param_hoisting(
|
||||
top: Dict[str, Any], bottom: Dict[str, Any], top_key: Any = None
|
||||
) -> Dict[str, Any]:
|
||||
top: dict[str, Any], bottom: dict[str, Any], top_key: Any = None
|
||||
) -> dict[str, Any]:
|
||||
"""Merge two dictionaries with hoisting of parameters from bottom to top.
|
||||
|
||||
Args:
|
||||
@ -198,18 +198,18 @@ class NoOpParser(Invoker):
|
||||
return data
|
||||
|
||||
|
||||
class InvokerFactory(object):
|
||||
class InvokerFactory:
|
||||
"""Factory for creating invokers."""
|
||||
|
||||
_instance = None
|
||||
_renderers: Dict[str, Type[Invoker]] = {}
|
||||
_parsers: Dict[str, Type[Invoker]] = {}
|
||||
_executors: Dict[str, Type[Invoker]] = {}
|
||||
_processors: Dict[str, Type[Invoker]] = {}
|
||||
_renderers: dict[str, type[Invoker]] = {}
|
||||
_parsers: dict[str, type[Invoker]] = {}
|
||||
_executors: dict[str, type[Invoker]] = {}
|
||||
_processors: dict[str, type[Invoker]] = {}
|
||||
|
||||
def __new__(cls) -> InvokerFactory:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(InvokerFactory, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
# Add NOOP invokers
|
||||
cls._renderers["NOOP"] = NoOpParser
|
||||
cls._parsers["NOOP"] = NoOpParser
|
||||
@ -221,7 +221,7 @@ class InvokerFactory(object):
|
||||
self,
|
||||
type: Literal["renderer", "parser", "executor", "processor"],
|
||||
name: str,
|
||||
invoker: Type[Invoker],
|
||||
invoker: type[Invoker],
|
||||
) -> None:
|
||||
if type == "renderer":
|
||||
self._renderers[name] = invoker
|
||||
@ -264,7 +264,7 @@ class InvokerFactory(object):
|
||||
else:
|
||||
raise ValueError(f"Invalid type {type}")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"renderers": {
|
||||
k: f"{v.__module__}.{v.__name__}" for k, v in self._renderers.items()
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import Runnable, RunnableLambda
|
||||
@ -10,10 +10,10 @@ from .utils import load, prepare
|
||||
def create_chat_prompt(
|
||||
path: str,
|
||||
input_name_agent_scratchpad: str = "agent_scratchpad",
|
||||
) -> Runnable[Dict[str, Any], ChatPromptTemplate]:
|
||||
) -> Runnable[dict[str, Any], ChatPromptTemplate]:
|
||||
"""Create a chat prompt from a Langchain schema."""
|
||||
|
||||
def runnable_chat_lambda(inputs: Dict[str, Any]) -> ChatPromptTemplate:
|
||||
def runnable_chat_lambda(inputs: dict[str, Any]) -> ChatPromptTemplate:
|
||||
p = load(path)
|
||||
parsed = prepare(p, inputs)
|
||||
# Parsed messages have been templated
|
||||
|
@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import re
|
||||
from typing import Dict, List, Type, Union
|
||||
from typing import Union
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@ -15,7 +15,7 @@ from .core import Invoker, Prompty, SimpleModel
|
||||
|
||||
|
||||
class RoleMap:
|
||||
_ROLE_MAP: Dict[str, Type[BaseMessage]] = {
|
||||
_ROLE_MAP: dict[str, type[BaseMessage]] = {
|
||||
"system": SystemMessage,
|
||||
"user": HumanMessage,
|
||||
"human": HumanMessage,
|
||||
@ -26,7 +26,7 @@ class RoleMap:
|
||||
ROLES = _ROLE_MAP.keys()
|
||||
|
||||
@classmethod
|
||||
def get_message_class(cls, role: str) -> Type[BaseMessage]:
|
||||
def get_message_class(cls, role: str) -> type[BaseMessage]:
|
||||
return cls._ROLE_MAP[role]
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ class PromptyChatParser(Invoker):
|
||||
"and .jpg / .jpeg are supported."
|
||||
)
|
||||
|
||||
def parse_content(self, content: str) -> Union[str, List]:
|
||||
def parse_content(self, content: str) -> Union[str, list]:
|
||||
"""for parsing inline images"""
|
||||
# regular expression to parse markdown images
|
||||
image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)"
|
||||
|
@ -1,6 +1,6 @@
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from .core import (
|
||||
Frontmatter,
|
||||
@ -120,7 +120,7 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
|
||||
|
||||
def prepare(
|
||||
prompt: Prompty,
|
||||
inputs: Dict[str, Any] = {},
|
||||
inputs: dict[str, Any] = {},
|
||||
) -> Any:
|
||||
"""Prepare the inputs for the prompty.
|
||||
|
||||
@ -166,9 +166,9 @@ def prepare(
|
||||
|
||||
def run(
|
||||
prompt: Prompty,
|
||||
content: Union[Dict, List, str],
|
||||
configuration: Dict[str, Any] = {},
|
||||
parameters: Dict[str, Any] = {},
|
||||
content: Union[dict, list, str],
|
||||
configuration: dict[str, Any] = {},
|
||||
parameters: dict[str, Any] = {},
|
||||
raw: bool = False,
|
||||
) -> Any:
|
||||
"""Run the prompty.
|
||||
@ -219,9 +219,9 @@ def run(
|
||||
|
||||
def execute(
|
||||
prompt: Union[str, Prompty],
|
||||
configuration: Dict[str, Any] = {},
|
||||
parameters: Dict[str, Any] = {},
|
||||
inputs: Dict[str, Any] = {},
|
||||
configuration: dict[str, Any] = {},
|
||||
parameters: dict[str, Any] = {},
|
||||
inputs: dict[str, Any] = {},
|
||||
raw: bool = False,
|
||||
connection: str = "default",
|
||||
) -> Any:
|
||||
|
@ -45,9 +45,12 @@ langchain-core = { path = "../../core", editable = true }
|
||||
langchain-text-splitters = { path = "../../text-splitters", editable = true }
|
||||
langchain = { path = "../../langchain", editable = true }
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
select = ["E", "F", "I"]
|
||||
target-version = "py39"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "T201", "UP"]
|
||||
ignore = [ "UP007", ]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
Some files were not shown because too many files have changed in this diff.