docs: langchain docstrings updates (#21032)

Added missed docstings. Formatted docstrings into a consistent format.
This commit is contained in:
Leonid Ganeline 2024-04-29 14:40:44 -07:00 committed by GitHub
parent 85094cbb3a
commit 08d08d7c83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 78 additions and 16 deletions

View File

@ -16,6 +16,15 @@ from langchain.chains.llm import LLMChain
def remove_prefix(text: str, prefix: str) -> str: def remove_prefix(text: str, prefix: str) -> str:
"""Remove a prefix from a text.
Args:
text: Text to remove the prefix from.
prefix: Prefix to remove from the text.
Returns:
Text with the prefix removed.
"""
if text.startswith(prefix): if text.startswith(prefix):
return text[len(prefix) :] return text[len(prefix) :]
return text return text

View File

@ -54,6 +54,14 @@ SPARQL_GENERATION_PROMPT = PromptTemplate(
def extract_sparql(query: str) -> str: def extract_sparql(query: str) -> str:
"""Extract SPARQL code from a text.
Args:
query: Text to extract SPARQL code from.
Returns:
SPARQL code extracted from the text.
"""
query = query.strip() query = query.strip()
querytoks = query.split("```") querytoks = query.split("```")
if len(querytoks) == 3: if len(querytoks) == 3:

View File

@ -35,7 +35,7 @@ def create_tagging_chain(
prompt: Optional[ChatPromptTemplate] = None, prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any, **kwargs: Any,
) -> Chain: ) -> Chain:
"""Creates a chain that extracts information from a passage """Create a chain that extracts information from a passage
based on a schema. based on a schema.
Args: Args:
@ -65,7 +65,7 @@ def create_tagging_chain_pydantic(
prompt: Optional[ChatPromptTemplate] = None, prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any, **kwargs: Any,
) -> Chain: ) -> Chain:
"""Creates a chain that extracts information from a passage """Create a chain that extracts information from a passage
based on a pydantic schema. based on a pydantic schema.
Args: Args:

View File

@ -3,7 +3,7 @@ from typing import Any, Dict
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any: def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
""" """
Resolves the $ref keys in a JSON schema object using the provided definitions. Resolve the $ref keys in a JSON schema object using the provided definitions.
""" """
if isinstance(schema, list): if isinstance(schema, list):
for i, item in enumerate(schema): for i, item in enumerate(schema):
@ -29,7 +29,7 @@ def _convert_schema(schema: dict) -> dict:
def get_llm_kwargs(function: dict) -> dict: def get_llm_kwargs(function: dict) -> dict:
"""Returns the kwargs for the LLMChain constructor. """Return the kwargs for the LLMChain constructor.
Args: Args:
function: The function to use. function: The function to use.

View File

@ -63,7 +63,7 @@ class ISO8601Date(TypedDict):
@v_args(inline=True) @v_args(inline=True)
class QueryTransformer(Transformer): class QueryTransformer(Transformer):
"""Transforms a query string into an intermediate representation.""" """Transform a query string into an intermediate representation."""
def __init__( def __init__(
self, self,
@ -159,8 +159,7 @@ def get_parser(
allowed_operators: Optional[Sequence[Operator]] = None, allowed_operators: Optional[Sequence[Operator]] = None,
allowed_attributes: Optional[Sequence[str]] = None, allowed_attributes: Optional[Sequence[str]] = None,
) -> Lark: ) -> Lark:
""" """Return a parser for the query language.
Returns a parser for the query language.
Args: Args:
allowed_comparators: Optional[Sequence[Comparator]] allowed_comparators: Optional[Sequence[Comparator]]

View File

@ -9,7 +9,7 @@ from langchain.evaluation.schema import StringEvaluator
class JsonValidityEvaluator(StringEvaluator): class JsonValidityEvaluator(StringEvaluator):
"""Evaluates whether the prediction is valid JSON. """Evaluate whether the prediction is valid JSON.
This evaluator checks if the prediction is a valid JSON string. It does not This evaluator checks if the prediction is a valid JSON string. It does not
require any input or reference. require any input or reference.
@ -77,7 +77,7 @@ class JsonValidityEvaluator(StringEvaluator):
class JsonEqualityEvaluator(StringEvaluator): class JsonEqualityEvaluator(StringEvaluator):
"""Evaluates whether the prediction is equal to the reference after """Evaluate whether the prediction is equal to the reference after
parsing both as JSON. parsing both as JSON.
This evaluator checks if the prediction, after parsing as JSON, is equal This evaluator checks if the prediction, after parsing as JSON, is equal

View File

@ -37,7 +37,7 @@ def push(
new_repo_description: str = "", new_repo_description: str = "",
) -> str: ) -> str:
""" """
Pushes an object to the hub and returns the URL it can be viewed at in a browser. Push an object to the hub and returns the URL it can be viewed at in a browser.
:param repo_full_name: The full name of the repo to push to in the format of :param repo_full_name: The full name of the repo to push to in the format of
`owner/repo`. `owner/repo`.
@ -71,7 +71,7 @@ def pull(
api_key: Optional[str] = None, api_key: Optional[str] = None,
) -> Any: ) -> Any:
""" """
Pulls an object from the hub and returns it as a LangChain object. Pull an object from the hub and returns it as a LangChain object.
:param owner_repo_commit: The full name of the repo to pull from in the format of :param owner_repo_commit: The full name of the repo to pull from in the format of
`owner/repo:commit_hash`. `owner/repo:commit_hash`.

