core: docstrings utils update (#24213)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-07-15 08:36:00 -07:00 committed by GitHub
parent e8a21146d3
commit 36ee083753
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 419 additions and 58 deletions

View File

@ -8,6 +8,17 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
dictionaries but has a value of None in 'left'. In such cases, the method uses the dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary. value from 'right' for that key in the merged dictionary.
Args:
left: The first dictionary to merge.
others: The other dictionaries to merge.
Returns:
The merged dictionary.
Raises:
TypeError: If the key exists in both dictionaries but has a different type.
TypeError: If the value has an unsupported type.
Example: Example:
If left = {"function_call": {"arguments": None}} and If left = {"function_call": {"arguments": None}} and
right = {"function_call": {"arguments": "{\n"}} right = {"function_call": {"arguments": "{\n"}}
@ -46,7 +57,15 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]: def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
"""Add many lists, handling None.""" """Add many lists, handling None.
Args:
left: The first list to merge.
others: The other lists to merge.
Returns:
The merged list.
"""
merged = left.copy() if left is not None else None merged = left.copy() if left is not None else None
for other in others: for other in others:
if other is None: if other is None:
@ -75,6 +94,23 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]
def merge_obj(left: Any, right: Any) -> Any: def merge_obj(left: Any, right: Any) -> Any:
"""Merge two objects.
It handles specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Args:
left: The first object to merge.
right: The other object to merge.
Returns:
The merged object.
Raises:
TypeError: If the key exists in both dictionaries but has a different type.
ValueError: If the two objects cannot be merged.
"""
if left is None or right is None: if left is None or right is None:
return left if left is not None else right return left if left is not None else right
elif type(left) is not type(right): elif type(left) is not type(right):

View File

