core: docstrings utils update (#24213)

Added missing docstrings. Formatted docstrings into a consistent form.
Leonid Ganeline 2024-07-15 08:36:00 -07:00 committed by GitHub
parent e8a21146d3
commit 36ee083753
15 changed files with 419 additions and 58 deletions


@ -8,6 +8,17 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Args:
left: The first dictionary to merge.
others: The other dictionaries to merge.
Returns:
The merged dictionary.
Raises:
TypeError: If the key exists in both dictionaries but has a different type.
TypeError: If the value has an unsupported type.
Example:
If left = {"function_call": {"arguments": None}} and
right = {"function_call": {"arguments": "{\n"}}
@ -46,7 +57,15 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
"""Add many lists, handling None."""
"""Add many lists, handling None.
Args:
left: The first list to merge.
others: The other lists to merge.
Returns:
The merged list.
"""
merged = left.copy() if left is not None else None
for other in others:
if other is None:
@ -75,6 +94,23 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]
def merge_obj(left: Any, right: Any) -> Any:
"""Merge two objects.
It handles specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Args:
left: The first object to merge.
right: The other object to merge.
Returns:
The merged object.
Raises:
TypeError: If the key exists in both dictionaries but has a different type.
ValueError: If the two objects cannot be merged.
"""
if left is None or right is None:
return left if left is not None else right
elif type(left) is not type(right):
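A hedged sketch of the `None` and mismatched-type behaviour described in the `merge_obj` docstring (same assumed module):

```python
from langchain_core.utils._merge import merge_obj

merge_obj(None, {"a": 1})      # -> {"a": 1}; None yields to the non-None side
merge_obj({"a": 1}, {"b": 2})  # -> {"a": 1, "b": 2}; dicts are merged via merge_dicts
merge_obj({"a": 1}, ["a"])     # raises TypeError: the two types differ
```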


@ -44,6 +44,18 @@ def py_anext(
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
Args:
iterator: The async iterator to advance.
default: The value to return if the iterator is exhausted.
If not provided, a StopAsyncIteration exception is raised.
Returns:
The next value from the iterator, or the default value
if the iterator is exhausted.
Raises:
TypeError: If the iterator is not an async iterator.
"""
try:
@ -71,7 +83,7 @@ def py_anext(
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
"""Dummy lock that provides the proper interface but no protection."""
async def __aenter__(self) -> None:
pass
@ -88,7 +100,21 @@ async def tee_peer(
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
"""An individual iterator of a :py:func:`~.tee`.
This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.
Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.
Yields:
The next item from the shared iterator.
"""
try:
while True:
if not buffer:
@ -204,6 +230,7 @@ class Tee(Generic[T]):
return False
async def aclose(self) -> None:
"""Async close all child iterators."""
for child in self._children:
await child.aclose()
@ -258,7 +285,7 @@ async def abatch_iterate(
iterable: The async iterable to batch.
Returns:
An async iterator over the batches
An async iterator over the batches.
"""
batch: List[T] = []
async for element in iterable:
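A minimal usage sketch for `abatch_iterate` (module path `langchain_core.utils.aiter` assumed):

```python
import asyncio

from langchain_core.utils.aiter import abatch_iterate


async def numbers():
    for i in range(5):
        yield i


async def main():
    # Collect batches of size 2 from the async generator.
    return [batch async for batch in abatch_iterate(2, numbers())]


print(asyncio.run(main()))  # [[0, 1], [2, 3], [4]]
```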


@ -36,7 +36,7 @@ def get_from_dict_or_env(
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment.
or the environment. Defaults to None.
"""
if isinstance(key, (list, tuple)):
for k in key:
@ -56,7 +56,22 @@ def get_from_dict_or_env(
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""Get a value from a dictionary or an environment variable."""
"""Get a value from a dictionary or an environment variable.
Args:
key: The key to look up in the dictionary.
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment. Defaults to None.
Returns:
str: The value of the key.
Raises:
ValueError: If the key is not in the dictionary and no default value is
provided or if the environment variable is not set.
"""
if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
elif default is not None:
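A hedged example of the lookup order (explicit value first, then the environment variable, then the default), assuming the helper is re-exported from `langchain_core.utils`:

```python
import os

from langchain_core.utils import get_from_dict_or_env

os.environ["MY_SERVICE_API_KEY"] = "secret-token"  # hypothetical variable name

get_from_dict_or_env({"api_key": "from-dict"}, "api_key", "MY_SERVICE_API_KEY")  # 'from-dict'
get_from_dict_or_env({}, "api_key", "MY_SERVICE_API_KEY")                        # 'secret-token'
get_from_dict_or_env({}, "api_key", "MISSING_VAR", default="fallback")           # 'fallback'
```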


@ -10,7 +10,19 @@ class StrictFormatter(Formatter):
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
"""Check that no arguments are provided.
Args:
format_string: The format string.
args: The arguments.
kwargs: The keyword arguments.
Returns:
The formatted string.
Raises:
ValueError: If any arguments are provided.
"""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
@ -21,6 +33,15 @@ class StrictFormatter(Formatter):
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
"""Check that all input variables are used in the format string.
Args:
format_string: The format string.
input_variables: The input variables.
Raises:
ValueError: If any input variables are not used in the format string.
"""
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
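A hedged sketch of the strict behaviour (class assumed importable from `langchain_core.utils.formatting`):

```python
from langchain_core.utils.formatting import StrictFormatter

fmt = StrictFormatter()
fmt.format("Hello {name}!", name="world")                # 'Hello world!'
fmt.validate_input_variables("Hello {name}!", ["name"])  # passes silently
fmt.format("Hello {0}!", "world")                        # raises ValueError: positional args rejected
```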


@ -55,7 +55,9 @@ class ToolDescription(TypedDict):
"""Representation of a callable function to the OpenAI API."""
type: Literal["function"]
"""The type of the tool."""
function: FunctionDescription
"""The function description."""
def _rm_titles(kv: dict, prev_key: str = "") -> dict:
@ -85,7 +87,19 @@ def convert_pydantic_to_openai_function(
description: Optional[str] = None,
rm_titles: bool = True,
) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
"""Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
rm_titles: Whether to remove titles from the schema. Defaults to True.
Returns:
The function description.
"""
schema = dereference_refs(model.schema())
schema.pop("definitions", None)
title = schema.pop("title", "")
@ -108,7 +122,18 @@ def convert_pydantic_to_openai_tool(
name: Optional[str] = None,
description: Optional[str] = None,
) -> ToolDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
"""Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
Returns:
The tool description.
"""
function = convert_pydantic_to_openai_function(
model, name=name, description=description
)
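A hedged sketch of the Pydantic conversion (using the `langchain_core.pydantic_v1` namespace this code base ships):

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_pydantic_to_openai_tool


class GetWeather(BaseModel):
    """Return the current weather for a city."""

    city: str = Field(..., description="Name of the city")


tool = convert_pydantic_to_openai_tool(GetWeather)
# {'type': 'function',
#  'function': {'name': 'GetWeather',
#               'description': 'Return the current weather for a city.',
#               'parameters': {...JSON schema, titles removed...}}}
```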
@ -133,6 +158,12 @@ def convert_python_function_to_openai_function(
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.
Args:
function: The Python function to convert.
Returns:
The OpenAI function description.
"""
from langchain_core import tools
@ -157,7 +188,14 @@ def convert_python_function_to_openai_function(
removal="0.3.0",
)
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API."""
"""Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The function description.
"""
if tool.args_schema:
return convert_pydantic_to_openai_function(
tool.args_schema, name=tool.name, description=tool.description
@ -187,7 +225,14 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
removal="0.3.0",
)
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
"""Format tool into the OpenAI function API."""
"""Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The tool description.
"""
function = format_tool_to_openai_function(tool)
return {"type": "function", "function": function}
@ -206,6 +251,9 @@ def convert_to_openai_function(
Returns:
A dict version of the passed in function which is compatible with the
OpenAI function-calling API.
Raises:
ValueError: If the function is not in a supported format.
"""
from langchain_core.tools import BaseTool
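For the dispatching `convert_to_openai_function`, a hedged sketch with a plain typed, docstring-annotated callable:

```python
from langchain_core.utils.function_calling import convert_to_openai_function


def multiply(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: First factor.
        b: Second factor.
    """
    return a * b


fn = convert_to_openai_function(multiply)
fn["name"]        # 'multiply'
fn["parameters"]  # JSON schema with 'a' and 'b' as required integer properties
```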
@ -284,7 +332,7 @@ def tool_example_to_messages(
BaseModels
tool_outputs: Optional[List[str]], a list of tool call outputs.
Does not need to be provided. If not provided, a placeholder value
will be inserted.
will be inserted. Defaults to None.
Returns:
A list of messages


@ -34,11 +34,11 @@ DEFAULT_LINK_REGEX = (
def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
"""Extract all links from a raw html string.
"""Extract all links from a raw HTML string.
Args:
raw_html: original html.
pattern: Regex to use for extracting links from raw html.
raw_html: original HTML.
pattern: Regex to use for extracting links from raw HTML.
Returns:
List[str]: all links
@ -57,20 +57,20 @@ def extract_sub_links(
exclude_prefixes: Sequence[str] = (),
continue_on_failure: bool = False,
) -> List[str]:
"""Extract all links from a raw html string and convert into absolute paths.
"""Extract all links from a raw HTML string and convert into absolute paths.
Args:
raw_html: original html.
url: the url of the html.
base_url: the base url to check for outside links against.
pattern: Regex to use for extracting links from raw html.
raw_html: original HTML.
url: the url of the HTML.
base_url: the base URL to check for outside links against.
pattern: Regex to use for extracting links from raw HTML.
prevent_outside: If True, ignore external links which are not children
of the base url.
of the base URL.
exclude_prefixes: Exclude any URLs that start with one of these prefixes.
continue_on_failure: If True, continue if parsing a specific link raises an
exception. Otherwise, raise the exception.
Returns:
List[str]: sub links
List[str]: sub links.
"""
base_url_to_use = base_url if base_url is not None else url
parsed_base_url = urlparse(base_url_to_use)
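A hedged example of link extraction (module path `langchain_core.utils.html` assumed; output shown under the assumption that `prevent_outside` defaults to True):

```python
from langchain_core.utils.html import extract_sub_links

raw_html = '<a href="/docs/intro">Intro</a> <a href="https://other.org/x">External</a>'
extract_sub_links(raw_html, "https://example.com/docs/")
# ['https://example.com/docs/intro']  -- the external link is filtered out
```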


@ -3,12 +3,27 @@ import mimetypes
def encode_image(image_path: str) -> str:
"""Get base64 string from image URI."""
"""Get base64 string from image URI.
Args:
image_path: The path to the image.
Returns:
The base64 string of the image.
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def image_to_data_url(image_path: str) -> str:
"""Get data URL from image URI.
Args:
image_path: The path to the image.
Returns:
The data URL of the image.
"""
encoding = encode_image(image_path)
mime_type = mimetypes.guess_type(image_path)[0]
return f"data:{mime_type};base64,{encoding}"


@ -14,7 +14,15 @@ _TEXT_COLOR_MAPPING = {
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
"""Get mapping for items to a support color.
Args:
items: The items to map to colors.
excluded_colors: The colors to exclude.
Returns:
The mapping of items to colors.
"""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
@ -23,20 +31,45 @@ def get_color_mapping(
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
"""Get colored text.
Args:
text: The text to color.
color: The color to use.
Returns:
The colored text.
"""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
"""Get bolded text.
Args:
text: The text to bold.
Returns:
The bolded text.
"""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
"""Print text with highlighting and no end characters.
If a color is provided, the text will be printed in that color.
If a file is provided, the text will be written to that file.
Args:
text: The text to print.
color: The color to use. Defaults to None.
end: The end character to use. Defaults to "".
file: The file to write to. Defaults to None.
"""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
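A hedged sketch of how these helpers combine (module path `langchain_core.utils.input` assumed):

```python
from langchain_core.utils.input import get_color_mapping, print_text

mapping = get_color_mapping(["step-1", "step-2"], excluded_colors=["red"])
print_text("running step-1", color=mapping["step-1"], end="\n")  # prints in an ANSI color
```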


@ -22,7 +22,7 @@ T = TypeVar("T")
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
"""Dummy lock that provides the proper interface but no protection."""
def __enter__(self) -> None:
pass
@ -39,7 +39,21 @@ def tee_peer(
peers: List[Deque[T]],
lock: ContextManager[Any],
) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`"""
"""An individual iterator of a :py:func:`~.tee`.
This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.
Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.
Yields:
The next item from the shared iterator.
"""
try:
while True:
if not buffer:
@ -118,6 +132,14 @@ class Tee(Generic[T]):
*,
lock: Optional[ContextManager[Any]] = None,
):
"""Create a new ``tee``.
Args:
iterable: The iterable to split.
n: The number of iterators to create. Defaults to 2.
lock: The lock to synchronise access to the shared buffers.
Defaults to None.
"""
self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
@ -170,8 +192,8 @@ def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T
size: The size of the batch. If None, returns a single batch.
iterable: The iterable to batch.
Returns:
An iterator over the batches.
Yields:
The batches of the iterable.
"""
it = iter(iterable)
while True:
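A minimal sketch of the synchronous batching helper (module path `langchain_core.utils.iter` assumed):

```python
from langchain_core.utils.iter import batch_iterate

list(batch_iterate(2, range(5)))     # [[0, 1], [2, 3], [4]]
list(batch_iterate(None, range(3)))  # [[0, 1, 2]] -- size=None yields one batch
```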


@ -124,8 +124,7 @@ _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
"""
Parse a JSON string from a Markdown string.
"""Parse a JSON string from a Markdown string.
Args:
json_string: The Markdown string.
@ -175,6 +174,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
Returns:
The parsed JSON object as a Python dictionary.
Raises:
OutputParserException: If the JSON string is invalid or does not contain
the expected keys.
"""
try:
json_obj = parse_json_markdown(text)
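A hedged example of the Markdown-fence parsing (module path `langchain_core.utils.json` assumed):

```python
from langchain_core.utils.json import parse_json_markdown

parse_json_markdown('```json\n{"answer": 42}\n```')  # {'answer': 42}
parse_json_markdown('{"answer": 42}')                # bare JSON is accepted too
```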


@ -90,7 +90,16 @@ def dereference_refs(
full_schema: Optional[dict] = None,
skip_keys: Optional[Sequence[str]] = None,
) -> dict:
"""Try to substitute $refs in JSON Schema."""
"""Try to substitute $refs in JSON Schema.
Args:
schema_obj: The schema object to dereference.
full_schema: The full schema object. Defaults to None.
skip_keys: The keys to skip. Defaults to None.
Returns:
The dereferenced schema object.
"""
full_schema = full_schema or schema_obj
skip_keys = (
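A hedged sketch of `$ref` substitution (module path `langchain_core.utils.json_schema` assumed):

```python
from langchain_core.utils.json_schema import dereference_refs

schema = {
    "type": "object",
    "properties": {"person": {"$ref": "#/definitions/Person"}},
    "definitions": {
        "Person": {"type": "object", "properties": {"name": {"type": "string"}}}
    },
}
resolved = dereference_refs(schema)
resolved["properties"]["person"]
# {'type': 'object', 'properties': {'name': {'type': 'string'}}}
```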


@ -42,7 +42,15 @@ class ChevronError(SyntaxError):
def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
"""Parse a literal from the template."""
"""Parse a literal from the template.
Args:
template: The template to parse.
l_del: The left delimiter.
Returns:
Tuple[str, str]: The literal and the template.
"""
global _CURRENT_LINE
@ -59,7 +67,16 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
"""Do a preliminary check to see if a tag could be a standalone."""
"""Do a preliminary check to see if a tag could be a standalone.
Args:
template: The template. (Not used.)
literal: The literal.
is_standalone: Whether the tag is standalone.
Returns:
bool: Whether the tag could be a standalone.
"""
# If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone:
@ -77,7 +94,16 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
"""Do a final check to see if a tag could be a standalone."""
"""Do a final check to see if a tag could be a standalone.
Args:
template: The template.
tag_type: The type of the tag.
is_standalone: Whether the tag is standalone.
Returns:
bool: Whether the tag could be a standalone.
"""
# Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]:
@ -95,7 +121,20 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
"""Parse a tag from a template."""
"""Parse a tag from a template.
Args:
template: The template.
l_del: The left delimiter.
r_del: The right delimiter.
Returns:
Tuple[Tuple[str, str], str]: The tag and the template.
Raises:
ChevronError: If the tag is unclosed.
ChevronError: If the set delimiter tag is unclosed.
"""
global _CURRENT_LINE
global _LAST_TAG_LINE
@ -404,36 +443,36 @@ def render(
Arguments:
template -- A file-like object or a string containing the template
template -- A file-like object or a string containing the template.
data -- A python dictionary with your data scope
data -- A python dictionary with your data scope.
partials_path -- The path to where your partials are stored
partials_path -- The path to where your partials are stored.
If set to None, then partials won't be loaded from the file system
(defaults to '.')
(defaults to '.').
partials_ext -- The extension that you want the parser to look for
(defaults to 'mustache')
(defaults to 'mustache').
partials_dict -- A python dictionary which will be search for partials
before the filesystem is. {'include': 'foo'} is the same
as a file called include.mustache
(defaults to {})
(defaults to {}).
padding -- This is for padding partials, and shouldn't be used
(but can be if you really want to)
(but can be if you really want to).
def_ldel -- The default left delimiter
("{{" by default, as in spec compliant mustache)
("{{" by default, as in spec compliant mustache).
def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache)
("}}" by default, as in spec compliant mustache).
scopes -- The list of scopes that get_key will look through
scopes -- The list of scopes that get_key will look through.
warn -- Log a warning when a template substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data.
Returns:
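For orientation, a hedged sketch of the public `render` entry point (module path `langchain_core.utils.mustache` assumed):

```python
from langchain_core.utils.mustache import render

render("Hello, {{name}}! You have {{count}} new messages.", {"name": "Ada", "count": 3})
# 'Hello, Ada! You have 3 new messages.'
```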


@ -21,12 +21,27 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# How to type hint this?
def pre_init(func: Callable) -> Any:
"""Decorator to run a function before model initialization."""
"""Decorator to run a function before model initialization.
Args:
func (Callable): The function to run before model initialization.
Returns:
Any: The decorated function.
"""
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
"""Decorator to run a function before model initialization."""
"""Decorator to run a function before model initialization.
Args:
cls (Type[BaseModel]): The model class.
values (Dict[str, Any]): The values to initialize the model with.
Returns:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.__fields__
for name, field_info in fields.items():
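A hedged sketch of how `pre_init` is used on a model (the `Settings` model below is hypothetical):

```python
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import pre_init


class Settings(BaseModel):
    name: str
    greeting: str = ""

    @pre_init
    def build_greeting(cls, values):
        # Runs before validation; field defaults have already been inserted into values.
        if not values.get("greeting"):
            values["greeting"] = f"Hello, {values['name']}!"
        return values


Settings(name="Ada").greeting  # 'Hello, Ada!'
```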


@ -36,5 +36,12 @@ def stringify_dict(data: dict) -> str:
def comma_list(items: List[Any]) -> str:
"""Convert a list to a comma-separated string."""
"""Convert a list to a comma-separated string.
Args:
items: The list to convert.
Returns:
str: The comma-separated string.
"""
return ", ".join(str(item) for item in items)


@ -15,7 +15,18 @@ from langchain_core.pydantic_v1 import SecretStr
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
"""Validate specified keyword args are mutually exclusive."
Args:
*arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args.
Returns:
Callable: Decorator that validates the specified keyword args
are mutually exclusive.
Raises:
ValueError: If more than one arg in a group is defined.
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
@ -41,7 +52,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
def raise_for_status_with_text(response: Response) -> None:
"""Raise an error with the response text."""
"""Raise an error with the response text.
Args:
response (Response): The response to check for errors.
Raises:
ValueError: If the response has an error status code.
"""
try:
response.raise_for_status()
except HTTPError as e:
@ -52,6 +70,12 @@ def raise_for_status_with_text(response: Response) -> None:
def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests.
Args:
dt_value: The datetime value to use for datetime.now().
Yields:
datetime.datetime: The mocked datetime class.
Example:
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
@ -86,7 +110,21 @@ def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
"""Dynamically import a module and raise an exception if the module is not
installed."""
installed.
Args:
module_name (str): The name of the module to import.
pip_name (str, optional): The name of the module to install with pip.
Defaults to None.
package (str, optional): The package to import the module from.
Defaults to None.
Returns:
Any: The imported module.
Raises:
ImportError: If the module is not installed.
"""
try:
module = importlib.import_module(module_name, package)
except (ImportError, ModuleNotFoundError):
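A hedged sketch of `guard_import` (module path `langchain_core.utils.utils` assumed; package names are illustrative):

```python
from langchain_core.utils.utils import guard_import

np = guard_import("numpy")                      # returns the imported module if installed
yaml = guard_import("yaml", pip_name="pyyaml")  # error message would suggest `pip install pyyaml`
# guard_import("nonexistent_pkg")               # would raise ImportError with an install hint
```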
@ -105,7 +143,22 @@ def check_package_version(
gt_version: Optional[str] = None,
gte_version: Optional[str] = None,
) -> None:
"""Check the version of a package."""
"""Check the version of a package.
Args:
package (str): The name of the package.
lt_version (str, optional): The version must be less than this.
Defaults to None.
lte_version (str, optional): The version must be less than or equal to this.
Defaults to None.
gt_version (str, optional): The version must be greater than this.
Defaults to None.
gte_version (str, optional): The version must be greater than or equal to this.
Defaults to None.
Raises:
ValueError: If the package version does not meet the requirements.
"""
imported_version = parse(version(package))
if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError(
@ -133,7 +186,11 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
"""Get field names, including aliases, for a pydantic class.
Args:
pydantic_cls: Pydantic class."""
pydantic_cls: Pydantic class.
Returns:
Set[str]: Field names.
"""
all_required_field_names = set()
for field in pydantic_cls.__fields__.values():
all_required_field_names.add(field.name)
@ -153,6 +210,13 @@ def build_extra_kwargs(
extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class.
Returns:
Dict[str, Any]: Extra kwargs.
Raises:
ValueError: If a field is specified in both values and extra_kwargs.
ValueError: If a field is specified in model_kwargs.
"""
for field_name in list(values):
if field_name in extra_kwargs:
@ -176,7 +240,14 @@ def build_extra_kwargs(
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
"""Convert a string to a SecretStr if needed.
Args:
value (Union[SecretStr, str]): The value to convert.
Returns:
SecretStr: The SecretStr value.
"""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
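Finally, a hedged sketch of the secret-string helper (re-export from `langchain_core.utils` assumed):

```python
from langchain_core.utils import convert_to_secret_str

key = convert_to_secret_str("sk-plain-text-key")  # hypothetical key value
key                     # SecretStr('**********') -- masked in repr and logs
key.get_secret_value()  # 'sk-plain-text-key'
```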