commit d98c1f115f (parent d97f70def4)
Author: Eugene Yurtsev
Date:   2024-08-06 11:46:39 -04:00

12 changed files with 36 additions and 42 deletions

View File

@@ -4,6 +4,7 @@ import re
from abc import abstractmethod
from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from typing import Optional as Optional
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
yield [part]
ListOutputParser.model_rebuild()
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""

View File

@@ -3,6 +3,7 @@ import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch # type: ignore[import]
from pydantic import BaseModel, root_validator
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
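
This file follows the pattern repeated throughout the commit: the langchain_core.pydantic_v1 compatibility import is dropped in favour of importing BaseModel and root_validator directly from pydantic. root_validator still exists in pydantic v2 as a deprecated shim, which keeps @root_validator(pre=True) usages working during the transition; the v2-native spelling is model_validator(mode="before"). A rough sketch of that native form, using a made-up FunctionCall model rather than the parser in this file:

import json

from pydantic import BaseModel, model_validator

class FunctionCall(BaseModel):
    name: str
    arguments: str = "{}"

    @model_validator(mode="before")
    @classmethod
    def coerce_arguments(cls, values):
        # Accept a dict for "arguments" and store it as a JSON string.
        if isinstance(values, dict) and isinstance(values.get("arguments"), dict):
            values["arguments"] = json.dumps(values["arguments"])
        return values

print(FunctionCall(name="get_weather", arguments={"city": "Paris"}))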

View File

@@ -1,7 +1,9 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type
from pydantic import BaseModel, ValidationError
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
@@ -13,9 +15,7 @@ from langchain_core.messages.tool import (
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
def parse_tool_call(

View File

@@ -106,6 +106,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return self.pydantic_object
PydanticOutputParser.model_rebuild()
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
@@ -115,10 +118,3 @@ Here is the output schema:
```
{schema}
```""" # noqa: E501
# Re-exporting types for backwards compatibility
__all__ = [
"PydanticBaseModel",
"PydanticOutputParser",
"TBaseModel",
]
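
For context, PydanticOutputParser is the class these format instructions belong to, and after the migration it works directly with a plain pydantic v2 model. A minimal usage sketch (assuming langchain_core is installed; the Joke model is invented for illustration):

from pydantic import BaseModel, Field

from langchain_core.output_parsers import PydanticOutputParser

class Joke(BaseModel):
    setup: str = Field(description="the question")
    punchline: str = Field(description="the answer")

parser = PydanticOutputParser(pydantic_object=Joke)

# get_format_instructions() renders the _PYDANTIC_FORMAT_INSTRUCTIONS template
# above with Joke's JSON schema filled in.
print(parser.get_format_instructions())

# parse() validates raw model output against Joke and returns a Joke instance.
print(parser.parse('{"setup": "Why?", "punchline": "Because."}'))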

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -24,3 +24,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
StrOutputParser.model_rebuild()

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
from typing import Any, Dict, List, Literal, Union
from pydantic import root_validator
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts

View File

@@ -1,7 +1,8 @@
from typing import List, Optional
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
class ChatResult(BaseModel):

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from typing import List, Optional, Union
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
class LLMResult(BaseModel):
@@ -16,7 +18,7 @@ class LLMResult(BaseModel):
wants to return.
"""
generations: List[List[Generation]]
generations: Union[List[List[Generation]], List[List[ChatGeneration]]]
"""Generated outputs.
The first dimension of the list represents completions for different input
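
The widening of generations from List[List[Generation]] to an explicit Union is worth noting. A plausible reading (an assumption, not stated in the diff) is that pydantic v2 serializes a field by its declared type, so chat generations stored in a field declared as the base Generation would lose their chat-specific data on model_dump(). A toy sketch with stand-in models:

from typing import List, Union
from pydantic import BaseModel

class Generation(BaseModel):
    text: str

class ChatGeneration(Generation):
    role: str = "ai"  # stand-in for the real message payload

class NarrowResult(BaseModel):
    generations: List[List[Generation]]

class WideResult(BaseModel):
    generations: Union[List[List[Generation]], List[List[ChatGeneration]]]

chat = [[ChatGeneration(text="hi", role="ai")]]

# Declared-type serialization: the narrow field drops the subclass-only
# "role" value (and emits a serializer warning); the union keeps it.
print(NarrowResult(generations=chat).model_dump())
print(WideResult(generations=chat).model_dump())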

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from uuid import UUID
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class RunInfo(BaseModel):

View File

@@ -18,6 +18,7 @@ from typing import (
)
import yaml
from pydantic import BaseModel, ConfigDict, Field, root_validator
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
@@ -25,7 +26,6 @@ from langchain_core.prompt_values import (
PromptValue,
StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model
@@ -99,10 +99,7 @@ class BasePromptTemplate(
Returns True."""
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def OutputType(self) -> Any:
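
The nested class Config removed here is the pydantic v1 way of configuring a model; v2 replaces it with a ConfigDict assigned to model_config, which is what the hunk above switches to. A small sketch of the equivalence, using a made-up PromptLike model and TemplateEngine class:

from typing import Optional
from pydantic import BaseModel, ConfigDict

class TemplateEngine:
    """A plain, non-pydantic class used as a field type."""

class PromptLike(BaseModel):
    # v1: class Config: arbitrary_types_allowed = True
    # v2: a ConfigDict assigned to model_config
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Allowed only because arbitrary_types_allowed is set; pydantic falls
    # back to a plain isinstance() check for TemplateEngine.
    engine: Optional[TemplateEngine] = None

print(PromptLike(engine=TemplateEngine()))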

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Annotated,
Any,
Dict,
List,
@@ -21,6 +22,8 @@ from typing import (
overload,
)
from pydantic import Field, PositiveInt, SkipValidation, root_validator
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
@@ -38,7 +41,6 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@@ -922,7 +924,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
""" # noqa: E501
messages: List[MessageLike]
messages: Annotated[List[MessageLike], SkipValidation]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Extra, root_validator
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import (
@@ -18,7 +20,6 @@ from langchain_core.prompts.string import (
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
class _FewShotPromptTemplateMixin(BaseModel):
@@ -31,12 +32,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
example_selector: Optional[BaseExampleSelector] = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
@@ -160,11 +156,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
]
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
@@ -369,11 +361,7 @@ class FewShotChatMessagePromptTemplate(
"""Return whether or not the class is serializable."""
return False
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.