mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 12:01:54 +00:00
chore(core): fix some ruff preview rules (#32785)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
f4e83e0ad8
commit
01fdeede50
@@ -361,7 +361,8 @@ def deprecated(
|
||||
# Modify the docstring to include a deprecation notice.
|
||||
if (
|
||||
_alternative
|
||||
and _alternative.split(".")[-1].lower() == _alternative.split(".")[-1]
|
||||
and _alternative.rsplit(".", maxsplit=1)[-1].lower()
|
||||
== _alternative.rsplit(".", maxsplit=1)[-1]
|
||||
):
|
||||
_alternative = f":meth:`~{_alternative}`"
|
||||
elif _alternative:
|
||||
@@ -369,8 +370,8 @@ def deprecated(
|
||||
|
||||
if (
|
||||
_alternative_import
|
||||
and _alternative_import.split(".")[-1].lower()
|
||||
== _alternative_import.split(".")[-1]
|
||||
and _alternative_import.rsplit(".", maxsplit=1)[-1].lower()
|
||||
== _alternative_import.rsplit(".", maxsplit=1)[-1]
|
||||
):
|
||||
_alternative_import = f":meth:`~{_alternative_import}`"
|
||||
elif _alternative_import:
|
||||
@@ -474,7 +475,7 @@ def warn_deprecated(
|
||||
if not message:
|
||||
message = ""
|
||||
package_ = (
|
||||
package or name.split(".")[0].replace("_", "-")
|
||||
package or name.split(".", maxsplit=1)[0].replace("_", "-")
|
||||
if "." in name
|
||||
else "LangChain"
|
||||
)
|
||||
@@ -493,7 +494,7 @@ def warn_deprecated(
|
||||
message += f" and will be removed {removal}"
|
||||
|
||||
if alternative_import:
|
||||
alt_package = alternative_import.split(".")[0].replace("_", "-")
|
||||
alt_package = alternative_import.split(".", maxsplit=1)[0].replace("_", "-")
|
||||
if alt_package == package_:
|
||||
message += f". Use {alternative_import} instead."
|
||||
else:
|
||||
|
@@ -119,7 +119,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
||||
rng = np.random.default_rng(seed)
|
||||
return list(rng.normal(size=self.size))
|
||||
|
||||
def _get_seed(self, text: str) -> int:
|
||||
@staticmethod
|
||||
def _get_seed(text: str) -> int:
|
||||
"""Get a seed for the random generator, using the hash of the text."""
|
||||
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
|
||||
|
||||
|
@@ -148,8 +148,6 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"type": key,
|
||||
key: block[key],
|
||||
}
|
||||
else:
|
||||
pass
|
||||
messages_to_trace.append(message_to_trace)
|
||||
|
||||
return messages_to_trace
|
||||
|
@@ -352,10 +352,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
|
||||
for chunk in self.tool_call_chunks:
|
||||
try:
|
||||
if chunk["args"] is not None and chunk["args"] != "":
|
||||
args_ = parse_partial_json(chunk["args"])
|
||||
else:
|
||||
args_ = {}
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] else {}
|
||||
if isinstance(args_, dict):
|
||||
tool_calls.append(
|
||||
create_tool_call(
|
||||
|
@@ -179,9 +179,7 @@ def merge_content(
|
||||
elif merged and isinstance(merged[-1], str):
|
||||
merged[-1] += content
|
||||
# If second content is an empty string, treat as a no-op
|
||||
elif content == "":
|
||||
pass
|
||||
else:
|
||||
elif content:
|
||||
# Otherwise, add the second content as a new element of the list
|
||||
merged.append(content)
|
||||
return merged
|
||||
|
@@ -46,11 +46,13 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
||||
@staticmethod
|
||||
def _get_schema(pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
||||
return pydantic_object.model_json_schema()
|
||||
return pydantic_object.schema()
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
|
@@ -155,6 +155,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"""
|
||||
return ["langchain", "output_parsers", "list"]
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Return the format instructions for the comma-separated list output."""
|
||||
return (
|
||||
@@ -162,6 +163,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"eg: `foo, bar, baz` or `foo,bar,baz`"
|
||||
)
|
||||
|
||||
@override
|
||||
def parse(self, text: str) -> list[str]:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
@@ -224,6 +226,7 @@ class MarkdownListOutputParser(ListOutputParser):
|
||||
pattern: str = r"^\s*[-*]\s([^\n]+)$"
|
||||
"""The pattern to match a Markdown list item."""
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Return the format instructions for the Markdown list output."""
|
||||
return "Your response should be a markdown list, eg: `- foo\n- bar\n- baz`"
|
||||
|
@@ -1,5 +1,7 @@
|
||||
"""String output parser."""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
|
||||
@@ -29,6 +31,7 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
"""Return the output parser type for serialization."""
|
||||
return "default"
|
||||
|
||||
@override
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
@@ -210,7 +210,7 @@ class BasePromptTemplate(
|
||||
if self.metadata:
|
||||
config["metadata"] = {**config["metadata"], **self.metadata}
|
||||
if self.tags:
|
||||
config["tags"] = config["tags"] + self.tags
|
||||
config["tags"] += self.tags
|
||||
return self._call_with_config(
|
||||
self._format_prompt_with_error_handling,
|
||||
input,
|
||||
|
@@ -166,7 +166,7 @@ def mustache_schema(
|
||||
prefix = section_stack.pop()
|
||||
elif type_ in {"section", "inverted section"}:
|
||||
section_stack.append(prefix)
|
||||
prefix = prefix + tuple(key.split("."))
|
||||
prefix += tuple(key.split("."))
|
||||
fields[prefix] = False
|
||||
elif type_ in {"variable", "no escape"}:
|
||||
fields[prefix + tuple(key.split("."))] = True
|
||||
|
@@ -155,7 +155,7 @@ def draw_mermaid(
|
||||
nonlocal mermaid_graph
|
||||
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
||||
if prefix and not self_loop:
|
||||
subgraph = prefix.split(":")[-1]
|
||||
subgraph = prefix.rsplit(":", maxsplit=1)[-1]
|
||||
if subgraph in seen_subgraphs:
|
||||
msg = (
|
||||
f"Found duplicate subgraph '{subgraph}' -- this likely means that "
|
||||
@@ -214,7 +214,7 @@ def draw_mermaid(
|
||||
|
||||
# Add remaining subgraphs with edges
|
||||
for prefix, edges_ in edge_groups.items():
|
||||
if ":" in prefix or prefix == "":
|
||||
if not prefix or ":" in prefix:
|
||||
continue
|
||||
add_subgraph(edges_, prefix)
|
||||
seen_subgraphs.add(prefix)
|
||||
|
@@ -82,7 +82,7 @@ class _TracerCore(ABC):
|
||||
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
def _persist_run(self, run: Run) -> Union[Coroutine[Any, Any, None], None]:
|
||||
"""Persist a run."""
|
||||
|
||||
@staticmethod
|
||||
@@ -108,7 +108,7 @@ class _TracerCore(ABC):
|
||||
except: # noqa: E722
|
||||
return msg
|
||||
|
||||
def _start_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # type: ignore[return]
|
||||
def _start_trace(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # type: ignore[return]
|
||||
current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id)
|
||||
if run.parent_run_id:
|
||||
if parent := self.order_map.get(run.parent_run_id):
|
||||
@@ -538,7 +538,7 @@ class _TracerCore(ABC):
|
||||
"""Return self copied."""
|
||||
return self
|
||||
|
||||
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _end_trace(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""End a trace for a run.
|
||||
|
||||
Args:
|
||||
@@ -546,7 +546,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_run_create(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process a run upon creation.
|
||||
|
||||
Args:
|
||||
@@ -554,7 +554,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_run_update(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process a run upon update.
|
||||
|
||||
Args:
|
||||
@@ -562,7 +562,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_llm_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the LLM Run upon start.
|
||||
|
||||
Args:
|
||||
@@ -575,7 +575,7 @@ class _TracerCore(ABC):
|
||||
run: Run, # noqa: ARG002
|
||||
token: str, # noqa: ARG002
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002
|
||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
) -> Union[Coroutine[Any, Any, None], None]:
|
||||
"""Process new LLM token.
|
||||
|
||||
Args:
|
||||
@@ -585,7 +585,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_llm_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the LLM Run.
|
||||
|
||||
Args:
|
||||
@@ -593,7 +593,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_llm_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the LLM Run upon error.
|
||||
|
||||
Args:
|
||||
@@ -601,7 +601,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_chain_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Chain Run upon start.
|
||||
|
||||
Args:
|
||||
@@ -609,7 +609,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_chain_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Chain Run.
|
||||
|
||||
Args:
|
||||
@@ -617,7 +617,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_chain_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Chain Run upon error.
|
||||
|
||||
Args:
|
||||
@@ -625,7 +625,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_tool_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Tool Run upon start.
|
||||
|
||||
Args:
|
||||
@@ -633,7 +633,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_tool_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Tool Run.
|
||||
|
||||
Args:
|
||||
@@ -641,7 +641,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_tool_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Tool Run upon error.
|
||||
|
||||
Args:
|
||||
@@ -649,7 +649,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_chat_model_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Chat Model Run upon start.
|
||||
|
||||
Args:
|
||||
@@ -657,7 +657,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_retriever_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Retriever Run upon start.
|
||||
|
||||
Args:
|
||||
@@ -665,7 +665,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_retriever_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Retriever Run.
|
||||
|
||||
Args:
|
||||
@@ -673,7 +673,7 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
|
||||
def _on_retriever_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002
|
||||
"""Process the Retriever Run upon error.
|
||||
|
||||
Args:
|
||||
|
@@ -37,7 +37,7 @@ _no_default = object()
|
||||
# before 3.10, the builtin anext() was not available
|
||||
def py_anext(
|
||||
iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
|
||||
) -> Awaitable[Union[T, None, Any]]:
|
||||
) -> Awaitable[Union[T, Any, None]]:
|
||||
"""Pure-Python implementation of anext() for testing purposes.
|
||||
|
||||
Closely matches the builtin anext() C implementation.
|
||||
|
@@ -82,7 +82,7 @@ def l_sa_check(
|
||||
"""
|
||||
# If there is a newline, or the previous tag was a standalone
|
||||
if literal.find("\n") != -1 or is_standalone:
|
||||
padding = literal.split("\n")[-1]
|
||||
padding = literal.rsplit("\n", maxsplit=1)[-1]
|
||||
|
||||
# If all the characters since the last newline are spaces
|
||||
# Then the next tag could be a standalone
|
||||
|
@@ -134,7 +134,7 @@ def guard_import(
|
||||
try:
|
||||
module = importlib.import_module(module_name, package)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
pip_name = pip_name or module_name.split(".", maxsplit=1)[0].replace("_", "-")
|
||||
msg = (
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
|
Reference in New Issue
Block a user