diff --git a/libs/core/langchain_core/example_selectors/base.py b/libs/core/langchain_core/example_selectors/base.py
index ff2e099c810..5061e140a75 100644
--- a/libs/core/langchain_core/example_selectors/base.py
+++ b/libs/core/langchain_core/example_selectors/base.py
@@ -8,7 +8,7 @@ class BaseExampleSelector(ABC):
 
     @abstractmethod
     def add_example(self, example: Dict[str, str]) -> Any:
-        """Add new example to store for a key."""
+        """Add new example to store."""
 
     @abstractmethod
     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
diff --git a/libs/core/langchain_core/output_parsers/__init__.py b/libs/core/langchain_core/output_parsers/__init__.py
index 2acaab77b92..51caa200f3a 100644
--- a/libs/core/langchain_core/output_parsers/__init__.py
+++ b/libs/core/langchain_core/output_parsers/__init__.py
@@ -3,7 +3,7 @@ from langchain_core.output_parsers.base import (
     BaseLLMOutputParser,
     BaseOutputParser,
 )
-from langchain_core.output_parsers.json import SimpleJsonOutputParser
+from langchain_core.output_parsers.json import JsonOutputParser, SimpleJsonOutputParser
 from langchain_core.output_parsers.list import (
     CommaSeparatedListOutputParser,
     ListOutputParser,
@@ -30,4 +30,5 @@ __all__ = [
     "BaseCumulativeTransformOutputParser",
     "SimpleJsonOutputParser",
     "XMLOutputParser",
+    "JsonOutputParser",
 ]
diff --git a/libs/core/langchain_core/output_parsers/format_instructions.py b/libs/core/langchain_core/output_parsers/format_instructions.py
new file mode 100644
index 00000000000..400756e8663
--- /dev/null
+++ b/libs/core/langchain_core/output_parsers/format_instructions.py
@@ -0,0 +1,11 @@
+# flake8: noqa
+
+JSON_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
+
+As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
+the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
+
+Here is the output schema:
+```
+{schema}
+```"""
diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py
index d2d7254a590..14d70a5c7b4 100644
--- a/libs/core/langchain_core/output_parsers/json.py
+++ b/libs/core/langchain_core/output_parsers/json.py
@@ -3,12 +3,14 @@ from __future__ import annotations
 import json
 import re
 from json import JSONDecodeError
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Type
 
 import jsonpatch  # type: ignore[import]
 
 from langchain_core.exceptions import OutputParserException
+from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
 from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
+from langchain_core.pydantic_v1 import BaseModel
 
 
 def _replace_new_line(match: re.Match[str]) -> str:
@@ -170,7 +172,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
     return json_obj
 
 
-class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
+class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
     """Parse the output of an LLM call to a JSON object.
 
     When used in streaming mode, it will yield partial JSON objects containing
@@ -180,6 +182,8 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
     describing the difference between the previous and the current object.
     """
 
+    pydantic_object: Optional[Type[BaseModel]] = None
+
     def _diff(self, prev: Optional[Any], next: Any) -> Any:
         return jsonpatch.make_patch(prev, next).patch
 
@@ -190,6 +194,26 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
         except JSONDecodeError as e:
             raise OutputParserException(f"Invalid json output: {text}") from e
 
+    def get_format_instructions(self) -> str:
+        """Return prompt instructions describing the expected JSON output."""
+        if self.pydantic_object is None:
+            return "Return a JSON object."
+
+        # Build a filtered copy: pydantic v1 caches the dict returned by
+        # ``schema()``, so deleting keys in place would corrupt later calls.
+        reduced_schema = {
+            key: value
+            for key, value in self.pydantic_object.schema().items()
+            if key not in ("title", "type")
+        }
+        # Ensure json in context is well-formed with double quotes.
+        schema_str = json.dumps(reduced_schema)
+        return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
+
     @property
     def _type(self) -> str:
         return "simple_json_output_parser"
+
+
+# For backwards compatibility
+SimpleJsonOutputParser = JsonOutputParser
diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index 43de770d0b6..9a93023ce12 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -34,7 +34,11 @@ class XMLOutputParser(BaseTransformOutputParser):
         return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
 
     def parse(self, text: str) -> Dict[str, List[Any]]:
-        text = text.strip("`").strip("xml")
+        # Try to find XML string within triple backticks
+        match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
+        if match is not None:
+            # If match found, use the content within the backticks
+            text = match.group(2)
         encoding_match = self.encoding_matcher.search(text)
         if encoding_match:
             text = encoding_match.group(2)
diff --git a/libs/core/tests/unit_tests/output_parsers/test_imports.py b/libs/core/tests/unit_tests/output_parsers/test_imports.py
index d1ee5c1b4a6..bf4b19120ab 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_imports.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_imports.py
@@ -13,6 +13,7 @@ EXPECTED_ALL = [
     "BaseCumulativeTransformOutputParser",
     "SimpleJsonOutputParser",
     "XMLOutputParser",
+    "JsonOutputParser",
 ]
 
 
diff --git a/libs/langchain/langchain/output_parsers/datetime.py b/libs/langchain/langchain/output_parsers/datetime.py
index 41636fa6987..837f0b0b8be 100644
--- a/libs/langchain/langchain/output_parsers/datetime.py
+++ b/libs/langchain/langchain/output_parsers/datetime.py
@@ -39,8 +39,12 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):
 
     def get_format_instructions(self) -> str:
         examples = comma_list(_generate_random_datetime_strings(self.format))
-        return f"""Write a datetime string that matches the
-            following pattern: "{self.format}". Examples: {examples}"""
+        return (
+            f"Write a datetime string that matches the "
+            f"following pattern: '{self.format}'.\n\n"
+            f"Examples: {examples}\n\n"
+            f"Return ONLY this string, no other words!"
+        )
 
     def parse(self, response: str) -> datetime:
         try: