mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
[core, langchain] modelio code improvements (#15277)
This commit is contained in:
parent
694bbb14cd
commit
b86803153e
@ -8,7 +8,7 @@ class BaseExampleSelector(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_example(self, example: Dict[str, str]) -> Any:
|
def add_example(self, example: Dict[str, str]) -> Any:
|
||||||
"""Add new example to store for a key."""
|
"""Add new example to store."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||||
|
@ -3,7 +3,7 @@ from langchain_core.output_parsers.base import (
|
|||||||
BaseLLMOutputParser,
|
BaseLLMOutputParser,
|
||||||
BaseOutputParser,
|
BaseOutputParser,
|
||||||
)
|
)
|
||||||
from langchain_core.output_parsers.json import SimpleJsonOutputParser
|
from langchain_core.output_parsers.json import JsonOutputParser, SimpleJsonOutputParser
|
||||||
from langchain_core.output_parsers.list import (
|
from langchain_core.output_parsers.list import (
|
||||||
CommaSeparatedListOutputParser,
|
CommaSeparatedListOutputParser,
|
||||||
ListOutputParser,
|
ListOutputParser,
|
||||||
@ -30,4 +30,5 @@ __all__ = [
|
|||||||
"BaseCumulativeTransformOutputParser",
|
"BaseCumulativeTransformOutputParser",
|
||||||
"SimpleJsonOutputParser",
|
"SimpleJsonOutputParser",
|
||||||
"XMLOutputParser",
|
"XMLOutputParser",
|
||||||
|
"JsonOutputParser",
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,11 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
JSON_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||||
|
|
||||||
|
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||||
|
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
|
||||||
|
|
||||||
|
Here is the output schema:
|
||||||
|
```
|
||||||
|
{schema}
|
||||||
|
```"""
|
@ -3,12 +3,14 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, List, Optional, Type
|
||||||
|
|
||||||
import jsonpatch # type: ignore[import]
|
import jsonpatch # type: ignore[import]
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def _replace_new_line(match: re.Match[str]) -> str:
|
def _replace_new_line(match: re.Match[str]) -> str:
|
||||||
@ -170,7 +172,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
|||||||
return json_obj
|
return json_obj
|
||||||
|
|
||||||
|
|
||||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||||
"""Parse the output of an LLM call to a JSON object.
|
"""Parse the output of an LLM call to a JSON object.
|
||||||
|
|
||||||
When used in streaming mode, it will yield partial JSON objects containing
|
When used in streaming mode, it will yield partial JSON objects containing
|
||||||
@ -180,6 +182,8 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
describing the difference between the previous and the current object.
|
describing the difference between the previous and the current object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
pydantic_object: Optional[Type[BaseModel]] = None
|
||||||
|
|
||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
@ -190,6 +194,26 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
except JSONDecodeError as e:
|
except JSONDecodeError as e:
|
||||||
raise OutputParserException(f"Invalid json output: {text}") from e
|
raise OutputParserException(f"Invalid json output: {text}") from e
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
    """Return the prompt text telling the model how to format its output.

    Returns:
        ``"Return a JSON object."`` when no ``pydantic_object`` is set.
        Otherwise ``JSON_FORMAT_INSTRUCTIONS`` with ``{schema}`` filled in by
        the model's JSON schema, serialized with ``json.dumps`` and with the
        top-level ``title`` and ``type`` entries left out.
    """
    if self.pydantic_object is None:
        return "Return a JSON object."

    # Build a filtered copy rather than deleting keys from the dict that
    # ``schema()`` returned. That dict may be cached on the model class
    # (pydantic v1 does this), so deleting from it in place would strip
    # "title"/"type" from the schema seen by every later caller.
    reduced_schema = {
        key: value
        for key, value in self.pydantic_object.schema().items()
        if key not in ("title", "type")
    }
    # json.dumps ensures the JSON shown in the prompt is well-formed and
    # uses double quotes (a plain str() of the dict would use single quotes).
    schema_str = json.dumps(reduced_schema)
    return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
return "simple_json_output_parser"
|
return "simple_json_output_parser"
|
||||||
|
|
||||||
|
|
||||||
|
# For backwards compatibility
|
||||||
|
SimpleJsonOutputParser = JsonOutputParser
|
||||||
|
@ -34,7 +34,11 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
||||||
|
|
||||||
def parse(self, text: str) -> Dict[str, List[Any]]:
|
def parse(self, text: str) -> Dict[str, List[Any]]:
|
||||||
text = text.strip("`").strip("xml")
|
# Try to find XML string within triple backticks
|
||||||
|
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
|
||||||
|
if match is not None:
|
||||||
|
# If match found, use the content within the backticks
|
||||||
|
text = match.group(2)
|
||||||
encoding_match = self.encoding_matcher.search(text)
|
encoding_match = self.encoding_matcher.search(text)
|
||||||
if encoding_match:
|
if encoding_match:
|
||||||
text = encoding_match.group(2)
|
text = encoding_match.group(2)
|
||||||
|
@ -13,6 +13,7 @@ EXPECTED_ALL = [
|
|||||||
"BaseCumulativeTransformOutputParser",
|
"BaseCumulativeTransformOutputParser",
|
||||||
"SimpleJsonOutputParser",
|
"SimpleJsonOutputParser",
|
||||||
"XMLOutputParser",
|
"XMLOutputParser",
|
||||||
|
"JsonOutputParser",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,8 +39,12 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):
|
|||||||
|
|
||||||
def get_format_instructions(self) -> str:
|
def get_format_instructions(self) -> str:
|
||||||
examples = comma_list(_generate_random_datetime_strings(self.format))
|
examples = comma_list(_generate_random_datetime_strings(self.format))
|
||||||
return f"""Write a datetime string that matches the
|
return (
|
||||||
following pattern: "{self.format}". Examples: {examples}"""
|
f"Write a datetime string that matches the "
|
||||||
|
f"following pattern: '{self.format}'.\n\n"
|
||||||
|
f"Examples: {examples}\n\n"
|
||||||
|
f"Return ONLY this string, no other words!"
|
||||||
|
)
|
||||||
|
|
||||||
def parse(self, response: str) -> datetime:
|
def parse(self, response: str) -> datetime:
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user