core[patch]: docstrings output_parsers (#23825)

Added missing docstrings. Formatted docstrings to a consistent form.
This commit is contained in:
Leonid Ganeline 2024-07-03 11:27:40 -07:00 committed by GitHub
parent 26cee2e878
commit 55f6f91f17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 387 additions and 20 deletions

View File

@ -38,6 +38,8 @@ class BaseLLMOutputParser(Generic[T], ABC):
Args: Args:
result: A list of Generations to be parsed. The Generations are assumed result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input. to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns: Returns:
Structured output. Structured output.
@ -46,11 +48,13 @@ class BaseLLMOutputParser(Generic[T], ABC):
async def aparse_result( async def aparse_result(
self, result: List[Generation], *, partial: bool = False self, result: List[Generation], *, partial: bool = False
) -> T: ) -> T:
"""Parse a list of candidate model Generations into a specific format. """Async parse a list of candidate model Generations into a specific format.
Args: Args:
result: A list of Generations to be parsed. The Generations are assumed result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input. to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns: Returns:
Structured output. Structured output.
@ -65,10 +69,12 @@ class BaseGenerationOutputParser(
@property @property
def InputType(self) -> Any: def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage] return Union[str, AnyMessage]
@property @property
def OutputType(self) -> Type[T]: def OutputType(self) -> Type[T]:
"""Return the output type for the parser."""
# even though mypy complains this isn't valid, # even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from # it is good enough for pydantic to build the schema from
return T # type: ignore[misc] return T # type: ignore[misc]
@ -148,10 +154,18 @@ class BaseOutputParser(
@property @property
def InputType(self) -> Any: def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage] return Union[str, AnyMessage]
@property @property
def OutputType(self) -> Type[T]: def OutputType(self) -> Type[T]:
"""Return the output type for the parser.
This property is inferred from the first type argument of the class.
Raises:
TypeError: If the class doesn't have an inferable OutputType.
"""
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls) type_args = get_args(cls)
if type_args and len(type_args) == 1: if type_args and len(type_args) == 1:
@ -214,6 +228,8 @@ class BaseOutputParser(
Args: Args:
result: A list of Generations to be parsed. The Generations are assumed result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input. to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns: Returns:
Structured output. Structured output.
@ -234,7 +250,7 @@ class BaseOutputParser(
async def aparse_result( async def aparse_result(
self, result: List[Generation], *, partial: bool = False self, result: List[Generation], *, partial: bool = False
) -> T: ) -> T:
"""Parse a list of candidate model Generations into a specific format. """Async parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation. is assumed to be the highest-likelihood Generation.
@ -242,6 +258,8 @@ class BaseOutputParser(
Args: Args:
result: A list of Generations to be parsed. The Generations are assumed result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input. to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns: Returns:
Structured output. Structured output.
@ -249,7 +267,7 @@ class BaseOutputParser(
return await run_in_executor(None, self.parse_result, result, partial=partial) return await run_in_executor(None, self.parse_result, result, partial=partial)
async def aparse(self, text: str) -> T: async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure. """Async parse a single string model output into some structure.
Args: Args:
text: String output of a language model. text: String output of a language model.
@ -272,7 +290,7 @@ class BaseOutputParser(
prompt: Input PromptValue. prompt: Input PromptValue.
Returns: Returns:
Structured output Structured output.
""" """
return self.parse(completion) return self.parse(completion)

View File

@ -41,6 +41,8 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
""" """
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
"""The Pydantic object to use for validation.
If None, no validation is performed."""
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch
@ -54,6 +56,22 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return pydantic_object.schema() return pydantic_object.schema()
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
"""
text = result[0].text text = result[0].text
text = text.strip() text = text.strip()
if partial: if partial:
@ -69,9 +87,22 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
raise OutputParserException(msg, llm_output=text) from e raise OutputParserException(msg, llm_output=text) from e
def parse(self, text: str) -> Any: def parse(self, text: str) -> Any:
"""Parse the output of an LLM call to a JSON object.
Args:
text: The output of the LLM call.
Returns:
The parsed JSON object.
"""
return self.parse_result([Generation(text=text)]) return self.parse_result([Generation(text=text)])
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
"""Return the format instructions for the JSON output.
Returns:
The format instructions for the JSON output.
"""
if self.pydantic_object is None: if self.pydantic_object is None:
return "Return a JSON object." return "Return a JSON object."
else: else:

View File

@ -12,7 +12,15 @@ T = TypeVar("T")
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]: def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
"""Drop the last n elements of an iterator.""" """Drop the last n elements of an iterator.
Args:
iter: The iterator to drop elements from.
n: The number of elements to drop.
Yields:
The elements of the iterator, except the last n elements.
"""
buffer: Deque[T] = deque() buffer: Deque[T] = deque()
for item in iter: for item in iter:
buffer.append(item) buffer.append(item)
@ -29,10 +37,24 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
@abstractmethod @abstractmethod
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
def parse_iter(self, text: str) -> Iterator[re.Match]: def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Yields:
A match object for each part of the output.
Raises:
NotImplementedError: If the subclass does not override this method.
"""
raise NotImplementedError raise NotImplementedError
def _transform( def _transform(
@ -105,21 +127,36 @@ class CommaSeparatedListOutputParser(ListOutputParser):
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Check if the langchain object is serializable.
Returns True."""
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object.
Returns:
A list of strings.
Default is ["langchain", "output_parsers", "list"].
"""
return ["langchain", "output_parsers", "list"] return ["langchain", "output_parsers", "list"]
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
"""Return the format instructions for the comma-separated list output."""
return ( return (
"Your response should be a list of comma separated values, " "Your response should be a list of comma separated values, "
"eg: `foo, bar, baz` or `foo,bar,baz`" "eg: `foo, bar, baz` or `foo,bar,baz`"
) )
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
return [part.strip() for part in text.split(",")] return [part.strip() for part in text.split(",")]
@property @property
@ -131,6 +168,7 @@ class NumberedListOutputParser(ListOutputParser):
"""Parse a numbered list.""" """Parse a numbered list."""
pattern: str = r"\d+\.\s([^\n]+)" pattern: str = r"\d+\.\s([^\n]+)"
"""The pattern to match a numbered list item."""
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return ( return (
@ -139,11 +177,25 @@ class NumberedListOutputParser(ListOutputParser):
) )
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
return re.findall(self.pattern, text) return re.findall(self.pattern, text)
def parse_iter(self, text: str) -> Iterator[re.Match]: def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Yields:
A match object for each part of the output.
"""
return re.finditer(self.pattern, text) return re.finditer(self.pattern, text)
@property @property
@ -152,19 +204,35 @@ class NumberedListOutputParser(ListOutputParser):
class MarkdownListOutputParser(ListOutputParser): class MarkdownListOutputParser(ListOutputParser):
"""Parse a markdown list.""" """Parse a Markdown list."""
pattern: str = r"^\s*[-*]\s([^\n]+)$" pattern: str = r"^\s*[-*]\s([^\n]+)$"
"""The pattern to match a Markdown list item."""
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
"""Return the format instructions for the Markdown list output."""
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
return re.findall(self.pattern, text, re.MULTILINE) return re.findall(self.pattern, text, re.MULTILINE)
def parse_iter(self, text: str) -> Iterator[re.Match]: def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Yields:
A match object for each part of the output.
"""
return re.finditer(self.pattern, text, re.MULTILINE) return re.finditer(self.pattern, text, re.MULTILINE)
@property @property

View File

@ -21,6 +21,18 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Whether to only return the arguments to the function call.""" """Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
raise OutputParserException( raise OutputParserException(
@ -59,6 +71,19 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
"""
if len(result) != 1: if len(result) != 1:
raise OutputParserException( raise OutputParserException(
f"Expected exactly one result, but got {len(result)}" f"Expected exactly one result, but got {len(result)}"
@ -120,6 +145,14 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
# This method would be called by the default implementation of `parse_result` # This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed. # but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any: def parse(self, text: str) -> Any:
"""Parse the output of an LLM call to a JSON object.
Args:
text: The output of the LLM call.
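Returns:
Nothing; this method is not implemented.
Raises:
NotImplementedError: Always. `parse_result` is overridden on this class,
so `parse` is not needed.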
"""
raise NotImplementedError() raise NotImplementedError()
@ -130,6 +163,15 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
"""The name of the key to return.""" """The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
"""
res = super().parse_result(result, partial=partial) res = super().parse_result(result, partial=partial)
if partial and res is None: if partial and res is None:
return None return None
@ -186,6 +228,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@root_validator(pre=True) @root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict: def validate_schema(cls, values: Dict) -> Dict:
"""Validate the pydantic schema.
Args:
values: The values to validate.
Returns:
The validated values.
Raises:
ValueError: If the schema is not a pydantic schema.
"""
schema = values["pydantic_schema"] schema = values["pydantic_schema"]
if "args_only" not in values: if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass( values["args_only"] = isinstance(schema, type) and issubclass(
@ -199,6 +252,15 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
return values return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Currently not forwarded to the parent `parse_result` call.
Returns:
The parsed JSON object.
"""
_result = super().parse_result(result) _result = super().parse_result(result)
if self.args_only: if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
@ -216,5 +278,14 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""The name of the attribute to return.""" """The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Currently not forwarded to the parent `parse_result` call.
Returns:
The value of the `attr_name` attribute of the parsed result.
"""
result = super().parse_result(result) result = super().parse_result(result)
return getattr(result, self.attr_name) return getattr(result, self.attr_name)

View File

@ -18,7 +18,21 @@ def parse_tool_call(
strict: bool = False, strict: bool = False,
return_id: bool = True, return_id: bool = True,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""Parse a single tool call.""" """Parse a single tool call.
Args:
raw_tool_call: The raw tool call to parse.
partial: Whether to parse partial JSON. Default is False.
strict: Whether to allow non-JSON-compliant strings.
Default is False.
return_id: Whether to return the tool call id. Default is True.
Returns:
The parsed tool call.
Raises:
OutputParserException: If the tool call is not valid JSON.
"""
if "function" not in raw_tool_call: if "function" not in raw_tool_call:
return None return None
if partial: if partial:
@ -52,7 +66,15 @@ def make_invalid_tool_call(
raw_tool_call: Dict[str, Any], raw_tool_call: Dict[str, Any],
error_msg: Optional[str], error_msg: Optional[str],
) -> InvalidToolCall: ) -> InvalidToolCall:
"""Create an InvalidToolCall from a raw tool call.""" """Create an InvalidToolCall from a raw tool call.
Args:
raw_tool_call: The raw tool call.
error_msg: The error message.
Returns:
An InvalidToolCall instance with the error message.
"""
return InvalidToolCall( return InvalidToolCall(
name=raw_tool_call["function"]["name"], name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"], args=raw_tool_call["function"]["arguments"],
@ -68,7 +90,21 @@ def parse_tool_calls(
strict: bool = False, strict: bool = False,
return_id: bool = True, return_id: bool = True,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Parse a list of tool calls.""" """Parse a list of tool calls.
Args:
raw_tool_calls: The raw tool calls to parse.
partial: Whether to parse partial JSON. Default is False.
strict: Whether to allow non-JSON-compliant strings.
Default is False.
return_id: Whether to return the tool call id. Default is True.
Returns:
The parsed tool calls.
Raises:
OutputParserException: If any of the tool calls are not valid JSON.
"""
final_tools: List[Dict[str, Any]] = [] final_tools: List[Dict[str, Any]] = []
exceptions = [] exceptions = []
for tool_call in raw_tool_calls: for tool_call in raw_tool_calls:
@ -110,6 +146,23 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
""" """
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed tool calls.
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
raise OutputParserException( raise OutputParserException(
@ -141,6 +194,14 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
return tool_calls return tool_calls
def parse(self, text: str) -> Any: def parse(self, text: str) -> Any:
"""Parse the output of an LLM call to a list of tool calls.
Args:
text: The output of the LLM call.
Returns:
Nothing; this method is not implemented.
Raises:
NotImplementedError: Always. Use `parse_result` instead.
"""
raise NotImplementedError() raise NotImplementedError()
@ -151,6 +212,19 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
"""The type of tools to return.""" """The type of tools to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed tool calls.
"""
parsed_result = super().parse_result(result, partial=partial) parsed_result = super().parse_result(result, partial=partial)
if self.first_tool_only: if self.first_tool_only:
@ -175,10 +249,27 @@ class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response.""" """Parse tools from OpenAI response."""
tools: List[Type[BaseModel]] tools: List[Type[BaseModel]]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all # TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present. # Pydantic object fields are present.
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed Pydantic objects.
Raises:
OutputParserException: If the output is not valid JSON.
"""
json_results = super().parse_result(result, partial=partial) json_results = super().parse_result(result, partial=partial)
if not json_results: if not json_results:
return None if self.first_tool_only else [] return None if self.first_tool_only else []

View File

@ -57,13 +57,38 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
def parse_result( def parse_result(
self, result: List[Generation], *, partial: bool = False self, result: List[Generation], *, partial: bool = False
) -> TBaseModel: ) -> TBaseModel:
"""Parse the result of an LLM call to a pydantic object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Currently not forwarded to the parent JSON parser, so the
output is always parsed as a complete JSON object.
Returns:
The parsed pydantic object.
"""
json_object = super().parse_result(result) json_object = super().parse_result(result)
return self._parse_obj(json_object) return self._parse_obj(json_object)
def parse(self, text: str) -> TBaseModel: def parse(self, text: str) -> TBaseModel:
"""Parse the output of an LLM call to a pydantic object.
Args:
text: The output of the LLM call.
Returns:
The parsed pydantic object.
"""
return super().parse(text) return super().parse(text)
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
"""Return the format instructions for the JSON output.
Returns:
The format instructions for the JSON output.
"""
# Copy schema to avoid altering original Pydantic schema. # Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()} schema = {k: v for k, v in self.pydantic_object.schema().items()}

View File

@ -47,6 +47,16 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[T]: ) -> Iterator[T]:
"""Transform the input into the output format.
Args:
input: The input to transform.
config: The configuration to use for the transformation.
kwargs: Additional keyword arguments.
Yields:
The transformed output.
"""
yield from self._transform_stream_with_config( yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser" input, self._transform, config, run_type="parser"
) )
@ -57,6 +67,16 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
"""Async transform the input into the output format.
Args:
input: The input to transform.
config: The configuration to use for the transformation.
kwargs: Additional keyword arguments.
Yields:
The transformed output.
"""
async for chunk in self._atransform_stream_with_config( async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser" input, self._atransform, config, run_type="parser"
): ):
@ -73,7 +93,15 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
def _diff(self, prev: Optional[T], next: T) -> T: def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are """Convert parsed outputs into a diff format. The semantics of this are
up to the output parser.""" up to the output parser.
Args:
prev: The previous parsed output.
next: The current parsed output.
Returns:
The diff between the previous and current parsed output.
"""
raise NotImplementedError() raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:

View File

@ -38,6 +38,10 @@ class _StreamingParser:
Args: Args:
parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'. parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
See documentation in XMLOutputParser for more information. See documentation in XMLOutputParser for more information.
Raises:
ImportError: If defusedxml is not installed and the defusedxml
parser is requested.
""" """
if parser == "defusedxml": if parser == "defusedxml":
try: try:
@ -66,6 +70,9 @@ class _StreamingParser:
Yields: Yields:
AddableDict: A dictionary representing the parsed XML element. AddableDict: A dictionary representing the parsed XML element.
Raises:
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
""" """
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
# extract text # extract text
@ -116,7 +123,13 @@ class _StreamingParser:
raise raise
def close(self) -> None: def close(self) -> None:
"""Close the parser.""" """Close the parser.
This should be called after all chunks have been parsed.
Raises:
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
"""
try: try:
self.pull_parser.close() self.pull_parser.close()
except xml.etree.ElementTree.ParseError: except xml.etree.ElementTree.ParseError:
@ -153,9 +166,23 @@ class XMLOutputParser(BaseTransformOutputParser):
""" """
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
"""Return the format instructions for the XML output."""
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A dictionary representing the parsed XML.
Raises:
OutputParserException: If the XML is not well-formed.
ImportError: If defusedxml is not installed and the defusedxml
parser is requested.
"""
# Try to find XML string within triple backticks # Try to find XML string within triple backticks
# Imports are temporarily placed here to avoid issue with caching on CI # Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file # likely if you're reading this you can move them to the top of the file
@ -227,7 +254,15 @@ class XMLOutputParser(BaseTransformOutputParser):
def nested_element(path: List[str], elem: ET.Element) -> Any: def nested_element(path: List[str], elem: ET.Element) -> Any:
"""Get nested element from path.""" """Get nested element from path.
Args:
path: The path to the element.
elem: The element to extract.
Returns:
The nested element.
"""
if len(path) == 0: if len(path) == 0:
return AddableDict({elem.tag: elem.text}) return AddableDict({elem.tag: elem.text})
else: else: