From 36ee0837536d1164b80ae54f21c468fda86a4716 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Mon, 15 Jul 2024 08:36:00 -0700 Subject: [PATCH] core: docstrings `utils` update (#24213) Added missed docstrings. Formatted docstrings to the consistent form. --- libs/core/langchain_core/utils/_merge.py | 38 ++++++++- libs/core/langchain_core/utils/aiter.py | 33 +++++++- libs/core/langchain_core/utils/env.py | 19 ++++- libs/core/langchain_core/utils/formatting.py | 23 ++++- .../langchain_core/utils/function_calling.py | 58 +++++++++++-- libs/core/langchain_core/utils/html.py | 20 ++--- libs/core/langchain_core/utils/image.py | 17 +++- libs/core/langchain_core/utils/input.py | 41 ++++++++- libs/core/langchain_core/utils/iter.py | 30 ++++++- libs/core/langchain_core/utils/json.py | 7 +- libs/core/langchain_core/utils/json_schema.py | 11 ++- libs/core/langchain_core/utils/mustache.py | 69 +++++++++++---- libs/core/langchain_core/utils/pydantic.py | 19 ++++- libs/core/langchain_core/utils/strings.py | 9 +- libs/core/langchain_core/utils/utils.py | 83 +++++++++++++++++-- 15 files changed, 419 insertions(+), 58 deletions(-) diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index da0acb42b5d..42cdde85493 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -8,6 +8,17 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any] dictionaries but has a value of None in 'left'. In such cases, the method uses the value from 'right' for that key in the merged dictionary. + Args: + left: The first dictionary to merge. + others: The other dictionaries to merge. + + Returns: + The merged dictionary. + + Raises: + TypeError: If the key exists in both dictionaries but has a different type. + TypeError: If the value has an unsupported type. + Example: If left = {"function_call": {"arguments": None}} and right = {"function_call": {"arguments": "{\n"}} @@ -46,7 +57,15 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any] def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]: - """Add many lists, handling None.""" + """Add many lists, handling None. + + Args: + left: The first list to merge. + others: The other lists to merge. + + Returns: + The merged list. + """ merged = left.copy() if left is not None else None for other in others: if other is None: @@ -75,6 +94,23 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List] def merge_obj(left: Any, right: Any) -> Any: + """Merge two objects. + + It handles specific scenarios where a key exists in both + dictionaries but has a value of None in 'left'. In such cases, the method uses the + value from 'right' for that key in the merged dictionary. + + Args: + left: The first object to merge. + right: The other object to merge. + + Returns: + The merged object. + + Raises: + TypeError: If the key exists in both dictionaries but has a different type. + ValueError: If the two objects cannot be merged. + """ if left is None or right is None: return left if left is not None else right elif type(left) is not type(right): diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index eb55079e601..0748e272da8 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -44,6 +44,18 @@ def py_anext( Can be used to compare the built-in implementation of the inner coroutines machinery to C-implementation of __anext__() and send() or throw() on the returned generator. + + Args: + iterator: The async iterator to advance. + default: The value to return if the iterator is exhausted. + If not provided, a StopAsyncIteration exception is raised. + + Returns: + The next value from the iterator, or the default value + if the iterator is exhausted. + + Raises: + TypeError: If the iterator is not an async iterator. """ try: @@ -71,7 +83,7 @@ def py_anext( class NoLock: - """Dummy lock that provides the proper interface but no protection""" + """Dummy lock that provides the proper interface but no protection.""" async def __aenter__(self) -> None: pass @@ -88,7 +100,21 @@ async def tee_peer( peers: List[Deque[T]], lock: AsyncContextManager[Any], ) -> AsyncGenerator[T, None]: - """An individual iterator of a :py:func:`~.tee`""" + """An individual iterator of a :py:func:`~.tee`. + + This function is a generator that yields items from the shared iterator + ``iterator``. It buffers items until the least advanced iterator has + yielded them as well. The buffer is shared with all other peers. + + Args: + iterator: The shared iterator. + buffer: The buffer for this peer. + peers: The buffers of all peers. + lock: The lock to synchronise access to the shared buffers. + + Yields: + The next item from the shared iterator. + """ try: while True: if not buffer: @@ -204,6 +230,7 @@ class Tee(Generic[T]): return False async def aclose(self) -> None: + """Async close all child iterators.""" for child in self._children: await child.aclose() @@ -258,7 +285,7 @@ async def abatch_iterate( iterable: The async iterable to batch. Returns: - An async iterator over the batches + An async iterator over the batches. """ batch: List[T] = [] async for element in iterable: diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index d509fca2501..6c5cff88819 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -36,7 +36,7 @@ def get_from_dict_or_env( env_key: The environment variable to look up if the key is not in the dictionary. default: The default value to return if the key is not in the dictionary - or the environment. + or the environment. Defaults to None. """ if isinstance(key, (list, tuple)): for k in key: @@ -56,7 +56,22 @@ def get_from_dict_or_env( def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: - """Get a value from a dictionary or an environment variable.""" + """Get a value from a dictionary or an environment variable. + + Args: + key: The key to look up in the dictionary. + env_key: The environment variable to look up if the key is not + in the dictionary. + default: The default value to return if the key is not in the dictionary + or the environment. Defaults to None. + + Returns: + str: The value of the key. + + Raises: + ValueError: If the key is not in the dictionary and no default value is + provided or if the environment variable is not set. + """ if env_key in os.environ and os.environ[env_key]: return os.environ[env_key] elif default is not None: diff --git a/libs/core/langchain_core/utils/formatting.py b/libs/core/langchain_core/utils/formatting.py index f27051c12b8..b5b880f66b1 100644 --- a/libs/core/langchain_core/utils/formatting.py +++ b/libs/core/langchain_core/utils/formatting.py @@ -10,7 +10,19 @@ class StrictFormatter(Formatter): def vformat( self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] ) -> str: - """Check that no arguments are provided.""" + """Check that no arguments are provided. + + Args: + format_string: The format string. + args: The arguments. + kwargs: The keyword arguments. + + Returns: + The formatted string. + + Raises: + ValueError: If any arguments are provided. + """ if len(args) > 0: raise ValueError( "No arguments should be provided, " @@ -21,6 +33,15 @@ class StrictFormatter(Formatter): def validate_input_variables( self, format_string: str, input_variables: List[str] ) -> None: + """Check that all input variables are used in the format string. + + Args: + format_string: The format string. + input_variables: The input variables. + + Raises: + ValueError: If any input variables are not used in the format string. + """ dummy_inputs = {input_variable: "foo" for input_variable in input_variables} super().format(format_string, **dummy_inputs) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 280db0d3ab6..5cc34060c85 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -55,7 +55,9 @@ class ToolDescription(TypedDict): """Representation of a callable function to the OpenAI API.""" type: Literal["function"] + """The type of the tool.""" function: FunctionDescription + """The function description.""" def _rm_titles(kv: dict, prev_key: str = "") -> dict: @@ -85,7 +87,19 @@ def convert_pydantic_to_openai_function( description: Optional[str] = None, rm_titles: bool = True, ) -> FunctionDescription: - """Converts a Pydantic model to a function description for the OpenAI API.""" + """Converts a Pydantic model to a function description for the OpenAI API. + + Args: + model: The Pydantic model to convert. + name: The name of the function. If not provided, the title of the schema will be + used. + description: The description of the function. If not provided, the description + of the schema will be used. + rm_titles: Whether to remove titles from the schema. Defaults to True. + + Returns: + The function description. + """ schema = dereference_refs(model.schema()) schema.pop("definitions", None) title = schema.pop("title", "") @@ -108,7 +122,18 @@ def convert_pydantic_to_openai_tool( name: Optional[str] = None, description: Optional[str] = None, ) -> ToolDescription: - """Converts a Pydantic model to a function description for the OpenAI API.""" + """Converts a Pydantic model to a function description for the OpenAI API. + + Args: + model: The Pydantic model to convert. + name: The name of the function. If not provided, the title of the schema will be + used. + description: The description of the function. If not provided, the description + of the schema will be used. + + Returns: + The tool description. + """ function = convert_pydantic_to_openai_function( model, name=name, description=description ) @@ -133,6 +158,12 @@ def convert_python_function_to_openai_function( Assumes the Python function has type hints and a docstring with a description. If the docstring has Google Python style argument descriptions, these will be included as well. + + Args: + function: The Python function to convert. + + Returns: + The OpenAI function description. """ from langchain_core import tools @@ -157,7 +188,14 @@ def convert_python_function_to_openai_function( removal="0.3.0", ) def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: - """Format tool into the OpenAI function API.""" + """Format tool into the OpenAI function API. + + Args: + tool: The tool to format. + + Returns: + The function description. + """ if tool.args_schema: return convert_pydantic_to_openai_function( tool.args_schema, name=tool.name, description=tool.description @@ -187,7 +225,14 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: removal="0.3.0", ) def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription: - """Format tool into the OpenAI function API.""" + """Format tool into the OpenAI function API. + + Args: + tool: The tool to format. + + Returns: + The tool description. + """ function = format_tool_to_openai_function(tool) return {"type": "function", "function": function} @@ -206,6 +251,9 @@ def convert_to_openai_function( Returns: A dict version of the passed in function which is compatible with the OpenAI function-calling API. + + Raises: + ValueError: If the function is not in a supported format. """ from langchain_core.tools import BaseTool @@ -284,7 +332,7 @@ def tool_example_to_messages( BaseModels tool_outputs: Optional[List[str]], a list of tool call outputs. Does not need to be provided. If not provided, a placeholder value - will be inserted. + will be inserted. Defaults to None. Returns: A list of messages diff --git a/libs/core/langchain_core/utils/html.py b/libs/core/langchain_core/utils/html.py index 3e41c187b4a..9750515bd1f 100644 --- a/libs/core/langchain_core/utils/html.py +++ b/libs/core/langchain_core/utils/html.py @@ -34,11 +34,11 @@ DEFAULT_LINK_REGEX = ( def find_all_links( raw_html: str, *, pattern: Union[str, re.Pattern, None] = None ) -> List[str]: - """Extract all links from a raw html string. + """Extract all links from a raw HTML string. Args: - raw_html: original html. - pattern: Regex to use for extracting links from raw html. + raw_html: original HTML. + pattern: Regex to use for extracting links from raw HTML. Returns: List[str]: all links @@ -57,20 +57,20 @@ def extract_sub_links( exclude_prefixes: Sequence[str] = (), continue_on_failure: bool = False, ) -> List[str]: - """Extract all links from a raw html string and convert into absolute paths. + """Extract all links from a raw HTML string and convert into absolute paths. Args: - raw_html: original html. - url: the url of the html. - base_url: the base url to check for outside links against. - pattern: Regex to use for extracting links from raw html. + raw_html: original HTML. + url: the url of the HTML. + base_url: the base URL to check for outside links against. + pattern: Regex to use for extracting links from raw HTML. prevent_outside: If True, ignore external links which are not children - of the base url. + of the base URL. exclude_prefixes: Exclude any URLs that start with one of these prefixes. continue_on_failure: If True, continue if parsing a specific link raises an exception. Otherwise, raise the exception. Returns: - List[str]: sub links + List[str]: sub links. """ base_url_to_use = base_url if base_url is not None else url parsed_base_url = urlparse(base_url_to_use) diff --git a/libs/core/langchain_core/utils/image.py b/libs/core/langchain_core/utils/image.py index b59682bd37f..708f0c7fdea 100644 --- a/libs/core/langchain_core/utils/image.py +++ b/libs/core/langchain_core/utils/image.py @@ -3,12 +3,27 @@ import mimetypes def encode_image(image_path: str) -> str: - """Get base64 string from image URI.""" + """Get base64 string from image URI. + + Args: + image_path: The path to the image. + + Returns: + The base64 string of the image. + """ with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") def image_to_data_url(image_path: str) -> str: + """Get data URL from image URI. + + Args: + image_path: The path to the image. + + Returns: + The data URL of the image. + """ encoding = encode_image(image_path) mime_type = mimetypes.guess_type(image_path)[0] return f"data:{mime_type};base64,{encoding}" diff --git a/libs/core/langchain_core/utils/input.py b/libs/core/langchain_core/utils/input.py index 2e52a0e282f..7edfcd7b6cb 100644 --- a/libs/core/langchain_core/utils/input.py +++ b/libs/core/langchain_core/utils/input.py @@ -14,7 +14,15 @@ _TEXT_COLOR_MAPPING = { def get_color_mapping( items: List[str], excluded_colors: Optional[List] = None ) -> Dict[str, str]: - """Get mapping for items to a support color.""" + """Get mapping for items to a support color. + + Args: + items: The items to map to colors. + excluded_colors: The colors to exclude. + + Returns: + The mapping of items to colors. + """ colors = list(_TEXT_COLOR_MAPPING.keys()) if excluded_colors is not None: colors = [c for c in colors if c not in excluded_colors] @@ -23,20 +31,45 @@ def get_color_mapping( def get_colored_text(text: str, color: str) -> str: - """Get colored text.""" + """Get colored text. + + Args: + text: The text to color. + color: The color to use. + + Returns: + The colored text. + """ color_str = _TEXT_COLOR_MAPPING[color] return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" def get_bolded_text(text: str) -> str: - """Get bolded text.""" + """Get bolded text. + + Args: + text: The text to bold. + + Returns: + The bolded text. + """ return f"\033[1m{text}\033[0m" def print_text( text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None ) -> None: - """Print text with highlighting and no end characters.""" + """Print text with highlighting and no end characters. + + If a color is provided, the text will be printed in that color. + If a file is provided, the text will be written to that file. + + Args: + text: The text to print. + color: The color to use. Defaults to None. + end: The end character to use. Defaults to "". + file: The file to write to. Defaults to None. + """ text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) if file: diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py index 4ebc9381c9b..9645b75597b 100644 --- a/libs/core/langchain_core/utils/iter.py +++ b/libs/core/langchain_core/utils/iter.py @@ -22,7 +22,7 @@ T = TypeVar("T") class NoLock: - """Dummy lock that provides the proper interface but no protection""" + """Dummy lock that provides the proper interface but no protection.""" def __enter__(self) -> None: pass @@ -39,7 +39,21 @@ def tee_peer( peers: List[Deque[T]], lock: ContextManager[Any], ) -> Generator[T, None, None]: - """An individual iterator of a :py:func:`~.tee`""" + """An individual iterator of a :py:func:`~.tee`. + + This function is a generator that yields items from the shared iterator + ``iterator``. It buffers items until the least advanced iterator has + yielded them as well. The buffer is shared with all other peers. + + Args: + iterator: The shared iterator. + buffer: The buffer for this peer. + peers: The buffers of all peers. + lock: The lock to synchronise access to the shared buffers. + + Yields: + The next item from the shared iterator. + """ try: while True: if not buffer: @@ -118,6 +132,14 @@ class Tee(Generic[T]): *, lock: Optional[ContextManager[Any]] = None, ): + """Create a new ``tee``. + + Args: + iterable: The iterable to split. + n: The number of iterators to create. Defaults to 2. + lock: The lock to synchronise access to the shared buffers. + Defaults to None. + """ self._iterator = iter(iterable) self._buffers: List[Deque[T]] = [deque() for _ in range(n)] self._children = tuple( @@ -170,8 +192,8 @@ def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T size: The size of the batch. If None, returns a single batch. iterable: The iterable to batch. - Returns: - An iterator over the batches. + Yields: + The batches of the iterable. """ it = iter(iterable) while True: diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index 7e1d42a555d..61b21d867a4 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -124,8 +124,7 @@ _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL) def parse_json_markdown( json_string: str, *, parser: Callable[[str], Any] = parse_partial_json ) -> dict: - """ - Parse a JSON string from a Markdown string. + """Parse a JSON string from a Markdown string. Args: json_string: The Markdown string. @@ -175,6 +174,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: Returns: The parsed JSON object as a Python dictionary. + + Raises: + OutputParserException: If the JSON string is invalid or does not contain + the expected keys. """ try: json_obj = parse_json_markdown(text) diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index aecdae470d8..3368f838133 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -90,7 +90,16 @@ def dereference_refs( full_schema: Optional[dict] = None, skip_keys: Optional[Sequence[str]] = None, ) -> dict: - """Try to substitute $refs in JSON Schema.""" + """Try to substitute $refs in JSON Schema. + + Args: + schema_obj: The schema object to dereference. + full_schema: The full schema object. Defaults to None. + skip_keys: The keys to skip. Defaults to None. + + Returns: + The dereferenced schema object. + """ full_schema = full_schema or schema_obj skip_keys = ( diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 09d0284eced..4d135354f7d 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -42,7 +42,15 @@ class ChevronError(SyntaxError): def grab_literal(template: str, l_del: str) -> Tuple[str, str]: - """Parse a literal from the template.""" + """Parse a literal from the template. + + Args: + template: The template to parse. + l_del: The left delimiter. + + Returns: + Tuple[str, str]: The literal and the template. + """ global _CURRENT_LINE @@ -59,7 +67,16 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]: def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: - """Do a preliminary check to see if a tag could be a standalone.""" + """Do a preliminary check to see if a tag could be a standalone. + + Args: + template: The template. (Not used.) + literal: The literal. + is_standalone: Whether the tag is standalone. + + Returns: + bool: Whether the tag could be a standalone. + """ # If there is a newline, or the previous tag was a standalone if literal.find("\n") != -1 or is_standalone: @@ -77,7 +94,16 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: - """Do a final check to see if a tag could be a standalone.""" + """Do a final check to see if a tag could be a standalone. + + Args: + template: The template. + tag_type: The type of the tag. + is_standalone: Whether the tag is standalone. + + Returns: + bool: Whether the tag could be a standalone. + """ # Check right side if we might be a standalone if is_standalone and tag_type not in ["variable", "no escape"]: @@ -95,7 +121,20 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]: - """Parse a tag from a template.""" + """Parse a tag from a template. + + Args: + template: The template. + l_del: The left delimiter. + r_del: The right delimiter. + + Returns: + Tuple[Tuple[str, str], str]: The tag and the template. + + Raises: + ChevronError: If the tag is unclosed. + ChevronError: If the set delimiter tag is unclosed. + """ global _CURRENT_LINE global _LAST_TAG_LINE @@ -404,36 +443,36 @@ def render( Arguments: - template -- A file-like object or a string containing the template + template -- A file-like object or a string containing the template. - data -- A python dictionary with your data scope + data -- A python dictionary with your data scope. - partials_path -- The path to where your partials are stored + partials_path -- The path to where your partials are stored. If set to None, then partials won't be loaded from the file system - (defaults to '.') + (defaults to '.'). partials_ext -- The extension that you want the parser to look for - (defaults to 'mustache') + (defaults to 'mustache'). partials_dict -- A python dictionary which will be search for partials before the filesystem is. {'include': 'foo'} is the same as a file called include.mustache - (defaults to {}) + (defaults to {}). padding -- This is for padding partials, and shouldn't be used - (but can be if you really want to) + (but can be if you really want to). def_ldel -- The default left delimiter - ("{{" by default, as in spec compliant mustache) + ("{{" by default, as in spec compliant mustache). def_rdel -- The default right delimiter - ("}}" by default, as in spec compliant mustache) + ("}}" by default, as in spec compliant mustache). - scopes -- The list of scopes that get_key will look through + scopes -- The list of scopes that get_key will look through. warn -- Log a warning when a template substitution isn't found in the data - keep -- Keep unreplaced tags when a substitution isn't found in the data + keep -- Keep unreplaced tags when a substitution isn't found in the data. Returns: diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index a79eb70fefe..e60a8d9e21e 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -21,12 +21,27 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() # How to type hint this? def pre_init(func: Callable) -> Any: - """Decorator to run a function before model initialization.""" + """Decorator to run a function before model initialization. + + Args: + func (Callable): The function to run before model initialization. + + Returns: + Any: The decorated function. + """ @root_validator(pre=True) @wraps(func) def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]: - """Decorator to run a function before model initialization.""" + """Decorator to run a function before model initialization. + + Args: + cls (Type[BaseModel]): The model class. + values (Dict[str, Any]): The values to initialize the model with. + + Returns: + Dict[str, Any]: The values to initialize the model with. + """ # Insert default values fields = cls.__fields__ for name, field_info in fields.items(): diff --git a/libs/core/langchain_core/utils/strings.py b/libs/core/langchain_core/utils/strings.py index 3e866f059ff..f54ddb4fbaa 100644 --- a/libs/core/langchain_core/utils/strings.py +++ b/libs/core/langchain_core/utils/strings.py @@ -36,5 +36,12 @@ def stringify_dict(data: dict) -> str: def comma_list(items: List[Any]) -> str: - """Convert a list to a comma-separated string.""" + """Convert a list to a comma-separated string. + + Args: + items: The list to convert. + + Returns: + str: The comma-separated string. + """ return ", ".join(str(item) for item in items) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index d989d94cfa0..4ee4ee6de67 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -15,7 +15,18 @@ from langchain_core.pydantic_v1 import SecretStr def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: - """Validate specified keyword args are mutually exclusive.""" + """Validate specified keyword args are mutually exclusive." + + Args: + *arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args. + + Returns: + Callable: Decorator that validates the specified keyword args + are mutually exclusive + + Raises: + ValueError: If more than one arg in a group is defined. + """ def decorator(func: Callable) -> Callable: @functools.wraps(func) @@ -41,7 +52,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: def raise_for_status_with_text(response: Response) -> None: - """Raise an error with the response text.""" + """Raise an error with the response text. + + Args: + response (Response): The response to check for errors. + + Raises: + ValueError: If the response has an error status code. + """ try: response.raise_for_status() except HTTPError as e: @@ -52,6 +70,12 @@ def raise_for_status_with_text(response: Response) -> None: def mock_now(dt_value): # type: ignore """Context manager for mocking out datetime.now() in unit tests. + Args: + dt_value: The datetime value to use for datetime.now(). + + Yields: + datetime.datetime: The mocked datetime class. + Example: with mock_now(datetime.datetime(2011, 2, 3, 10, 11)): assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11) @@ -86,7 +110,21 @@ def guard_import( module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None ) -> Any: """Dynamically import a module and raise an exception if the module is not - installed.""" + installed. + + Args: + module_name (str): The name of the module to import. + pip_name (str, optional): The name of the module to install with pip. + Defaults to None. + package (str, optional): The package to import the module from. + Defaults to None. + + Returns: + Any: The imported module. + + Raises: + ImportError: If the module is not installed. + """ try: module = importlib.import_module(module_name, package) except (ImportError, ModuleNotFoundError): @@ -105,7 +143,22 @@ def check_package_version( gt_version: Optional[str] = None, gte_version: Optional[str] = None, ) -> None: - """Check the version of a package.""" + """Check the version of a package. + + Args: + package (str): The name of the package. + lt_version (str, optional): The version must be less than this. + Defaults to None. + lte_version (str, optional): The version must be less than or equal to this. + Defaults to None. + gt_version (str, optional): The version must be greater than this. + Defaults to None. + gte_version (str, optional): The version must be greater than or equal to this. + Defaults to None. + + Raises: + ValueError: If the package version does not meet the requirements. + """ imported_version = parse(version(package)) if lt_version is not None and imported_version >= parse(lt_version): raise ValueError( @@ -133,7 +186,11 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: """Get field names, including aliases, for a pydantic class. Args: - pydantic_cls: Pydantic class.""" + pydantic_cls: Pydantic class. + + Returns: + Set[str]: Field names. + """ all_required_field_names = set() for field in pydantic_cls.__fields__.values(): all_required_field_names.add(field.name) @@ -153,6 +210,13 @@ def build_extra_kwargs( extra_kwargs: Extra kwargs passed in by user. values: Values passed in by user. all_required_field_names: All required field names for the pydantic class. + + Returns: + Dict[str, Any]: Extra kwargs. + + Raises: + ValueError: If a field is specified in both values and extra_kwargs. + ValueError: If a field is specified in model_kwargs. """ for field_name in list(values): if field_name in extra_kwargs: @@ -176,7 +240,14 @@ def build_extra_kwargs( def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" + """Convert a string to a SecretStr if needed. + + Args: + value (Union[SecretStr, str]): The value to convert. + + Returns: + SecretStr: The SecretStr value. + """ if isinstance(value, SecretStr): return value return SecretStr(value)