core[patch]: Add B (bugbear) ruff rules (#25520)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Christophe Bornet 2024-08-28 09:09:29 +02:00 committed by GitHub
parent d5ddaac1fc
commit ff0df5ea15
46 changed files with 149 additions and 112 deletions

View File

@@ -270,7 +270,7 @@ def warn_beta(
         message += f" {addendum}"
     warning = LangChainBetaWarning(message)
-    warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
+    warnings.warn(warning, category=LangChainBetaWarning, stacklevel=4)
 def surface_langchain_beta_warnings() -> None:
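The stacklevel changes in this commit come from bugbear B028 (warnings.warn without an explicit stacklevel). A minimal sketch of what the parameter does; the helper names are illustrative, not from this commit:

    import warnings

    def _warn_deprecated(message: str) -> None:
        # stacklevel=2 attributes the warning to the caller of this helper
        # rather than to the helper itself; every additional wrapper frame
        # between the user's code and warnings.warn() needs the value bumped
        # by one, which is why the decorators here use 4 and validators 5-7.
        warnings.warn(message, DeprecationWarning, stacklevel=2)

    def old_api() -> None:
        _warn_deprecated("old_api is deprecated; use new_api instead.")

    old_api()  # the warning's reported file and line point at this call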

View File

@@ -444,7 +444,7 @@ def warn_deprecated(
         LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
     )
     warning = warning_cls(message)
-    warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
+    warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=4)
 def surface_langchain_deprecation_warnings() -> None:

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
     from langchain_core.documents.base import Blob
-class BaseLoader(ABC):
+class BaseLoader(ABC):  # noqa: B024
     """Interface for Document Loader.
     Implementations should implement the lazy-loading method using generators
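The # noqa here silences bugbear B024, which flags subclasses of ABC that declare no abstract methods: such classes can be instantiated even though they read as abstract, which is intentional for BaseLoader. A sketch of the rule with made-up class names:

    from abc import ABC, abstractmethod

    class LooksAbstract(ABC):  # B024: ABC without any abstract methods
        def load(self) -> list:
            return []

    class ActuallyAbstract(ABC):
        @abstractmethod
        def load(self) -> list: ...

    LooksAbstract()       # instantiates silently, which B024 warns about
    # ActuallyAbstract()  # would raise TypeError: can't instantiate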

View File

@@ -80,12 +80,12 @@ def _texts_to_nodes(
     for text in texts:
         try:
             _metadata = next(metadatas_it).copy() if metadatas_it else {}
-        except StopIteration:
-            raise ValueError("texts iterable longer than metadatas")
+        except StopIteration as e:
+            raise ValueError("texts iterable longer than metadatas") from e
         try:
             _id = next(ids_it) if ids_it else None
-        except StopIteration:
-            raise ValueError("texts iterable longer than ids")
+        except StopIteration as e:
+            raise ValueError("texts iterable longer than ids") from e
         links = _metadata.pop(METADATA_LINKS_KEY, [])
         if not isinstance(links, list):
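The `from e` suffixes that recur throughout this commit implement bugbear B904: a raise inside an except block should chain the original exception so the traceback keeps its cause. A self-contained sketch with an invented function:

    def parse_port(raw: str) -> int:
        try:
            return int(raw)
        except ValueError as e:
            # Without "from e", Python prints "During handling of the above
            # exception, another exception occurred", which misreads as a bug
            # in the error handler; "from e" records the ValueError as the
            # direct cause instead.
            raise RuntimeError(f"invalid port: {raw!r}") from e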

View File

@@ -91,7 +91,7 @@ class _HashedDocument(Document):
             raise ValueError(
                 f"Failed to hash metadata: {e}. "
                 f"Please use a dict that can be serialized using json."
-            )
+            ) from e
         values["content_hash"] = content_hash
         values["metadata_hash"] = metadata_hash

View File

@@ -64,12 +64,12 @@ def get_tokenizer() -> Any:
     """
     try:
         from transformers import GPT2TokenizerFast  # type: ignore[import]
-    except ImportError:
+    except ImportError as e:
        raise ImportError(
            "Could not import transformers python package. "
            "This is needed in order to calculate get_token_ids. "
            "Please install it with `pip install transformers`."
-        )
+        ) from e
    # create a GPT-2 tokenizer instance
    return GPT2TokenizerFast.from_pretrained("gpt2")

View File

@@ -235,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=5,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values

View File

@@ -69,6 +69,10 @@ class FakeListLLM(LLM):
         return {"responses": self.responses}
+class FakeListLLMError(Exception):
+    """Fake error for testing purposes."""
 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes.
@@ -98,7 +102,7 @@ class FakeStreamingListLLM(FakeListLLM):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListLLMError
             yield c
     async def astream(
@@ -118,5 +122,5 @@ class FakeStreamingListLLM(FakeListLLM):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListLLMError
             yield c
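Raising a dedicated exception type instead of a bare Exception("Fake error") lets the tests further down assert the exact failure they provoked. A sketch of the pattern with invented names:

    import pytest

    class FakeModelError(Exception):
        """Sentinel error a fake model raises on purpose."""

    def stream_chunk() -> None:
        raise FakeModelError

    def test_stream_chunk() -> None:
        # pytest.raises(Exception) would also pass if an unrelated bug
        # raised TypeError; asserting the concrete type keeps the test honest.
        with pytest.raises(FakeModelError):
            stream_chunk()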

View File

@@ -44,6 +44,10 @@ class FakeMessagesListChatModel(BaseChatModel):
         return "fake-messages-list-chat-model"
+class FakeListChatModelError(Exception):
+    pass
 class FakeListChatModel(SimpleChatModel):
     """Fake ChatModel for testing purposes."""
@@ -93,7 +97,7 @@ class FakeListChatModel(SimpleChatModel):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListChatModelError
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@@ -116,7 +120,7 @@ class FakeListChatModel(SimpleChatModel):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListChatModelError
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
     @property

View File

@@ -114,7 +114,6 @@ def create_base_retry_decorator(
                 _log_error_once(f"Error in on_retry: {e}")
             else:
                 run_manager.on_retry(retry_state)
-        return None
     min_seconds = 4
     max_seconds = 10
@@ -311,6 +310,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=5,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values

View File

@@ -290,10 +290,10 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
             msg_type = msg_kwargs.pop("type")
             # None msg content is not allowed
             msg_content = msg_kwargs.pop("content") or ""
-        except KeyError:
+        except KeyError as e:
             raise ValueError(
                 f"Message dict must contain 'role' and 'content' keys, got {message}"
-            )
+            ) from e
         _message = _create_message_from_message_type(
             msg_type, msg_content, **msg_kwargs
         )
@@ -344,9 +344,7 @@ def _runnable_support(func: Callable) -> Callable:
         if messages is not None:
             return func(messages, **kwargs)
         else:
-            return RunnableLambda(
-                partial(func, **kwargs), name=getattr(func, "__name__")
-            )
+            return RunnableLambda(partial(func, **kwargs), name=func.__name__)
     wrapped.__doc__ = func.__doc__
     return wrapped
@@ -791,7 +789,7 @@ def trim_messages(
         raise ValueError
     messages = convert_to_messages(messages)
     if hasattr(token_counter, "get_num_tokens_from_messages"):
-        list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
+        list_token_counter = token_counter.get_num_tokens_from_messages
     elif callable(token_counter):
         if (
             list(inspect.signature(token_counter).parameters.values())[0].annotation
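The getattr rewrites are bugbear B009: getattr with a constant string literal is just attribute access in disguise, slower and opaque to static type checkers. A sketch:

    class Counter:
        def get_num_tokens_from_messages(self, messages: list) -> int:
            return sum(len(str(m)) for m in messages)

    c = Counter()
    by_string = getattr(c, "get_num_tokens_from_messages")  # B009
    by_attribute = c.get_num_tokens_from_messages           # same bound method
    assert by_string == by_attribute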

View File

@@ -42,7 +42,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
         try:
             func_call = copy.deepcopy(message.additional_kwargs["function_call"])
         except KeyError as exc:
-            raise OutputParserException(f"Could not parse function call: {exc}")
+            raise OutputParserException(
+                f"Could not parse function call: {exc}"
+            ) from exc
         if self.args_only:
             return func_call["arguments"]
@@ -100,7 +102,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
             if partial:
                 return None
             else:
-                raise OutputParserException(f"Could not parse function call: {exc}")
+                raise OutputParserException(
+                    f"Could not parse function call: {exc}"
+                ) from exc
         try:
             if partial:
                 try:
@@ -126,7 +130,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                 except (json.JSONDecodeError, TypeError) as exc:
                     raise OutputParserException(
                         f"Could not parse function call data: {exc}"
-                    )
+                    ) from exc
             else:
                 try:
                     return {
@@ -138,7 +142,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                 except (json.JSONDecodeError, TypeError) as exc:
                     raise OutputParserException(
                         f"Could not parse function call data: {exc}"
-                    )
+                    ) from exc
         except KeyError:
             return None

View File

@@ -55,7 +55,7 @@ def parse_tool_call(
                 f"Function {raw_tool_call['function']['name']} arguments:\n\n"
                 f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
                 f"Received JSONDecodeError {e}"
-            )
+            ) from e
     parsed = {
         "name": raw_tool_call["function"]["name"] or "",
         "args": function_args or {},

View File

@@ -32,12 +32,12 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
                     {self.pydantic_object.__class__}"
                 )
             except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
-                raise self._parser_exception(e, obj)
+                raise self._parser_exception(e, obj) from e
         else:  # pydantic v1
             try:
                 return self.pydantic_object.parse_obj(obj)
             except pydantic.ValidationError as e:
-                raise self._parser_exception(e, obj)
+                raise self._parser_exception(e, obj) from e
     def _parser_exception(
         self, e: Exception, json_object: dict

View File

@@ -46,12 +46,12 @@ class _StreamingParser:
         if parser == "defusedxml":
             try:
                 from defusedxml import ElementTree as DET  # type: ignore
-            except ImportError:
+            except ImportError as e:
                 raise ImportError(
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml` "
-                )
+                ) from e
             _parser = DET.DefusedXMLParser(target=TreeBuilder())
         else:
             _parser = None
@@ -189,13 +189,13 @@ class XMLOutputParser(BaseTransformOutputParser):
         if self.parser == "defusedxml":
             try:
                 from defusedxml import ElementTree as DET  # type: ignore
-            except ImportError:
+            except ImportError as e:
                 raise ImportError(
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml`"
                     "See https://github.com/tiran/defusedxml for more details"
-                )
+                ) from e
             _ET = DET  # Use the defusedxml parser
         else:
             _ET = ET  # Use the standard library parser

View File

@@ -235,7 +235,9 @@ class PromptTemplate(StringPromptTemplate):
             template = f.read()
         if input_variables:
             warnings.warn(
-                "`input_variables' is deprecated and ignored.", DeprecationWarning
+                "`input_variables' is deprecated and ignored.",
+                DeprecationWarning,
+                stacklevel=2,
             )
         return cls.from_template(template=template, **kwargs)

View File

@@ -40,14 +40,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """
     try:
         from jinja2.sandbox import SandboxedEnvironment
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "jinja2 not installed, which is needed to use the jinja2_formatter. "
             "Please install it with `pip install jinja2`."
             "Please be cautious when using jinja2 templates. "
             "Do not expand jinja2 templates using unverified or user-controlled "
             "inputs as that can result in arbitrary Python code execution."
-        )
+        ) from e
     # This uses a sandboxed environment to prevent arbitrary code execution.
     # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
@@ -81,17 +81,17 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
         warning_message += f"Extra variables: {extra_variables}"
     if warning_message:
-        warnings.warn(warning_message.strip())
+        warnings.warn(warning_message.strip(), stacklevel=7)
 def _get_jinja2_variables_from_template(template: str) -> Set[str]:
     try:
         from jinja2 import Environment, meta
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "jinja2 not installed, which is needed to use the jinja2_formatter. "
             "Please install it with `pip install jinja2`."
-        )
+        ) from e
     env = Environment()
     ast = env.parse(template)
     variables = meta.find_undeclared_variables(ast)

View File

@@ -155,6 +155,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
                 "Retrievers must implement abstract `_get_relevant_documents` method"
                 " instead of `get_relevant_documents`",
                 DeprecationWarning,
+                stacklevel=4,
             )
             swap = cls.get_relevant_documents
             cls.get_relevant_documents = (  # type: ignore[assignment]
@@ -169,6 +170,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
                 "Retrievers must implement abstract `_aget_relevant_documents` method"
                 " instead of `aget_relevant_documents`",
                 DeprecationWarning,
+                stacklevel=4,
             )
             aswap = cls.aget_relevant_documents
             cls.aget_relevant_documents = (  # type: ignore[assignment]

View File

@@ -3915,7 +3915,7 @@ class RunnableGenerator(Runnable[Input, Output]):
     @property
     def InputType(self) -> Any:
-        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
+        func = getattr(self, "_transform", None) or self._atransform
         try:
             params = inspect.signature(func).parameters
             first_param = next(iter(params.values()), None)
@@ -3928,7 +3928,7 @@ class RunnableGenerator(Runnable[Input, Output]):
     @property
     def OutputType(self) -> Any:
-        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
+        func = getattr(self, "_transform", None) or self._atransform
         try:
             sig = inspect.signature(func)
             return (
@@ -4152,7 +4152,7 @@ class RunnableLambda(Runnable[Input, Output]):
     @property
     def InputType(self) -> Any:
         """The type of the input to this Runnable."""
-        func = getattr(self, "func", None) or getattr(self, "afunc")
+        func = getattr(self, "func", None) or self.afunc
         try:
             params = inspect.signature(func).parameters
             first_param = next(iter(params.values()), None)
@@ -4174,7 +4174,7 @@ class RunnableLambda(Runnable[Input, Output]):
         Returns:
             The input schema for this Runnable.
         """
-        func = getattr(self, "func", None) or getattr(self, "afunc")
+        func = getattr(self, "func", None) or self.afunc
         if isinstance(func, itemgetter):
             # This is terrible, but afaict it's not possible to access _items
@@ -4212,7 +4212,7 @@ class RunnableLambda(Runnable[Input, Output]):
         Returns:
             The type of the output of this Runnable.
         """
-        func = getattr(self, "func", None) or getattr(self, "afunc")
+        func = getattr(self, "func", None) or self.afunc
         try:
             sig = inspect.signature(func)
             if sig.return_annotation != inspect.Signature.empty:

View File

@@ -236,6 +236,7 @@ def get_config_list(
         warnings.warn(
             "Provided run_id be used only for the first element of the batch.",
             category=RuntimeWarning,
+            stacklevel=3,
         )
         subsequent = cast(
             RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}

View File

@@ -537,7 +537,7 @@ class Graph:
         *,
         with_styles: bool = True,
         curve_style: CurveStyle = CurveStyle.LINEAR,
-        node_colors: NodeStyles = NodeStyles(),
+        node_colors: Optional[NodeStyles] = None,
         wrap_label_n_words: int = 9,
     ) -> str:
         """Draw the graph as a Mermaid syntax string.
@@ -573,7 +573,7 @@ class Graph:
         self,
         *,
         curve_style: CurveStyle = CurveStyle.LINEAR,
-        node_colors: NodeStyles = NodeStyles(),
+        node_colors: Optional[NodeStyles] = None,
         wrap_label_n_words: int = 9,
         output_file_path: Optional[str] = None,
         draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
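Replacing NodeStyles() defaults with Optional[...] = None is bugbear B008: a function call in a default argument runs once at definition time, so every call shares the same object. A sketch with an invented dataclass:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Styles:
        fill: str = "#f2f0ff"

    def draw_bad(styles: Styles = Styles()) -> Styles:  # B008
        return styles

    def draw_good(styles: Optional[Styles] = None) -> Styles:
        # materialize the default inside the body, once per call
        return styles if styles is not None else Styles()

    assert draw_bad() is draw_bad()        # all calls share one instance
    assert draw_good() is not draw_good()  # each call gets a fresh one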

View File

@@ -20,7 +20,7 @@ def draw_mermaid(
     last_node: Optional[str] = None,
     with_styles: bool = True,
     curve_style: CurveStyle = CurveStyle.LINEAR,
-    node_styles: NodeStyles = NodeStyles(),
+    node_styles: Optional[NodeStyles] = None,
     wrap_label_n_words: int = 9,
 ) -> str:
     """Draws a Mermaid graph using the provided graph data.
@@ -153,7 +153,7 @@ def draw_mermaid(
     # Add custom styles for nodes
     if with_styles:
-        mermaid_graph += _generate_mermaid_graph_styles(node_styles)
+        mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
     return mermaid_graph

View File

@@ -218,6 +218,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
         def pending(iterable: List[U]) -> List[U]:
             return [item for idx, item in enumerate(iterable) if idx not in results_map]
+        not_set: List[Output] = []
+        result = not_set
         try:
             for attempt in self._sync_retrying():
                 with attempt:
@@ -247,9 +249,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
                     ):
                         attempt.retry_state.set_result(result)
         except RetryError as e:
-            try:
-                result
-            except UnboundLocalError:
+            if result is not_set:
                 result = cast(List[Output], [e] * len(inputs))
         outputs: List[Union[Output, Exception]] = []
@@ -284,6 +284,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
         def pending(iterable: List[U]) -> List[U]:
             return [item for idx, item in enumerate(iterable) if idx not in results_map]
+        not_set: List[Output] = []
+        result = not_set
         try:
             async for attempt in self._async_retrying():
                 with attempt:
@@ -313,9 +315,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
                     ):
                         attempt.retry_state.set_result(result)
         except RetryError as e:
-            try:
-                result
-            except UnboundLocalError:
+            if result is not_set:
                 result = cast(List[Output], [e] * len(inputs))
         outputs: List[Union[Output, Exception]] = []
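The not_set sentinel replaces a try/except UnboundLocalError probe: comparing against a unique object by identity is an ordinary branch instead of control flow by exception. A sketch of the pattern:

    from typing import List

    _not_set: List[str] = []  # unique object; identity marks "never assigned"

    def run_attempts(attempts: int) -> List[str]:
        result = _not_set
        for _ in range(attempts):
            result = ["ok"]  # may never execute when attempts == 0
        if result is _not_set:  # instead of catching UnboundLocalError
            result = ["fallback"]
        return result

    assert run_attempts(0) == ["fallback"]
    assert run_attempts(3) == ["ok"]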

View File

@@ -748,7 +748,7 @@ def is_async_generator(
     """
     return (
         inspect.isasyncgenfunction(func)
-        or hasattr(func, "__call__")
+        or hasattr(func, "__call__")  # noqa: B004
         and inspect.isasyncgenfunction(func.__call__)
     )
@@ -767,6 +767,6 @@ def is_async_callable(
     """
     return (
         asyncio.iscoroutinefunction(func)
-        or hasattr(func, "__call__")
+        or hasattr(func, "__call__")  # noqa: B004
        and asyncio.iscoroutinefunction(func.__call__)
    )
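B004 flags hasattr(x, "__call__") because callable(x) says the same thing more directly; the noqa keeps the existing spelling here since the bound __call__ is inspected on the next line anyway. A sketch of the distinction:

    import inspect

    class AsyncGenCallable:
        async def __call__(self):
            yield 1

    obj = AsyncGenCallable()
    assert callable(obj)  # the idiomatic form of hasattr(obj, "__call__")
    # the attribute itself is still needed to inspect the bound method:
    assert inspect.isasyncgenfunction(obj.__call__)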

View File

@@ -443,6 +443,7 @@ class ChildTool(BaseTool):
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=6,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values

View File

@@ -244,7 +244,7 @@ def _get_schema_from_runnable_and_arg_types(
             "Tool input must be str or dict. If dict, dict arguments must be "
             "typed. Either annotate types (e.g., with TypedDict) or pass "
             f"arg_types into `.as_tool` to specify. {str(e)}"
-        )
+        ) from e
     fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
     return create_model(name, **fields)  # type: ignore

View File

@@ -516,15 +516,19 @@ class _TracerCore(ABC):
     def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """End a trace for a run."""
+        return None
     def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process a run upon creation."""
+        return None
     def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process a run upon update."""
+        return None
     def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the LLM Run upon start."""
+        return None
     def _on_llm_new_token(
         self,
@@ -533,39 +537,52 @@ class _TracerCore(ABC):
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
     ) -> Union[None, Coroutine[Any, Any, None]]:
         """Process new LLM token."""
+        return None
     def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the LLM Run."""
+        return None
     def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the LLM Run upon error."""
+        return None
     def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chain Run upon start."""
+        return None
     def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chain Run."""
+        return None
     def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chain Run upon error."""
+        return None
     def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Tool Run upon start."""
+        return None
     def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Tool Run."""
+        return None
     def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Tool Run upon error."""
+        return None
     def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chat Model Run upon start."""
+        return None
     def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Retriever Run upon start."""
+        return None
     def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Retriever Run."""
+        return None
     def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Retriever Run upon error."""
+        return None

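Adding an explicit return None to these hook methods appears to satisfy bugbear B027, which flags empty (docstring-only) methods in an ABC as possibly forgotten @abstractmethod markers. A sketch of the rule, under that assumption:

    from abc import ABC

    class Hooks(ABC):
        def on_start(self) -> None:
            """Optional hook; subclasses may override."""
            # A docstring-only body triggers B027; the explicit no-op
            # return makes "intentionally does nothing" unambiguous.
            return None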
View File

@@ -144,11 +144,7 @@ class EvaluatorCallbackHandler(BaseTracer):
             example_id = str(run.reference_example_id)
         with self.lock:
             for res in eval_results:
-                run_id = (
-                    str(getattr(res, "target_run_id"))
-                    if hasattr(res, "target_run_id")
-                    else str(run.id)
-                )
+                run_id = str(getattr(res, "target_run_id", run.id))
                 self.logged_eval_results.setdefault((run_id, example_id), []).append(
                     res
                 )
@@ -179,11 +175,9 @@ class EvaluatorCallbackHandler(BaseTracer):
         source_info_: Dict[str, Any] = {}
         if res.evaluator_info:
             source_info_ = {**res.evaluator_info, **source_info_}
-        run_id_ = (
-            getattr(res, "target_run_id")
-            if hasattr(res, "target_run_id") and res.target_run_id is not None
-            else run.id
-        )
+        run_id_ = getattr(res, "target_run_id", None)
+        if run_id_ is None:
+            run_id_ = run.id
         self.client.create_feedback(
             run_id_,
             res.key,

View File

@@ -22,6 +22,7 @@ def RunTypeEnum() -> Type[RunTypeEnumDep]:
         "RunTypeEnum is deprecated. Please directly use a string instead"
         " (e.g. 'llm', 'chain', 'tool').",
         DeprecationWarning,
+        stacklevel=2,
     )
     return RunTypeEnumDep

View File

@@ -62,8 +62,8 @@ def py_anext(
         __anext__ = cast(
             Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
         )
-    except AttributeError:
-        raise TypeError(f"{iterator!r} is not an async iterator")
+    except AttributeError as e:
+        raise TypeError(f"{iterator!r} is not an async iterator") from e
     if default is _no_default:
         return __anext__(iterator)

View File

@@ -182,7 +182,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
     try:
         json_obj = parse_json_markdown(text)
     except json.JSONDecodeError as e:
-        raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
     for key in expected_keys:
         if key not in json_obj:
             raise OutputParserException(

View File

@@ -21,7 +21,9 @@ def try_load_from_hub(
 ) -> Any:
     warnings.warn(
         "Loading from the deprecated github-based Hub is no longer supported. "
-        "Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
+        "Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
+        DeprecationWarning,
+        stacklevel=2,
     )
     # return None, which indicates that we shouldn't load from old hub
     # and might just be a filepath for e.g. load_chain

View File

@@ -3,13 +3,17 @@ Adapted from https://github.com/noahmorrison/chevron
 MIT License
 """
+from __future__ import annotations
 import logging
+from types import MappingProxyType
 from typing import (
     Any,
     Dict,
     Iterator,
     List,
     Literal,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
@@ -22,7 +26,7 @@ from typing_extensions import TypeAlias
 logger = logging.getLogger(__name__)
-Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
+Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]]
 # Globals
@@ -152,8 +156,8 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
     # Get the tag
     try:
         tag, template = template.split(r_del, 1)
-    except ValueError:
-        raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}")
+    except ValueError as e:
+        raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}") from e
     # Find the type meaning of the first character
     tag_type = tag_types.get(tag[0], "variable")
@@ -279,12 +283,12 @@ def tokenize(
     # is the same as us
     try:
         last_section = open_sections.pop()
-    except IndexError:
+    except IndexError as e:
         raise ChevronError(
             f'Trying to close tag "{tag_key}"\n'
             "Looks like it was not opened.\n"
             f"line {_CURRENT_LINE + 1}"
-        )
+        ) from e
     if tag_key != last_section:
         # Otherwise we need to complain
         raise ChevronError(
@@ -411,7 +415,7 @@ def _get_key(
     return ""
-def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
+def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
     """Load a partial"""
     try:
         # Maybe the partial is in the dictionary
@@ -425,11 +429,13 @@ def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
 #
 g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
+EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
 def render(
     template: Union[str, List[Tuple[str, str]]] = "",
-    data: Dict[str, Any] = {},
-    partials_dict: Dict[str, str] = {},
+    data: Mapping[str, Any] = EMPTY_DICT,
+    partials_dict: Mapping[str, str] = EMPTY_DICT,
     padding: str = "",
     def_ldel: str = "{{",
     def_rdel: str = "}}",
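The EMPTY_DICT constant addresses bugbear B006: a mutable default like data={} is created once at definition time and shared by every call, so one caller's mutation leaks into the next. MappingProxyType gives a read-only default that cannot be mutated at all. A sketch:

    from types import MappingProxyType
    from typing import Dict, Mapping

    _EMPTY: Mapping[str, str] = MappingProxyType({})

    def render_bad(data: Dict[str, str] = {}) -> Dict[str, str]:  # B006
        data.setdefault("seen", "yes")
        return data

    def render_good(data: Mapping[str, str] = _EMPTY) -> Mapping[str, str]:
        return data

    assert render_bad() is render_bad()  # both calls mutate one shared dict
    # render_good()["seen"] = "yes"      # would raise TypeError: read-only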

View File

@@ -131,12 +131,12 @@ def guard_import(
     """
     try:
         module = importlib.import_module(module_name, package)
-    except (ImportError, ModuleNotFoundError):
+    except (ImportError, ModuleNotFoundError) as e:
         pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
         raise ImportError(
             f"Could not import {module_name} python package. "
             f"Please install it with `pip install {pip_name}`."
-        )
+        ) from e
     return module
@@ -235,7 +235,8 @@ def build_extra_kwargs(
             warnings.warn(
                 f"""WARNING! {field_name} is not default parameter.
                     {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
+                    Please confirm that {field_name} is what you intended.""",
+                stacklevel=7,
             )
         extra_kwargs[field_name] = values.pop(field_name)

View File

@@ -558,7 +558,8 @@ class VectorStore(ABC):
         ):
             warnings.warn(
                 "Relevance scores must be between"
-                f" 0 and 1, got {docs_and_similarities}"
+                f" 0 and 1, got {docs_and_similarities}",
+                stacklevel=2,
             )
         if score_threshold is not None:
@@ -568,7 +569,7 @@ class VectorStore(ABC):
                 if similarity >= score_threshold
             ]
             if len(docs_and_similarities) == 0:
-                warnings.warn(
+                logger.warning(
                     "No relevant docs were retrieved using the relevance score"
                     f" threshold {score_threshold}"
                 )
@@ -605,7 +606,8 @@ class VectorStore(ABC):
         ):
             warnings.warn(
                 "Relevance scores must be between"
-                f" 0 and 1, got {docs_and_similarities}"
+                f" 0 and 1, got {docs_and_similarities}",
+                stacklevel=2,
             )
         if score_threshold is not None:
@@ -615,7 +617,7 @@ class VectorStore(ABC):
                 if similarity >= score_threshold
             ]
             if len(docs_and_similarities) == 0:
-                warnings.warn(
+                logger.warning(
                     "No relevant docs were retrieved using the relevance score"
                     f" threshold {score_threshold}"
                 )
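Switching the empty-result message from warnings.warn to logger.warning reflects the usual split: warnings point developers at code they can change, while logs report runtime, data-dependent conditions to operators. A sketch of the two channels with invented names:

    import logging
    import warnings

    logger = logging.getLogger(__name__)

    def search(threshold: float, scores: list) -> list:
        if threshold > 1:
            # caller misuse: a developer concern, attributed to the call site
            warnings.warn("threshold should be <= 1", stacklevel=2)
        hits = [s for s in scores if s >= threshold]
        if not hits:
            # data-dependent outcome: an operator concern, so log it
            logger.warning("no results above threshold %s", threshold)
        return hits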

View File

@@ -435,11 +435,11 @@ class InMemoryVectorStore(VectorStore):
         try:
             import numpy as np
-        except ImportError:
+        except ImportError as e:
             raise ImportError(
                 "numpy must be installed to use max_marginal_relevance_search "
                 "pip install numpy"
-            )
+            ) from e
         mmr_chosen_indices = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),

View File

@@ -34,11 +34,11 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
     """
     try:
         import numpy as np
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "cosine_similarity requires numpy to be installed. "
             "Please install numpy with `pip install numpy`."
-        )
+        ) from e
     if len(X) == 0 or len(Y) == 0:
         return np.array([])
@@ -93,11 +93,11 @@ def maximal_marginal_relevance(
     """
     try:
         import numpy as np
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "maximal_marginal_relevance requires numpy to be installed. "
             "Please install numpy with `pip install numpy`."
-        )
+        ) from e
     if min(k, len(embedding_list)) <= 0:
         return []

View File

@@ -41,7 +41,7 @@ python = ">=3.12.4"
 [tool.poetry.extras]
 [tool.ruff.lint]
-select = [ "E", "F", "I", "T201", "UP",]
+select = [ "B", "E", "F", "I", "T201", "UP",]
 ignore = [ "UP006", "UP007",]
 [tool.coverage.run]

View File

@@ -7,6 +7,7 @@ import pytest
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel, FakeListChatModel
+from langchain_core.language_models.fake_chat_models import FakeListChatModelError
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -110,7 +111,7 @@ async def test_stream_error_callback() -> None:
         responses=[message],
         error_on_chunk_number=i,
     )
-    with pytest.raises(Exception):
+    with pytest.raises(FakeListChatModelError):
         cb_async = FakeAsyncCallbackHandler()
         async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
             pass

View File

@@ -7,6 +7,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
+from langchain_core.language_models.fake import FakeListLLMError
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
@@ -108,7 +109,7 @@ async def test_stream_error_callback() -> None:
         responses=[message],
         error_on_chunk_number=i,
     )
-    with pytest.raises(Exception):
+    with pytest.raises(FakeListLLMError):
         cb_async = FakeAsyncCallbackHandler()
         async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
             pass

View File

@@ -3,6 +3,7 @@ from typing import Any, AsyncIterator, Iterator, Tuple
 import pytest
+from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers.json import (
     SimpleJsonOutputParser,
 )
@@ -531,7 +532,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None:
 def test_raises_error() -> None:
     parser = SimpleJsonOutputParser()
-    with pytest.raises(Exception):
+    with pytest.raises(OutputParserException):
         parser.invoke("hi")

View File

@@ -164,13 +164,9 @@ def test_pydantic_output_parser_fail() -> None:
         pydantic_object=TestModel
     )
-    try:
+    with pytest.raises(OutputParserException) as e:
         pydantic_parser.parse(DEF_RESULT_FAIL)
-    except OutputParserException as e:
-        print("parse_result:", e)  # noqa: T201
-        assert "Failed to parse TestModel from completion" in str(e)
-    else:
-        assert False, "Expected OutputParserException"
+    assert "Failed to parse TestModel from completion" in str(e)
 def test_pydantic_output_parser_type_inference() -> None:

View File

@@ -28,7 +28,7 @@ def _replace_all_of_with_ref(schema: Any) -> None:
             del schema["default"]
         else:
             # Recursively process nested schemas
-            for key, value in schema.items():
+            for value in schema.values():
                 if isinstance(value, (dict, list)):
                     _replace_all_of_with_ref(value)
     elif isinstance(schema, list):
@@ -47,7 +47,7 @@ def _remove_bad_none_defaults(schema: Any) -> None:
     See difference between Optional and NotRequired types in python.
     """
     if isinstance(schema, dict):
-        for key, value in schema.items():
+        for value in schema.values():
             if isinstance(value, dict):
                 if "default" in value and value["default"] is None:
                     any_of = value.get("anyOf", [])
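These loops are bugbear B007: the variable key was bound but never used. Iterating over .values() (or renaming the unused variable to _key) states the intent directly. A sketch:

    schema = {"a": {"default": None}, "b": {"type": "string"}}

    for key, value in schema.items():  # B007: "key" is never used
        assert isinstance(value, dict)

    for value in schema.values():  # same iteration, no dead binding
        assert isinstance(value, dict)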

View File

@@ -307,7 +307,7 @@ async def test_fallbacks_astream() -> None:
     runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
         [RunnableGenerator(_agenerate)]
     )
-    async for c in runnable.astream({}):
+    async for _ in runnable.astream({}):
         pass
@@ -373,7 +373,7 @@ def test_fallbacks_getattr() -> None:
     assert llm_with_fallbacks.foo == 3
     with pytest.raises(AttributeError):
-        llm_with_fallbacks.bar
+        assert llm_with_fallbacks.bar == 4

View File

@@ -1516,7 +1516,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     ) == [5, 7]
     assert len(spy.call_args_list) == 2
-    for i, call in enumerate(spy.call_args_list):
+    for call in spy.call_args_list:
         call_arg = call.args[0]
         if call_arg == "hello":
@@ -1533,7 +1533,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
     assert len(spy.call_args_list) == 2
     assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
-    for i, call in enumerate(spy.call_args_list):
+    for call in spy.call_args_list:
         assert call.args[1].get("tags") == ["a-tag"]
         assert call.args[1].get("metadata") == {}
     spy.reset_mock()
@@ -5205,7 +5205,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
     assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
     tracer = FakeTracer()
-    for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
+    for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
         pass
     assert tracer.runs[0].name == "RunnableAssign<urls>"
@@ -5225,7 +5225,7 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
     assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
     tracer = FakeTracer()
-    async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
+    async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
         pass
     assert tracer.runs[0].name == "RunnableAssign<urls>"

View File

@@ -140,7 +140,7 @@ class LangChainProjectNameTest(unittest.TestCase):
             projects = []
             def mock_create_run(**kwargs: Any) -> Any:
-                projects.append(kwargs.get("project_name"))
+                projects.append(kwargs.get("project_name"))  # noqa: B023
                 return unittest.mock.MagicMock()
             client.create_run = mock_create_run
@@ -151,6 +151,4 @@ class LangChainProjectNameTest(unittest.TestCase):
                 run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
             )
             tracer.wait_for_futures()
-            assert (
-                len(projects) == 1 and projects[0] == case.expected_project_name
-            )
+            assert projects == [case.expected_project_name]
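The # noqa: B023 above marks an intentional closure: B023 warns when a function defined in a loop captures the loop variable, because the capture is by reference and sees only the final value. Here the shared projects list is deliberate, but the pitfall the rule guards against looks like this:

    late = [lambda: i for i in range(3)]       # B023: "i" resolved at call time
    assert [f() for f in late] == [2, 2, 2]

    bound = [lambda i=i: i for i in range(3)]  # bind the current value now
    assert [f() for f in bound] == [0, 1, 2]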