mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-18 17:11:25 +00:00
feat(langchain): add ruff rules TRY (#32047)
See https://docs.astral.sh/ruff/rules/#tryceratops-try * TRY004 (replace by TypeError) in main code is escaped with `noqa` to not break backward compatibility. The rule is still interesting for new code. * TRY301 ignored at the moment. This one is quite hard to fix and I'm not sure it's very interesting to activate it. Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
8b8d90bea5
commit
64261449b8
@ -106,11 +106,12 @@ def create_importer(
|
|||||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"module {new_module} has no attribute {name}"
|
msg = f"module {new_module} has no attribute {name}"
|
||||||
raise AttributeError(msg) from e
|
raise AttributeError(msg) from e
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
if fallback_module:
|
if fallback_module:
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(fallback_module)
|
module = importlib.import_module(fallback_module)
|
||||||
@ -139,12 +140,13 @@ def create_importer(
|
|||||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"module {fallback_module} has no attribute {name}"
|
msg = f"module {fallback_module} has no attribute {name}"
|
||||||
raise AttributeError(msg) from e
|
raise AttributeError(msg) from e
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
msg = f"module {package} has no attribute {name}"
|
msg = f"module {package} has no attribute {name}"
|
||||||
raise AttributeError(msg)
|
raise AttributeError(msg)
|
||||||
|
|
||||||
|
@ -1380,7 +1380,7 @@ class AgentExecutor(Chain):
|
|||||||
observation = self.handle_parsing_errors(e)
|
observation = self.handle_parsing_errors(e)
|
||||||
else:
|
else:
|
||||||
msg = "Got unexpected type of `handle_parsing_errors`"
|
msg = "Got unexpected type of `handle_parsing_errors`"
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e # noqa: TRY004
|
||||||
output = AgentAction("_Exception", observation, text)
|
output = AgentAction("_Exception", observation, text)
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_agent_action(output, color="green")
|
run_manager.on_agent_action(output, color="green")
|
||||||
@ -1519,7 +1519,7 @@ class AgentExecutor(Chain):
|
|||||||
observation = self.handle_parsing_errors(e)
|
observation = self.handle_parsing_errors(e)
|
||||||
else:
|
else:
|
||||||
msg = "Got unexpected type of `handle_parsing_errors`"
|
msg = "Got unexpected type of `handle_parsing_errors`"
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e # noqa: TRY004
|
||||||
output = AgentAction("_Exception", observation, text)
|
output = AgentAction("_Exception", observation, text)
|
||||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
observation = await ExceptionTool().arun(
|
observation = await ExceptionTool().arun(
|
||||||
|
@ -55,7 +55,7 @@ class ChatAgent(Agent):
|
|||||||
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
|
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
|
||||||
if not isinstance(agent_scratchpad, str):
|
if not isinstance(agent_scratchpad, str):
|
||||||
msg = "agent_scratchpad should be of type string."
|
msg = "agent_scratchpad should be of type string."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
if agent_scratchpad:
|
if agent_scratchpad:
|
||||||
return (
|
return (
|
||||||
f"This was your previous work "
|
f"This was your previous work "
|
||||||
|
@ -358,12 +358,12 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
|||||||
run = self._wait_for_run(run.id, run.thread_id)
|
run = self._wait_for_run(run.id, run.thread_id)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise
|
||||||
try:
|
try:
|
||||||
response = self._get_response(run)
|
response = self._get_response(run)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e, metadata=run.dict())
|
run_manager.on_chain_error(e, metadata=run.dict())
|
||||||
raise e
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(response)
|
run_manager.on_chain_end(response)
|
||||||
return response
|
return response
|
||||||
@ -494,12 +494,12 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
|||||||
run = await self._await_for_run(run.id, run.thread_id)
|
run = await self._await_for_run(run.id, run.thread_id)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise
|
||||||
try:
|
try:
|
||||||
response = self._get_response(run)
|
response = self._get_response(run)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e, metadata=run.dict())
|
run_manager.on_chain_error(e, metadata=run.dict())
|
||||||
raise e
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(response)
|
run_manager.on_chain_end(response)
|
||||||
return response
|
return response
|
||||||
|
@ -87,7 +87,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
|
|||||||
) -> Union[AgentAction, AgentFinish]:
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
msg = "This output parser only works on ChatGeneration output"
|
msg = "This output parser only works on ChatGeneration output"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
message = result[0].message
|
message = result[0].message
|
||||||
return self._parse_ai_message(message)
|
return self._parse_ai_message(message)
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
|
|||||||
) -> Union[list[AgentAction], AgentFinish]:
|
) -> Union[list[AgentAction], AgentFinish]:
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
msg = "This output parser only works on ChatGeneration output"
|
msg = "This output parser only works on ChatGeneration output"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
message = result[0].message
|
message = result[0].message
|
||||||
return parse_ai_message_to_openai_tool_action(message)
|
return parse_ai_message_to_openai_tool_action(message)
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
|
|||||||
) -> Union[list[AgentAction], AgentFinish]:
|
) -> Union[list[AgentAction], AgentFinish]:
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
msg = "This output parser only works on ChatGeneration output"
|
msg = "This output parser only works on ChatGeneration output"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
message = result[0].message
|
message = result[0].message
|
||||||
return parse_ai_message_to_tool_action(message)
|
return parse_ai_message_to_tool_action(message)
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ class StructuredChatAgent(Agent):
|
|||||||
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
|
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
|
||||||
if not isinstance(agent_scratchpad, str):
|
if not isinstance(agent_scratchpad, str):
|
||||||
msg = "agent_scratchpad should be of type string."
|
msg = "agent_scratchpad should be of type string."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
if agent_scratchpad:
|
if agent_scratchpad:
|
||||||
return (
|
return (
|
||||||
f"This was your previous work "
|
f"This was your previous work "
|
||||||
|
@ -71,12 +71,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
|||||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
try:
|
try:
|
||||||
if self.output_fixing_parser is not None:
|
if self.output_fixing_parser is not None:
|
||||||
parsed_obj: Union[AgentAction, AgentFinish] = (
|
return self.output_fixing_parser.parse(text)
|
||||||
self.output_fixing_parser.parse(text)
|
return self.base_parser.parse(text)
|
||||||
)
|
|
||||||
else:
|
|
||||||
parsed_obj = self.base_parser.parse(text)
|
|
||||||
return parsed_obj
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"Could not parse LLM output: {text}"
|
msg = f"Could not parse LLM output: {text}"
|
||||||
raise OutputParserException(msg) from e
|
raise OutputParserException(msg) from e
|
||||||
|
@ -174,7 +174,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
)
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise
|
||||||
run_manager.on_chain_end(outputs)
|
run_manager.on_chain_end(outputs)
|
||||||
|
|
||||||
if include_run_info:
|
if include_run_info:
|
||||||
@ -228,7 +228,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
)
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise
|
||||||
await run_manager.on_chain_end(outputs)
|
await run_manager.on_chain_end(outputs)
|
||||||
|
|
||||||
if include_run_info:
|
if include_run_info:
|
||||||
|
@ -127,7 +127,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"Output parser of llm_chain should be a RegexParser,"
|
"Output parser of llm_chain should be a RegexParser,"
|
||||||
f" got {output_parser}"
|
f" got {output_parser}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
output_keys = output_parser.output_keys
|
output_keys = output_parser.output_keys
|
||||||
if self.rank_key not in output_keys:
|
if self.rank_key not in output_keys:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -57,7 +57,7 @@ def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
|
|||||||
f"Unsupported chat history format: {type(dialogue_turn)}."
|
f"Unsupported chat history format: {type(dialogue_turn)}."
|
||||||
f" Full chat history: {chat_history} "
|
f" Full chat history: {chat_history} "
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,12 +164,13 @@ class ElasticsearchDatabaseChain(Chain):
|
|||||||
chain_result: dict[str, Any] = {self.output_key: final_result}
|
chain_result: dict[str, Any] = {self.output_key: final_result}
|
||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
||||||
return chain_result
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Append intermediate steps to exception, to aid in logging and later
|
# Append intermediate steps to exception, to aid in logging and later
|
||||||
# improvement of few shot prompt seeds
|
# improvement of few shot prompt seeds
|
||||||
exc.intermediate_steps = intermediate_steps # type: ignore[attr-defined]
|
exc.intermediate_steps = intermediate_steps # type: ignore[attr-defined]
|
||||||
raise exc
|
raise
|
||||||
|
|
||||||
|
return chain_result
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
|
@ -251,7 +251,7 @@ class LLMChain(Chain):
|
|||||||
response = self.generate(input_list, run_manager=run_manager)
|
response = self.generate(input_list, run_manager=run_manager)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise
|
||||||
outputs = self.create_outputs(response)
|
outputs = self.create_outputs(response)
|
||||||
run_manager.on_chain_end({"outputs": outputs})
|
run_manager.on_chain_end({"outputs": outputs})
|
||||||
return outputs
|
return outputs
|
||||||
@ -276,7 +276,7 @@ class LLMChain(Chain):
|
|||||||
response = await self.agenerate(input_list, run_manager=run_manager)
|
response = await self.agenerate(input_list, run_manager=run_manager)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise
|
||||||
outputs = self.create_outputs(response)
|
outputs = self.create_outputs(response)
|
||||||
await run_manager.on_chain_end({"outputs": outputs})
|
await run_manager.on_chain_end({"outputs": outputs})
|
||||||
return outputs
|
return outputs
|
||||||
|
@ -117,7 +117,7 @@ def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsCh
|
|||||||
|
|
||||||
if not isinstance(llm_chain, LLMChain):
|
if not isinstance(llm_chain, LLMChain):
|
||||||
msg = f"Expected LLMChain, got {llm_chain}"
|
msg = f"Expected LLMChain, got {llm_chain}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
if "document_prompt" in config:
|
if "document_prompt" in config:
|
||||||
prompt_config = config.pop("document_prompt")
|
prompt_config = config.pop("document_prompt")
|
||||||
@ -150,7 +150,7 @@ def _load_map_reduce_documents_chain(
|
|||||||
|
|
||||||
if not isinstance(llm_chain, LLMChain):
|
if not isinstance(llm_chain, LLMChain):
|
||||||
msg = f"Expected LLMChain, got {llm_chain}"
|
msg = f"Expected LLMChain, got {llm_chain}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
if "reduce_documents_chain" in config:
|
if "reduce_documents_chain" in config:
|
||||||
reduce_documents_chain = load_chain_from_config(
|
reduce_documents_chain = load_chain_from_config(
|
||||||
|
@ -361,13 +361,13 @@ def get_openapi_chain(
|
|||||||
try:
|
try:
|
||||||
spec = conversion(spec)
|
spec = conversion(spec)
|
||||||
break
|
break
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
raise e
|
raise
|
||||||
except Exception: # noqa: S110
|
except Exception: # noqa: S110
|
||||||
pass
|
pass
|
||||||
if isinstance(spec, str):
|
if isinstance(spec, str):
|
||||||
msg = f"Unable to parse spec from source {spec}"
|
msg = f"Unable to parse spec from source {spec}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
openai_fns, call_api_fn = openapi_spec_to_openai_fn(spec)
|
openai_fns, call_api_fn = openapi_spec_to_openai_fn(spec)
|
||||||
if not llm:
|
if not llm:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -125,7 +125,7 @@ class LLMRouterChain(RouterChain):
|
|||||||
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
|
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
|
||||||
super()._validate_outputs(outputs)
|
super()._validate_outputs(outputs)
|
||||||
if not isinstance(outputs["next_inputs"], dict):
|
if not isinstance(outputs["next_inputs"], dict):
|
||||||
raise ValueError
|
raise ValueError # noqa: TRY004
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@ -178,10 +178,10 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
|
|||||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||||
if not isinstance(parsed["destination"], str):
|
if not isinstance(parsed["destination"], str):
|
||||||
msg = "Expected 'destination' to be a string."
|
msg = "Expected 'destination' to be a string."
|
||||||
raise ValueError(msg)
|
raise TypeError(msg)
|
||||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||||
msg = f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
msg = f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
||||||
raise ValueError(msg)
|
raise TypeError(msg)
|
||||||
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
||||||
if (
|
if (
|
||||||
parsed["destination"].strip().lower()
|
parsed["destination"].strip().lower()
|
||||||
@ -190,7 +190,7 @@ class RouterOutputParser(BaseOutputParser[dict[str, str]]):
|
|||||||
parsed["destination"] = None
|
parsed["destination"] = None
|
||||||
else:
|
else:
|
||||||
parsed["destination"] = parsed["destination"].strip()
|
parsed["destination"] = parsed["destination"].strip()
|
||||||
return parsed
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"Parsing text\n{text}\n raised following error:\n{e}"
|
msg = f"Parsing text\n{text}\n raised following error:\n{e}"
|
||||||
raise OutputParserException(msg) from e
|
raise OutputParserException(msg) from e
|
||||||
|
return parsed
|
||||||
|
@ -340,7 +340,7 @@ class CacheBackedEmbeddings(Embeddings):
|
|||||||
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
|
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
|
||||||
"or a callable that encodes keys."
|
"or a callable that encodes keys."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
document_embedding_store = EncoderBackedStore[str, list[float]](
|
document_embedding_store = EncoderBackedStore[str, list[float]](
|
||||||
document_embedding_cache,
|
document_embedding_cache,
|
||||||
|
@ -23,11 +23,10 @@ from langchain.schema import RUN_KEY
|
|||||||
def _import_numpy() -> Any:
|
def _import_numpy() -> Any:
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
return np
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
msg = "Could not import numpy, please install with `pip install numpy`."
|
msg = "Could not import numpy, please install with `pip install numpy`."
|
||||||
raise ImportError(msg) from e
|
raise ImportError(msg) from e
|
||||||
|
return np
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -74,9 +74,9 @@ class JsonValidityEvaluator(StringEvaluator):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
parse_json_markdown(prediction, parser=json.loads)
|
parse_json_markdown(prediction, parser=json.loads)
|
||||||
return {"score": 1}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"score": 0, "reasoning": str(e)}
|
return {"score": 0, "reasoning": str(e)}
|
||||||
|
return {"score": 1}
|
||||||
|
|
||||||
|
|
||||||
class JsonEqualityEvaluator(StringEvaluator):
|
class JsonEqualityEvaluator(StringEvaluator):
|
||||||
|
@ -80,11 +80,9 @@ class JsonSchemaEvaluator(StringEvaluator):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
validate(instance=prediction, schema=schema)
|
validate(instance=prediction, schema=schema)
|
||||||
return {
|
|
||||||
"score": True,
|
|
||||||
}
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
return {"score": False, "reasoning": repr(e)}
|
return {"score": False, "reasoning": repr(e)}
|
||||||
|
return {"score": True}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _evaluate_strings(
|
def _evaluate_strings(
|
||||||
|
@ -153,7 +153,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
"""Create the database schema."""
|
"""Create the database schema."""
|
||||||
if isinstance(self.engine, AsyncEngine):
|
if isinstance(self.engine, AsyncEngine):
|
||||||
msg = "This method is not supported for async engines."
|
msg = "This method is not supported for async engines."
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg) # noqa: TRY004
|
||||||
|
|
||||||
Base.metadata.create_all(self.engine)
|
Base.metadata.create_all(self.engine)
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
|
|
||||||
if not isinstance(self.engine, AsyncEngine):
|
if not isinstance(self.engine, AsyncEngine):
|
||||||
msg = "This method is not supported for sync engines."
|
msg = "This method is not supported for sync engines."
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg) # noqa: TRY004
|
||||||
|
|
||||||
async with self.engine.begin() as session:
|
async with self.engine.begin() as session:
|
||||||
await session.run_sync(Base.metadata.create_all)
|
await session.run_sync(Base.metadata.create_all)
|
||||||
@ -173,7 +173,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
|
|
||||||
if isinstance(self.session_factory, async_sessionmaker):
|
if isinstance(self.session_factory, async_sessionmaker):
|
||||||
msg = "This method is not supported for async engines."
|
msg = "This method is not supported for async engines."
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg) # noqa: TRY004
|
||||||
|
|
||||||
session = self.session_factory()
|
session = self.session_factory()
|
||||||
try:
|
try:
|
||||||
@ -187,7 +187,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
|
|
||||||
if not isinstance(self.session_factory, async_sessionmaker):
|
if not isinstance(self.session_factory, async_sessionmaker):
|
||||||
msg = "This method is not supported for sync engines."
|
msg = "This method is not supported for sync engines."
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg) # noqa: TRY004
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
yield session
|
yield session
|
||||||
@ -221,7 +221,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
dt = float(dt)
|
dt = float(dt)
|
||||||
if not isinstance(dt, float):
|
if not isinstance(dt, float):
|
||||||
msg = f"Unexpected type for datetime: {type(dt)}"
|
msg = f"Unexpected type for datetime: {type(dt)}"
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg) # noqa: TRY004
|
||||||
return dt
|
return dt
|
||||||
|
|
||||||
async def aget_time(self) -> float:
|
async def aget_time(self) -> float:
|
||||||
@ -254,7 +254,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
dt = float(dt)
|
dt = float(dt)
|
||||||
if not isinstance(dt, float):
|
if not isinstance(dt, float):
|
||||||
msg = f"Unexpected type for datetime: {type(dt)}"
|
msg = f"Unexpected type for datetime: {type(dt)}"
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg) # noqa: TRY004
|
||||||
return dt
|
return dt
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
|
@ -128,7 +128,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
|
|||||||
self.redis_client = Redis(url=url, token=token)
|
self.redis_client = Redis(url=url, token=token)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
error_msg = "Upstash Redis instance could not be initiated"
|
error_msg = "Upstash Redis instance could not be initiated"
|
||||||
logger.error(error_msg)
|
logger.exception(error_msg)
|
||||||
raise RuntimeError(error_msg) from exc
|
raise RuntimeError(error_msg) from exc
|
||||||
|
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
@ -237,8 +237,8 @@ class RedisEntityStore(BaseEntityStore):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||||
except redis.exceptions.ConnectionError as error:
|
except redis.exceptions.ConnectionError:
|
||||||
logger.error(error)
|
logger.exception("Redis client could not connect")
|
||||||
|
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.key_prefix = key_prefix
|
self.key_prefix = key_prefix
|
||||||
|
@ -39,7 +39,7 @@ class ModelLaboratory:
|
|||||||
"If you want to initialize with LLMs, use the `from_llms` method "
|
"If you want to initialize with LLMs, use the `from_llms` method "
|
||||||
"instead (`ModelLaboratory.from_llms(...)`)"
|
"instead (`ModelLaboratory.from_llms(...)`)"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
if len(chain.input_keys) != 1:
|
if len(chain.input_keys) != 1:
|
||||||
msg = (
|
msg = (
|
||||||
"Currently only support chains with one input variable, "
|
"Currently only support chains with one input variable, "
|
||||||
|
@ -70,7 +70,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
return self.parser.parse(completion)
|
return self.parser.parse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||||
completion = self.retry_chain.run(
|
completion = self.retry_chain.run(
|
||||||
@ -107,7 +107,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
return await self.parser.aparse(completion)
|
return await self.parser.aparse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||||
completion = await self.retry_chain.arun(
|
completion = await self.retry_chain.arun(
|
||||||
|
@ -104,9 +104,9 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
while retries <= self.max_retries:
|
while retries <= self.max_retries:
|
||||||
try:
|
try:
|
||||||
return self.parser.parse(completion)
|
return self.parser.parse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||||
completion = self.retry_chain.run(
|
completion = self.retry_chain.run(
|
||||||
@ -141,7 +141,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
return await self.parser.aparse(completion)
|
return await self.parser.aparse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||||
completion = await self.retry_chain.arun(
|
completion = await self.retry_chain.arun(
|
||||||
@ -232,7 +232,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
return self.parser.parse(completion)
|
return self.parser.parse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||||
completion = self.retry_chain.run(
|
completion = self.retry_chain.run(
|
||||||
@ -260,7 +260,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
return await self.parser.aparse(completion)
|
return await self.parser.aparse(completion)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||||
completion = await self.retry_chain.arun(
|
completion = await self.retry_chain.arun(
|
||||||
|
@ -48,7 +48,7 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
|
|||||||
documents = _transformer.transform_documents(documents)
|
documents = _transformer.transform_documents(documents)
|
||||||
else:
|
else:
|
||||||
msg = f"Got unexpected transformer type: {_transformer}"
|
msg = f"Got unexpected transformer type: {_transformer}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
async def acompress_documents(
|
async def acompress_documents(
|
||||||
@ -78,5 +78,5 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
|
|||||||
documents = await _transformer.atransform_documents(documents)
|
documents = await _transformer.atransform_documents(documents)
|
||||||
else:
|
else:
|
||||||
msg = f"Got unexpected transformer type: {_transformer}"
|
msg = f"Got unexpected transformer type: {_transformer}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
return documents
|
return documents
|
||||||
|
@ -116,7 +116,7 @@ class EnsembleRetriever(BaseRetriever):
|
|||||||
result = self.rank_fusion(input, run_manager=run_manager, config=config)
|
result = self.rank_fusion(input, run_manager=run_manager, config=config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
run_manager.on_retriever_error(e)
|
run_manager.on_retriever_error(e)
|
||||||
raise e
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_retriever_end(
|
run_manager.on_retriever_end(
|
||||||
result,
|
result,
|
||||||
@ -157,7 +157,7 @@ class EnsembleRetriever(BaseRetriever):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await run_manager.on_retriever_error(e)
|
await run_manager.on_retriever_error(e)
|
||||||
raise e
|
raise
|
||||||
else:
|
else:
|
||||||
await run_manager.on_retriever_end(
|
await run_manager.on_retriever_end(
|
||||||
result,
|
result,
|
||||||
|
@ -558,7 +558,7 @@ def _construct_run_evaluator(
|
|||||||
return run_evaluator_dec(eval_config)
|
return run_evaluator_dec(eval_config)
|
||||||
else:
|
else:
|
||||||
msg = f"Unknown evaluator type: {type(eval_config)}"
|
msg = f"Unknown evaluator type: {type(eval_config)}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
if isinstance(evaluator_, StringEvaluator):
|
if isinstance(evaluator_, StringEvaluator):
|
||||||
if evaluator_.requires_reference and reference_key is None:
|
if evaluator_.requires_reference and reference_key is None:
|
||||||
@ -668,7 +668,7 @@ def _load_run_evaluators(
|
|||||||
f"Unsupported custom evaluator: {custom_evaluator}."
|
f"Unsupported custom evaluator: {custom_evaluator}."
|
||||||
f" Expected RunEvaluator or StringEvaluator."
|
f" Expected RunEvaluator or StringEvaluator."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
return run_evaluators
|
return run_evaluators
|
||||||
|
|
||||||
@ -1040,7 +1040,7 @@ def _prepare_eval_run(
|
|||||||
)
|
)
|
||||||
except (HTTPError, ValueError, LangSmithError) as e:
|
except (HTTPError, ValueError, LangSmithError) as e:
|
||||||
if "already exists " not in str(e):
|
if "already exists " not in str(e):
|
||||||
raise e
|
raise
|
||||||
uid = uuid.uuid4()
|
uid = uuid.uuid4()
|
||||||
example_msg = f"""
|
example_msg = f"""
|
||||||
run_on_dataset(
|
run_on_dataset(
|
||||||
@ -1123,9 +1123,9 @@ class _DatasetRunContainer:
|
|||||||
run_id=None,
|
run_id=None,
|
||||||
project_id=self.project.id,
|
project_id=self.project.id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Error running batch evaluator %s: %s", repr(evaluator), e
|
"Error running batch evaluator %s", repr(evaluator)
|
||||||
)
|
)
|
||||||
return aggregate_feedback
|
return aggregate_feedback
|
||||||
|
|
||||||
|
@ -181,6 +181,7 @@ select = [
|
|||||||
"T10", # flake8-debugger
|
"T10", # flake8-debugger
|
||||||
"T20", # flake8-print
|
"T20", # flake8-print
|
||||||
"TID", # flake8-tidy-imports
|
"TID", # flake8-tidy-imports
|
||||||
|
"TRY", # tryceratops
|
||||||
"UP", # pyupgrade
|
"UP", # pyupgrade
|
||||||
"W", # pycodestyle warning
|
"W", # pycodestyle warning
|
||||||
"YTT", # flake8-2020
|
"YTT", # flake8-2020
|
||||||
@ -201,6 +202,9 @@ ignore = [
|
|||||||
"RUF012", # Doesn't play well with Pydantic
|
"RUF012", # Doesn't play well with Pydantic
|
||||||
"SLF001", # Private member access
|
"SLF001", # Private member access
|
||||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
"TRY301", # tryceratops: raise-within-try
|
||||||
]
|
]
|
||||||
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
f"Expected generate to return a ChatResult, "
|
f"Expected generate to return a ChatResult, "
|
||||||
f"but got {type(chat_result)} instead."
|
f"but got {type(chat_result)} instead."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
message = chat_result.generations[0].message
|
message = chat_result.generations[0].message
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
f"Expected invoke to return an AIMessage, "
|
f"Expected invoke to return an AIMessage, "
|
||||||
f"but got {type(message)} instead."
|
f"but got {type(message)} instead."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
content = message.content
|
content = message.content
|
||||||
|
|
||||||
|
@ -2721,7 +2721,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
||||||
{ name = "langsmith", specifier = ">=0.3.45" },
|
{ name = "langsmith", specifier = ">=0.3.45" },
|
||||||
{ name = "packaging", specifier = ">=23.2,<25" },
|
{ name = "packaging", specifier = ">=23.2" },
|
||||||
{ name = "pydantic", specifier = ">=2.7.4" },
|
{ name = "pydantic", specifier = ">=2.7.4" },
|
||||||
{ name = "pyyaml", specifier = ">=5.3" },
|
{ name = "pyyaml", specifier = ">=5.3" },
|
||||||
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
||||||
|
Loading…
Reference in New Issue
Block a user