@ -44,6 +44,18 @@ def py_anext(
Can be used to compare the built-in implementation of the inner Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send() coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator. or throw() on the returned generator.
Args:
iterator: The async iterator to advance.
default: The value to return if the iterator is exhausted.
If not provided, a StopAsyncIteration exception is raised.
Returns:
The next value from the iterator, or the default value
if the iterator is exhausted.
Raises:
TypeError: If the iterator is not an async iterator.
""" """
try: try:
@ -71,7 +83,7 @@ def py_anext(
class NoLock: class NoLock:
"""Dummy lock that provides the proper interface but no protection""" """Dummy lock that provides the proper interface but no protection."""
async def __aenter__(self) -> None: async def __aenter__(self) -> None:
pass pass
@ -88,7 +100,21 @@ async def tee_peer(
peers: List[Deque[T]], peers: List[Deque[T]],
lock: AsyncContextManager[Any], lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]: ) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`""" """An individual iterator of a :py:func:`~.tee`.
This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.
Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.
Yields:
The next item from the shared iterator.
"""
try: try:
while True: while True:
if not buffer: if not buffer:
@ -204,6 +230,7 @@ class Tee(Generic[T]):
return False return False
async def aclose(self) -> None: async def aclose(self) -> None:
"""Async close all child iterators."""
for child in self._children: for child in self._children:
await child.aclose() await child.aclose()
@ -258,7 +285,7 @@ async def abatch_iterate(
iterable: The async iterable to batch. iterable: The async iterable to batch.
Returns: Returns:
An async iterator over the batches An async iterator over the batches.
""" """
batch: List[T] = [] batch: List[T] = []
async for element in iterable: async for element in iterable:

View File

@ -36,7 +36,7 @@ def get_from_dict_or_env(
env_key: The environment variable to look up if the key is not env_key: The environment variable to look up if the key is not
in the dictionary. in the dictionary.
default: The default value to return if the key is not in the dictionary default: The default value to return if the key is not in the dictionary
or the environment. or the environment. Defaults to None.
""" """
if isinstance(key, (list, tuple)): if isinstance(key, (list, tuple)):
for k in key: for k in key:
@ -56,7 +56,22 @@ def get_from_dict_or_env(
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""Get a value from a dictionary or an environment variable.""" """Get a value from a dictionary or an environment variable.
Args:
key: The key to look up in the dictionary.
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment. Defaults to None.
Returns:
str: The value of the key.
Raises:
ValueError: If the key is not in the dictionary and no default value is
provided or if the environment variable is not set.
"""
if env_key in os.environ and os.environ[env_key]: if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key] return os.environ[env_key]
elif default is not None: elif default is not None:

View File

@ -10,7 +10,19 @@ class StrictFormatter(Formatter):
def vformat( def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str: ) -> str:
"""Check that no arguments are provided.""" """Check that no arguments are provided.
Args:
format_string: The format string.
args: The arguments.
kwargs: The keyword arguments.
Returns:
The formatted string.
Raises:
ValueError: If any arguments are provided.
"""
if len(args) > 0: if len(args) > 0:
raise ValueError( raise ValueError(
"No arguments should be provided, " "No arguments should be provided, "
@ -21,6 +33,15 @@ class StrictFormatter(Formatter):
def validate_input_variables( def validate_input_variables(
self, format_string: str, input_variables: List[str] self, format_string: str, input_variables: List[str]
) -> None: ) -> None:
"""Check that all input variables are used in the format string.
Args:
format_string: The format string.
input_variables: The input variables.
Raises:
ValueError: If any input variables are not used in the format string.
"""
dummy_inputs = {input_variable: "foo" for input_variable in input_variables} dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs) super().format(format_string, **dummy_inputs)

View File

@ -55,7 +55,9 @@ class ToolDescription(TypedDict):
"""Representation of a callable function to the OpenAI API.""" """Representation of a callable function to the OpenAI API."""
type: Literal["function"] type: Literal["function"]
"""The type of the tool."""
function: FunctionDescription function: FunctionDescription
"""The function description."""
def _rm_titles(kv: dict, prev_key: str = "") -> dict: def _rm_titles(kv: dict, prev_key: str = "") -> dict:
@ -85,7 +87,19 @@ def convert_pydantic_to_openai_function(
description: Optional[str] = None, description: Optional[str] = None,
rm_titles: bool = True, rm_titles: bool = True,
) -> FunctionDescription: ) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the OpenAI API.""" """Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
rm_titles: Whether to remove titles from the schema. Defaults to True.
Returns:
The function description.
"""
schema = dereference_refs(model.schema()) schema = dereference_refs(model.schema())
schema.pop("definitions", None) schema.pop("definitions", None)
title = schema.pop("title", "") title = schema.pop("title", "")
@ -108,7 +122,18 @@ def convert_pydantic_to_openai_tool(
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
) -> ToolDescription: ) -> ToolDescription:
"""Converts a Pydantic model to a function description for the OpenAI API.""" """Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
Returns:
The tool description.
"""
function = convert_pydantic_to_openai_function( function = convert_pydantic_to_openai_function(
model, name=name, description=description model, name=name, description=description
) )
@ -133,6 +158,12 @@ def convert_python_function_to_openai_function(
Assumes the Python function has type hints and a docstring with a description. If Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be the docstring has Google Python style argument descriptions, these will be
included as well. included as well.
Args:
function: The Python function to convert.
Returns:
The OpenAI function description.
""" """
from langchain_core import tools from langchain_core import tools
@ -157,7 +188,14 @@ def convert_python_function_to_openai_function(
removal="0.3.0", removal="0.3.0",
) )
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API.""" """Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The function description.
"""
if tool.args_schema: if tool.args_schema:
return convert_pydantic_to_openai_function( return convert_pydantic_to_openai_function(
tool.args_schema, name=tool.name, description=tool.description tool.args_schema, name=tool.name, description=tool.description
@ -187,7 +225,14 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
removal="0.3.0", removal="0.3.0",
) )
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription: def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
"""Format tool into the OpenAI function API.""" """Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The tool description.
"""
function = format_tool_to_openai_function(tool) function = format_tool_to_openai_function(tool)
return {"type": "function", "function": function} return {"type": "function", "function": function}
@ -206,6 +251,9 @@ def convert_to_openai_function(
Returns: Returns:
A dict version of the passed in function which is compatible with the A dict version of the passed in function which is compatible with the
OpenAI function-calling API. OpenAI function-calling API.
Raises:
ValueError: If the function is not in a supported format.
""" """
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -284,7 +332,7 @@ def tool_example_to_messages(
BaseModels BaseModels
tool_outputs: Optional[List[str]], a list of tool call outputs. tool_outputs: Optional[List[str]], a list of tool call outputs.
Does not need to be provided. If not provided, a placeholder value Does not need to be provided. If not provided, a placeholder value
will be inserted. will be inserted. Defaults to None.
Returns: Returns:
A list of messages A list of messages

View File

@ -34,11 +34,11 @@ DEFAULT_LINK_REGEX = (
def find_all_links( def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]: ) -> List[str]:
"""Extract all links from a raw html string. """Extract all links from a raw HTML string.
Args: Args:
raw_html: original html. raw_html: original HTML.
pattern: Regex to use for extracting links from raw html. pattern: Regex to use for extracting links from raw HTML.
Returns: Returns:
List[str]: all links List[str]: all links
@ -57,20 +57,20 @@ def extract_sub_links(
exclude_prefixes: Sequence[str] = (), exclude_prefixes: Sequence[str] = (),
continue_on_failure: bool = False, continue_on_failure: bool = False,
) -> List[str]: ) -> List[str]:
"""Extract all links from a raw html string and convert into absolute paths. """Extract all links from a raw HTML string and convert into absolute paths.
Args: Args:
raw_html: original html. raw_html: original HTML.
url: the url of the html. url: the url of the HTML.
base_url: the base url to check for outside links against. base_url: the base URL to check for outside links against.
pattern: Regex to use for extracting links from raw html. pattern: Regex to use for extracting links from raw HTML.
prevent_outside: If True, ignore external links which are not children prevent_outside: If True, ignore external links which are not children
of the base url. of the base URL.
exclude_prefixes: Exclude any URLs that start with one of these prefixes. exclude_prefixes: Exclude any URLs that start with one of these prefixes.
continue_on_failure: If True, continue if parsing a specific link raises an continue_on_failure: If True, continue if parsing a specific link raises an
exception. Otherwise, raise the exception. exception. Otherwise, raise the exception.
Returns: Returns:
List[str]: sub links List[str]: sub links.
""" """
base_url_to_use = base_url if base_url is not None else url base_url_to_use = base_url if base_url is not None else url
parsed_base_url = urlparse(base_url_to_use) parsed_base_url = urlparse(base_url_to_use)

View File

@ -3,12 +3,27 @@ import mimetypes
def encode_image(image_path: str) -> str: def encode_image(image_path: str) -> str:
"""Get base64 string from image URI.""" """Get base64 string from image URI.
Args:
image_path: The path to the image.
Returns:
The base64 string of the image.
"""
with open(image_path, "rb") as image_file: with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8") return base64.b64encode(image_file.read()).decode("utf-8")
def image_to_data_url(image_path: str) -> str: def image_to_data_url(image_path: str) -> str:
"""Get data URL from image URI.
Args:
image_path: The path to the image.
Returns:
The data URL of the image.
"""
encoding = encode_image(image_path) encoding = encode_image(image_path)
mime_type = mimetypes.guess_type(image_path)[0] mime_type = mimetypes.guess_type(image_path)[0]
return f"data:{mime_type};base64,{encoding}" return f"data:{mime_type};base64,{encoding}"

View File

@ -14,7 +14,15 @@ _TEXT_COLOR_MAPPING = {
def get_color_mapping( def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Get mapping for items to a support color.""" """Get mapping for items to a support color.
Args:
items: The items to map to colors.
excluded_colors: The colors to exclude.
Returns:
The mapping of items to colors.
"""
colors = list(_TEXT_COLOR_MAPPING.keys()) colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None: if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors] colors = [c for c in colors if c not in excluded_colors]
@ -23,20 +31,45 @@ def get_color_mapping(
def get_colored_text(text: str, color: str) -> str: def get_colored_text(text: str, color: str) -> str:
"""Get colored text.""" """Get colored text.
Args:
text: The text to color.
color: The color to use.
Returns:
The colored text.
"""
color_str = _TEXT_COLOR_MAPPING[color] color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str: def get_bolded_text(text: str) -> str:
"""Get bolded text.""" """Get bolded text.
Args:
text: The text to bold.
Returns:
The bolded text.
"""
return f"\033[1m{text}\033[0m" return f"\033[1m{text}\033[0m"
def print_text( def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None: ) -> None:
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters.
If a color is provided, the text will be printed in that color.
If a file is provided, the text will be written to that file.
Args:
text: The text to print.
color: The color to use. Defaults to None.
end: The end character to use. Defaults to "".
file: The file to write to. Defaults to None.
"""
text_to_print = get_colored_text(text, color) if color else text text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file) print(text_to_print, end=end, file=file)
if file: if file:

View File

@ -22,7 +22,7 @@ T = TypeVar("T")
class NoLock: class NoLock:
"""Dummy lock that provides the proper interface but no protection""" """Dummy lock that provides the proper interface but no protection."""
def __enter__(self) -> None: def __enter__(self) -> None:
pass pass
@ -39,7 +39,21 @@ def tee_peer(
peers: List[Deque[T]], peers: List[Deque[T]],
lock: ContextManager[Any], lock: ContextManager[Any],
) -> Generator[T, None, None]: ) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`""" """An individual iterator of a :py:func:`~.tee`.
This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.
Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.
Yields:
The next item from the shared iterator.
"""
try: try:
while True: while True:
if not buffer: if not buffer:
@ -118,6 +132,14 @@ class Tee(Generic[T]):
*, *,
lock: Optional[ContextManager[Any]] = None, lock: Optional[ContextManager[Any]] = None,
): ):
"""Create a new ``tee``.
Args:
iterable: The iterable to split.
n: The number of iterators to create. Defaults to 2.
lock: The lock to synchronise access to the shared buffers.
Defaults to None.
"""
self._iterator = iter(iterable) self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)] self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple( self._children = tuple(
@ -170,8 +192,8 @@ def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T
size: The size of the batch. If None, returns a single batch. size: The size of the batch. If None, returns a single batch.
iterable: The iterable to batch. iterable: The iterable to batch.
Returns: Yields:
An iterator over the batches. The batches of the iterable.
""" """
it = iter(iterable) it = iter(iterable)
while True: while True:

View File

@ -124,8 +124,7 @@ _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
def parse_json_markdown( def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = parse_partial_json json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict: ) -> dict:
""" """Parse a JSON string from a Markdown string.
Parse a JSON string from a Markdown string.
Args: Args:
json_string: The Markdown string. json_string: The Markdown string.
@ -175,6 +174,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
Returns: Returns:
The parsed JSON object as a Python dictionary. The parsed JSON object as a Python dictionary.
Raises:
OutputParserException: If the JSON string is invalid or does not contain
the expected keys.
""" """
try: try:
json_obj = parse_json_markdown(text) json_obj = parse_json_markdown(text)

View File

@ -90,7 +90,16 @@ def dereference_refs(
full_schema: Optional[dict] = None, full_schema: Optional[dict] = None,
skip_keys: Optional[Sequence[str]] = None, skip_keys: Optional[Sequence[str]] = None,
) -> dict: ) -> dict:
"""Try to substitute $refs in JSON Schema.""" """Try to substitute $refs in JSON Schema.
Args:
schema_obj: The schema object to dereference.
full_schema: The full schema object. Defaults to None.
skip_keys: The keys to skip. Defaults to None.
Returns:
The dereferenced schema object.
"""
full_schema = full_schema or schema_obj full_schema = full_schema or schema_obj
skip_keys = ( skip_keys = (

View File

@ -42,7 +42,15 @@ class ChevronError(SyntaxError):
def grab_literal(template: str, l_del: str) -> Tuple[str, str]: def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
"""Parse a literal from the template.""" """Parse a literal from the template.
Args:
template: The template to parse.
l_del: The left delimiter.
Returns:
Tuple[str, str]: The literal and the template.
"""
global _CURRENT_LINE global _CURRENT_LINE
@ -59,7 +67,16 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
"""Do a preliminary check to see if a tag could be a standalone.""" """Do a preliminary check to see if a tag could be a standalone.
Args:
template: The template. (Not used.)
literal: The literal.
is_standalone: Whether the tag is standalone.
Returns:
bool: Whether the tag could be a standalone.
"""
# If there is a newline, or the previous tag was a standalone # If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone: if literal.find("\n") != -1 or is_standalone:
@ -77,7 +94,16 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
"""Do a final check to see if a tag could be a standalone.""" """Do a final check to see if a tag could be a standalone.
Args:
template: The template.
tag_type: The type of the tag.
is_standalone: Whether the tag is standalone.
Returns:
bool: Whether the tag could be a standalone.
"""
# Check right side if we might be a standalone # Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]: if is_standalone and tag_type not in ["variable", "no escape"]:
@ -95,7 +121,20 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]: def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
"""Parse a tag from a template.""" """Parse a tag from a template.
Args:
template: The template.
l_del: The left delimiter.
r_del: The right delimiter.
Returns:
Tuple[Tuple[str, str], str]: The tag and the template.
Raises:
ChevronError: If the tag is unclosed.
ChevronError: If the set delimiter tag is unclosed.
"""
global _CURRENT_LINE global _CURRENT_LINE
global _LAST_TAG_LINE global _LAST_TAG_LINE
@ -404,36 +443,36 @@ def render(
Arguments: Arguments:
template -- A file-like object or a string containing the template template -- A file-like object or a string containing the template.
data -- A python dictionary with your data scope data -- A python dictionary with your data scope.
partials_path -- The path to where your partials are stored partials_path -- The path to where your partials are stored.
If set to None, then partials won't be loaded from the file system If set to None, then partials won't be loaded from the file system
(defaults to '.') (defaults to '.').
partials_ext -- The extension that you want the parser to look for partials_ext -- The extension that you want the parser to look for
(defaults to 'mustache') (defaults to 'mustache').
partials_dict -- A python dictionary which will be search for partials partials_dict -- A python dictionary which will be search for partials
before the filesystem is. {'include': 'foo'} is the same before the filesystem is. {'include': 'foo'} is the same
as a file called include.mustache as a file called include.mustache
(defaults to {}) (defaults to {}).
padding -- This is for padding partials, and shouldn't be used padding -- This is for padding partials, and shouldn't be used
(but can be if you really want to) (but can be if you really want to).
def_ldel -- The default left delimiter def_ldel -- The default left delimiter
("{{" by default, as in spec compliant mustache) ("{{" by default, as in spec compliant mustache).
def_rdel -- The default right delimiter def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache) ("}}" by default, as in spec compliant mustache).
scopes -- The list of scopes that get_key will look through scopes -- The list of scopes that get_key will look through.
warn -- Log a warning when a template substitution isn't found in the data warn -- Log a warning when a template substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data keep -- Keep unreplaced tags when a substitution isn't found in the data.
Returns: Returns:

View File

@ -21,12 +21,27 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# How to type hint this? # How to type hint this?
def pre_init(func: Callable) -> Any: def pre_init(func: Callable) -> Any:
"""Decorator to run a function before model initialization.""" """Decorator to run a function before model initialization.
Args:
func (Callable): The function to run before model initialization.
Returns:
Any: The decorated function.
"""
@root_validator(pre=True) @root_validator(pre=True)
@wraps(func) @wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]: def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
"""Decorator to run a function before model initialization.""" """Decorator to run a function before model initialization.
Args:
cls (Type[BaseModel]): The model class.
values (Dict[str, Any]): The values to initialize the model with.
Returns:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values # Insert default values
fields = cls.__fields__ fields = cls.__fields__
for name, field_info in fields.items(): for name, field_info in fields.items():

View File

@ -36,5 +36,12 @@ def stringify_dict(data: dict) -> str:
def comma_list(items: List[Any]) -> str: def comma_list(items: List[Any]) -> str:
"""Convert a list to a comma-separated string.""" """Convert a list to a comma-separated string.
Args:
items: The list to convert.
Returns:
str: The comma-separated string.
"""
return ", ".join(str(item) for item in items) return ", ".join(str(item) for item in items)

View File

@ -15,7 +15,18 @@ from langchain_core.pydantic_v1 import SecretStr
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive.""" """Validate specified keyword args are mutually exclusive."
Args:
*arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args.
Returns:
Callable: Decorator that validates the specified keyword args
are mutually exclusive
Raises:
ValueError: If more than one arg in a group is defined.
"""
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@functools.wraps(func) @functools.wraps(func)
@ -41,7 +52,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
def raise_for_status_with_text(response: Response) -> None: def raise_for_status_with_text(response: Response) -> None:
"""Raise an error with the response text.""" """Raise an error with the response text.
Args:
response (Response): The response to check for errors.
Raises:
ValueError: If the response has an error status code.
"""
try: try:
response.raise_for_status() response.raise_for_status()
except HTTPError as e: except HTTPError as e:
@ -52,6 +70,12 @@ def raise_for_status_with_text(response: Response) -> None:
def mock_now(dt_value): # type: ignore def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests. """Context manager for mocking out datetime.now() in unit tests.
Args:
dt_value: The datetime value to use for datetime.now().
Yields:
datetime.datetime: The mocked datetime class.
Example: Example:
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)): with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11) assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
@ -86,7 +110,21 @@ def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any: ) -> Any:
"""Dynamically import a module and raise an exception if the module is not """Dynamically import a module and raise an exception if the module is not
installed.""" installed.
Args:
module_name (str): The name of the module to import.
pip_name (str, optional): The name of the module to install with pip.
Defaults to None.
package (str, optional): The package to import the module from.
Defaults to None.
Returns:
Any: The imported module.
Raises:
ImportError: If the module is not installed.
"""
try: try:
module = importlib.import_module(module_name, package) module = importlib.import_module(module_name, package)
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
@ -105,7 +143,22 @@ def check_package_version(
gt_version: Optional[str] = None, gt_version: Optional[str] = None,
gte_version: Optional[str] = None, gte_version: Optional[str] = None,
) -> None: ) -> None:
"""Check the version of a package.""" """Check the version of a package.
Args:
package (str): The name of the package.
lt_version (str, optional): The version must be less than this.
Defaults to None.
lte_version (str, optional): The version must be less than or equal to this.
Defaults to None.
gt_version (str, optional): The version must be greater than this.
Defaults to None.
gte_version (str, optional): The version must be greater than or equal to this.
Defaults to None.
Raises:
ValueError: If the package version does not meet the requirements.
"""
imported_version = parse(version(package)) imported_version = parse(version(package))
if lt_version is not None and imported_version >= parse(lt_version): if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError( raise ValueError(
@ -133,7 +186,11 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
"""Get field names, including aliases, for a pydantic class. """Get field names, including aliases, for a pydantic class.
Args: Args:
pydantic_cls: Pydantic class.""" pydantic_cls: Pydantic class.
Returns:
Set[str]: Field names.
"""
all_required_field_names = set() all_required_field_names = set()
for field in pydantic_cls.__fields__.values(): for field in pydantic_cls.__fields__.values():
all_required_field_names.add(field.name) all_required_field_names.add(field.name)
@ -153,6 +210,13 @@ def build_extra_kwargs(
extra_kwargs: Extra kwargs passed in by user. extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user. values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class. all_required_field_names: All required field names for the pydantic class.
Returns:
Dict[str, Any]: Extra kwargs.
Raises:
ValueError: If a field is specified in both values and extra_kwargs.
ValueError: If a field is specified in model_kwargs.
""" """
for field_name in list(values): for field_name in list(values):
if field_name in extra_kwargs: if field_name in extra_kwargs:
@ -176,7 +240,14 @@ def build_extra_kwargs(
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed.""" """Convert a string to a SecretStr if needed.
Args:
value (Union[SecretStr, str]): The value to convert.
Returns:
SecretStr: The SecretStr value.
"""
if isinstance(value, SecretStr): if isinstance(value, SecretStr):
return value return value
return SecretStr(value) return SecretStr(value)