feat(langchain): add ruff rules TRY (#32047)

See https://docs.astral.sh/ruff/rules/#tryceratops-try

* TRY004 (replace by TypeError) in main code is escaped with `noqa` to
not break backward compatibility. The rule is still interesting for new
code.
* TRY301 ignored at the moment. This one is quite hard to fix and I'm
not sure it's very interesting to activate it.

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-07-21 19:41:20 +02:00 committed by GitHub
parent 8b8d90bea5
commit 64261449b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 72 additions and 72 deletions

View File

@ -106,11 +106,12 @@ def create_importer(
"<https://python.langchain.com/docs/versions/v0_2/>"
),
)
return result
except Exception as e:
msg = f"module {new_module} has no attribute {name}"
raise AttributeError(msg) from e
return result
if fallback_module:
try:
module = importlib.import_module(fallback_module)
@ -139,12 +140,13 @@ def create_importer(
"<https://python.langchain.com/docs/versions/v0_2/>"
),
)
return result
except Exception as e:
msg = f"module {fallback_module} has no attribute {name}"
raise AttributeError(msg) from e
return result
msg = f"module {package} has no attribute {name}"
raise AttributeError(msg)

View File

@ -1380,7 +1380,7 @@ class AgentExecutor(Chain):
observation = self.handle_parsing_errors(e)
else:
msg = "Got unexpected type of `handle_parsing_errors`"
raise ValueError(msg) from e
raise ValueError(msg) from e # noqa: TRY004
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
@ -1519,7 +1519,7 @@ class AgentExecutor(Chain):
observation = self.handle_parsing_errors(e)
else:
msg = "Got unexpected type of `handle_parsing_errors`"
raise ValueError(msg) from e
raise ValueError(msg) from e # noqa: TRY004
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = await ExceptionTool().arun(

View File

@ -55,7 +55,7 @@ class ChatAgent(Agent):
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
msg = "agent_scratchpad should be of type string."
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
if agent_scratchpad:
return (
f"This was your previous work "

View File

@ -358,12 +358,12 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run = self._wait_for_run(run.id, run.thread_id)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
raise
try:
response = self._get_response(run)
except BaseException as e:
run_manager.on_chain_error(e, metadata=run.dict())
raise e
raise
else:
run_manager.on_chain_end(response)
return response
@ -494,12 +494,12 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
run = await self._await_for_run(run.id, run.thread_id)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
raise
try:
response = self._get_response(run)
except BaseException as e:
run_manager.on_chain_error(e, metadata=run.dict())
raise e
raise
else:
run_manager.on_chain_end(response)
return response

View File

@ -87,7 +87,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
message = result[0].message
return self._parse_ai_message(message)

View File

@ -61,7 +61,7 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
message = result[0].message
return parse_ai_message_to_openai_tool_action(message)

View File

@ -98,7 +98,7 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
msg = "This output parser only works on ChatGeneration output"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
message = result[0].message
return parse_ai_message_to_tool_action(message)

View File

@ -56,7 +56,7 @@ class StructuredChatAgent(Agent):
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
msg = "agent_scratchpad should be of type string."
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
if agent_scratchpad:
return (
f"This was your previous work "

View File

@ -71,12 +71,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
if self.output_fixing_parser is not None:
parsed_obj: Union[AgentAction, AgentFinish] = (
self.output_fixing_parser.parse(text)
)
else:
parsed_obj = self.base_parser.parse(text)
return parsed_obj
return self.output_fixing_parser.parse(text)
return self.base_parser.parse(text)
except Exception as e:
msg = f"Could not parse LLM output: {text}"
raise OutputParserException(msg) from e

View File

@ -174,7 +174,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
raise
run_manager.on_chain_end(outputs)
if include_run_info:
@ -228,7 +228,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
raise
await run_manager.on_chain_end(outputs)
if include_run_info:

View File

@ -127,7 +127,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"Output parser of llm_chain should be a RegexParser,"
f" got {output_parser}"
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
output_keys = output_parser.output_keys
if self.rank_key not in output_keys:
msg = (

View File

@ -57,7 +57,7 @@ def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
f"Unsupported chat history format: {type(dialogue_turn)}."
f" Full chat history: {chat_history} "
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
return buffer

View File

@ -164,12 +164,13 @@ class ElasticsearchDatabaseChain(Chain):
chain_result: dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
except Exception as exc:
# Append intermediate steps to exception, to aid in logging and later
# improvement of few shot prompt seeds
exc.intermediate_steps = intermediate_steps # type: ignore[attr-defined]
raise exc
raise
return chain_result
@property
def _chain_type(self) -> str:

View File

@ -251,7 +251,7 @@ class LLMChain(Chain):
response = self.generate(input_list, run_manager=run_manager)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
raise
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
@ -276,7 +276,7 @@ class LLMChain(Chain):
response = await self.agenerate(input_list, run_manager=run_manager)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
raise
outputs = self.create_outputs(response)
await run_manager.on_chain_end({"outputs": outputs})
return outputs

View File

@ -117,7 +117,7 @@ def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsCh
if not isinstance(llm_chain, LLMChain):
msg = f"Expected LLMChain, got {llm_chain}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
if "document_prompt" in config:
prompt_config = config.pop("document_prompt")
@ -150,7 +150,7 @@ def _load_map_reduce_documents_chain(
if not isinstance(llm_chain, LLMChain):
msg = f"Expected LLMChain, got {llm_chain}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
if "reduce_documents_chain" in config:
reduce_documents_chain = load_chain_from_config(

View File

@ -361,13 +361,13 @@ def get_openapi_chain(
try:
spec = conversion(spec)
break
except ImportError as e:
raise e
except ImportError:
raise
except Exception: # noqa: S110
pass
if isinstance(spec, str):
msg = f"Unable to parse spec from source {spec}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
openai_fns, call_api_fn = openapi_spec_to_openai_fn(spec)
if not llm:
msg = (

View File

@ -125,7 +125,7 @@ class LLMRouterChain(RouterChain):
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict):
raise ValueError
raise ValueError # noqa: TRY004
def _call(
self,
@ -178,10 +178,10 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
parsed = parse_and_check_json_markdown(text, expected_keys)
if not isinstance(parsed["destination"], str):
msg = "Expected 'destination' to be a string."
raise ValueError(msg)
raise TypeError(msg)
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
msg = f"Expected 'next_inputs' to be {self.next_inputs_type}."
raise ValueError(msg)
raise TypeError(msg)
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
if (
parsed["destination"].strip().lower()
@ -190,7 +190,7 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
parsed["destination"] = None
else:
parsed["destination"] = parsed["destination"].strip()
return parsed
except Exception as e:
msg = f"Parsing text\n{text}\n raised following error:\n{e}"
raise OutputParserException(msg) from e
return parsed

View File

@ -340,7 +340,7 @@ class CacheBackedEmbeddings(Embeddings):
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
"or a callable that encodes keys."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
document_embedding_store = EncoderBackedStore[str, list[float]](
document_embedding_cache,

View File

@ -23,11 +23,10 @@ from langchain.schema import RUN_KEY
def _import_numpy() -> Any:
try:
import numpy as np
return np
except ImportError as e:
msg = "Could not import numpy, please install with `pip install numpy`."
raise ImportError(msg) from e
return np
logger = logging.getLogger(__name__)

View File

@ -74,9 +74,9 @@ class JsonValidityEvaluator(StringEvaluator):
"""
try:
parse_json_markdown(prediction, parser=json.loads)
return {"score": 1}
except Exception as e:
return {"score": 0, "reasoning": str(e)}
return {"score": 1}
class JsonEqualityEvaluator(StringEvaluator):

View File

@ -80,11 +80,9 @@ class JsonSchemaEvaluator(StringEvaluator):
try:
validate(instance=prediction, schema=schema)
return {
"score": True,
}
except ValidationError as e:
return {"score": False, "reasoning": repr(e)}
return {"score": True}
@override
def _evaluate_strings(

View File

@ -153,7 +153,7 @@ class SQLRecordManager(RecordManager):
"""Create the database schema."""
if isinstance(self.engine, AsyncEngine):
msg = "This method is not supported for async engines."
raise AssertionError(msg)
raise AssertionError(msg) # noqa: TRY004
Base.metadata.create_all(self.engine)
@ -162,7 +162,7 @@ class SQLRecordManager(RecordManager):
if not isinstance(self.engine, AsyncEngine):
msg = "This method is not supported for sync engines."
raise AssertionError(msg)
raise AssertionError(msg) # noqa: TRY004
async with self.engine.begin() as session:
await session.run_sync(Base.metadata.create_all)
@ -173,7 +173,7 @@ class SQLRecordManager(RecordManager):
if isinstance(self.session_factory, async_sessionmaker):
msg = "This method is not supported for async engines."
raise AssertionError(msg)
raise AssertionError(msg) # noqa: TRY004
session = self.session_factory()
try:
@ -187,7 +187,7 @@ class SQLRecordManager(RecordManager):
if not isinstance(self.session_factory, async_sessionmaker):
msg = "This method is not supported for sync engines."
raise AssertionError(msg)
raise AssertionError(msg) # noqa: TRY004
async with self.session_factory() as session:
yield session
@ -221,7 +221,7 @@ class SQLRecordManager(RecordManager):
dt = float(dt)
if not isinstance(dt, float):
msg = f"Unexpected type for datetime: {type(dt)}"
raise AssertionError(msg)
raise AssertionError(msg) # noqa: TRY004
return dt
async def aget_time(self) -> float:
@ -254,7 +254,7 @@ class SQLRecordManager(RecordManager):
dt = float(dt)
if not isinstance(dt, float):
msg = f"Unexpected type for datetime: {type(dt)}"
raise AssertionError(msg)
raise AssertionError(msg) # noqa: TRY004
return dt
def update(

View File

@ -128,7 +128,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
self.redis_client = Redis(url=url, token=token)
except Exception as exc:
error_msg = "Upstash Redis instance could not be initiated"
logger.error(error_msg)
logger.exception(error_msg)
raise RuntimeError(error_msg) from exc
self.session_id = session_id
@ -237,8 +237,8 @@ class RedisEntityStore(BaseEntityStore):
try:
self.redis_client = get_client(redis_url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)
except redis.exceptions.ConnectionError:
logger.exception("Redis client could not connect")
self.session_id = session_id
self.key_prefix = key_prefix

View File

@ -39,7 +39,7 @@ class ModelLaboratory:
"If you want to initialize with LLMs, use the `from_llms` method "
"instead (`ModelLaboratory.from_llms(...)`)"
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
if len(chain.input_keys) != 1:
msg = (
"Currently only support chains with one input variable, "

View File

@ -70,7 +70,7 @@ class OutputFixingParser(BaseOutputParser[T]):
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
raise
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -107,7 +107,7 @@ class OutputFixingParser(BaseOutputParser[T]):
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
raise
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

View File

@ -104,9 +104,9 @@ class RetryOutputParser(BaseOutputParser[T]):
while retries <= self.max_retries:
try:
return self.parser.parse(completion)
except OutputParserException as e:
except OutputParserException:
if retries == self.max_retries:
raise e
raise
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -141,7 +141,7 @@ class RetryOutputParser(BaseOutputParser[T]):
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
raise
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
@ -232,7 +232,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
raise
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
@ -260,7 +260,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
raise
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

View File

@ -48,7 +48,7 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
documents = _transformer.transform_documents(documents)
else:
msg = f"Got unexpected transformer type: {_transformer}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
return documents
async def acompress_documents(
@ -78,5 +78,5 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
documents = await _transformer.atransform_documents(documents)
else:
msg = f"Got unexpected transformer type: {_transformer}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
return documents

View File

@ -116,7 +116,7 @@ class EnsembleRetriever(BaseRetriever):
result = self.rank_fusion(input, run_manager=run_manager, config=config)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
raise
else:
run_manager.on_retriever_end(
result,
@ -157,7 +157,7 @@ class EnsembleRetriever(BaseRetriever):
)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e
raise
else:
await run_manager.on_retriever_end(
result,

View File

@ -558,7 +558,7 @@ def _construct_run_evaluator(
return run_evaluator_dec(eval_config)
else:
msg = f"Unknown evaluator type: {type(eval_config)}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
if isinstance(evaluator_, StringEvaluator):
if evaluator_.requires_reference and reference_key is None:
@ -668,7 +668,7 @@ def _load_run_evaluators(
f"Unsupported custom evaluator: {custom_evaluator}."
f" Expected RunEvaluator or StringEvaluator."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
return run_evaluators
@ -1040,7 +1040,7 @@ def _prepare_eval_run(
)
except (HTTPError, ValueError, LangSmithError) as e:
if "already exists " not in str(e):
raise e
raise
uid = uuid.uuid4()
example_msg = f"""
run_on_dataset(
@ -1123,9 +1123,9 @@ class _DatasetRunContainer:
run_id=None,
project_id=self.project.id,
)
except Exception as e:
except Exception:
logger.exception(
"Error running batch evaluator %s: %s", repr(evaluator), e
"Error running batch evaluator %s", repr(evaluator)
)
return aggregate_feedback

View File

@ -181,6 +181,7 @@ select = [
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"TRY", # tryceratops
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
@ -201,6 +202,9 @@ ignore = [
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
# TODO
"TRY301", # tryceratops: raise-within-try
]
unfixable = ["B028"] # People should intentionally tune the stacklevel

View File

@ -105,7 +105,7 @@ class GenericFakeChatModel(BaseChatModel):
f"Expected generate to return a ChatResult, "
f"but got {type(chat_result)} instead."
)
raise ValueError(msg)
raise TypeError(msg)
message = chat_result.generations[0].message
@ -114,7 +114,7 @@ class GenericFakeChatModel(BaseChatModel):
f"Expected invoke to return an AIMessage, "
f"but got {type(message)} instead."
)
raise ValueError(msg)
raise TypeError(msg)
content = message.content

View File

@ -2721,7 +2721,7 @@ dependencies = [
requires-dist = [
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
{ name = "langsmith", specifier = ">=0.3.45" },
{ name = "packaging", specifier = ">=23.2,<25" },
{ name = "packaging", specifier = ">=23.2" },
{ name = "pydantic", specifier = ">=2.7.4" },
{ name = "pyyaml", specifier = ">=5.3" },
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },