From 55f6f91f1738287f89238f7f9eea96bba8b1b5ef Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Wed, 3 Jul 2024 11:27:40 -0700 Subject: [PATCH] core[patch]: docstrings `output_parsers` (#23825) Added missed docstrings. Formatted docstrings to the consistent form. --- .../langchain_core/output_parsers/base.py | 26 ++++- .../langchain_core/output_parsers/json.py | 31 ++++++ .../langchain_core/output_parsers/list.py | 88 +++++++++++++++-- .../output_parsers/openai_functions.py | 71 ++++++++++++++ .../output_parsers/openai_tools.py | 97 ++++++++++++++++++- .../langchain_core/output_parsers/pydantic.py | 25 +++++ .../output_parsers/transform.py | 30 +++++- .../core/langchain_core/output_parsers/xml.py | 39 +++++++- 8 files changed, 387 insertions(+), 20 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index cad5da3b6a9..caac385b6bd 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -38,6 +38,8 @@ class BaseLLMOutputParser(Generic[T], ABC): Args: result: A list of Generations to be parsed. The Generations are assumed to be different candidate outputs for a single model input. + partial: Whether to parse the output as a partial result. This is useful + for parsers that can parse partial results. Default is False. Returns: Structured output. @@ -46,11 +48,13 @@ class BaseLLMOutputParser(Generic[T], ABC): async def aparse_result( self, result: List[Generation], *, partial: bool = False ) -> T: - """Parse a list of candidate model Generations into a specific format. + """Async parse a list of candidate model Generations into a specific format. Args: result: A list of Generations to be parsed. The Generations are assumed to be different candidate outputs for a single model input. + partial: Whether to parse the output as a partial result. This is useful + for parsers that can parse partial results. Default is False. Returns: Structured output. @@ -65,10 +69,12 @@ class BaseGenerationOutputParser( @property def InputType(self) -> Any: + """Return the input type for the parser.""" return Union[str, AnyMessage] @property def OutputType(self) -> Type[T]: + """Return the output type for the parser.""" # even though mypy complains this isn't valid, # it is good enough for pydantic to build the schema from return T # type: ignore[misc] @@ -148,10 +154,18 @@ class BaseOutputParser( @property def InputType(self) -> Any: + """Return the input type for the parser.""" return Union[str, AnyMessage] @property def OutputType(self) -> Type[T]: + """Return the output type for the parser. + + This property is inferred from the first type argument of the class. + + Raises: + TypeError: If the class doesn't have an inferable OutputType. + """ for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] type_args = get_args(cls) if type_args and len(type_args) == 1: @@ -214,6 +228,8 @@ class BaseOutputParser( Args: result: A list of Generations to be parsed. The Generations are assumed to be different candidate outputs for a single model input. + partial: Whether to parse the output as a partial result. This is useful + for parsers that can parse partial results. Default is False. Returns: Structured output. @@ -234,7 +250,7 @@ class BaseOutputParser( async def aparse_result( self, result: List[Generation], *, partial: bool = False ) -> T: - """Parse a list of candidate model Generations into a specific format. + """Async parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which is assumed to be the highest-likelihood Generation. @@ -242,6 +258,8 @@ class BaseOutputParser( Args: result: A list of Generations to be parsed. The Generations are assumed to be different candidate outputs for a single model input. + partial: Whether to parse the output as a partial result. This is useful + for parsers that can parse partial results. Default is False. Returns: Structured output. @@ -249,7 +267,7 @@ class BaseOutputParser( return await run_in_executor(None, self.parse_result, result, partial=partial) async def aparse(self, text: str) -> T: - """Parse a single string model output into some structure. + """Async parse a single string model output into some structure. Args: text: String output of a language model. @@ -272,7 +290,7 @@ class BaseOutputParser( prompt: Input PromptValue. Returns: - Structured output + Structured output. """ return self.parse(completion) diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 9652fde424e..58a7090be70 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -41,6 +41,8 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): """ pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore + """The Pydantic object to use for validation. + If None, no validation is performed.""" def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch @@ -54,6 +56,22 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): return pydantic_object.schema() def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a JSON object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. + If True, the output will be a JSON object containing + all the keys that have been returned so far. + If False, the output will be the full JSON object. + Default is False. + + Returns: + The parsed JSON object. + + Raises: + OutputParserException: If the output is not valid JSON. + """ text = result[0].text text = text.strip() if partial: @@ -69,9 +87,22 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): raise OutputParserException(msg, llm_output=text) from e def parse(self, text: str) -> Any: + """Parse the output of an LLM call to a JSON object. + + Args: + text: The output of the LLM call. + + Returns: + The parsed JSON object. + """ return self.parse_result([Generation(text=text)]) def get_format_instructions(self) -> str: + """Return the format instructions for the JSON output. + + Returns: + The format instructions for the JSON output. + """ if self.pydantic_object is None: return "Return a JSON object." else: diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index 2fdb3e0f10a..34371946b6e 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -12,7 +12,15 @@ T = TypeVar("T") def droplastn(iter: Iterator[T], n: int) -> Iterator[T]: - """Drop the last n elements of an iterator.""" + """Drop the last n elements of an iterator. + + Args: + iter: The iterator to drop elements from. + n: The number of elements to drop. + + Yields: + The elements of the iterator, except the last n elements. + """ buffer: Deque[T] = deque() for item in iter: buffer.append(item) @@ -29,10 +37,24 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]): @abstractmethod def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Returns: + A list of strings. + """ def parse_iter(self, text: str) -> Iterator[re.Match]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Yields: + A match object for each part of the output. + """ raise NotImplementedError def _transform( @@ -105,21 +127,36 @@ class CommaSeparatedListOutputParser(ListOutputParser): @classmethod def is_lc_serializable(cls) -> bool: + """Check if the langchain object is serializable. + Returns True.""" return True @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + + Returns: + A list of strings. + Default is ["langchain", "output_parsers", "list"]. + """ return ["langchain", "output_parsers", "list"] def get_format_instructions(self) -> str: + """Return the format instructions for the comma-separated list output.""" return ( "Your response should be a list of comma separated values, " "eg: `foo, bar, baz` or `foo,bar,baz`" ) def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Returns: + A list of strings. + """ return [part.strip() for part in text.split(",")] @property @@ -131,6 +168,7 @@ class NumberedListOutputParser(ListOutputParser): """Parse a numbered list.""" pattern: str = r"\d+\.\s([^\n]+)" + """The pattern to match a numbered list item.""" def get_format_instructions(self) -> str: return ( @@ -139,11 +177,25 @@ class NumberedListOutputParser(ListOutputParser): ) def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Returns: + A list of strings. + """ return re.findall(self.pattern, text) def parse_iter(self, text: str) -> Iterator[re.Match]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Yields: + A match object for each part of the output. + """ return re.finditer(self.pattern, text) @property @@ -152,19 +204,35 @@ class NumberedListOutputParser(ListOutputParser): class MarkdownListOutputParser(ListOutputParser): - """Parse a markdown list.""" + """Parse a Markdown list.""" pattern: str = r"^\s*[-*]\s([^\n]+)$" + """The pattern to match a Markdown list item.""" def get_format_instructions(self) -> str: + """Return the format instructions for the Markdown list output.""" return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Returns: + A list of strings. + """ return re.findall(self.pattern, text, re.MULTILINE) def parse_iter(self, text: str) -> Iterator[re.Match]: - """Parse the output of an LLM call.""" + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Yields: + A match object for each part of the output. + """ return re.finditer(self.pattern, text, re.MULTILINE) @property diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index 08391f5243f..bbb50eaade9 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -21,6 +21,18 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): """Whether to only return the arguments to the function call.""" def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a JSON object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. Default is False. + + Returns: + The parsed JSON object. + + Raises: + OutputParserException: If the output is not valid JSON. + """ generation = result[0] if not isinstance(generation, ChatGeneration): raise OutputParserException( @@ -59,6 +71,19 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): return jsonpatch.make_patch(prev, next).patch def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a JSON object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. Default is False. + + Returns: + The parsed JSON object. + + Raises: + OutputParserException: If the output is not valid JSON. + """ + if len(result) != 1: raise OutputParserException( f"Expected exactly one result, but got {len(result)}" @@ -120,6 +145,14 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): # This method would be called by the default implementation of `parse_result` # but we're overriding that method so it's not needed. def parse(self, text: str) -> Any: + """Parse the output of an LLM call to a JSON object. + + Args: + text: The output of the LLM call. + + Returns: + The parsed JSON object. + """ raise NotImplementedError() @@ -130,6 +163,15 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): """The name of the key to return.""" def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a JSON object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. Default is False. + + Returns: + The parsed JSON object. + """ res = super().parse_result(result, partial=partial) if partial and res is None: return None @@ -186,6 +228,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): @root_validator(pre=True) def validate_schema(cls, values: Dict) -> Dict: + """Validate the pydantic schema. + + Args: + values: The values to validate. + + Returns: + The validated values. + + Raises: + ValueError: If the schema is not a pydantic schema. + """ schema = values["pydantic_schema"] if "args_only" not in values: values["args_only"] = isinstance(schema, type) and issubclass( @@ -199,6 +252,15 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): return values def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a JSON object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. Default is False. + + Returns: + The parsed JSON object. + """ _result = super().parse_result(result) if self.args_only: pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore @@ -216,5 +278,14 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): """The name of the attribute to return.""" def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a JSON object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. Default is False. + + Returns: + The parsed JSON object. + """ result = super().parse_result(result) return getattr(result, self.attr_name) diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 370afe2195a..acc7fdb94fe 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -18,7 +18,21 @@ def parse_tool_call( strict: bool = False, return_id: bool = True, ) -> Optional[Dict[str, Any]]: - """Parse a single tool call.""" + """Parse a single tool call. + + Args: + raw_tool_call: The raw tool call to parse. + partial: Whether to parse partial JSON. Default is False. + strict: Whether to allow non-JSON-compliant strings. + Default is False. + return_id: Whether to return the tool call id. Default is True. + + Returns: + The parsed tool call. + + Raises: + OutputParserException: If the tool call is not valid JSON. + """ if "function" not in raw_tool_call: return None if partial: @@ -52,7 +66,15 @@ def make_invalid_tool_call( raw_tool_call: Dict[str, Any], error_msg: Optional[str], ) -> InvalidToolCall: - """Create an InvalidToolCall from a raw tool call.""" + """Create an InvalidToolCall from a raw tool call. + + Args: + raw_tool_call: The raw tool call. + error_msg: The error message. + + Returns: + An InvalidToolCall instance with the error message. + """ return InvalidToolCall( name=raw_tool_call["function"]["name"], args=raw_tool_call["function"]["arguments"], @@ -68,7 +90,21 @@ def parse_tool_calls( strict: bool = False, return_id: bool = True, ) -> List[Dict[str, Any]]: - """Parse a list of tool calls.""" + """Parse a list of tool calls. + + Args: + raw_tool_calls: The raw tool calls to parse. + partial: Whether to parse partial JSON. Default is False. + strict: Whether to allow non-JSON-compliant strings. + Default is False. + return_id: Whether to return the tool call id. Default is True. + + Returns: + The parsed tool calls. + + Raises: + OutputParserException: If any of the tool calls are not valid JSON. + """ final_tools: List[Dict[str, Any]] = [] exceptions = [] for tool_call in raw_tool_calls: @@ -110,6 +146,23 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]): """ def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a list of tool calls. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON. + If True, the output will be a JSON object containing + all the keys that have been returned so far. + If False, the output will be the full JSON object. + Default is False. + + Returns: + The parsed tool calls. + + Raises: + OutputParserException: If the output is not valid JSON. + """ + generation = result[0] if not isinstance(generation, ChatGeneration): raise OutputParserException( @@ -141,6 +194,14 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]): return tool_calls def parse(self, text: str) -> Any: + """Parse the output of an LLM call to a list of tool calls. + + Args: + text: The output of the LLM call. + + Returns: + The parsed tool calls. + """ raise NotImplementedError() @@ -151,6 +212,19 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser): """The type of tools to return.""" def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a list of tool calls. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON. + If True, the output will be a JSON object containing + all the keys that have been returned so far. + If False, the output will be the full JSON object. + Default is False. + + Returns: + The parsed tool calls. + """ parsed_result = super().parse_result(result, partial=partial) if self.first_tool_only: @@ -175,10 +249,27 @@ class PydanticToolsParser(JsonOutputToolsParser): """Parse tools from OpenAI response.""" tools: List[Type[BaseModel]] + """The tools to parse.""" # TODO: Support more granular streaming of objects. Currently only streams once all # Pydantic object fields are present. def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + """Parse the result of an LLM call to a list of Pydantic objects. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON. + If True, the output will be a JSON object containing + all the keys that have been returned so far. + If False, the output will be the full JSON object. + Default is False. + + Returns: + The parsed Pydantic objects. + + Raises: + OutputParserException: If the output is not valid JSON. + """ json_results = super().parse_result(result, partial=partial) if not json_results: return None if self.first_tool_only else [] diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 73444d45af2..1c2debcb6b1 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -57,13 +57,38 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): def parse_result( self, result: List[Generation], *, partial: bool = False ) -> TBaseModel: + """Parse the result of an LLM call to a pydantic object. + + Args: + result: The result of the LLM call. + partial: Whether to parse partial JSON objects. + If True, the output will be a JSON object containing + all the keys that have been returned so far. + Defaults to False. + + Returns: + The parsed pydantic object. + """ json_object = super().parse_result(result) return self._parse_obj(json_object) def parse(self, text: str) -> TBaseModel: + """Parse the output of an LLM call to a pydantic object. + + Args: + text: The output of the LLM call. + + Returns: + The parsed pydantic object. + """ return super().parse(text) def get_format_instructions(self) -> str: + """Return the format instructions for the JSON output. + + Returns: + The format instructions for the JSON output. + """ # Copy schema to avoid altering original Pydantic schema. schema = {k: v for k, v in self.pydantic_object.schema().items()} diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index 96688174d4a..c0b93629378 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -47,6 +47,16 @@ class BaseTransformOutputParser(BaseOutputParser[T]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Iterator[T]: + """Transform the input into the output format. + + Args: + input: The input to transform. + config: The configuration to use for the transformation. + kwargs: Additional keyword arguments. + + Yields: + The transformed output. + """ yield from self._transform_stream_with_config( input, self._transform, config, run_type="parser" ) @@ -57,6 +67,16 @@ class BaseTransformOutputParser(BaseOutputParser[T]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> AsyncIterator[T]: + """Async transform the input into the output format. + + Args: + input: The input to transform. + config: The configuration to use for the transformation. + kwargs: Additional keyword arguments. + + Yields: + The transformed output. + """ async for chunk in self._atransform_stream_with_config( input, self._atransform, config, run_type="parser" ): @@ -73,7 +93,15 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): def _diff(self, prev: Optional[T], next: T) -> T: """Convert parsed outputs into a diff format. The semantics of this are - up to the output parser.""" + up to the output parser. + + Args: + prev: The previous parsed output. + next: The current parsed output. + + Returns: + The diff between the previous and current parsed output. + """ raise NotImplementedError() def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 890a4d7c717..238c7f6d14b 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -38,6 +38,10 @@ class _StreamingParser: Args: parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'. See documentation in XMLOutputParser for more information. + + Raises: + ImportError: If defusedxml is not installed and the defusedxml + parser is requested. """ if parser == "defusedxml": try: @@ -66,6 +70,9 @@ class _StreamingParser: Yields: AddableDict: A dictionary representing the parsed XML element. + + Raises: + xml.etree.ElementTree.ParseError: If the XML is not well-formed. """ if isinstance(chunk, BaseMessage): # extract text @@ -116,7 +123,13 @@ class _StreamingParser: raise def close(self) -> None: - """Close the parser.""" + """Close the parser. + + This should be called after all chunks have been parsed. + + Raises: + xml.etree.ElementTree.ParseError: If the XML is not well-formed. + """ try: self.pull_parser.close() except xml.etree.ElementTree.ParseError: @@ -153,9 +166,23 @@ class XMLOutputParser(BaseTransformOutputParser): """ def get_format_instructions(self) -> str: + """Return the format instructions for the XML output.""" return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: + """Parse the output of an LLM call. + + Args: + text: The output of an LLM call. + + Returns: + A dictionary representing the parsed XML. + + Raises: + OutputParserException: If the XML is not well-formed. + ImportError: If defusedxml is not installed and the defusedxml + parser is requested. + """ # Try to find XML string within triple backticks # Imports are temporarily placed here to avoid issue with caching on CI # likely if you're reading this you can move them to the top of the file @@ -227,7 +254,15 @@ class XMLOutputParser(BaseTransformOutputParser): def nested_element(path: List[str], elem: ET.Element) -> Any: - """Get nested element from path.""" + """Get nested element from path. + + Args: + path: The path to the element. + elem: The element to extract. + + Returns: + The nested element. + """ if len(path) == 0: return AddableDict({elem.tag: elem.text}) else: