REFACTOR: Refactor langchain_core (#13627)

Changes:
- remove langchain_core/schema, since there was no clear distinction between schema and non-schema modules
- make every module that doesn't end in -y plural
- where easy, have 1-2 classes per file
- no more than one level of nesting in directories
- only import from top-level core modules in langchain (see the import sketch below)
Bagatur 2023-11-21 08:35:29 -08:00 committed by GitHub
parent 17c6551c18
commit d32e511826
783 changed files with 2992 additions and 2899 deletions
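As a quick illustration of the new layout (an editorial sketch, not part of the diff; the exact symbols are taken from the import changes shown in this commit):

# Before this commit, consumers reached through the schema package:
from langchain_core.schema.messages import BaseMessage
from langchain_core.schema.output import Generation, LLMResult
from langchain_core.schema.document import Document

# After, each topic is a plural, top-level module:
from langchain_core.messages import BaseMessage
from langchain_core.outputs import Generation, LLMResult
from langchain_core.documents import Document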

View File

@@ -16,9 +16,12 @@ from .deprecation import (
surface_langchain_deprecation_warnings,
warn_deprecated,
)
from .path import as_import_path, get_relative_path
__all__ = [
"as_import_path",
"deprecated",
"get_relative_path",
"LangChainDeprecationWarning",
"suppress_langchain_deprecation_warning",
"surface_langchain_deprecation_warnings",

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
class AgentAction(Serializable):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from langchain_core.schema.output import Generation
from langchain_core.outputs import Generation
RETURN_VAL_TYPE = Sequence[Generation]

View File

@@ -0,0 +1,69 @@
from langchain_core.callbacks.base import (
AsyncCallbackHandler,
BaseCallbackHandler,
BaseCallbackManager,
CallbackManagerMixin,
Callbacks,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainGroup,
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForLLMRun,
AsyncCallbackManagerForRetrieverRun,
AsyncCallbackManagerForToolRun,
AsyncParentRunManager,
AsyncRunManager,
BaseRunManager,
CallbackManager,
CallbackManagerForChainGroup,
CallbackManagerForChainRun,
CallbackManagerForLLMRun,
CallbackManagerForRetrieverRun,
CallbackManagerForToolRun,
ParentRunManager,
RunManager,
env_var_is_set,
register_configure_hook,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
__all__ = [
"RetrieverManagerMixin",
"LLMManagerMixin",
"ChainManagerMixin",
"ToolManagerMixin",
"Callbacks",
"CallbackManagerMixin",
"RunManagerMixin",
"BaseCallbackHandler",
"AsyncCallbackHandler",
"BaseCallbackManager",
"BaseRunManager",
"RunManager",
"ParentRunManager",
"AsyncRunManager",
"AsyncParentRunManager",
"CallbackManagerForLLMRun",
"AsyncCallbackManagerForLLMRun",
"CallbackManagerForChainRun",
"AsyncCallbackManagerForChainRun",
"CallbackManagerForToolRun",
"AsyncCallbackManagerForToolRun",
"CallbackManagerForRetrieverRun",
"AsyncCallbackManagerForRetrieverRun",
"CallbackManager",
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",
"env_var_is_set",
"register_configure_hook",
]
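A minimal usage sketch for the new top-level callbacks package (illustrative, not from the diff; relies on the re-exports listed above):

from langchain_core.callbacks import CallbackManager, StdOutCallbackHandler

# CallbackManager takes its list of handlers as the first argument.
manager = CallbackManager(handlers=[StdOutCallbackHandler()])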

View File

@@ -6,10 +6,10 @@ from uuid import UUID
from tenacity import RetryCallState
from langchain_core.schema.agent import AgentAction, AgentFinish
from langchain_core.schema.document import Document
from langchain_core.schema.messages import BaseMessage
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
class RetrieverManagerMixin:

View File

@@ -30,6 +30,10 @@ from langsmith import utils as ls_utils
from langsmith.run_helpers import get_run_tree_context
from tenacity import RetryCallState
from langchain_core.agents import (
AgentAction,
AgentFinish,
)
from langchain_core.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
@@ -41,23 +45,16 @@ from langchain_core.callbacks.base import (
ToolManagerMixin,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.tracers import run_collector
from langchain_core.callbacks.tracers.langchain import (
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.tracers import run_collector
from langchain_core.tracers.langchain import (
LangChainTracer,
)
from langchain_core.callbacks.tracers.langchain_v1 import (
LangChainTracerV1,
TracerSessionV1,
)
from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain_core.schema import (
AgentAction,
AgentFinish,
Document,
LLMResult,
)
from langchain_core.schema.messages import BaseMessage, get_buffer_string
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.schemas import TracerSessionV1
from langchain_core.tracers.stdout import ConsoleCallbackHandler
if TYPE_CHECKING:
from langsmith import Client as LangSmithClient

View File

@@ -1,9 +1,10 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.schema import AgentAction, AgentFinish, LLMResult
from langchain_core.utils.input import print_text
from langchain_core.outputs import LLMResult
from langchain_core.utils import print_text
class StdOutCallbackHandler(BaseCallbackHandler):

View File

@@ -2,9 +2,10 @@
import sys
from typing import Any, Dict, List
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.schema import AgentAction, AgentFinish, LLMResult
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
class StreamingStdOutCallbackHandler(BaseCallbackHandler):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
class BaseChatMessageHistory(ABC):

View File

@@ -1,6 +1,6 @@
from typing import Sequence, TypedDict
from langchain_core.schema import BaseMessage
from langchain_core.messages import BaseMessage
class ChatSession(TypedDict, total=False):

View File

@@ -3,27 +3,9 @@ from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Literal, Sequence
from typing import Any, Sequence
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
class Document(Serializable):
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
type: Literal["Document"] = "Document"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
from langchain_core.documents import Document
class BaseDocumentTransformer(ABC):

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
class Document(Serializable):
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
type: Literal["Document"] = "Document"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
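A short, hypothetical usage sketch of the relocated Document class (contents made up for illustration):

from langchain_core.documents import Document

doc = Document(
    page_content="This is a joke",
    metadata={"page": "1"},  # arbitrary, user-defined metadata
)
assert doc.type == "Document"
assert Document.is_lc_serializable()  # participates in LangChain serialization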

View File

@@ -1,14 +1,18 @@
"""Logic for selecting examples to include in prompts."""
from langchain_core.prompts.example_selector.length_based import (
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.example_selectors.length_based import (
LengthBasedExampleSelector,
)
from langchain_core.prompts.example_selector.semantic_similarity import (
from langchain_core.example_selectors.semantic_similarity import (
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
sorted_values,
)
__all__ = [
"BaseExampleSelector",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"SemanticSimilarityExampleSelector",
"sorted_values",
]

View File

@@ -2,7 +2,7 @@
import re
from typing import Callable, Dict, List
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Type
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.embeddings import Embeddings
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.schema.embeddings import Embeddings
from langchain_core.schema.vectorstore import VectorStore
from langchain_core.vectorstores import VectorStore
def sorted_values(values: Dict[str, str]) -> List[Any]:

View File

@@ -0,0 +1,48 @@
from typing import Any, Optional
class LangChainException(Exception):
"""General LangChain exception."""
class TracerException(LangChainException):
"""Base class for exceptions in tracers module."""
class OutputParserException(ValueError, LangChainException):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: String model output which is error-ing.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm
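A hedged sketch of how a custom parser might raise this exception; parse_bool and its messages are hypothetical, the constructor arguments match the class above:

from langchain_core.exceptions import OutputParserException

def parse_bool(text: str) -> bool:
    # Raise OutputParserException so callers can distinguish parse
    # failures from other runtime errors.
    cleaned = text.strip().lower()
    if cleaned not in ("true", "false"):
        raise OutputParserException(
            f"Expected 'true' or 'false', got {text!r}",
            observation="Reply with exactly 'true' or 'false'.",
            llm_output=text,
            send_to_llm=True,  # requires observation and llm_output to be set
        )
    return cleaned == "true"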

View File

@@ -4,7 +4,7 @@ import warnings
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from langchain_core.schema import BaseCache
from langchain_core.caches import BaseCache
# DO NOT USE THESE VALUES DIRECTLY!

View File

@@ -0,0 +1,12 @@
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.llms import LLM, BaseLLM
__all__ = [
"BaseLanguageModel",
"BaseChatModel",
"SimpleChatModel",
"BaseLLM",
"LLM",
"LanguageModelInput",
]

View File

@@ -15,14 +15,14 @@ from typing import (
from typing_extensions import TypeAlias
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.outputs import LLMResult
from langchain_core.prompts import PromptValue
from langchain_core.runnables import RunnableSerializable
from langchain_core.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.schema.output import LLMResult
from langchain_core.schema.prompt import PromptValue
from langchain_core.utils import get_pydantic_field_names
if TYPE_CHECKING:
from langchain_core.callbacks.manager import Callbacks
from langchain_core.callbacks import Callbacks
@lru_cache(maxsize=None) # Cache the tokenizer
@@ -74,8 +74,8 @@ class BaseLanguageModel(
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValueConcrete
from langchain_core.prompts.string import StringPromptValue
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes

View File

@@ -14,36 +14,34 @@ from typing import (
cast,
)
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.callbacks.manager import (
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
BaseCallbackManager,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.load.dump import dumpd, dumps
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig
from langchain_core.schema import (
ChatGeneration,
ChatResult,
LLMResult,
PromptValue,
RunInfo,
)
from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain_core.schema.messages import (
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.schema.output import ChatGenerationChunk
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
RunInfo,
)
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig
def _get_verbosity() -> bool:

View File

@@ -46,16 +46,13 @@ from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.load.dump import dumpd
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValue
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import get_config_list
from langchain_core.schema import Generation, LLMResult, PromptValue, RunInfo
from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain_core.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.schema.output import GenerationChunk
from langchain_core.runnables import RunnableConfig, get_config_list
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,120 @@
from typing import List, Sequence, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
message_to_dict,
messages_to_dict,
)
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain_core.messages import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
__all__ = [
"AIMessage",
"AIMessageChunk",
"AnyMessage",
"BaseMessage",
"BaseMessageChunk",
"ChatMessage",
"ChatMessageChunk",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"SystemMessage",
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"get_buffer_string",
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"merge_content",
]
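A round-trip sketch using the serialization helpers exported above (values are made up; assumes message equality compares fields, as with pydantic models):

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    messages_from_dict,
    messages_to_dict,
)

msgs = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
as_dicts = messages_to_dict(msgs)  # [{"type": "human", "data": {...}}, ...]
assert messages_from_dict(as_dicts) == msgs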

View File

@@ -0,0 +1,47 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["ai"] = "ai"
AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
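A small sketch of the chunk concatenation implemented above (hypothetical contents):

from langchain_core.messages import AIMessageChunk

chunk = AIMessageChunk(content="Hello") + AIMessageChunk(content=", world")
assert chunk.content == "Hello, world"  # merge_content joined the strings
assert isinstance(chunk, AIMessageChunk)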

View File

@@ -0,0 +1,126 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: Union[str, List[Union[str, Dict]]]
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
type: str
class Config:
extra = Extra.allow
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
second_content: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
# If first chunk is a string
if isinstance(first_content, str):
# If the second chunk is also a string, then merge them naively
if isinstance(second_content, str):
return first_content + second_content
# If the second chunk is a list, add the first chunk to the start of the list
else:
return_list: List[Union[str, Dict]] = [first_content]
return return_list + second_content
# If both are lists, merge them naively
elif isinstance(second_content, List):
return first_content + second_content
# If the first content is a list, and the second content is a string
else:
# If the last element of the first content is a string
# Add the second content to the last element
if isinstance(first_content[-1], str):
return first_content[:-1] + [first_content[-1] + second_content]
else:
# Otherwise, add the second content as a new element of the list
return first_content + [second_content]
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
return self.__class__(
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
def message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [message_to_dict(m) for m in messages]
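merge_content is a pure function, so its branches are easy to pin down with a few illustrative cases (values below are made up):

from langchain_core.messages.base import merge_content

assert merge_content("Hello, ", "world") == "Hello, world"  # str + str
assert merge_content("a", [{"type": "text"}]) == ["a", {"type": "text"}]  # str + list
assert merge_content(["a", "b"], "c") == ["a", "bc"]  # trailing str merges into last str element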

View File

@@ -0,0 +1,53 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
type: Literal["chat"] = "chat"
ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
raise ValueError(
"Cannot concatenate ChatMessageChunks with different roles."
)
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
elif isinstance(other, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
return super().__add__(other)

View File

@@ -0,0 +1,45 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
type: Literal["function"] = "function"
FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
raise ValueError(
"Cannot concatenate FunctionMessageChunks with different names."
)
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)

View File

@@ -0,0 +1,26 @@
from typing import Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["human"] = "human"
HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501

View File

@@ -0,0 +1,23 @@
from typing import Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
type: Literal["system"] = "system"
SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501

View File

@@ -0,0 +1,45 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class ToolMessage(BaseMessage):
"""A Message for passing the result of executing a tool back to a model."""
tool_call_id: str
"""Tool call that this message is responding to."""
type: Literal["tool"] = "tool"
ToolMessage.update_forward_refs()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
"""A Tool Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ToolMessageChunk):
if self.tool_call_id != other.tool_call_id:
raise ValueError(
"Cannot concatenate ToolMessageChunks with different names."
)
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)

View File

@@ -0,0 +1,29 @@
from langchain_core.output_parsers.base import (
BaseGenerationOutputParser,
BaseLLMOutputParser,
BaseOutputParser,
)
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain_core.output_parsers.str import StrOutputParser
from langchain_core.output_parsers.transform import (
BaseCumulativeTransformOutputParser,
BaseTransformOutputParser,
)
__all__ = [
"BaseLLMOutputParser",
"BaseGenerationOutputParser",
"BaseOutputParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"NumberedListOutputParser",
"MarkdownListOutputParser",
"StrOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
]

View File

@@ -5,10 +5,8 @@ import functools
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
Type,
@@ -18,15 +16,13 @@ from typing import (
from typing_extensions import get_args
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain_core.schema.output import (
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
from langchain_core.schema.prompt import PromptValue
from langchain_core.prompts.value import PromptValue
from langchain_core.runnables import RunnableConfig, RunnableSerializable
T = TypeVar("T")
@@ -303,173 +299,3 @@ class BaseOutputParser(
except NotImplementedError:
pass
return output_parser_dict
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
yield chunk
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
diff: bool = False
"""In streaming mode, whether to yield diffs between the previous and current
parsed output, or just the current parsed output.
"""
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
# TODO: Deprecate
NoOpOutputParser = StrOutputParser
class OutputParserException(ValueError):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: String model output which is error-ing.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm

View File

@@ -4,7 +4,7 @@ import re
from abc import abstractmethod
from typing import List
from langchain_core.schema import BaseOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
class ListOutputParser(BaseOutputParser[List[str]]):

View File

@@ -0,0 +1,19 @@
from langchain_core.output_parsers.transform import BaseTransformOutputParser
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
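A usage sketch for the relocated StrOutputParser (illustrative; relies on BaseOutputParser.invoke accepting either a string or a message, per the base class shown earlier):

from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser

parser = StrOutputParser()
assert parser.parse("hello") == "hello"          # plain passthrough
assert parser.invoke(AIMessage(content="hi")) == "hi"  # message content extracted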

View File

@@ -0,0 +1,128 @@
from __future__ import annotations
from typing import (
Any,
AsyncIterator,
Iterator,
Optional,
Union,
)
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.output_parsers.base import BaseOutputParser, T
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
from langchain_core.runnables import RunnableConfig
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
yield chunk
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
diff: bool = False
"""In streaming mode, whether to yield diffs between the previous and current
parsed output, or just the current parsed output.
"""
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed

View File

@@ -0,0 +1,15 @@
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.chat_result import ChatResult
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.outputs.run_info import RunInfo
__all__ = [
"ChatGeneration",
"ChatGenerationChunk",
"ChatResult",
"Generation",
"GenerationChunk",
"LLMResult",
"RunInfo",
]

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Any, Dict, Literal
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator
class ChatGeneration(Generation):
"""A single chat generation output."""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
try:
values["text"] = values["message"].content
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)

View File

@@ -0,0 +1,15 @@
from typing import List, Optional
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from typing import Any, Dict, Literal, Optional
from langchain_core.load import Serializable
class Generation(Serializable):
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes."""
# TODO: add log probs as separate attribute
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
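The __add__ above merges both text and generation_info; a quick sketch with made-up values:

from langchain_core.outputs import GenerationChunk

a = GenerationChunk(text="Hel", generation_info={"model": "demo"})
b = GenerationChunk(text="lo", generation_info={"finish_reason": "stop"})
merged = a + b
assert merged.text == "Hello"
assert merged.generation_info == {"model": "demo", "finish_reason": "stop"}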

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)
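A sketch of flatten() with two batched inputs (hypothetical token counts), showing that usage is kept only on the first flattened result:

from langchain_core.outputs import Generation, LLMResult

result = LLMResult(
    generations=[[Generation(text="A")], [Generation(text="B")]],
    llm_output={"token_usage": {"total_tokens": 7}},
)
flat = result.flatten()
assert flat[0].llm_output == {"token_usage": {"total_tokens": 7}}
assert flat[1].llm_output == {"token_usage": {}}  # zeroed to avoid double counting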

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from uuid import UUID
from langchain_core.pydantic_v1 import BaseModel
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""

View File

@@ -27,21 +27,18 @@ from multiple components. Prompt classes and functions make constructing
ChatPromptValue
""" # noqa: E501
from langchain_core.prompts.base import StringPromptTemplate
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
ChatPromptValueConcrete,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain_core.prompts.example_selector import (
LengthBasedExampleSelector,
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
)
from langchain_core.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
@@ -50,7 +47,7 @@ from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemp
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import Prompt, PromptTemplate
from langchain_core.schema.prompt_template import BasePromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, StringPromptValue
__all__ = [
"AIMessagePromptTemplate",
@@ -58,18 +55,22 @@ __all__ = [
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"PromptValue",
"StringPromptValue",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
"format_document",
]
from langchain_core.prompts.value import PromptValue

View File

@@ -1,173 +1,228 @@
"""BasePrompt schema definition."""
from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set
from langchain_core.schema.messages import BaseMessage, HumanMessage
from langchain_core.schema.prompt import PromptValue
from langchain_core.schema.prompt_template import BasePromptTemplate
from langchain_core.utils.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2.
*Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
SandboxedEnvironment by default. However, this sand-boxing should
be treated as a best-effort approach rather than a guarantee of security.
Do not accept jinja2 templates from untrusted sources as they may lead
to arbitrary Python code execution.
https://jinja.palletsprojects.com/en/3.1.x/sandbox/
"""
try:
from jinja2.sandbox import SandboxedEnvironment
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
"Please be cautious when using jinja2 templates. "
"Do not expand jinja2 templates using unverified or user-controlled "
"inputs as that can result in arbitrary Python code execution."
)
# This uses a sandboxed environment to prevent arbitrary code execution.
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
# Please treat this sand-boxing as a best-effort approach rather than
# a guarantee of security.
# We recommend to never use jinja2 templates with untrusted inputs.
# https://jinja.palletsprojects.com/en/3.1.x/sandbox/
# approach not a guarantee of security.
return SandboxedEnvironment().from_string(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found.
Args:
template: The template string.
input_variables: The input variables.
"""
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
warning_message = ""
if missing_variables:
warning_message += f"Missing variables: {missing_variables} "
if extra_variables:
warning_message += f"Extra variables: {extra_variables}"
if warning_message:
warnings.warn(warning_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
) -> None:
"""Check that template string is valid.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
input_variables: The input variables.
Raises:
ValueError: If the template format is not supported.
"""
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try:
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables)
except KeyError as e:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "
+ str(e)
)
def get_template_variables(template: str, template_format: str) -> List[str]:
"""Get the variables from the template.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
Returns:
The variables from the template.
Raises:
ValueError: If the template format is not supported.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(input_variables)
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
import yaml
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.value import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
input_types: Dict[str, Any] = Field(default_factory=dict)
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def OutputType(self) -> Any:
from langchain_core.prompts.chat import ChatPromptValueConcrete
from langchain_core.prompts.string import StringPromptValue
return Union[StringPromptValue, ChatPromptValueConcrete]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(
**{key: inner_input[key] for key in self.input_variables}
),
input,
config,
run_type="prompt",
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path to directory to save prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
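# Sketch: the file suffix picks the serializer, and load_prompt (exported from
# langchain_core.prompts in this refactor) can read either format back.
#   prompt.save("prompt.yaml")      # yaml.dump via the ".yaml" branch
#   prompt.save("prompt.json")      # json.dump via the ".json" branch
#   load_prompt("prompt.yaml")      # -> an equivalent prompt template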
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
1. `page_content`:
This takes the information from the `document.page_content`
and assigns it to a variable named `page_content`.
2. metadata:
This takes information from `document.metadata` and assigns
it to variables of the same name.
Those variables are then passed into the `prompt` to produce a formatted string.
Args:
doc: Document, the page_content and metadata will be used to create
the final string.
prompt: BasePromptTemplate, will be used to format the page_content
and metadata into the final string.
Returns:
string of the document formatted.
Example:
.. code-block:: python
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)

View File

@ -19,15 +19,8 @@ from typing import (
)
from langchain_core._api import deprecated
from langchain_core.load.serializable import Serializable
from langchain_core.prompts.base import StringPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.schema import (
BasePromptTemplate,
PromptValue,
)
from langchain_core.schema.messages import (
from langchain_core.load import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
@ -36,6 +29,11 @@ from langchain_core.schema.messages import (
SystemMessage,
get_buffer_string,
)
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.pydantic_v1 import Field, root_validator
class BaseMessagePromptTemplate(Serializable, ABC):

View File

@ -4,20 +4,19 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.prompts.base import (
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
)
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.schema.messages import BaseMessage, get_buffer_string
class _FewShotPromptTemplateMixin(BaseModel):
@ -27,7 +26,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Optional[BaseExampleSelector] = None
example_selector: Any = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
@ -253,7 +252,7 @@ class FewShotChatMessagePromptTemplate(
vectorstore=vectorstore
)
from langchain_core.schema import SystemMessage
from langchain_core.messages import SystemMessage
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate

View File

@ -2,9 +2,11 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain_core.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
)
from langchain_core.pydantic_v1 import Extra, root_validator
@ -15,7 +17,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Optional[BaseExampleSelector] = None
example_selector: Any = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""

View File

@ -6,13 +6,11 @@ from typing import Callable, Dict, Union
import yaml
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.schema import (
BasePromptTemplate,
StrOutputParser,
)
from langchain_core.utils.loading import try_load_from_hub
from langchain_core.utils import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__name__)

View File

@ -1,8 +1,9 @@
from typing import Any, Dict, List, Tuple
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.pydantic_v1 import root_validator
from langchain_core.schema import BasePromptTemplate, PromptValue
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.prompts.base import (
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,

View File

@ -0,0 +1,173 @@
"""BasePrompt schema definition."""
from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.utils.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2.
*Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
SandboxedEnvironment by default. However, this sand-boxing should
be treated as a best-effort approach rather than a guarantee of security.
Do not accept jinja2 templates from untrusted sources as they may lead
to arbitrary Python code execution.
https://jinja.palletsprojects.com/en/3.1.x/sandbox/
"""
try:
from jinja2.sandbox import SandboxedEnvironment
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
"Please be cautious when using jinja2 templates. "
"Do not expand jinja2 templates using unverified or user-controlled "
"inputs as that can result in arbitrary Python code execution."
)
# This uses a sandboxed environment to prevent arbitrary code execution.
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
# Please treat this sand-boxing as a best-effort approach rather than
# a guarantee of security.
# We recommend never using jinja2 templates with untrusted inputs.
# https://jinja.palletsprojects.com/en/3.1.x/sandbox/
return SandboxedEnvironment().from_string(template).render(**kwargs)
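# Usage sketch (requires jinja2; rendered inside the SandboxedEnvironment):
#   jinja2_formatter("Hello {{ name }}!", name="World")  # -> "Hello World!"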
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found.
Args:
template: The template string.
input_variables: The input variables.
"""
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
warning_message = ""
if missing_variables:
warning_message += f"Missing variables: {missing_variables} "
if extra_variables:
warning_message += f"Extra variables: {extra_variables}"
if warning_message:
warnings.warn(warning_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
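# e.g. _get_jinja2_variables_from_template("{{ a }} likes {{ b }}") -> {"a", "b"},
# which validate_jinja2 compares against input_variables to warn on mismatches.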
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
) -> None:
"""Check that template string is valid.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
input_variables: The input variables.
Raises:
ValueError: If the template format is not supported.
"""
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try:
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables)
except KeyError as e:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "
+ str(e)
)
def get_template_variables(template: str, template_format: str) -> List[str]:
"""Get the variables from the template.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
Returns:
The variables from the template.
Raises:
ValueError: If the template format is not supported.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(input_variables)
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List
from langchain_core.load.serializable import Serializable
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
class PromptValue(Serializable, ABC):

View File

@ -7,9 +7,9 @@ from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.document import Document
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (

View File

@ -25,7 +25,11 @@ from langchain_core.runnables.base import (
RunnableSerializable,
)
from langchain_core.runnables.branch import RunnableBranch
from langchain_core.runnables.config import RunnableConfig, patch_config
from langchain_core.runnables.config import (
RunnableConfig,
get_config_list,
patch_config,
)
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.router import RouterInput, RouterRunnable
@ -33,6 +37,7 @@ from langchain_core.runnables.utils import (
ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
add,
)
__all__ = [
@ -54,4 +59,6 @@ __all__ = [
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",
"get_config_list",
"add",
]

View File

@ -36,11 +36,11 @@ if TYPE_CHECKING:
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain_core.callbacks.tracers.root_listeners import Listener
from langchain_core.runnables.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
from langchain_core.tracers.root_listeners import Listener
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable
@ -198,7 +198,7 @@ class Runnable(Generic[Input, Output], ABC):
... code-block:: python
from langchain_core.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.tracers import ConsoleCallbackHandler
chain.invoke(
...,
@ -559,7 +559,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.callbacks.tracers.log_stream import (
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
RunLog,
RunLogPatch,
@ -725,7 +725,7 @@ class Runnable(Generic[Input, Output], ABC):
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer
from langchain_core.tracers.root_listeners import RootListenersTracer
return RunnableBinding(
bound=self,
@ -2945,7 +2945,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer
from langchain_core.tracers.root_listeners import RootListenersTracer
return self.__class__(
bound=self.bound,

View File

@ -66,7 +66,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
# response.
from langchain_core.prompts import PromptTemplate
from langchain_core.schema.output_parser import StrOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
def when_all_is_lost(inputs):

View File

@ -13,6 +13,7 @@ from typing import (
Union,
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load import load
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
@ -21,12 +22,11 @@ from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
from langchain_core.schema.chat_history import BaseChatMessageHistory
if TYPE_CHECKING:
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.messages import BaseMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_core.schema.messages import BaseMessage
from langchain_core.tracers.schemas import Run
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
@ -178,7 +178,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
) -> Type[BaseModel]:
super_schema = super().get_input_schema(config)
if super_schema.__custom_root_type__ is not None:
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
@ -202,10 +202,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
) -> List[BaseMessage]:
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
if isinstance(input_val, str):
from langchain_core.schema.messages import HumanMessage
from langchain_core.messages import HumanMessage
return [HumanMessage(content=input_val)]
elif isinstance(input_val, BaseMessage):
@ -221,13 +221,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
if isinstance(output_val, dict):
output_val = output_val[self.output_messages_key or "output"]
if isinstance(output_val, str):
from langchain_core.schema.messages import AIMessage
from langchain_core.messages import AIMessage
return [AIMessage(content=output_val)]
elif isinstance(output_val, BaseMessage):

View File

@ -1,78 +0,0 @@
"""**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain_core.schema.agent import AgentAction, AgentFinish
from langchain_core.schema.cache import BaseCache
from langchain_core.schema.chat_history import BaseChatMessageHistory
from langchain_core.schema.document import BaseDocumentTransformer, Document
from langchain_core.schema.exceptions import LangChainException
from langchain_core.schema.memory import BaseMemory
from langchain_core.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
_message_from_dict,
_message_to_dict,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
from langchain_core.schema.output import (
ChatGeneration,
ChatResult,
Generation,
LLMResult,
RunInfo,
)
from langchain_core.schema.output_parser import (
BaseLLMOutputParser,
BaseOutputParser,
OutputParserException,
StrOutputParser,
)
from langchain_core.schema.prompt import PromptValue
from langchain_core.schema.prompt_template import BasePromptTemplate, format_document
from langchain_core.schema.retriever import BaseRetriever
from langchain_core.schema.storage import BaseStore
RUN_KEY = "__run"
Memory = BaseMemory
__all__ = [
"BaseCache",
"BaseMemory",
"BaseStore",
"AgentFinish",
"AgentAction",
"Document",
"BaseChatMessageHistory",
"BaseDocumentTransformer",
"BaseMessage",
"ChatMessage",
"FunctionMessage",
"HumanMessage",
"AIMessage",
"SystemMessage",
"messages_from_dict",
"messages_to_dict",
"_message_to_dict",
"_message_from_dict",
"get_buffer_string",
"RunInfo",
"LLMResult",
"ChatResult",
"ChatGeneration",
"Generation",
"PromptValue",
"LangChainException",
"BaseRetriever",
"RUN_KEY",
"Memory",
"OutputParserException",
"StrOutputParser",
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
"format_document",
]

View File

@ -1,2 +0,0 @@
class LangChainException(Exception):
"""General LangChain exception."""

View File

@ -1,415 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from typing_extensions import Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain_core.schema import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: Union[str, List[Union[str, Dict]]]
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
type: str
class Config:
extra = Extra.allow
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
second_content: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
# If first chunk is a string
if isinstance(first_content, str):
# If the second chunk is also a string, then merge them naively
if isinstance(second_content, str):
return first_content + second_content
# If the second chunk is a list, add the first chunk to the start of the list
else:
return_list: List[Union[str, Dict]] = [first_content]
return return_list + second_content
# If both are lists, merge them naively
elif isinstance(second_content, List):
return first_content + second_content
# If the first content is a list, and the second content is a string
else:
# If the last element of the first content is a string
# Add the second content to the last element
if isinstance(first_content[-1], str):
return first_content[:-1] + [first_content[-1] + second_content]
else:
# Otherwise, add the second content as a new element of the list
return first_content + [second_content]
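# Behavior sketch for merge_content, which backs the chunk concatenation below:
#   merge_content("foo", "bar")                    # -> "foobar"
#   merge_content("a", [{"type": "image_url"}])    # -> ["a", {"type": "image_url"}]
#   merge_content(["a"], "b")                      # -> ["ab"] (joins trailing string)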
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
if isinstance(self, ChatMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return self.__class__(
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["human"] = "human"
HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["ai"] = "ai"
AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
type: Literal["system"] = "system"
SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
type: Literal["function"] = "function"
FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
raise ValueError(
"Cannot concatenate FunctionMessageChunks with different names."
)
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class ToolMessage(BaseMessage):
"""A Message for passing the result of executing a tool back to a model."""
tool_call_id: str
"""Tool call that this message is responding to."""
type: Literal["tool"] = "tool"
ToolMessage.update_forward_refs()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
"""A Tool Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ToolMessageChunk):
if self.tool_call_id != other.tool_call_id:
raise ValueError(
"Cannot concatenate ToolMessageChunks with different names."
)
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
type: Literal["chat"] = "chat"
ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
raise ValueError(
"Cannot concatenate ChatMessageChunks with different roles."
)
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
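# Round-trip sketch for the two converters above:
#   msgs = [HumanMessage(content="hi"), AIMessage(content="hello")]
#   messages_from_dict(messages_to_dict(msgs)) == msgs  # -> True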

View File

@ -1,175 +0,0 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List, Literal, Optional
from uuid import UUID
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.schema.messages import BaseMessage, BaseMessageChunk
class Generation(Serializable):
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes."""
# TODO: add log probs as separate attribute
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
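# e.g. GenerationChunk(text="Hel") + GenerationChunk(text="lo")
# -> GenerationChunk(text="Hello"); generation_info dicts are shallow-merged,
# with the right-hand chunk's keys winning on collision.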
class ChatGeneration(Generation):
"""A single chat generation output."""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
try:
values["text"] = values["message"].content
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)
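# flatten() sketch: with llm_output = {"token_usage": usage},
#   LLMResult(generations=[[g0], [g1]], llm_output={"token_usage": usage}).flatten()
#   # -> [LLMResult(generations=[[g0]], llm_output={"token_usage": usage}),
#   #     LLMResult(generations=[[g1]], llm_output={"token_usage": {}})]
# so token usage is counted exactly once downstream.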

View File

@ -1,228 +0,0 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
import yaml
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.document import Document
from langchain_core.schema.output_parser import BaseOutputParser
from langchain_core.schema.prompt import PromptValue
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
input_types: Dict[str, Any] = Field(default_factory=dict)
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def OutputType(self) -> Any:
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValueConcrete
return Union[StringPromptValue, ChatPromptValueConcrete]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(
**{key: inner_input[key] for key in self.input_variables}
),
input,
config,
run_type="prompt",
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path to directory to save prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
1. `page_content`:
This takes the information from the `document.page_content`
and assigns it to a variable named `page_content`.
2. metadata:
This takes information from `document.metadata` and assigns
it to variables of the same name.
Those variables are then passed into the `prompt` to produce a formatted string.
Args:
doc: Document, the page_content and metadata will be used to create
the final string.
prompt: BasePromptTemplate, will be used to format the page_content
and metadata into the final string.
Returns:
string of the document formatted.
Example:
.. code-block:: python
from langchain_core.schema import Document
from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)

View File

@ -0,0 +1,16 @@
__all__ = [
"BaseTracer",
"EvaluatorCallbackHandler",
"LangChainTracer",
"ConsoleCallbackHandler",
"Run",
"RunLog",
"RunLogPatch",
]
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.evaluation import EvaluatorCallbackHandler
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
from langchain_core.tracers.schemas import Run
from langchain_core.tracers.stdout import ConsoleCallbackHandler

View File

@ -9,24 +9,21 @@ from uuid import UUID
from tenacity import RetryCallState
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.load.dump import dumpd
from langchain_core.schema.document import Document
from langchain_core.schema.output import (
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.tracers.schemas import Run
logger = logging.getLogger(__name__)
class TracerException(Exception):
"""Base class for exceptions in tracers module."""
class BaseTracer(BaseCallbackHandler, ABC):
"""Base interface for tracers."""

View File

@ -12,10 +12,10 @@ import langsmith
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
from langchain_core.callbacks import manager
from langchain_core.callbacks.tracers import langchain as langchain_tracer
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.langchain import _get_executor
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.tracers import langchain as langchain_tracer
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.langchain import _get_executor
from langchain_core.tracers.schemas import Run
logger = logging.getLogger(__name__)

View File

@ -17,11 +17,11 @@ from tenacity import (
wait_exponential_jitter,
)
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.env import get_runtime_environment
from langchain_core.load.dump import dumpd
from langchain_core.schema.messages import BaseMessage
from langchain_core.load import dumpd
from langchain_core.messages import BaseMessage
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
logger = logging.getLogger(__name__)
_LOGGED = set()

View File

@ -6,8 +6,9 @@ from typing import Any, Dict, Optional, Union
import requests
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import (
from langchain_core.messages import get_buffer_string
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import (
ChainRun,
LLMRun,
Run,
@ -16,7 +17,6 @@ from langchain_core.callbacks.tracers.schemas import (
TracerSessionV1,
TracerSessionV1Base,
)
from langchain_core.schema.messages import get_buffer_string
from langchain_core.utils import raise_for_status_with_text
logger = logging.getLogger(__name__)

View File

@ -18,10 +18,10 @@ from uuid import UUID
import jsonpatch
from anyio import create_memory_object_stream
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.load.load import load
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
from langchain_core.load import load
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
class LogEntry(TypedDict):

View File

@ -1,12 +1,12 @@
from typing import Callable, Optional, Union
from uuid import UUID
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.runnables.config import (
RunnableConfig,
call_func_with_variable_args,
)
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]

View File

@ -3,8 +3,8 @@
from typing import Any, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
class RunCollectorCallbackHandler(BaseTracer):

View File

@ -9,8 +9,8 @@ from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain_core.outputs import LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.schema import LLMResult
def RunTypeEnum() -> Type[RunTypeEnumDep]:

View File

@ -1,8 +1,8 @@
import json
from typing import Any, Callable, List
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
from langchain_core.utils.input import get_bolded_text, get_colored_text

View File

@ -11,6 +11,7 @@ from langchain_core.utils.input import (
get_colored_text,
print_text,
)
from langchain_core.utils.loading import try_load_from_hub
from langchain_core.utils.utils import (
check_package_version,
convert_to_secret_str,
@ -35,4 +36,5 @@ __all__ = [
"print_text",
"raise_for_status_with_text",
"xor_args",
"try_load_from_hub",
]

View File

@ -21,10 +21,10 @@ from typing import (
TypeVar,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.schema import BaseRetriever
from langchain_core.schema.document import Document
from langchain_core.schema.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-core"
version = "0.0.1"
version = "0.0.2"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

View File

@ -6,6 +6,8 @@ EXPECTED_ALL = [
"suppress_langchain_deprecation_warning",
"surface_langchain_deprecation_warnings",
"warn_deprecated",
"as_import_path",
"get_relative_path",
]

View File

@ -1,10 +1,10 @@
"""Test functionality related to length based selector."""
import pytest
from langchain_core.prompts.example_selector.length_based import (
from langchain_core.example_selectors import (
LengthBasedExampleSelector,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts import PromptTemplate
EXAMPLES = [
{"question": "Question: who are you?\nAnswer: foo"},

View File

@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.schema.messages import BaseMessage
class BaseFakeCallbackHandler(BaseModel):

View File

@ -7,10 +7,9 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.chat_model import BaseChatModel, SimpleChatModel
from langchain_core.schema import ChatResult
from langchain_core.schema.messages import AIMessageChunk, BaseMessage
from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
class FakeMessagesListChatModel(BaseChatModel):

View File

@ -6,9 +6,8 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.llm import LLM
from langchain_core.language_models import LLM, LanguageModelInput
from langchain_core.runnables import RunnableConfig
from langchain_core.schema.language_model import LanguageModelInput
class FakeListLLM(LLM):

View File

@ -1,10 +1,10 @@
from typing import List
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.schema import (
from langchain_core.chat_history import (
BaseChatMessageHistory,
)
from langchain_core.schema.messages import BaseMessage
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):

View File

@ -3,6 +3,13 @@ from typing import Any, List, Union
import pytest
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
@ -15,13 +22,6 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
_convert_to_message,
)
from langchain_core.schema.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
def create_messages() -> List[BaseMessagePromptTemplate]:

View File

@ -3,19 +3,19 @@ from typing import Any, Dict, List, Sequence, Tuple
import pytest
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain_core.prompts.chat import SystemMessagePromptTemplate
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.schema import AIMessage, HumanMessage, SystemMessage
EXAMPLE_PROMPT = PromptTemplate(
input_variables=["question", "answer"], template="{question}: {answer}"

View File

@ -6,20 +6,22 @@ EXPECTED_ALL = [
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"ChatPromptValueConcrete",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
"format_document",
"ChatPromptValue",
"PromptValue",
"StringPromptValue",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
]

View File

@ -1,5 +1,5 @@
"""Test functionality related to prompt utils."""
from langchain_core.prompts.example_selector.semantic_similarity import sorted_values
from langchain_core.example_selectors import sorted_values
def test_sorted_vals() -> None:

View File

@ -1,8 +1,8 @@
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain_core.runnables.config import RunnableConfig, merge_configs
from langchain_core.tracers.stdout import ConsoleCallbackHandler
def test_merge_config_callbacks() -> None:

View File

@ -1,9 +1,10 @@
from typing import Any, Callable, Sequence, Union
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.schema import AIMessage, BaseMessage, HumanMessage
from tests.unit_tests.fake.memory import ChatMessageHistory

View File

@ -26,53 +26,55 @@ from langchain_core.callbacks.manager import (
collect_runs,
trace_as_chain_group,
)
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain_core.load.dump import dumpd, dumps
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import (
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import (
RouterRunnable,
Runnable,
RunnableBranch,
RunnableConfig,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
RunnableSequence,
RunnableWithFallbacks,
)
from langchain_core.runnables.base import (
ConfigurableField,
RunnableBinding,
RunnableGenerator,
)
from langchain_core.runnables.utils import (
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
add,
)
from langchain_core.schema.document import Document
from langchain_core.schema.messages import (
from langchain_core.documents import Document
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.schema.output_parser import BaseOutputParser, StrOutputParser
from langchain_core.schema.retriever import BaseRetriever
from langchain_core.tool import BaseTool, tool
from langchain_core.output_parsers import (
BaseOutputParser,
CommaSeparatedListOutputParser,
StrOutputParser,
)
from langchain_core.prompts import (
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
StringPromptValue,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
RouterRunnable,
Runnable,
RunnableBinding,
RunnableBranch,
RunnableConfig,
RunnableGenerator,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
RunnableSequence,
RunnableWithFallbacks,
add,
)
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
BaseTracer,
ConsoleCallbackHandler,
Run,
RunLog,
RunLogPatch,
)
from tests.unit_tests.fake.chat_model import FakeListChatModel
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
@ -1539,7 +1541,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
)
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat
chain: Runnable = prompt | chat
mock_start = mocker.Mock()
mock_end = mocker.Mock()
@ -1572,7 +1574,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
)
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat
chain: Runnable = prompt | chat
mock_start = mocker.Mock()
mock_end = mocker.Mock()
@ -1608,7 +1610,7 @@ def test_prompt_with_chat_model(
)
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat
chain: Runnable = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
@ -1712,7 +1714,7 @@ async def test_prompt_with_chat_model_async(
)
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat
chain: Runnable = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
@ -1819,7 +1821,7 @@ async def test_prompt_with_llm(
)
llm = FakeListLLM(responses=["foo", "bar"])
chain = prompt | llm
chain: Runnable = prompt | llm
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
@ -2325,13 +2327,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
async def test_router_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
chain1 = ChatPromptTemplate.from_template(
chain1: Runnable = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
chain2 = ChatPromptTemplate.from_template(
chain2: Runnable = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
router = RouterRunnable({"math": chain1, "english": chain2})
router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
chain: Runnable = {
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
@ -2377,10 +2379,10 @@ async def test_router_runnable(
async def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain = ChatPromptTemplate.from_template(
math_chain: Runnable = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain = ChatPromptTemplate.from_template(
english_chain: Runnable = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableParallel(
@ -3096,7 +3098,7 @@ async def test_deep_astream_assign() -> None:
def test_runnable_sequence_transform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = llm | StrOutputParser()
chain: Runnable = llm | StrOutputParser()
stream = chain.transform(llm.stream("Hi there!"))
@ -3111,7 +3113,7 @@ def test_runnable_sequence_transform() -> None:
async def test_runnable_sequence_atransform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = llm | StrOutputParser()
chain: Runnable = llm | StrOutputParser()
stream = chain.atransform(llm.astream("Hi there!"))

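The repeated `chain: Runnable = ...` edits above only add explicit annotations for the type checker; piping with `|` still produces a RunnableSequence. A toy sketch:

from langchain_core.runnables import Runnable, RunnableLambda, RunnableSequence

# `|` composes two runnables into a sequence; the annotation satisfies mypy.
chain: Runnable = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
assert isinstance(chain, RunnableSequence)
assert chain.invoke(3) == 8  # (3 + 1) * 2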
View File

@ -1,43 +0,0 @@
from langchain_core.schema import __all__
EXPECTED_ALL = [
"BaseCache",
"BaseMemory",
"BaseStore",
"AgentFinish",
"AgentAction",
"Document",
"BaseChatMessageHistory",
"BaseDocumentTransformer",
"BaseMessage",
"ChatMessage",
"FunctionMessage",
"HumanMessage",
"AIMessage",
"SystemMessage",
"messages_from_dict",
"messages_to_dict",
"_message_to_dict",
"_message_from_dict",
"get_buffer_string",
"RunInfo",
"LLMResult",
"ChatResult",
"ChatGeneration",
"Generation",
"PromptValue",
"LangChainException",
"BaseRetriever",
"RUN_KEY",
"Memory",
"OutputParserException",
"StrOutputParser",
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
"format_document",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

View File

@ -1,6 +1,6 @@
import pytest
from langchain_core.schema.messages import (
from langchain_core.messages import (
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,

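For context, a small sketch (not from the diff) of what these chunk tests rely on: message chunks overload `+` so streamed deltas can be accumulated:

from langchain_core.messages import AIMessageChunk

# Concatenating chunks merges their string content.
chunk = AIMessageChunk(content="Hel") + AIMessageChunk(content="lo")
assert chunk.content == "Hello"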
View File

@ -1,5 +1,5 @@
from langchain_core.schema.messages import HumanMessageChunk
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
from langchain_core.messages import HumanMessageChunk
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
def test_generation_chunk() -> None:

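Likewise for generation chunks, a minimal sketch of the `+` behavior under test:

from langchain_core.outputs import GenerationChunk

# Merging chunks concatenates text (generation_info is merged too).
merged = GenerationChunk(text="Hello ") + GenerationChunk(text="world")
assert merged.text == "Hello world"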
View File

@ -7,12 +7,12 @@ from typing import Any, List, Optional, Type, Union
import pytest
from langchain_core.callbacks.manager import (
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tool import (
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
StructuredTool,

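Finally, a hedged sketch of the langchain_core.tools surface these tests target (the add tool is invented):

from langchain_core.tools import tool

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

# Tools are runnables; structured arguments are passed as a dict.
assert add.invoke({"a": 1, "b": 2}) == 3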
Some files were not shown because too many files have changed in this diff.