mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
[core, langchain] modelio code improvements (#15277)
This commit is contained in:
parent
694bbb14cd
commit
b86803153e
@ -8,7 +8,7 @@ class BaseExampleSelector(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def add_example(self, example: Dict[str, str]) -> Any:
|
||||
"""Add new example to store for a key."""
|
||||
"""Add new example to store."""
|
||||
|
||||
@abstractmethod
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
|
@ -3,7 +3,7 @@ from langchain_core.output_parsers.base import (
|
||||
BaseLLMOutputParser,
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.json import SimpleJsonOutputParser
|
||||
from langchain_core.output_parsers.json import JsonOutputParser, SimpleJsonOutputParser
|
||||
from langchain_core.output_parsers.list import (
|
||||
CommaSeparatedListOutputParser,
|
||||
ListOutputParser,
|
||||
@ -30,4 +30,5 @@ __all__ = [
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
"SimpleJsonOutputParser",
|
||||
"XMLOutputParser",
|
||||
"JsonOutputParser",
|
||||
]
|
||||
|
@ -0,0 +1,11 @@
|
||||
# flake8: noqa
|
||||
|
||||
JSON_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
|
||||
|
||||
Here is the output schema:
|
||||
```
|
||||
{schema}
|
||||
```"""
|
@ -3,12 +3,14 @@ from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, List, Optional, Type
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
@ -170,7 +172,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
return json_obj
|
||||
|
||||
|
||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse the output of an LLM call to a JSON object.
|
||||
|
||||
When used in streaming mode, it will yield partial JSON objects containing
|
||||
@ -180,6 +182,8 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Optional[Type[BaseModel]] = None
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
@ -190,6 +194,26 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
except JSONDecodeError as e:
|
||||
raise OutputParserException(f"Invalid json output: {text}") from e
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
if self.pydantic_object is None:
|
||||
return "Return a JSON object."
|
||||
else:
|
||||
schema = self.pydantic_object.schema()
|
||||
|
||||
# Remove extraneous fields.
|
||||
reduced_schema = schema
|
||||
if "title" in reduced_schema:
|
||||
del reduced_schema["title"]
|
||||
if "type" in reduced_schema:
|
||||
del reduced_schema["type"]
|
||||
# Ensure json in context is well-formed with double quotes.
|
||||
schema_str = json.dumps(reduced_schema)
|
||||
return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "simple_json_output_parser"
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
SimpleJsonOutputParser = JsonOutputParser
|
||||
|
@ -34,7 +34,11 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
||||
|
||||
def parse(self, text: str) -> Dict[str, List[Any]]:
|
||||
text = text.strip("`").strip("xml")
|
||||
# Try to find XML string within triple backticks
|
||||
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
|
||||
if match is not None:
|
||||
# If match found, use the content within the backticks
|
||||
text = match.group(2)
|
||||
encoding_match = self.encoding_matcher.search(text)
|
||||
if encoding_match:
|
||||
text = encoding_match.group(2)
|
||||
|
@ -13,6 +13,7 @@ EXPECTED_ALL = [
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
"SimpleJsonOutputParser",
|
||||
"XMLOutputParser",
|
||||
"JsonOutputParser",
|
||||
]
|
||||
|
||||
|
||||
|
@ -39,8 +39,12 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
examples = comma_list(_generate_random_datetime_strings(self.format))
|
||||
return f"""Write a datetime string that matches the
|
||||
following pattern: "{self.format}". Examples: {examples}"""
|
||||
return (
|
||||
f"Write a datetime string that matches the "
|
||||
f"following pattern: '{self.format}'.\n\n"
|
||||
f"Examples: {examples}\n\n"
|
||||
f"Return ONLY this string, no other words!"
|
||||
)
|
||||
|
||||
def parse(self, response: str) -> datetime:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user