View File

@ -4,7 +4,7 @@ from langchain_core.memory import BaseMemory
class ReadOnlySharedMemory(BaseMemory): class ReadOnlySharedMemory(BaseMemory):
"""A memory wrapper that is read-only and cannot be changed.""" """Memory wrapper that is read-only and cannot be changed."""
memory: BaseMemory memory: BaseMemory

View File

@ -13,7 +13,7 @@ T = TypeVar("T")
class OutputFixingParser(BaseOutputParser[T]): class OutputFixingParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors.""" """Wrap a parser and try to fix parsing errors."""
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:

View File

@ -34,7 +34,7 @@ T = TypeVar("T")
class RetryOutputParser(BaseOutputParser[T]): class RetryOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors. """Wrap a parser and try to fix parsing errors.
Does this by passing the original prompt and the completion to another Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt. LLM, and telling it the completion did not satisfy criteria in the prompt.
@ -138,7 +138,7 @@ class RetryOutputParser(BaseOutputParser[T]):
class RetryWithErrorOutputParser(BaseOutputParser[T]): class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors. """Wrap a parser and try to fix parsing errors.
Does this by passing the original prompt, the completion, AND the error Does this by passing the original prompt, the completion, AND the error
that was raised to another language model and telling it that the completion that was raised to another language model and telling it that the completion

View File

@ -15,7 +15,7 @@ line_template = '\t"{name}": {type} // {description}'
class ResponseSchema(BaseModel): class ResponseSchema(BaseModel):
"""A schema for a response from a structured output parser.""" """Schema for a response from a structured output parser."""
name: str name: str
"""The name of the schema.""" """The name of the schema."""

View File

@ -38,6 +38,15 @@ H = TypeVar("H", bound=Hashable)
def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]: def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
"""Yield unique elements of an iterable based on a key function.
Args:
iterable: The iterable to filter.
key: A function that returns a hashable key for each element.
Yields:
Unique elements of the iterable based on the key function.
"""
seen = set() seen = set()
for e in iterable: for e in iterable:
if (k := key(e)) not in seen: if (k := key(e)) not in seen:

View File

@ -13,6 +13,8 @@ from langchain_core.structured_query import (
class TencentVectorDBTranslator(Visitor): class TencentVectorDBTranslator(Visitor):
"""Translate StructuredQuery to Tencent VectorDB query."""
COMPARATOR_MAP = { COMPARATOR_MAP = {
Comparator.EQ: "=", Comparator.EQ: "=",
Comparator.NE: "!=", Comparator.NE: "!=",
@ -32,9 +34,22 @@ class TencentVectorDBTranslator(Visitor):
] ]
def __init__(self, meta_keys: Optional[Sequence[str]] = None): def __init__(self, meta_keys: Optional[Sequence[str]] = None):
"""Initialize the translator.
Args:
meta_keys: List of meta keys to be used in the query. Default: [].
"""
self.meta_keys = meta_keys or [] self.meta_keys = meta_keys or []
def visit_operation(self, operation: Operation) -> str: def visit_operation(self, operation: Operation) -> str:
"""Visit an operation node and return the translated query.
Args:
operation: Operation node to be visited.
Returns:
Translated query.
"""
if operation.operator in (Operator.AND, Operator.OR): if operation.operator in (Operator.AND, Operator.OR):
ret = f" {operation.operator.value} ".join( ret = f" {operation.operator.value} ".join(
[arg.accept(self) for arg in operation.arguments] [arg.accept(self) for arg in operation.arguments]
@ -46,6 +61,14 @@ class TencentVectorDBTranslator(Visitor):
return f"not ({operation.arguments[0].accept(self)})" return f"not ({operation.arguments[0].accept(self)})"
def visit_comparison(self, comparison: Comparison) -> str: def visit_comparison(self, comparison: Comparison) -> str:
"""Visit a comparison node and return the translated query.
Args:
comparison: Comparison node to be visited.
Returns:
Translated query.
"""
if self.meta_keys and comparison.attribute not in self.meta_keys: if self.meta_keys and comparison.attribute not in self.meta_keys:
raise ValueError( raise ValueError(
f"Expr Filtering found Unsupported attribute: {comparison.attribute}" f"Expr Filtering found Unsupported attribute: {comparison.attribute}"
@ -78,6 +101,14 @@ class TencentVectorDBTranslator(Visitor):
def visit_structured_query( def visit_structured_query(
self, structured_query: StructuredQuery self, structured_query: StructuredQuery
) -> Tuple[str, dict]: ) -> Tuple[str, dict]:
"""Visit a structured query node and return the translated query.
Args:
structured_query: StructuredQuery node to be visited.
Returns:
Translated query and query kwargs.
"""
if structured_query.filter is None: if structured_query.filter is None:
kwargs = {} kwargs = {}
else: else:

View File

@ -281,6 +281,12 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
class ChatModelInput(TypedDict): class ChatModelInput(TypedDict):
"""Input for a chat model.
Parameters:
messages: List of chat messages.
"""
messages: List[BaseMessage] messages: List[BaseMessage]