langchain: Add ruff rules B (#31908)

See https://docs.astral.sh/ruff/rules/#flake8-bugbear-b

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-07-08 16:48:18 +02:00 committed by GitHub
parent 65b098325b
commit 3f839d566a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
56 changed files with 168 additions and 130 deletions

View File

@ -29,10 +29,12 @@ def _warn_on_import(name: str, replacement: Optional[str] = None) -> None:
warnings.warn(
f"Importing {name} from langchain root module is no longer supported. "
f"Please use {replacement} instead.",
stacklevel=3,
)
else:
warnings.warn(
f"Importing {name} from langchain root module is no longer supported.",
stacklevel=3,
)

View File

@ -1366,7 +1366,7 @@ class AgentExecutor(Chain):
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
f"This is the error: {e!s}"
)
raise ValueError(msg)
raise ValueError(msg) from e
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
@ -1380,7 +1380,7 @@ class AgentExecutor(Chain):
observation = self.handle_parsing_errors(e)
else:
msg = "Got unexpected type of `handle_parsing_errors`"
raise ValueError(msg)
raise ValueError(msg) from e
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
@ -1505,7 +1505,7 @@ class AgentExecutor(Chain):
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
f"This is the error: {e!s}"
)
raise ValueError(msg)
raise ValueError(msg) from e
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
@ -1519,7 +1519,7 @@ class AgentExecutor(Chain):
observation = self.handle_parsing_errors(e)
else:
msg = "Got unexpected type of `handle_parsing_errors`"
raise ValueError(msg)
raise ValueError(msg) from e
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = await ExceptionTool().arun(

View File

@ -36,9 +36,9 @@ class VectorStoreToolkit(BaseToolkit):
VectorStoreQATool,
VectorStoreQAWithSourcesTool,
)
except ImportError:
except ImportError as e:
msg = "You need to install langchain-community to use this toolkit."
raise ImportError(msg)
raise ImportError(msg) from e
description = VectorStoreQATool.get_description(
self.vectorstore_info.name,
self.vectorstore_info.description,
@ -79,9 +79,9 @@ class VectorStoreRouterToolkit(BaseToolkit):
from langchain_community.tools.vectorstore.tool import (
VectorStoreQATool,
)
except ImportError:
except ImportError as e:
msg = "You need to install langchain-community to use this toolkit."
raise ImportError(msg)
raise ImportError(msg) from e
for vectorstore_info in self.vectorstores:
description = VectorStoreQATool.get_description(
vectorstore_info.name,

View File

@ -32,6 +32,8 @@ from langchain.agents.output_parsers.openai_functions import (
OpenAIFunctionsAgentOutputParser,
)
_NOT_SET = object()
@deprecated("0.1.0", alternative="create_openai_functions_agent", removal="1.0")
class OpenAIFunctionsAgent(BaseSingleActionAgent):
@ -213,9 +215,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
@classmethod
def create_prompt(
cls,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant.",
),
system_message: Optional[SystemMessage] = _NOT_SET, # type: ignore[assignment]
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> ChatPromptTemplate:
"""Create prompt for this agent.
@ -230,8 +230,13 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
system_message_ = (
system_message
if system_message is not _NOT_SET
else SystemMessage(content="You are a helpful AI assistant.")
)
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
messages = [system_message] if system_message else []
messages = [system_message_] if system_message_ else []
messages.extend(
[
@ -249,9 +254,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant.",
),
system_message: Optional[SystemMessage] = _NOT_SET, # type: ignore[assignment]
**kwargs: Any,
) -> BaseSingleActionAgent:
"""Construct an agent from an LLM and tools.
@ -265,9 +268,14 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
Defaults to a default system message.
kwargs: Additional parameters to pass to the agent.
"""
system_message_ = (
system_message
if system_message is not _NOT_SET
else SystemMessage(content="You are a helpful AI assistant.")
)
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
system_message=system_message_,
)
return cls( # type: ignore[call-arg]
llm=llm,

View File

@ -46,21 +46,21 @@ def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFin
if function_call:
try:
arguments = json.loads(function_call["arguments"], strict=False)
except JSONDecodeError:
except JSONDecodeError as e:
msg = (
f"Could not parse tool input: {function_call} because "
f"the `arguments` is not valid JSON."
)
raise OutputParserException(msg)
raise OutputParserException(msg) from e
try:
tools = arguments["actions"]
except (TypeError, KeyError):
except (TypeError, KeyError) as e:
msg = (
f"Could not parse tool input: {function_call} because "
f"the `arguments` JSON does not contain `actions` key."
)
raise OutputParserException(msg)
raise OutputParserException(msg) from e
final_tools: list[AgentAction] = []
for tool_schema in tools:
@ -100,6 +100,9 @@ def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFin
)
_NOT_SET = object()
@deprecated("0.1.0", alternative="create_openai_tools_agent", removal="1.0")
class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
"""Agent driven by OpenAIs function powered API.
@ -263,9 +266,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
@classmethod
def create_prompt(
cls,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant.",
),
system_message: Optional[SystemMessage] = _NOT_SET, # type: ignore[assignment]
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> BasePromptTemplate:
"""Create prompt for this agent.
@ -280,8 +281,13 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
system_message_ = (
system_message
if system_message is not _NOT_SET
else SystemMessage(content="You are a helpful AI assistant.")
)
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
messages = [system_message] if system_message else []
messages = [system_message_] if system_message_ else []
messages.extend(
[
@ -299,9 +305,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant.",
),
system_message: Optional[SystemMessage] = _NOT_SET, # type: ignore[assignment]
**kwargs: Any,
) -> BaseMultiActionAgent:
"""Construct an agent from an LLM and tools.
@ -315,9 +319,14 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
Default is a default system message.
kwargs: Additional arguments.
"""
system_message_ = (
system_message
if system_message is not _NOT_SET
else SystemMessage(content="You are a helpful AI assistant.")
)
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
system_message=system_message_,
)
return cls( # type: ignore[call-arg]
llm=llm,

View File

@ -47,12 +47,12 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
else:
# otherwise it returns a json object
_tool_input = json.loads(function_call["arguments"], strict=False)
except JSONDecodeError:
except JSONDecodeError as e:
msg = (
f"Could not parse tool input: {function_call} because "
f"the `arguments` is not valid JSON."
)
raise OutputParserException(msg)
raise OutputParserException(msg) from e
# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable

View File

@ -72,10 +72,10 @@ class ReActJsonSingleInputOutputParser(AgentOutputParser):
text,
)
except Exception:
except Exception as e:
if not includes_answer:
msg = f"Could not parse LLM output: {text}"
raise OutputParserException(msg)
raise OutputParserException(msg) from e
output = text.split(FINAL_ANSWER_ACTION)[-1].strip()
return AgentFinish({"output": output}, text)

View File

@ -46,12 +46,12 @@ def parse_ai_message_to_tool_action(
tool_calls.append(
ToolCall(name=function_name, args=args, id=tool_call["id"]),
)
except JSONDecodeError:
except JSONDecodeError as e:
msg = (
f"Could not parse tool input: {function} because "
f"the `arguments` is not valid JSON."
)
raise OutputParserException(msg)
raise OutputParserException(msg) from e
for tool_call in tool_calls:
# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable

View File

@ -70,12 +70,12 @@ def StreamlitCallbackHandler(
from langchain_community.callbacks.streamlit.streamlit_callback_handler import ( # noqa: E501
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
)
except ImportError:
except ImportError as e:
msg = (
"To use the StreamlitCallbackHandler, please install "
"langchain-community with `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return _InternalStreamlitCallbackHandler(
parent_container,

View File

@ -255,6 +255,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
stacklevel=4,
)
values["callbacks"] = values.pop("callback_manager", None)
return values

View File

@ -504,6 +504,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
warnings.warn(
"`ChatVectorDBChain` is deprecated - "
"please use `from langchain.chains import ConversationalRetrievalChain`",
stacklevel=4,
)
return values

View File

@ -202,7 +202,7 @@ class FlareChain(Chain):
response = ""
for i in range(self.max_iter):
for _i in range(self.max_iter):
_run_manager.on_text(
f"Current Response: {response}",
color="blue",
@ -261,13 +261,13 @@ class FlareChain(Chain):
"""
try:
from langchain_openai import ChatOpenAI
except ImportError:
except ImportError as e:
msg = (
"OpenAI is required for FlareChain. "
"Please install langchain-openai."
"pip install langchain-openai"
)
raise ImportError(msg)
raise ImportError(msg) from e
llm = ChatOpenAI(
max_completion_tokens=max_generation_len,
logprobs=True,

View File

@ -350,6 +350,7 @@ class LLMChain(Chain):
warnings.warn(
"The predict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain.",
stacklevel=2,
)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
@ -365,6 +366,7 @@ class LLMChain(Chain):
warnings.warn(
"The apredict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain.",
stacklevel=2,
)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
@ -380,6 +382,7 @@ class LLMChain(Chain):
warnings.warn(
"The apply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain.",
stacklevel=2,
)
result = self.apply(input_list, callbacks=callbacks)
return self._parse_generation(result)
@ -404,6 +407,7 @@ class LLMChain(Chain):
warnings.warn(
"The aapply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain.",
stacklevel=2,
)
result = await self.aapply(input_list, callbacks=callbacks)
return self._parse_generation(result)

View File

@ -112,6 +112,7 @@ class LLMCheckerChain(Chain):
"Directly instantiating an LLMCheckerChain with an llm is deprecated. "
"Please instantiate with question_to_checked_assertions_chain "
"or using the from_llm class method.",
stacklevel=5,
)
if (
"question_to_checked_assertions_chain" not in values

View File

@ -166,17 +166,18 @@ class LLMMathChain(Chain):
def raise_deprecation(cls, values: dict) -> Any:
try:
import numexpr # noqa: F401
except ImportError:
except ImportError as e:
msg = (
"LLMMathChain requires the numexpr package. "
"Please install it with `pip install numexpr`."
)
raise ImportError(msg)
raise ImportError(msg) from e
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMMathChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method.",
stacklevel=5,
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
@ -216,7 +217,7 @@ class LLMMathChain(Chain):
f'LLMMathChain._evaluate("{expression}") raised error: {e}.'
" Please try again with a valid numerical expression"
)
raise ValueError(msg)
raise ValueError(msg) from e
# Remove any leading and trailing brackets from the output
return re.sub(r"^\[|\]$", "", output)

View File

@ -118,6 +118,7 @@ class LLMSummarizationCheckerChain(Chain):
"Directly instantiating an LLMSummarizationCheckerChain with an llm is "
"deprecated. Please instantiate with"
" sequential_chain argument or using the from_llm class method.",
stacklevel=5,
)
if "sequential_chain" not in values and values["llm"] is not None:
values["sequential_chain"] = _load_sequential_chain(

View File

@ -566,13 +566,13 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
try:
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
except ImportError:
except ImportError as e:
msg = (
"To use this GraphCypherQAChain functionality you must install the "
"langchain_community package. "
"You can install it with `pip install langchain_community`"
)
raise ImportError(msg)
raise ImportError(msg) from e
return GraphCypherQAChain(
graph=graph,
cypher_generation_chain=cypher_generation_chain,
@ -614,13 +614,13 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
try:
from langchain.chains.llm_requests import LLMRequestsChain
except ImportError:
except ImportError as e:
msg = (
"To use this LLMRequestsChain functionality you must install the "
"langchain package. "
"You can install it with `pip install langchain`"
)
raise ImportError(msg)
raise ImportError(msg) from e
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")

View File

@ -72,12 +72,12 @@ class OpenAIModerationChain(Chain):
values["client"] = openai.OpenAI(api_key=openai_api_key)
values["async_client"] = openai.AsyncOpenAI(api_key=openai_api_key)
except ImportError:
except ImportError as e:
msg = (
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return values
@property

View File

@ -74,6 +74,7 @@ class NatBotChain(Chain):
"Directly instantiating an NatBotChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method.",
stacklevel=5,
)
if "llm_chain" not in values and values["llm"] is not None:
values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()

View File

@ -63,12 +63,12 @@ class Crawler:
def __init__(self) -> None:
try:
from playwright.sync_api import sync_playwright
except ImportError:
except ImportError as e:
msg = (
"Could not import playwright python package. "
"Please install it with `pip install playwright`."
)
raise ImportError(msg)
raise ImportError(msg) from e
self.browser: Browser = (
sync_playwright().start().chromium.launch(headless=False)
)

View File

@ -94,12 +94,12 @@ def openapi_spec_to_openai_fn(
"""
try:
from langchain_community.tools import APIOperation
except ImportError:
except ImportError as e:
msg = (
"Could not import langchain_community.tools. "
"Please install it with `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
if not spec.paths:
return [], lambda: None

View File

@ -77,6 +77,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
warnings.warn(
"`VectorDBQAWithSourcesChain` is deprecated - "
"please use `from langchain.chains import RetrievalQAWithSourcesChain`",
stacklevel=5,
)
return values

View File

@ -64,7 +64,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
)
except Exception as e:
msg = f"Parsing text\n{text}\n raised following error:\n{e}"
raise OutputParserException(msg)
raise OutputParserException(msg) from e
@classmethod
def from_components(

View File

@ -162,6 +162,7 @@ class QueryTransformer(Transformer):
warnings.warn(
"Dates are expected to be provided in ISO 8601 date format "
"(YYYY-MM-DD).",
stacklevel=3,
)
return {"date": item, "type": "date"}
@ -173,9 +174,9 @@ class QueryTransformer(Transformer):
except ValueError:
try:
datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S")
except ValueError:
except ValueError as e:
msg = "Datetime values are expected to be in ISO 8601 format."
raise ValueError(msg)
raise ValueError(msg) from e
return {"datetime": item, "type": "datetime"}
def string(self, item: Any) -> str:

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from typing import Any, Optional
@ -320,15 +319,6 @@ class VectorDBQA(BaseRetrievalQA):
search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`VectorDBQA` is deprecated - "
"please use `from langchain.chains import RetrievalQA`",
)
return values
@model_validator(mode="before")
@classmethod
def validate_search_type(cls, values: dict) -> Any:

View File

@ -193,4 +193,4 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
return parsed
except Exception as e:
msg = f"Parsing text\n{text}\n raised following error:\n{e}"
raise OutputParserException(msg)
raise OutputParserException(msg) from e

View File

@ -102,7 +102,7 @@ class SequentialChain(Chain):
) -> dict[str, str]:
known_values = inputs.copy()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
for i, chain in enumerate(self.chains):
for _i, chain in enumerate(self.chains):
callbacks = _run_manager.get_child()
outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks)
known_values.update(outputs)
@ -116,7 +116,7 @@ class SequentialChain(Chain):
known_values = inputs.copy()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
for i, chain in enumerate(self.chains):
for _i, chain in enumerate(self.chains):
outputs = await chain.acall(
known_values,
return_only_outputs=True,

View File

@ -36,6 +36,7 @@ def __getattr__(name: str) -> None:
"Please import from langchain-community instead:\n\n"
f"`from langchain_community.chat_models import {name}`.\n\n"
"To install langchain-community run `pip install -U langchain-community`.",
stacklevel=2,
category=LangChainDeprecationWarning,
)

View File

@ -316,6 +316,7 @@ def init_chat_model(
f"{config_prefix=} has been set but no fields are configurable. Set "
f"`configurable_fields=(...)` to specify the model params that are "
f"configurable.",
stacklevel=2,
)
if not configurable_fields:

View File

@ -62,12 +62,12 @@ def _embedding_factory() -> Embeddings:
from langchain_community.embeddings.openai import ( # type: ignore[no-redef]
OpenAIEmbeddings,
)
except ImportError:
except ImportError as e:
msg = (
"Could not import OpenAIEmbeddings. Please install the "
"OpenAIEmbeddings package using `pip install langchain-openai`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return OpenAIEmbeddings()
@ -139,14 +139,14 @@ class _EmbeddingDistanceChainMixin(Chain):
if isinstance(embeddings, tuple(types_)):
try:
import tiktoken # noqa: F401
except ImportError:
except ImportError as e:
msg = (
"The tiktoken library is required to use the default "
"OpenAI embeddings with embedding distance evaluators."
" Please either manually select a different Embeddings object"
" or install tiktoken using `pip install tiktoken`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return values
model_config = ConfigDict(
@ -202,13 +202,13 @@ class _EmbeddingDistanceChainMixin(Chain):
"""
try:
from langchain_community.utils.math import cosine_similarity
except ImportError:
except ImportError as e:
msg = (
"The cosine_similarity function is required to compute cosine distance."
" Please install the langchain-community package using"
" `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return 1.0 - cosine_similarity(a, b)
@staticmethod

View File

@ -61,12 +61,12 @@ def load_dataset(uri: str) -> list[dict]:
"""
try:
from datasets import load_dataset
except ImportError:
except ImportError as e:
msg = (
"load_dataset requires the `datasets` package."
" Please install with `pip install datasets`"
)
raise ImportError(msg)
raise ImportError(msg) from e
dataset = load_dataset(f"LangChainDatasets/{uri}")
return list(dataset["train"])
@ -142,7 +142,7 @@ def load_evaluator(
from langchain_community.chat_models.openai import ( # type: ignore[no-redef]
ChatOpenAI,
)
except ImportError:
except ImportError as e:
msg = (
"Could not import langchain_openai or fallback onto "
"langchain_community. Please install langchain_openai "
@ -150,7 +150,7 @@ def load_evaluator(
"It's recommended to install langchain_openai AND "
"specify a language model explicitly."
)
raise ImportError(msg)
raise ImportError(msg) from e
llm = llm or ChatOpenAI(model="gpt-4", seed=42, temperature=0)
except Exception as e:

View File

@ -48,14 +48,14 @@ class JsonEditDistanceEvaluator(StringEvaluator):
else:
try:
from rapidfuzz import distance as rfd
except ImportError:
except ImportError as e:
msg = (
"The default string_distance operator for the "
" JsonEditDistanceEvaluator requires installation of "
"the rapidfuzz package. "
"Please install it with `pip install rapidfuzz`."
)
raise ImportError(msg)
raise ImportError(msg) from e
self._string_distance = rfd.DamerauLevenshtein.normalized_distance
if canonicalize is not None:
self._canonicalize = canonicalize

View File

@ -45,12 +45,12 @@ class JsonSchemaEvaluator(StringEvaluator):
super().__init__()
try:
import jsonschema # noqa: F401
except ImportError:
except ImportError as e:
msg = (
"The JsonSchemaEvaluator requires the jsonschema package."
" Please install it with `pip install jsonschema`."
)
raise ImportError(msg)
raise ImportError(msg) from e
@property
def requires_input(self) -> bool:
@ -70,9 +70,9 @@ class JsonSchemaEvaluator(StringEvaluator):
def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
if isinstance(node, str):
return parse_json_markdown(node)
if hasattr(node, "schema") and callable(getattr(node, "schema")):
if hasattr(node, "schema") and callable(node.schema):
# Pydantic model
return getattr(node, "schema")()
return node.schema()
return node
def _validate(self, prediction: Any, schema: Any) -> dict:

View File

@ -124,12 +124,12 @@ class _EvalArgsMixin:
msg = f"{self.__class__.__name__} requires an input string."
raise ValueError(msg)
if input_ is not None and not self.requires_input:
warn(self._skip_input_warning)
warn(self._skip_input_warning, stacklevel=3)
if self.requires_reference and reference is None:
msg = f"{self.__class__.__name__} requires a reference string."
raise ValueError(msg)
if reference is not None and not self.requires_reference:
warn(self._skip_reference_warning)
warn(self._skip_reference_warning, stacklevel=3)
class StringEvaluator(_EvalArgsMixin, ABC):

View File

@ -29,12 +29,12 @@ def _load_rapidfuzz() -> Any:
"""
try:
import rapidfuzz
except ImportError:
except ImportError as e:
msg = (
"Please install the rapidfuzz library to use the FuzzyMatchStringEvaluator."
"Please install it with `pip install rapidfuzz`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return rapidfuzz.distance

View File

@ -176,13 +176,14 @@ def _get_in_memory_vectorstore() -> type[VectorStore]:
try:
from langchain_community.vectorstores.inmemory import InMemoryVectorStore
except ImportError:
except ImportError as e:
msg = "Please install langchain-community to use the InMemoryVectorStore."
raise ImportError(msg)
raise ImportError(msg) from e
warnings.warn(
"Using InMemoryVectorStore as the default vectorstore."
"This memory store won't persist data. You should explicitly"
"specify a vectorstore when using VectorstoreIndexCreator",
stacklevel=3,
)
return InMemoryVectorStore

View File

@ -552,6 +552,7 @@ def __getattr__(name: str) -> Any:
"Please import from langchain-community instead:\n\n"
f"`from langchain_community.llms import {name}`.\n\n"
"To install langchain-community run `pip install -U langchain-community`.",
stacklevel=2,
category=LangChainDeprecationWarning,
)

View File

@ -58,6 +58,7 @@ class BaseChatMemory(BaseMemory, ABC):
f"'{self.__class__.__name__}' got multiple output keys:"
f" {outputs.keys()}. The default 'output' key is being used."
f" If this is not desired, please manually set 'output_key'.",
stacklevel=3,
)
else:
msg = (

View File

@ -42,6 +42,7 @@ class CombinedMemory(BaseMemory):
"When using CombinedMemory, "
"input keys should be so the input is known. "
f" Was not set on {val}",
stacklevel=5,
)
return value

View File

@ -115,12 +115,12 @@ class UpstashRedisEntityStore(BaseEntityStore):
):
try:
from upstash_redis import Redis
except ImportError:
except ImportError as e:
msg = (
"Could not import upstash_redis python package. "
"Please install it with `pip install upstash_redis`."
)
raise ImportError(msg)
raise ImportError(msg) from e
super().__init__(*args, **kwargs)
@ -211,23 +211,23 @@ class RedisEntityStore(BaseEntityStore):
):
try:
import redis
except ImportError:
except ImportError as e:
msg = (
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
raise ImportError(msg)
raise ImportError(msg) from e
super().__init__(*args, **kwargs)
try:
from langchain_community.utilities.redis import get_client
except ImportError:
except ImportError as e:
msg = (
"Could not import langchain_community.utilities.redis.get_client. "
"Please install it with `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
try:
self.redis_client = get_client(redis_url=url, decode_responses=True)
@ -311,12 +311,12 @@ class SQLiteEntityStore(BaseEntityStore):
super().__init__(*args, **kwargs)
try:
import sqlite3
except ImportError:
except ImportError as e:
msg = (
"Could not import sqlite3 python package. "
"Please install it with `pip install sqlite3`."
)
raise ImportError(msg)
raise ImportError(msg) from e
# Basic validation to prevent obviously malicious table/session names
if not table_name.isidentifier() or not session_id.isidentifier():

View File

@ -26,12 +26,12 @@ class EnumOutputParser(BaseOutputParser[Enum]):
def parse(self, response: str) -> Enum:
try:
return self.enum(response.strip())
except ValueError:
except ValueError as e:
msg = (
f"Response '{response}' is not one of the "
f"expected values: {self._valid_values}"
)
raise OutputParserException(msg)
raise OutputParserException(msg) from e
def get_format_instructions(self) -> str:
return f"Select one of the following options: {', '.join(self._valid_values)}"

View File

@ -137,17 +137,17 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
self.dataframe[request_params],
request_type,
)()
except (AttributeError, IndexError, KeyError):
except (AttributeError, IndexError, KeyError) as e:
if request_type not in {"column", "row"}:
msg = f"Unsupported request type '{request_type}'. \
Please check the format instructions."
raise OutputParserException(msg)
raise OutputParserException(msg) from e
msg = f"""Requested index {
request_params
if stripped_request_params is None
else stripped_request_params
} is out of bounds."""
raise OutputParserException(msg)
raise OutputParserException(msg) from e
return result

View File

@ -43,12 +43,12 @@ class CohereRerank(BaseDocumentCompressor):
if not values.get("client"):
try:
import cohere
except ImportError:
except ImportError as e:
msg = (
"Could not import cohere python package. "
"Please install it with `pip install cohere`."
)
raise ImportError(msg)
raise ImportError(msg) from e
cohere_api_key = get_from_dict_or_env(
values,
"cohere_api_key",
@ -92,7 +92,7 @@ class CohereRerank(BaseDocumentCompressor):
max_chunks_per_doc=max_chunks_per_doc,
)
if hasattr(results, "results"):
results = getattr(results, "results")
results = results.results
return [
{"index": res.index, "relevance_score": res.relevance_score}
for res in results

View File

@ -11,12 +11,12 @@ from pydantic import ConfigDict, Field
def _get_similarity_function() -> Callable:
try:
from langchain_community.utils.math import cosine_similarity
except ImportError:
except ImportError as e:
msg = (
"To use please install langchain-community "
"with `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
return cosine_similarity
@ -62,12 +62,12 @@ class EmbeddingsFilter(BaseDocumentCompressor):
_get_embeddings_from_stateful_docs,
get_stateful_documents,
)
except ImportError:
except ImportError as e:
msg = (
"To use please install langchain-community "
"with `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
try:
import numpy as np
@ -105,12 +105,12 @@ class EmbeddingsFilter(BaseDocumentCompressor):
_aget_embeddings_from_stateful_docs,
get_stateful_documents,
)
except ImportError:
except ImportError as e:
msg = (
"To use please install langchain-community "
"with `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
try:
import numpy as np

View File

@ -80,7 +80,7 @@ class MergerRetriever(BaseRetriever):
merged_documents = []
max_docs = max(map(len, retriever_docs), default=0)
for i in range(max_docs):
for retriever, doc in zip(self.retrievers, retriever_docs):
for _retriever, doc in zip(self.retrievers, retriever_docs):
if i < len(doc):
merged_documents.append(doc[i])
@ -116,7 +116,7 @@ class MergerRetriever(BaseRetriever):
merged_documents = []
max_docs = max(map(len, retriever_docs), default=0)
for i in range(max_docs):
for retriever, doc in zip(self.retrievers, retriever_docs):
for _retriever, doc in zip(self.retrievers, retriever_docs):
if i < len(doc):
merged_documents.append(doc[i])

View File

@ -27,12 +27,12 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
try:
import langchain_community # noqa: F401
except ImportError:
except ImportError as e:
msg = (
"The langchain-community package must be installed to use this feature."
" Please install it using `pip install langchain-community`."
)
raise ImportError(msg)
raise ImportError(msg) from e
from langchain_community.query_constructors.astradb import AstraDBTranslator
from langchain_community.query_constructors.chroma import ChromaTranslator

View File

@ -160,9 +160,9 @@ class EvalError(dict):
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
except KeyError as e:
msg = f"'EvalError' object has no attribute '{name}'"
raise AttributeError(msg)
raise AttributeError(msg) from e
def _wrap_in_chain_factory(
@ -350,7 +350,7 @@ def _validate_example_inputs_for_language_model(
except InputFormatError:
try:
_get_messages(first_example.inputs or {})
except InputFormatError:
except InputFormatError as e:
msg = (
"Example inputs do not match language model input format. "
"Expected a dictionary with messages or a single prompt."
@ -359,7 +359,7 @@ def _validate_example_inputs_for_language_model(
" to convert the example.inputs to a compatible format"
" for the llm or chat model you wish to evaluate."
)
raise InputFormatError(msg)
raise InputFormatError(msg) from e
def _validate_example_inputs_for_chain(
@ -1041,7 +1041,7 @@ run_on_dataset(
f"Test project {project_name} already exists. Please use a different name:"
f"\n\n{example_msg}"
)
raise ValueError(msg)
raise ValueError(msg) from e
comparison_url = dataset.url + f"/compare?selectedSessions={project.id}"
print( # noqa: T201
f"View the evaluation results for project '{project_name}'"

View File

@ -79,6 +79,7 @@ def __getattr__(name: str) -> Any:
"To install langchain-community run "
"`pip install -U langchain-community`.",
category=LangChainDeprecationWarning,
stacklevel=2,
)
return getattr(tools, name)

View File

@ -146,6 +146,7 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
select = [
"A", # flake8-builtins
"ASYNC", # flake8-async
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle

View File

@ -1,7 +1,7 @@
"""A mock Robot server."""
from enum import Enum
from typing import Any, Optional, Union
from typing import Annotated, Any, Optional, Union
from uuid import uuid4
import uvicorn
@ -127,7 +127,9 @@ async def goto(x: int, y: int, z: int, cautiousness: Cautiousness) -> dict[str,
@app.get("/get_state", description="Get the robot's state")
async def get_state(
fields: list[StateItems] = Query(..., description="List of state items to return"),
fields: Annotated[
list[StateItems], Query(..., description="List of state items to return")
],
) -> dict[str, Any]:
state = {}
for field in fields:

View File

@ -441,7 +441,10 @@ def test_agent_invalid_tool() -> None:
)
resp = agent("when was langchain made")
resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."
assert (
resp["intermediate_steps"][0][1]
== "Foo is not a valid tool, try one of [Search]."
)
async def test_runnable_agent() -> None:

View File

@ -359,4 +359,7 @@ async def test_agent_invalid_tool() -> None:
)
resp = await agent.acall("when was langchain made")
resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."
assert (
resp["intermediate_steps"][0][1]
== "Foo is not a valid tool, try one of [Search]."
)

View File

@ -26,7 +26,8 @@ def test_parse() -> None:
except OutputParserException:
pass # Test passes if OutputParserException is raised
else:
assert False, f"Expected OutputParserException, but got {parser.parse(text)}"
msg = f"Expected OutputParserException, but got {parser.parse(text)}"
raise AssertionError(msg)
def test_output_type() -> None:

View File

@ -93,7 +93,8 @@ def test_yaml_output_parser_fail() -> None:
print("parse_result:", e) # noqa: T201
assert "Failed to parse TestModel from completion" in str(e)
else:
assert False, "Expected OutputParserException"
msg = "Expected OutputParserException"
raise AssertionError(msg)
def test_yaml_output_parser_output_type() -> None:

View File

@ -65,7 +65,6 @@ _INVALID_PROMPTS = (
_VALID_MESSAGES,
)
def test__get_messages_valid(inputs: dict[str, Any]) -> None:
{"messages": []}
_get_messages(inputs)

View File

@ -2734,7 +2734,7 @@ dev = [
{ name = "jupyter", specifier = ">=1.0.0,<2.0.0" },
{ name = "setuptools", specifier = ">=67.6.1,<68.0.0" },
]
lint = [{ name = "ruff", specifier = ">=0.11.2,<0.12.0" }]
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }]
test = [
{ name = "blockbuster", specifier = "~=1.5.18" },
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
@ -2900,7 +2900,7 @@ requires-dist = [
[package.metadata.requires-dev]
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
dev = [{ name = "langchain-core", editable = "../core" }]
lint = [{ name = "ruff", specifier = ">=0.5,<1.0" }]
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }]
test = [
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
{ name = "langchain-core", editable = "../core" },