Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 15:43:54 +00:00)
REFACTOR: Refactor langchain_core (#13627)

Changes:
- remove langchain_core/schema, since there is no clear distinction between schema and non-schema modules
- make every module that doesn't end in -y plural
- where easy, have 1-2 classes per file
- allow no more than one level of nesting in directories
- only import from top-level core modules in langchain

Parent: 17c6551c18
Commit: d32e511826
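In practice the change is a mechanical rewrite of import paths: classes move out of the catch-all langchain_core.schema package into topical, pluralized top-level modules. A before/after sketch assembled from the hunks below (illustrative, not exhaustive):

    # Before this commit, core types lived under langchain_core.schema:
    # from langchain_core.schema.messages import BaseMessage
    # from langchain_core.schema.output import Generation, LLMResult
    # from langchain_core.schema.document import Document

    # After, each type has a topical top-level module:
    from langchain_core.messages import BaseMessage
    from langchain_core.outputs import Generation, LLMResult
    from langchain_core.documents import Document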
@@ -16,9 +16,12 @@ from .deprecation import (
     surface_langchain_deprecation_warnings,
     warn_deprecated,
 )
+from .path import as_import_path, get_relative_path

 __all__ = [
+    "as_import_path",
     "deprecated",
+    "get_relative_path",
     "LangChainDeprecationWarning",
     "suppress_langchain_deprecation_warning",
     "surface_langchain_deprecation_warnings",

@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Any, Literal, Sequence, Union

 from langchain_core.load.serializable import Serializable
-from langchain_core.schema.messages import BaseMessage
+from langchain_core.messages import BaseMessage


 class AgentAction(Serializable):

@@ -3,7 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Sequence

-from langchain_core.schema.output import Generation
+from langchain_core.outputs import Generation

 RETURN_VAL_TYPE = Sequence[Generation]
@@ -0,0 +1,69 @@
+from langchain_core.callbacks.base import (
+    AsyncCallbackHandler,
+    BaseCallbackHandler,
+    BaseCallbackManager,
+    CallbackManagerMixin,
+    Callbacks,
+    ChainManagerMixin,
+    LLMManagerMixin,
+    RetrieverManagerMixin,
+    RunManagerMixin,
+    ToolManagerMixin,
+)
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManager,
+    AsyncCallbackManagerForChainGroup,
+    AsyncCallbackManagerForChainRun,
+    AsyncCallbackManagerForLLMRun,
+    AsyncCallbackManagerForRetrieverRun,
+    AsyncCallbackManagerForToolRun,
+    AsyncParentRunManager,
+    AsyncRunManager,
+    BaseRunManager,
+    CallbackManager,
+    CallbackManagerForChainGroup,
+    CallbackManagerForChainRun,
+    CallbackManagerForLLMRun,
+    CallbackManagerForRetrieverRun,
+    CallbackManagerForToolRun,
+    ParentRunManager,
+    RunManager,
+    env_var_is_set,
+    register_configure_hook,
+)
+from langchain_core.callbacks.stdout import StdOutCallbackHandler
+from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+__all__ = [
+    "RetrieverManagerMixin",
+    "LLMManagerMixin",
+    "ChainManagerMixin",
+    "ToolManagerMixin",
+    "Callbacks",
+    "CallbackManagerMixin",
+    "RunManagerMixin",
+    "BaseCallbackHandler",
+    "AsyncCallbackHandler",
+    "BaseCallbackManager",
+    "BaseRunManager",
+    "RunManager",
+    "ParentRunManager",
+    "AsyncRunManager",
+    "AsyncParentRunManager",
+    "CallbackManagerForLLMRun",
+    "AsyncCallbackManagerForLLMRun",
+    "CallbackManagerForChainRun",
+    "AsyncCallbackManagerForChainRun",
+    "CallbackManagerForToolRun",
+    "AsyncCallbackManagerForToolRun",
+    "CallbackManagerForRetrieverRun",
+    "AsyncCallbackManagerForRetrieverRun",
+    "CallbackManager",
+    "CallbackManagerForChainGroup",
+    "AsyncCallbackManager",
+    "AsyncCallbackManagerForChainGroup",
+    "StdOutCallbackHandler",
+    "StreamingStdOutCallbackHandler",
+    "env_var_is_set",
+    "register_configure_hook",
+]
@@ -6,10 +6,10 @@ from uuid import UUID

 from tenacity import RetryCallState

-from langchain_core.schema.agent import AgentAction, AgentFinish
-from langchain_core.schema.document import Document
-from langchain_core.schema.messages import BaseMessage
-from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.documents import Document
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult


 class RetrieverManagerMixin:

@@ -30,6 +30,10 @@ from langsmith import utils as ls_utils
 from langsmith.run_helpers import get_run_tree_context
 from tenacity import RetryCallState

+from langchain_core.agents import (
+    AgentAction,
+    AgentFinish,
+)
 from langchain_core.callbacks.base import (
     BaseCallbackHandler,
     BaseCallbackManager,

@@ -41,23 +45,16 @@ from langchain_core.callbacks.base import (
     ToolManagerMixin,
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
-from langchain_core.callbacks.tracers import run_collector
-from langchain_core.callbacks.tracers.langchain import (
-    LangChainTracer,
-)
-from langchain_core.callbacks.tracers.langchain_v1 import (
-    LangChainTracerV1,
-    TracerSessionV1,
-)
-from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
-from langchain_core.schema import (
-    AgentAction,
-    AgentFinish,
-    Document,
-    LLMResult,
-)
-from langchain_core.schema.messages import BaseMessage, get_buffer_string
-from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
+from langchain_core.documents import Document
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
+from langchain_core.tracers import run_collector
+from langchain_core.tracers.langchain import (
+    LangChainTracer,
+)
+from langchain_core.tracers.langchain_v1 import LangChainTracerV1
+from langchain_core.tracers.schemas import TracerSessionV1
+from langchain_core.tracers.stdout import ConsoleCallbackHandler

 if TYPE_CHECKING:
     from langsmith import Client as LangSmithClient
@@ -1,9 +1,10 @@
 """Callback Handler that prints to std out."""
 from typing import Any, Dict, List, Optional

+from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain_core.schema import AgentAction, AgentFinish, LLMResult
-from langchain_core.utils.input import print_text
+from langchain_core.outputs import LLMResult
+from langchain_core.utils import print_text


 class StdOutCallbackHandler(BaseCallbackHandler):

@@ -2,9 +2,10 @@
 import sys
 from typing import Any, Dict, List

+from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain_core.schema import AgentAction, AgentFinish, LLMResult
-from langchain_core.schema.messages import BaseMessage
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import LLMResult


 class StreamingStdOutCallbackHandler(BaseCallbackHandler):

@@ -3,7 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import List

-from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


 class BaseChatMessageHistory(ABC):

@@ -1,6 +1,6 @@
 from typing import Sequence, TypedDict

-from langchain_core.schema import BaseMessage
+from langchain_core.messages import BaseMessage


 class ChatSession(TypedDict, total=False):
@@ -3,27 +3,9 @@ from __future__ import annotations
 import asyncio
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, Literal, Sequence
+from typing import Any, Sequence

-from langchain_core.load.serializable import Serializable
-from langchain_core.pydantic_v1 import Field
-
-
-class Document(Serializable):
-    """Class for storing a piece of text and associated metadata."""
-
-    page_content: str
-    """String text."""
-    metadata: dict = Field(default_factory=dict)
-    """Arbitrary metadata about the page content (e.g., source, relationships to other
-    documents, etc.).
-    """
-    type: Literal["Document"] = "Document"
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """Return whether this class is serializable."""
-        return True
+from langchain_core.documents import Document


 class BaseDocumentTransformer(ABC):
libs/core/langchain_core/documents.py (new file)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from langchain_core.load.serializable import Serializable
+from langchain_core.pydantic_v1 import Field
+
+
+class Document(Serializable):
+    """Class for storing a piece of text and associated metadata."""
+
+    page_content: str
+    """String text."""
+    metadata: dict = Field(default_factory=dict)
+    """Arbitrary metadata about the page content (e.g., source, relationships to other
+    documents, etc.).
+    """
+    type: Literal["Document"] = "Document"
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
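Usage of the relocated class is unchanged; a minimal sketch (field values are illustrative):

    from langchain_core.documents import Document

    doc = Document(
        page_content="Some text to index.",    # illustrative content
        metadata={"source": "example.txt"},    # arbitrary metadata, per the docstring
    )
    assert doc.type == "Document"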
@@ -1,14 +1,18 @@
 """Logic for selecting examples to include in prompts."""
-from langchain_core.prompts.example_selector.length_based import (
+from langchain_core.example_selectors.base import BaseExampleSelector
+from langchain_core.example_selectors.length_based import (
     LengthBasedExampleSelector,
 )
-from langchain_core.prompts.example_selector.semantic_similarity import (
+from langchain_core.example_selectors.semantic_similarity import (
     MaxMarginalRelevanceExampleSelector,
     SemanticSimilarityExampleSelector,
+    sorted_values,
 )

 __all__ = [
+    "BaseExampleSelector",
     "LengthBasedExampleSelector",
     "MaxMarginalRelevanceExampleSelector",
     "SemanticSimilarityExampleSelector",
+    "sorted_values",
 ]

@@ -2,7 +2,7 @@
 import re
 from typing import Callable, Dict, List

-from langchain_core.prompts.example_selector.base import BaseExampleSelector
+from langchain_core.example_selectors.base import BaseExampleSelector
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, validator

@@ -3,10 +3,10 @@ from __future__ import annotations

 from typing import Any, Dict, List, Optional, Type

-from langchain_core.prompts.example_selector.base import BaseExampleSelector
+from langchain_core.embeddings import Embeddings
+from langchain_core.example_selectors.base import BaseExampleSelector
 from langchain_core.pydantic_v1 import BaseModel, Extra
-from langchain_core.schema.embeddings import Embeddings
-from langchain_core.schema.vectorstore import VectorStore
+from langchain_core.vectorstores import VectorStore


 def sorted_values(values: Dict[str, str]) -> List[Any]:
libs/core/langchain_core/exceptions.py (new file)
@@ -0,0 +1,48 @@
+from typing import Any, Optional
+
+
+class LangChainException(Exception):
+    """General LangChain exception."""
+
+
+class TracerException(LangChainException):
+    """Base class for exceptions in tracers module."""
+
+
+class OutputParserException(ValueError, LangChainException):
+    """Exception that output parsers should raise to signify a parsing error.
+
+    This exists to differentiate parsing errors from other code or execution errors
+    that also may arise inside the output parser. OutputParserExceptions will be
+    available to catch and handle in ways to fix the parsing error, while other
+    errors will be raised.
+
+    Args:
+        error: The error that's being re-raised or an error message.
+        observation: String explanation of error which can be passed to a
+            model to try and remediate the issue.
+        llm_output: String model output which is error-ing.
+        send_to_llm: Whether to send the observation and llm_output back to an Agent
+            after an OutputParserException has been raised. This gives the underlying
+            model driving the agent the context that the previous output was improperly
+            structured, in the hopes that it will update the output to the correct
+            format.
+    """
+
+    def __init__(
+        self,
+        error: Any,
+        observation: Optional[str] = None,
+        llm_output: Optional[str] = None,
+        send_to_llm: bool = False,
+    ):
+        super(OutputParserException, self).__init__(error)
+        if send_to_llm:
+            if observation is None or llm_output is None:
+                raise ValueError(
+                    "Arguments 'observation' & 'llm_output'"
+                    " are required if 'send_to_llm' is True"
+                )
+        self.observation = observation
+        self.llm_output = llm_output
+        self.send_to_llm = send_to_llm
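A short sketch of the send_to_llm contract defined above (all strings are made-up examples); note that omitting observation or llm_output while send_to_llm=True raises a ValueError instead:

    from langchain_core.exceptions import OutputParserException

    try:
        raise OutputParserException(
            "Could not parse model output",
            observation="Reply with a single JSON object.",  # passed back to the model
            llm_output='{"action": "search",',               # the malformed output
            send_to_llm=True,                                # requires both fields above
        )
    except OutputParserException as e:
        assert e.send_to_llm and e.llm_output is not None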
@@ -4,7 +4,7 @@ import warnings
 from typing import TYPE_CHECKING, Optional

 if TYPE_CHECKING:
-    from langchain_core.schema import BaseCache
+    from langchain_core.caches import BaseCache


 # DO NOT USE THESE VALUES DIRECTLY!
libs/core/langchain_core/language_models/__init__.py (new file)
@@ -0,0 +1,12 @@
+from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
+from langchain_core.language_models.llms import LLM, BaseLLM
+
+__all__ = [
+    "BaseLanguageModel",
+    "BaseChatModel",
+    "SimpleChatModel",
+    "BaseLLM",
+    "LLM",
+    "LanguageModelInput",
+]
@@ -15,14 +15,14 @@ from typing import (

 from typing_extensions import TypeAlias

+from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
+from langchain_core.outputs import LLMResult
+from langchain_core.prompts import PromptValue
 from langchain_core.runnables import RunnableSerializable
-from langchain_core.schema.messages import AnyMessage, BaseMessage, get_buffer_string
-from langchain_core.schema.output import LLMResult
-from langchain_core.schema.prompt import PromptValue
 from langchain_core.utils import get_pydantic_field_names

 if TYPE_CHECKING:
-    from langchain_core.callbacks.manager import Callbacks
+    from langchain_core.callbacks import Callbacks


 @lru_cache(maxsize=None)  # Cache the tokenizer

@@ -74,8 +74,8 @@ class BaseLanguageModel(
     @property
     def InputType(self) -> TypeAlias:
         """Get the input type for this runnable."""
-        from langchain_core.prompts.base import StringPromptValue
         from langchain_core.prompts.chat import ChatPromptValueConcrete
+        from langchain_core.prompts.string import StringPromptValue

         # This is a version of LanguageModelInput which replaces the abstract
         # base class BaseMessage with a union of its subclasses, which makes
@@ -14,36 +14,34 @@ from typing import (
     cast,
 )

-from langchain_core.callbacks.base import BaseCallbackManager
-from langchain_core.callbacks.manager import (
+from langchain_core.callbacks import (
     AsyncCallbackManager,
     AsyncCallbackManagerForLLMRun,
+    BaseCallbackManager,
     CallbackManager,
     CallbackManagerForLLMRun,
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
-from langchain_core.load.dump import dumpd, dumps
-from langchain_core.prompts.base import StringPromptValue
-from langchain_core.prompts.chat import ChatPromptValue
-from langchain_core.pydantic_v1 import Field, root_validator
-from langchain_core.runnables import RunnableConfig
-from langchain_core.schema import (
-    ChatGeneration,
-    ChatResult,
-    LLMResult,
-    PromptValue,
-    RunInfo,
-)
-from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput
-from langchain_core.schema.messages import (
+from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.load import dumpd, dumps
+from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
 )
-from langchain_core.schema.output import ChatGenerationChunk
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    LLMResult,
+    RunInfo,
+)
+from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.runnables import RunnableConfig


 def _get_verbosity() -> bool:

@@ -46,16 +46,13 @@ from langchain_core.callbacks.manager import (
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
-from langchain_core.load.dump import dumpd
-from langchain_core.prompts.base import StringPromptValue
-from langchain_core.prompts.chat import ChatPromptValue
+from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.load import dumpd
+from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
+from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator, validator
-from langchain_core.runnables import RunnableConfig
-from langchain_core.runnables.config import get_config_list
-from langchain_core.schema import Generation, LLMResult, PromptValue, RunInfo
-from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput
-from langchain_core.schema.messages import AIMessage, BaseMessage, get_buffer_string
-from langchain_core.schema.output import GenerationChunk
+from langchain_core.runnables import RunnableConfig, get_config_list

 logger = logging.getLogger(__name__)
libs/core/langchain_core/messages/__init__.py (new file)
@@ -0,0 +1,120 @@
+from typing import List, Sequence, Union
+
+from langchain_core.messages.ai import AIMessage, AIMessageChunk
+from langchain_core.messages.base import (
+    BaseMessage,
+    BaseMessageChunk,
+    merge_content,
+    message_to_dict,
+    messages_to_dict,
+)
+from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
+from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
+from langchain_core.messages.human import HumanMessage, HumanMessageChunk
+from langchain_core.messages.system import SystemMessage, SystemMessageChunk
+from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
+
+AnyMessage = Union[
+    AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
+]
+
+
+def get_buffer_string(
+    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
+) -> str:
+    """Convert sequence of Messages to strings and concatenate them into one string.
+
+    Args:
+        messages: Messages to be converted to strings.
+        human_prefix: The prefix to prepend to contents of HumanMessages.
+        ai_prefix: The prefix to prepend to contents of AIMessages.
+
+    Returns:
+        A single string concatenation of all input messages.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_core import AIMessage, HumanMessage
+
+            messages = [
+                HumanMessage(content="Hi, how are you?"),
+                AIMessage(content="Good, how are you?"),
+            ]
+            get_buffer_string(messages)
+            # -> "Human: Hi, how are you?\nAI: Good, how are you?"
+    """
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, FunctionMessage):
+            role = "Function"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        message = f"{role}: {m.content}"
+        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
+            message += f"{m.additional_kwargs['function_call']}"
+        string_messages.append(message)
+
+    return "\n".join(string_messages)
+
+
+def _message_from_dict(message: dict) -> BaseMessage:
+    _type = message["type"]
+    if _type == "human":
+        return HumanMessage(**message["data"])
+    elif _type == "ai":
+        return AIMessage(**message["data"])
+    elif _type == "system":
+        return SystemMessage(**message["data"])
+    elif _type == "chat":
+        return ChatMessage(**message["data"])
+    elif _type == "function":
+        return FunctionMessage(**message["data"])
+    elif _type == "tool":
+        return ToolMessage(**message["data"])
+    else:
+        raise ValueError(f"Got unexpected message type: {_type}")
+
+
+def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
+    """Convert a sequence of messages from dicts to Message objects.
+
+    Args:
+        messages: Sequence of messages (as dicts) to convert.
+
+    Returns:
+        List of messages (BaseMessages).
+    """
+    return [_message_from_dict(m) for m in messages]
+
+
+__all__ = [
+    "AIMessage",
+    "AIMessageChunk",
+    "AnyMessage",
+    "BaseMessage",
+    "BaseMessageChunk",
+    "ChatMessage",
+    "ChatMessageChunk",
+    "FunctionMessage",
+    "FunctionMessageChunk",
+    "HumanMessage",
+    "HumanMessageChunk",
+    "SystemMessage",
+    "SystemMessageChunk",
+    "ToolMessage",
+    "ToolMessageChunk",
+    "get_buffer_string",
+    "messages_from_dict",
+    "messages_to_dict",
+    "message_to_dict",
+    "merge_content",
+]
libs/core/langchain_core/messages/ai.py (new file)
@@ -0,0 +1,47 @@
+from typing import Any, Literal
+
+from langchain_core.messages.base import (
+    BaseMessage,
+    BaseMessageChunk,
+    merge_content,
+)
+
+
+class AIMessage(BaseMessage):
+    """A Message from an AI."""
+
+    example: bool = False
+    """Whether this Message is being passed in to the model as part of an example
+    conversation.
+    """
+
+    type: Literal["ai"] = "ai"
+
+
+AIMessage.update_forward_refs()
+
+
+class AIMessageChunk(AIMessage, BaseMessageChunk):
+    """A Message chunk from an AI."""
+
+    # Ignoring mypy re-assignment here since we're overriding the value
+    # to make sure that the chunk variant can be discriminated from the
+    # non-chunk variant.
+    type: Literal["AIMessageChunk"] = "AIMessageChunk"  # type: ignore[assignment] # noqa: E501
+
+    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
+        if isinstance(other, AIMessageChunk):
+            if self.example != other.example:
+                raise ValueError(
+                    "Cannot concatenate AIMessageChunks with different example values."
+                )
+
+            return self.__class__(
+                example=self.example,
+                content=merge_content(self.content, other.content),
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+
+        return super().__add__(other)
libs/core/langchain_core/messages/base.py (new file)
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
+
+from langchain_core.load.serializable import Serializable
+from langchain_core.pydantic_v1 import Extra, Field
+
+if TYPE_CHECKING:
+    from langchain_core.prompts.chat import ChatPromptTemplate
+
+
+class BaseMessage(Serializable):
+    """The base abstract Message class.
+
+    Messages are the inputs and outputs of ChatModels.
+    """
+
+    content: Union[str, List[Union[str, Dict]]]
+    """The string contents of the message."""
+
+    additional_kwargs: dict = Field(default_factory=dict)
+    """Any additional information."""
+
+    type: str
+
+    class Config:
+        extra = Extra.allow
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+
+    def __add__(self, other: Any) -> ChatPromptTemplate:
+        from langchain_core.prompts.chat import ChatPromptTemplate
+
+        prompt = ChatPromptTemplate(messages=[self])
+        return prompt + other
+
+
+def merge_content(
+    first_content: Union[str, List[Union[str, Dict]]],
+    second_content: Union[str, List[Union[str, Dict]]],
+) -> Union[str, List[Union[str, Dict]]]:
+    # If first chunk is a string
+    if isinstance(first_content, str):
+        # If the second chunk is also a string, then merge them naively
+        if isinstance(second_content, str):
+            return first_content + second_content
+        # If the second chunk is a list, add the first chunk to the start of the list
+        else:
+            return_list: List[Union[str, Dict]] = [first_content]
+            return return_list + second_content
+    # If both are lists, merge them naively
+    elif isinstance(second_content, List):
+        return first_content + second_content
+    # If the first content is a list, and the second content is a string
+    else:
+        # If the last element of the first content is a string
+        # Add the second content to the last element
+        if isinstance(first_content[-1], str):
+            return first_content[:-1] + [first_content[-1] + second_content]
+        else:
+            # Otherwise, add the second content as a new element of the list
+            return first_content + [second_content]
+
+
+class BaseMessageChunk(BaseMessage):
+    """A Message chunk, which can be concatenated with other Message chunks."""
+
+    def _merge_kwargs_dict(
+        self, left: Dict[str, Any], right: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Merge additional_kwargs from another BaseMessageChunk into this one."""
+        merged = left.copy()
+        for k, v in right.items():
+            if k not in merged:
+                merged[k] = v
+            elif type(merged[k]) != type(v):
+                raise ValueError(
+                    f'additional_kwargs["{k}"] already exists in this message,'
+                    " but with a different type."
+                )
+            elif isinstance(merged[k], str):
+                merged[k] += v
+            elif isinstance(merged[k], dict):
+                merged[k] = self._merge_kwargs_dict(merged[k], v)
+            else:
+                raise ValueError(
+                    f"Additional kwargs key {k} already exists in this message."
+                )
+        return merged
+
+    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
+        if isinstance(other, BaseMessageChunk):
+            # If both are (subclasses of) BaseMessageChunk,
+            # concat into a single BaseMessageChunk
+
+            return self.__class__(
+                content=merge_content(self.content, other.content),
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+        else:
+            raise TypeError(
+                'unsupported operand type(s) for +: "'
+                f"{self.__class__.__name__}"
+                f'" and "{other.__class__.__name__}"'
+            )
+
+
+def message_to_dict(message: BaseMessage) -> dict:
+    return {"type": message.type, "data": message.dict()}
+
+
+def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
+    """Convert a sequence of Messages to a list of dictionaries.
+
+    Args:
+        messages: Sequence of messages (as BaseMessages) to convert.
+
+    Returns:
+        List of messages as dicts.
+    """
+    return [message_to_dict(m) for m in messages]
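The chunk classes above are what make token streaming additive; a minimal sketch of concatenation via __add__ (contents are illustrative):

    from langchain_core.messages import AIMessageChunk

    # merge_content concatenates the string contents;
    # _merge_kwargs_dict merges additional_kwargs the same way.
    chunk = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")
    assert chunk.content == "Hello, world"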
libs/core/langchain_core/messages/chat.py (new file)
@@ -0,0 +1,53 @@
+from typing import Any, Literal
+
+from langchain_core.messages.base import (
+    BaseMessage,
+    BaseMessageChunk,
+    merge_content,
+)
+
+
+class ChatMessage(BaseMessage):
+    """A Message that can be assigned an arbitrary speaker (i.e. role)."""
+
+    role: str
+    """The speaker / role of the Message."""
+
+    type: Literal["chat"] = "chat"
+
+
+ChatMessage.update_forward_refs()
+
+
+class ChatMessageChunk(ChatMessage, BaseMessageChunk):
+    """A Chat Message chunk."""
+
+    # Ignoring mypy re-assignment here since we're overriding the value
+    # to make sure that the chunk variant can be discriminated from the
+    # non-chunk variant.
+    type: Literal["ChatMessageChunk"] = "ChatMessageChunk"  # type: ignore
+
+    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
+        if isinstance(other, ChatMessageChunk):
+            if self.role != other.role:
+                raise ValueError(
+                    "Cannot concatenate ChatMessageChunks with different roles."
+                )
+
+            return self.__class__(
+                role=self.role,
+                content=merge_content(self.content, other.content),
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+        elif isinstance(other, BaseMessageChunk):
+            return self.__class__(
+                role=self.role,
+                content=merge_content(self.content, other.content),
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+        else:
+            return super().__add__(other)
libs/core/langchain_core/messages/function.py (new file)
@@ -0,0 +1,45 @@
+from typing import Any, Literal
+
+from langchain_core.messages.base import (
+    BaseMessage,
+    BaseMessageChunk,
+    merge_content,
+)
+
+
+class FunctionMessage(BaseMessage):
+    """A Message for passing the result of executing a function back to a model."""
+
+    name: str
+    """The name of the function that was executed."""
+
+    type: Literal["function"] = "function"
+
+
+FunctionMessage.update_forward_refs()
+
+
+class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
+    """A Function Message chunk."""
+
+    # Ignoring mypy re-assignment here since we're overriding the value
+    # to make sure that the chunk variant can be discriminated from the
+    # non-chunk variant.
+    type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk"  # type: ignore[assignment]
+
+    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
+        if isinstance(other, FunctionMessageChunk):
+            if self.name != other.name:
+                raise ValueError(
+                    "Cannot concatenate FunctionMessageChunks with different names."
+                )
+
+            return self.__class__(
+                name=self.name,
+                content=merge_content(self.content, other.content),
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+
+        return super().__add__(other)
libs/core/langchain_core/messages/human.py (new file)
@@ -0,0 +1,26 @@
+from typing import Literal
+
+from langchain_core.messages.base import BaseMessage, BaseMessageChunk
+
+
+class HumanMessage(BaseMessage):
+    """A Message from a human."""
+
+    example: bool = False
+    """Whether this Message is being passed in to the model as part of an example
+    conversation.
+    """
+
+    type: Literal["human"] = "human"
+
+
+HumanMessage.update_forward_refs()
+
+
+class HumanMessageChunk(HumanMessage, BaseMessageChunk):
+    """A Human Message chunk."""
+
+    # Ignoring mypy re-assignment here since we're overriding the value
+    # to make sure that the chunk variant can be discriminated from the
+    # non-chunk variant.
+    type: Literal["HumanMessageChunk"] = "HumanMessageChunk"  # type: ignore[assignment] # noqa: E501
libs/core/langchain_core/messages/system.py (new file)
@@ -0,0 +1,23 @@
+from typing import Literal
+
+from langchain_core.messages.base import BaseMessage, BaseMessageChunk
+
+
+class SystemMessage(BaseMessage):
+    """A Message for priming AI behavior, usually passed in as the first of a sequence
+    of input messages.
+    """
+
+    type: Literal["system"] = "system"
+
+
+SystemMessage.update_forward_refs()
+
+
+class SystemMessageChunk(SystemMessage, BaseMessageChunk):
+    """A System Message chunk."""
+
+    # Ignoring mypy re-assignment here since we're overriding the value
+    # to make sure that the chunk variant can be discriminated from the
+    # non-chunk variant.
+    type: Literal["SystemMessageChunk"] = "SystemMessageChunk"  # type: ignore[assignment] # noqa: E501
libs/core/langchain_core/messages/tool.py (new file)
@@ -0,0 +1,45 @@
+from typing import Any, Literal
+
+from langchain_core.messages.base import (
+    BaseMessage,
+    BaseMessageChunk,
+    merge_content,
+)
+
+
+class ToolMessage(BaseMessage):
+    """A Message for passing the result of executing a tool back to a model."""
+
+    tool_call_id: str
+    """Tool call that this message is responding to."""
+
+    type: Literal["tool"] = "tool"
+
+
+ToolMessage.update_forward_refs()
+
+
+class ToolMessageChunk(ToolMessage, BaseMessageChunk):
+    """A Tool Message chunk."""
+
+    # Ignoring mypy re-assignment here since we're overriding the value
+    # to make sure that the chunk variant can be discriminated from the
+    # non-chunk variant.
+    type: Literal["ToolMessageChunk"] = "ToolMessageChunk"  # type: ignore[assignment]
+
+    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
+        if isinstance(other, ToolMessageChunk):
+            if self.tool_call_id != other.tool_call_id:
+                raise ValueError(
+                    "Cannot concatenate ToolMessageChunks with different names."
+                )
+
+            return self.__class__(
+                tool_call_id=self.tool_call_id,
+                content=merge_content(self.content, other.content),
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+
+        return super().__add__(other)
@@ -0,0 +1,29 @@
+from langchain_core.output_parsers.base import (
+    BaseGenerationOutputParser,
+    BaseLLMOutputParser,
+    BaseOutputParser,
+)
+from langchain_core.output_parsers.list import (
+    CommaSeparatedListOutputParser,
+    ListOutputParser,
+    MarkdownListOutputParser,
+    NumberedListOutputParser,
+)
+from langchain_core.output_parsers.str import StrOutputParser
+from langchain_core.output_parsers.transform import (
+    BaseCumulativeTransformOutputParser,
+    BaseTransformOutputParser,
+)
+
+__all__ = [
+    "BaseLLMOutputParser",
+    "BaseGenerationOutputParser",
+    "BaseOutputParser",
+    "ListOutputParser",
+    "CommaSeparatedListOutputParser",
+    "NumberedListOutputParser",
+    "MarkdownListOutputParser",
+    "StrOutputParser",
+    "BaseTransformOutputParser",
+    "BaseCumulativeTransformOutputParser",
+]
@@ -5,10 +5,8 @@ import functools
 from abc import ABC, abstractmethod
 from typing import (
     Any,
-    AsyncIterator,
     Dict,
     Generic,
-    Iterator,
     List,
     Optional,
     Type,

@@ -18,15 +16,13 @@ from typing import (

 from typing_extensions import get_args

-from langchain_core.runnables import RunnableConfig, RunnableSerializable
-from langchain_core.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
-from langchain_core.schema.output import (
+from langchain_core.messages import AnyMessage, BaseMessage
+from langchain_core.outputs import (
     ChatGeneration,
-    ChatGenerationChunk,
     Generation,
-    GenerationChunk,
 )
-from langchain_core.schema.prompt import PromptValue
+from langchain_core.prompts.value import PromptValue
+from langchain_core.runnables import RunnableConfig, RunnableSerializable

 T = TypeVar("T")
@@ -303,173 +299,3 @@ class BaseOutputParser(
         except NotImplementedError:
             pass
         return output_parser_dict
-
-
-class BaseTransformOutputParser(BaseOutputParser[T]):
-    """Base class for an output parser that can handle streaming input."""
-
-    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
-        for chunk in input:
-            if isinstance(chunk, BaseMessage):
-                yield self.parse_result([ChatGeneration(message=chunk)])
-            else:
-                yield self.parse_result([Generation(text=chunk)])
-
-    async def _atransform(
-        self, input: AsyncIterator[Union[str, BaseMessage]]
-    ) -> AsyncIterator[T]:
-        async for chunk in input:
-            if isinstance(chunk, BaseMessage):
-                yield self.parse_result([ChatGeneration(message=chunk)])
-            else:
-                yield self.parse_result([Generation(text=chunk)])
-
-    def transform(
-        self,
-        input: Iterator[Union[str, BaseMessage]],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Any,
-    ) -> Iterator[T]:
-        yield from self._transform_stream_with_config(
-            input, self._transform, config, run_type="parser"
-        )
-
-    async def atransform(
-        self,
-        input: AsyncIterator[Union[str, BaseMessage]],
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[T]:
-        async for chunk in self._atransform_stream_with_config(
-            input, self._atransform, config, run_type="parser"
-        ):
-            yield chunk
-
-
-class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
-    """Base class for an output parser that can handle streaming input."""
-
-    diff: bool = False
-    """In streaming mode, whether to yield diffs between the previous and current
-    parsed output, or just the current parsed output.
-    """
-
-    def _diff(self, prev: Optional[T], next: T) -> T:
-        """Convert parsed outputs into a diff format. The semantics of this are
-        up to the output parser."""
-        raise NotImplementedError()
-
-    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
-        prev_parsed = None
-        acc_gen = None
-        for chunk in input:
-            if isinstance(chunk, BaseMessageChunk):
-                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
-            elif isinstance(chunk, BaseMessage):
-                chunk_gen = ChatGenerationChunk(
-                    message=BaseMessageChunk(**chunk.dict())
-                )
-            else:
-                chunk_gen = GenerationChunk(text=chunk)
-
-            if acc_gen is None:
-                acc_gen = chunk_gen
-            else:
-                acc_gen += chunk_gen
-
-            parsed = self.parse_result([acc_gen], partial=True)
-            if parsed is not None and parsed != prev_parsed:
-                if self.diff:
-                    yield self._diff(prev_parsed, parsed)
-                else:
-                    yield parsed
-                prev_parsed = parsed
-
-    async def _atransform(
-        self, input: AsyncIterator[Union[str, BaseMessage]]
-    ) -> AsyncIterator[T]:
-        prev_parsed = None
-        acc_gen = None
-        async for chunk in input:
-            if isinstance(chunk, BaseMessageChunk):
-                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
-            elif isinstance(chunk, BaseMessage):
-                chunk_gen = ChatGenerationChunk(
-                    message=BaseMessageChunk(**chunk.dict())
-                )
-            else:
-                chunk_gen = GenerationChunk(text=chunk)
-
-            if acc_gen is None:
-                acc_gen = chunk_gen
-            else:
-                acc_gen += chunk_gen
-
-            parsed = self.parse_result([acc_gen], partial=True)
-            if parsed is not None and parsed != prev_parsed:
-                if self.diff:
-                    yield self._diff(prev_parsed, parsed)
-                else:
-                    yield parsed
-                prev_parsed = parsed
-
-
-class StrOutputParser(BaseTransformOutputParser[str]):
-    """OutputParser that parses LLMResult into the top likely string."""
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """Return whether this class is serializable."""
-        return True
-
-    @property
-    def _type(self) -> str:
-        """Return the output parser type for serialization."""
-        return "default"
-
-    def parse(self, text: str) -> str:
-        """Returns the input text with no changes."""
-        return text
-
-
-# TODO: Deprecate
-NoOpOutputParser = StrOutputParser
-
-
-class OutputParserException(ValueError):
-    """Exception that output parsers should raise to signify a parsing error.
-
-    This exists to differentiate parsing errors from other code or execution errors
-    that also may arise inside the output parser. OutputParserExceptions will be
-    available to catch and handle in ways to fix the parsing error, while other
-    errors will be raised.
-
-    Args:
-        error: The error that's being re-raised or an error message.
-        observation: String explanation of error which can be passed to a
-            model to try and remediate the issue.
-        llm_output: String model output which is error-ing.
-        send_to_llm: Whether to send the observation and llm_output back to an Agent
-            after an OutputParserException has been raised. This gives the underlying
-            model driving the agent the context that the previous output was improperly
-            structured, in the hopes that it will update the output to the correct
-            format.
-    """
-
-    def __init__(
-        self,
-        error: Any,
-        observation: Optional[str] = None,
-        llm_output: Optional[str] = None,
-        send_to_llm: bool = False,
-    ):
-        super(OutputParserException, self).__init__(error)
-        if send_to_llm:
-            if observation is None or llm_output is None:
-                raise ValueError(
-                    "Arguments 'observation' & 'llm_output'"
-                    " are required if 'send_to_llm' is True"
-                )
-        self.observation = observation
-        self.llm_output = llm_output
-        self.send_to_llm = send_to_llm
@@ -4,7 +4,7 @@ import re
 from abc import abstractmethod
 from typing import List

-from langchain_core.schema import BaseOutputParser
+from langchain_core.output_parsers.base import BaseOutputParser


 class ListOutputParser(BaseOutputParser[List[str]]):
libs/core/langchain_core/output_parsers/str.py (new file)
@@ -0,0 +1,19 @@
+from langchain_core.output_parsers.transform import BaseTransformOutputParser
+
+
+class StrOutputParser(BaseTransformOutputParser[str]):
+    """OutputParser that parses LLMResult into the top likely string."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+
+    @property
+    def _type(self) -> str:
+        """Return the output parser type for serialization."""
+        return "default"
+
+    def parse(self, text: str) -> str:
+        """Returns the input text with no changes."""
+        return text
128
libs/core/langchain_core/output_parsers/transform.py
Normal file
128
libs/core/langchain_core/output_parsers/transform.py
Normal file
@ -0,0 +1,128 @@
from __future__ import annotations

from typing import (
    Any,
    AsyncIterator,
    Iterator,
    Optional,
    Union,
)

from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.output_parsers.base import BaseOutputParser, T
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    Generation,
    GenerationChunk,
)
from langchain_core.runnables import RunnableConfig


class BaseTransformOutputParser(BaseOutputParser[T]):
    """Base class for an output parser that can handle streaming input."""

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                yield self.parse_result([ChatGeneration(message=chunk)])
            else:
                yield self.parse_result([Generation(text=chunk)])

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                yield self.parse_result([ChatGeneration(message=chunk)])
            else:
                yield self.parse_result([Generation(text=chunk)])

    def transform(
        self,
        input: Iterator[Union[str, BaseMessage]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[T]:
        yield from self._transform_stream_with_config(
            input, self._transform, config, run_type="parser"
        )

    async def atransform(
        self,
        input: AsyncIterator[Union[str, BaseMessage]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[T]:
        async for chunk in self._atransform_stream_with_config(
            input, self._atransform, config, run_type="parser"
        ):
            yield chunk


class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
    """Base class for an output parser that can handle streaming input."""

    diff: bool = False
    """In streaming mode, whether to yield diffs between the previous and current
    parsed output, or just the current parsed output.
    """

    def _diff(self, prev: Optional[T], next: T) -> T:
        """Convert parsed outputs into a diff format. The semantics of this are
        up to the output parser."""
        raise NotImplementedError()

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
        prev_parsed = None
        acc_gen = None
        for chunk in input:
            if isinstance(chunk, BaseMessageChunk):
                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
            elif isinstance(chunk, BaseMessage):
                chunk_gen = ChatGenerationChunk(
                    message=BaseMessageChunk(**chunk.dict())
                )
            else:
                chunk_gen = GenerationChunk(text=chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen], partial=True)
            if parsed is not None and parsed != prev_parsed:
                if self.diff:
                    yield self._diff(prev_parsed, parsed)
                else:
                    yield parsed
                prev_parsed = parsed

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        prev_parsed = None
        acc_gen = None
        async for chunk in input:
            if isinstance(chunk, BaseMessageChunk):
                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
            elif isinstance(chunk, BaseMessage):
                chunk_gen = ChatGenerationChunk(
                    message=BaseMessageChunk(**chunk.dict())
                )
            else:
                chunk_gen = GenerationChunk(text=chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen], partial=True)
            if parsed is not None and parsed != prev_parsed:
                if self.diff:
                    yield self._diff(prev_parsed, parsed)
                else:
                    yield parsed
                prev_parsed = parsed
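An illustrative sketch (not part of the diff) of the cumulative-transform contract: a toy subclass that re-parses the accumulated text on every chunk, so transform yields only when the parsed value changes. WordListParser is a hypothetical name used for illustration.

from typing import List

from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser


class WordListParser(BaseCumulativeTransformOutputParser[List[str]]):
    """Toy parser: split the accumulated text into words."""

    def parse(self, text: str) -> List[str]:
        return text.split()


print(list(WordListParser().transform(iter(["Hello ", "wor", "ld"]))))
# -> [['Hello'], ['Hello', 'wor'], ['Hello', 'world']]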
libs/core/langchain_core/outputs/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.chat_result import ChatResult
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.outputs.run_info import RunInfo

__all__ = [
    "ChatGeneration",
    "ChatGenerationChunk",
    "ChatResult",
    "Generation",
    "GenerationChunk",
    "LLMResult",
    "RunInfo",
]
libs/core/langchain_core/outputs/chat_generation.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from __future__ import annotations

from typing import Any, Dict, Literal

from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator


class ChatGeneration(Generation):
    """A single chat generation output."""

    text: str = ""
    """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
    message: BaseMessage
    """The message output by the chat model."""
    # Override type to be ChatGeneration, ignore mypy error as this is intentional
    type: Literal["ChatGeneration"] = "ChatGeneration"  # type: ignore[assignment]
    """Type is used exclusively for serialization purposes."""

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Set the text attribute to be the contents of the message."""
        try:
            values["text"] = values["message"].content
        except (KeyError, AttributeError) as e:
            raise ValueError("Error while initializing ChatGeneration") from e
        return values


class ChatGenerationChunk(ChatGeneration):
    """A ChatGeneration chunk, which can be concatenated with other
    ChatGeneration chunks.

    Attributes:
        message: The message chunk output by the chat model.
    """

    message: BaseMessageChunk
    # Override type to be ChatGeneration, ignore mypy error as this is intentional
    type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk"  # type: ignore[assignment] # noqa: E501
    """Type is used exclusively for serialization purposes."""

    def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
        if isinstance(other, ChatGenerationChunk):
            generation_info = (
                {**(self.generation_info or {}), **(other.generation_info or {})}
                if self.generation_info is not None or other.generation_info is not None
                else None
            )
            return ChatGenerationChunk(
                message=self.message + other.message,
                generation_info=generation_info,
            )
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
            )
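A sketch (not in the diff) of how chunk concatenation behaves: message contents are concatenated and the generation_info dicts are merged.

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

a = ChatGenerationChunk(message=AIMessageChunk(content="Hello, "))
b = ChatGenerationChunk(
    message=AIMessageChunk(content="world"),
    generation_info={"finish_reason": "stop"},
)
merged = a + b
assert merged.text == "Hello, world"
assert merged.generation_info == {"finish_reason": "stop"}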
libs/core/langchain_core/outputs/chat_result.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from typing import List, Optional

from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel


class ChatResult(BaseModel):
    """Class that contains all results for a single chat model call."""

    generations: List[ChatGeneration]
    """List of the chat generations. This is a List because an input can have multiple
    candidate generations.
    """
    llm_output: Optional[dict] = None
    """For arbitrary LLM provider specific output."""
libs/core/langchain_core/outputs/generation.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import Any, Dict, Literal, Optional

from langchain_core.load import Serializable


class Generation(Serializable):
    """A single text generation output."""

    text: str
    """Generated text output."""

    generation_info: Optional[Dict[str, Any]] = None
    """Raw response from the provider. May include things like the
    reason for finishing or token log probabilities.
    """
    type: Literal["Generation"] = "Generation"
    """Type is used exclusively for serialization purposes."""
    # TODO: add log probs as separate attribute

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True


class GenerationChunk(Generation):
    """A Generation chunk, which can be concatenated with other Generation chunks."""

    def __add__(self, other: GenerationChunk) -> GenerationChunk:
        if isinstance(other, GenerationChunk):
            generation_info = (
                {**(self.generation_info or {}), **(other.generation_info or {})}
                if self.generation_info is not None or other.generation_info is not None
                else None
            )
            return GenerationChunk(
                text=self.text + other.text,
                generation_info=generation_info,
            )
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
            )
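The same idea for plain text chunks (sketch, not in the diff): streaming code sums chunks to rebuild the full generation.

from langchain_core.outputs import GenerationChunk

full = GenerationChunk(text="Hel") + GenerationChunk(text="lo")
assert full.text == "Hello"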
libs/core/langchain_core/outputs/llm_result.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from __future__ import annotations

from copy import deepcopy
from typing import List, Optional

from langchain_core.outputs.generation import Generation
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel


class LLMResult(BaseModel):
    """Class that contains all results for a batched LLM call."""

    generations: List[List[Generation]]
    """List of generated outputs. This is a List[List[]] because
    each input could have multiple candidate generations."""
    llm_output: Optional[dict] = None
    """Arbitrary LLM provider-specific output."""
    run: Optional[List[RunInfo]] = None
    """List of metadata info for model call for each input."""

    def flatten(self) -> List[LLMResult]:
        """Flatten generations into a single list.

        Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
        contains only a single Generation. If token usage information is available,
        it is kept only for the LLMResult corresponding to the top-choice
        Generation, to avoid over-counting of token usage downstream.

        Returns:
            List of LLMResults where each returned LLMResult contains a single
            Generation.
        """
        llm_results = []
        for i, gen_list in enumerate(self.generations):
            # Avoid double counting tokens in OpenAICallback
            if i == 0:
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=self.llm_output,
                    )
                )
            else:
                if self.llm_output is not None:
                    llm_output = deepcopy(self.llm_output)
                    llm_output["token_usage"] = dict()
                else:
                    llm_output = None
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=llm_output,
                    )
                )
        return llm_results

    def __eq__(self, other: object) -> bool:
        """Check for LLMResult equality by ignoring any metadata related to runs."""
        if not isinstance(other, LLMResult):
            return NotImplemented
        return (
            self.generations == other.generations
            and self.llm_output == other.llm_output
        )
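A sketch (not in the diff) of flatten()'s token-usage behavior: only the first per-prompt result keeps token_usage; the rest get an empty dict so downstream accounting does not double count.

from langchain_core.outputs import Generation, LLMResult

result = LLMResult(
    generations=[[Generation(text="a")], [Generation(text="b")]],
    llm_output={"token_usage": {"total_tokens": 7}},
)
flat = result.flatten()
assert flat[0].llm_output == {"token_usage": {"total_tokens": 7}}
assert flat[1].llm_output == {"token_usage": {}}  # zeroed to avoid double counting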
libs/core/langchain_core/outputs/run_info.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from __future__ import annotations

from uuid import UUID

from langchain_core.pydantic_v1 import BaseModel


class RunInfo(BaseModel):
    """Class that contains metadata for a single execution of a Chain or model."""

    run_id: UUID
    """A unique identifier for the model or chain run."""
@@ -27,21 +27,18 @@ from multiple components. Prompt classes and functions make constructing
     ChatPromptValue

 """  # noqa: E501
-from langchain_core.prompts.base import StringPromptTemplate
+from langchain_core.prompts.base import BasePromptTemplate, format_document
 from langchain_core.prompts.chat import (
     AIMessagePromptTemplate,
     BaseChatPromptTemplate,
     ChatMessagePromptTemplate,
     ChatPromptTemplate,
+    ChatPromptValue,
+    ChatPromptValueConcrete,
     HumanMessagePromptTemplate,
     MessagesPlaceholder,
     SystemMessagePromptTemplate,
 )
-from langchain_core.prompts.example_selector import (
-    LengthBasedExampleSelector,
-    MaxMarginalRelevanceExampleSelector,
-    SemanticSimilarityExampleSelector,
-)
 from langchain_core.prompts.few_shot import (
     FewShotChatMessagePromptTemplate,
     FewShotPromptTemplate,
@@ -50,7 +47,7 @@ from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemp
 from langchain_core.prompts.loading import load_prompt
 from langchain_core.prompts.pipeline import PipelinePromptTemplate
 from langchain_core.prompts.prompt import Prompt, PromptTemplate
-from langchain_core.schema.prompt_template import BasePromptTemplate
+from langchain_core.prompts.string import StringPromptTemplate, StringPromptValue

 __all__ = [
     "AIMessagePromptTemplate",
@@ -58,18 +55,22 @@ __all__ = [
     "BasePromptTemplate",
     "ChatMessagePromptTemplate",
     "ChatPromptTemplate",
+    "ChatPromptValue",
+    "ChatPromptValueConcrete",
     "FewShotPromptTemplate",
     "FewShotPromptWithTemplates",
+    "FewShotChatMessagePromptTemplate",
     "HumanMessagePromptTemplate",
-    "LengthBasedExampleSelector",
-    "MaxMarginalRelevanceExampleSelector",
     "MessagesPlaceholder",
     "PipelinePromptTemplate",
     "Prompt",
     "PromptTemplate",
-    "SemanticSimilarityExampleSelector",
+    "PromptValue",
+    "StringPromptValue",
     "StringPromptTemplate",
     "SystemMessagePromptTemplate",
     "load_prompt",
-    "FewShotChatMessagePromptTemplate",
+    "format_document",
 ]

+from langchain_core.prompts.value import PromptValue
@@ -1,173 +1,228 @@
(The old contents of this file, the string-prompt helpers jinja2_formatter, validate_jinja2, check_valid_template, get_template_variables, StringPromptValue, and StringPromptTemplate, were moved with updated imports to the new libs/core/langchain_core/prompts/string.py shown below. The rewritten base.py now reads:)

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union

import yaml

from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.value import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable


class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    input_types: Dict[str, Any] = Field(default_factory=dict)
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
        default_factory=dict
    )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def OutputType(self) -> Any:
        from langchain_core.prompts.chat import ChatPromptValueConcrete
        from langchain_core.prompts.string import StringPromptValue

        return Union[StringPromptValue, ChatPromptValueConcrete]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        # This is correct, but pydantic typings/mypy don't think so.
        return create_model(  # type: ignore[call-overload]
            "PromptInput",
            **{k: (self.input_types.get(k, str), None) for k in self.input_variables},
        )

    def invoke(
        self, input: Dict, config: Optional[RunnableConfig] = None
    ) -> PromptValue:
        return self._call_with_config(
            lambda inner_input: self.format_prompt(
                **{key: inner_input[key] for key in self.input_variables}
            ),
            input,
            config,
            run_type="prompt",
        )

    @abstractmethod
    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create Chat Messages."""

    @root_validator()
    def validate_variable_names(cls, values: Dict) -> Dict:
        """Validate variable names do not include restricted names."""
        if "stop" in values["input_variables"]:
            raise ValueError(
                "Cannot have an input variable named 'stop', as it is used internally,"
                " please rename."
            )
        if "stop" in values["partial_variables"]:
            raise ValueError(
                "Cannot have an partial variable named 'stop', as it is used "
                "internally, please rename."
            )

        overall = set(values["input_variables"]).intersection(
            values["partial_variables"]
        )
        if overall:
            raise ValueError(
                f"Found overlapping input and partial variables: {overall}"
            )
        return values

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template."""
        prompt_dict = self.__dict__.copy()
        prompt_dict["input_variables"] = list(
            set(self.input_variables).difference(kwargs)
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        # Get partial params:
        partial_kwargs = {
            k: v if isinstance(v, str) else v()
            for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @abstractmethod
    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of prompt."""
        prompt_dict = super().dict(**kwargs)
        try:
            prompt_dict["_type"] = self._prompt_type
        except NotImplementedError:
            pass
        return prompt_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt.

        Args:
            file_path: Path to directory to save prompt to.

        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")

        # Fetch dictionary to save
        prompt_dict = self.dict()
        if "_type" not in prompt_dict:
            raise NotImplementedError(f"Prompt {self} does not support saving.")

        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")


def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Format a document into a string based on a prompt template.

    First, this pulls information from the document from two sources:

    1. `page_content`:
        This takes the information from the `document.page_content`
        and assigns it to a variable named `page_content`.
    2. metadata:
        This takes information from `document.metadata` and assigns
        it to variables of the same name.

    Those variables are then passed into the `prompt` to produce a formatted string.

    Args:
        doc: Document, the page_content and metadata will be used to create
            the final string.
        prompt: BasePromptTemplate, will be used to format the page_content
            and metadata into the final string.

    Returns:
        string of the document formatted.

    Example:
        .. code-block:: python

            from langchain_core import Document
            from langchain_core.prompts import PromptTemplate

            doc = Document(page_content="This is a joke", metadata={"page": "1"})
            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
            format_document(doc, prompt)
            >>> "Page 1: This is a joke"
    """
    base_info = {"page_content": doc.page_content, **doc.metadata}
    missing_metadata = set(prompt.input_variables).difference(base_info)
    if len(missing_metadata) > 0:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{list(missing_metadata)}."
        )
    document_info = {k: base_info[k] for k in prompt.input_variables}
    return prompt.format(**document_info)
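A usage sketch (not part of the diff) of the template surface defined above: partial() pre-fills variables and narrows input_variables, while invoke() runs the template through the Runnable interface.

from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template("{greeting}, {name}!")
partial_prompt = prompt.partial(greeting="Hello")  # "greeting" leaves input_variables
assert partial_prompt.format(name="world") == "Hello, world!"
assert prompt.invoke({"greeting": "Hi", "name": "there"}).to_string() == "Hi, there!"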
@@ -19,15 +19,8 @@ from typing import (
 )

 from langchain_core._api import deprecated
-from langchain_core.load.serializable import Serializable
-from langchain_core.prompts.base import StringPromptTemplate
-from langchain_core.prompts.prompt import PromptTemplate
-from langchain_core.pydantic_v1 import Field, root_validator
-from langchain_core.schema import (
-    BasePromptTemplate,
-    PromptValue,
-)
-from langchain_core.schema.messages import (
+from langchain_core.load import Serializable
+from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
@@ -36,6 +29,11 @@ from langchain_core.schema.messages import (
     SystemMessage,
     get_buffer_string,
 )
+from langchain_core.prompts.base import BasePromptTemplate
+from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.prompts.string import StringPromptTemplate
+from langchain_core.prompts.value import PromptValue
+from langchain_core.pydantic_v1 import Field, root_validator


 class BaseMessagePromptTemplate(Serializable, ABC):
@@ -4,20 +4,19 @@ from __future__ import annotations
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union

-from langchain_core.prompts.base import (
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.prompts.chat import (
+    BaseChatPromptTemplate,
+    BaseMessagePromptTemplate,
+)
+from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.prompts.string import (
     DEFAULT_FORMATTER_MAPPING,
     StringPromptTemplate,
     check_valid_template,
     get_template_variables,
 )
-from langchain_core.prompts.chat import (
-    BaseChatPromptTemplate,
-    BaseMessagePromptTemplate,
-)
-from langchain_core.prompts.example_selector.base import BaseExampleSelector
-from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
-from langchain_core.schema.messages import BaseMessage, get_buffer_string


 class _FewShotPromptTemplateMixin(BaseModel):
@@ -27,7 +26,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
     """Examples to format into the prompt.
     Either this or example_selector should be provided."""

-    example_selector: Optional[BaseExampleSelector] = None
+    example_selector: Any = None
     """ExampleSelector to choose the examples to format into the prompt.
     Either this or examples should be provided."""

@@ -253,7 +252,7 @@ class FewShotChatMessagePromptTemplate(
                 vectorstore=vectorstore
             )

-            from langchain_core.schema import SystemMessage
+            from langchain_core import SystemMessage
             from langchain_core.prompts import HumanMessagePromptTemplate
             from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
@@ -2,9 +2,11 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

-from langchain_core.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
-from langchain_core.prompts.example_selector.base import BaseExampleSelector
 from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.prompts.string import (
+    DEFAULT_FORMATTER_MAPPING,
+    StringPromptTemplate,
+)
 from langchain_core.pydantic_v1 import Extra, root_validator


@@ -15,7 +17,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
     """Examples to format into the prompt.
     Either this or example_selector should be provided."""

-    example_selector: Optional[BaseExampleSelector] = None
+    example_selector: Any = None
     """ExampleSelector to choose the examples to format into the prompt.
     Either this or examples should be provided."""
@@ -6,13 +6,11 @@ from typing import Callable, Dict, Union

 import yaml

+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts.few_shot import FewShotPromptTemplate
 from langchain_core.prompts.prompt import PromptTemplate
-from langchain_core.schema import (
-    BasePromptTemplate,
-    StrOutputParser,
-)
-from langchain_core.utils.loading import try_load_from_hub
+from langchain_core.utils import try_load_from_hub

 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
 logger = logging.getLogger(__name__)
@@ -1,8 +1,9 @@
 from typing import Any, Dict, List, Tuple

+from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts.chat import BaseChatPromptTemplate
+from langchain_core.prompts.value import PromptValue
 from langchain_core.pydantic_v1 import root_validator
-from langchain_core.schema import BasePromptTemplate, PromptValue


 def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
@@ -4,7 +4,7 @@ from __future__ import annotations
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union

-from langchain_core.prompts.base import (
+from langchain_core.prompts.string import (
     DEFAULT_FORMATTER_MAPPING,
     StringPromptTemplate,
     check_valid_template,
libs/core/langchain_core/prompts/string.py (new file, 173 lines)
@@ -0,0 +1,173 @@
"""BasePrompt schema definition."""
from __future__ import annotations

import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.utils.formatting import formatter


def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2.

    *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
    SandboxedEnvironment by default. However, this sand-boxing should
    be treated as a best-effort approach rather than a guarantee of security.
    Do not accept jinja2 templates from untrusted sources as they may lead
    to arbitrary Python code execution.

    https://jinja.palletsprojects.com/en/3.1.x/sandbox/
    """
    try:
        from jinja2.sandbox import SandboxedEnvironment
    except ImportError:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
            "Please be cautious when using jinja2 templates. "
            "Do not expand jinja2 templates using unverified or user-controlled "
            "inputs as that can result in arbitrary Python code execution."
        )

    # This uses a sandboxed environment to prevent arbitrary code execution.
    # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
    # Please treat this sand-boxing as a best-effort approach rather than
    # a guarantee of security.
    # We recommend to never use jinja2 templates with untrusted inputs.
    # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
    # approach not a guarantee of security.
    return SandboxedEnvironment().from_string(template).render(**kwargs)


def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """
    Validate that the input variables are valid for the template.
    Issues a warning if missing or extra variables are found.

    Args:
        template: The template string.
        input_variables: The input variables.
    """
    input_variables_set = set(input_variables)
    valid_variables = _get_jinja2_variables_from_template(template)
    missing_variables = valid_variables - input_variables_set
    extra_variables = input_variables_set - valid_variables

    warning_message = ""
    if missing_variables:
        warning_message += f"Missing variables: {missing_variables} "

    if extra_variables:
        warning_message += f"Extra variables: {extra_variables}"

    if warning_message:
        warnings.warn(warning_message.strip())


def _get_jinja2_variables_from_template(template: str) -> Set[str]:
    try:
        from jinja2 import Environment, meta
    except ImportError:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
        )
    env = Environment()
    ast = env.parse(template)
    variables = meta.find_undeclared_variables(ast)
    return variables


DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.format,
    "jinja2": jinja2_formatter,
}

DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.validate_input_variables,
    "jinja2": validate_jinja2,
}


def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that template string is valid.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or "jinja2".
        input_variables: The input variables.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format not in DEFAULT_FORMATTER_MAPPING:
        valid_formats = list(DEFAULT_FORMATTER_MAPPING)
        raise ValueError(
            f"Invalid template format. Got `{template_format}`;"
            f" should be one of {valid_formats}"
        )
    try:
        validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
        validator_func(template, input_variables)
    except KeyError as e:
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters. "
            + str(e)
        )


def get_template_variables(template: str, template_format: str) -> List[str]:
    """Get the variables from the template.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or "jinja2".

    Returns:
        The variables from the template.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format == "jinja2":
        # Get the variables for the template
        input_variables = _get_jinja2_variables_from_template(template)
    elif template_format == "f-string":
        input_variables = {
            v for _, v, _, _ in Formatter().parse(template) if v is not None
        }
    else:
        raise ValueError(f"Unsupported template format: {template_format}")

    return sorted(input_variables)


class StringPromptValue(PromptValue):
    """String prompt value."""

    text: str
    """Prompt text."""
    type: Literal["StringPromptValue"] = "StringPromptValue"

    def to_string(self) -> str:
        """Return prompt as string."""
        return self.text

    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as messages."""
        return [HumanMessage(content=self.text)]


class StringPromptTemplate(BasePromptTemplate, ABC):
    """String prompt that exposes the format method, returning a prompt."""

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create Chat Messages."""
        return StringPromptValue(text=self.format(**kwargs))
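A sketch (not in the diff) of the two helpers most callers use from this module; the second call assumes jinja2 is installed.

from langchain_core.prompts.string import get_template_variables, jinja2_formatter

assert get_template_variables("{a} and {b}", "f-string") == ["a", "b"]
assert jinja2_formatter("{{ a }} and {{ b }}", a=1, b=2) == "1 and 2"  # needs jinja2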
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from typing import List

 from langchain_core.load.serializable import Serializable
-from langchain_core.schema.messages import BaseMessage
+from langchain_core.messages import BaseMessage


 class PromptValue(Serializable, ABC):
|
@ -7,9 +7,9 @@ from functools import partial
|
|||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
from langchain_core.load.dump import dumpd
|
from langchain_core.load.dump import dumpd
|
||||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||||
from langchain_core.schema.document import Document
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
@@ -25,7 +25,11 @@ from langchain_core.runnables.base import (
     RunnableSerializable,
 )
 from langchain_core.runnables.branch import RunnableBranch
-from langchain_core.runnables.config import RunnableConfig, patch_config
+from langchain_core.runnables.config import (
+    RunnableConfig,
+    get_config_list,
+    patch_config,
+)
 from langchain_core.runnables.fallbacks import RunnableWithFallbacks
 from langchain_core.runnables.passthrough import RunnablePassthrough
 from langchain_core.runnables.router import RouterInput, RouterRunnable
@@ -33,6 +37,7 @@ from langchain_core.runnables.utils import (
     ConfigurableField,
     ConfigurableFieldMultiOption,
     ConfigurableFieldSingleOption,
+    add,
 )

 __all__ = [
@@ -54,4 +59,6 @@ __all__ = [
     "RunnablePassthrough",
     "RunnableSequence",
     "RunnableWithFallbacks",
+    "get_config_list",
+    "add",
 ]
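A sketch (not in the diff) of the newly exported get_config_list, which normalizes a single config into one config per batch input.

from langchain_core.runnables import get_config_list
from langchain_core.runnables.config import RunnableConfig

config = RunnableConfig(tags=["batch"])
configs = get_config_list(config, 3)  # one config per input
assert len(configs) == 3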
@@ -36,11 +36,11 @@ if TYPE_CHECKING:
         AsyncCallbackManagerForChainRun,
         CallbackManagerForChainRun,
     )
-    from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch
-    from langchain_core.callbacks.tracers.root_listeners import Listener
     from langchain_core.runnables.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
+    from langchain_core.tracers.log_stream import RunLog, RunLogPatch
+    from langchain_core.tracers.root_listeners import Listener

 from langchain_core.load.dump import dumpd
 from langchain_core.load.serializable import Serializable
@@ -198,7 +198,7 @@ class Runnable(Generic[Input, Output], ABC):

         .. code-block:: python

-            from langchain_core.callbacks.tracers import ConsoleCallbackHandler
+            from langchain_core.tracers import ConsoleCallbackHandler

             chain.invoke(
                 ...,
@@ -559,7 +559,7 @@ class Runnable(Generic[Input, Output], ABC):
        """

        from langchain_core.callbacks.base import BaseCallbackManager
-        from langchain_core.callbacks.tracers.log_stream import (
+        from langchain_core.tracers.log_stream import (
            LogStreamCallbackHandler,
            RunLog,
            RunLogPatch,
@@ -725,7 +725,7 @@ class Runnable(Generic[Input, Output], ABC):
        type, input, output, error, start_time, end_time, and any tags or metadata
        added to the run.
        """
-        from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer
+        from langchain_core.tracers.root_listeners import RootListenersTracer

        return RunnableBinding(
            bound=self,
@@ -2945,7 +2945,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
        type, input, output, error, start_time, end_time, and any tags or metadata
        added to the run.
        """
-        from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer
+        from langchain_core.tracers.root_listeners import RootListenersTracer

        return self.__class__(
            bound=self.bound,
@@ -66,7 +66,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
                # response.

                from langchain_core.prompts import PromptTemplate
-                from langchain_core.schema.output_parser import StrOutputParser
+                from langchain_core.output_parser import StrOutputParser
                from langchain_core.runnables import RunnableLambda

                def when_all_is_lost(inputs):
@@ -13,6 +13,7 @@ from typing import (
     Union,
 )

+from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.load import load
 from langchain_core.pydantic_v1 import BaseModel, create_model
 from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
@@ -21,12 +22,11 @@ from langchain_core.runnables.utils import (
     ConfigurableFieldSpec,
     get_unique_config_specs,
 )
-from langchain_core.schema.chat_history import BaseChatMessageHistory

 if TYPE_CHECKING:
-    from langchain_core.callbacks.tracers.schemas import Run
+    from langchain_core.messages import BaseMessage
     from langchain_core.runnables.config import RunnableConfig
-    from langchain_core.schema.messages import BaseMessage
+    from langchain_core.tracers.schemas import Run

 MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
 GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
@@ -178,7 +178,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     ) -> Type[BaseModel]:
         super_schema = super().get_input_schema(config)
         if super_schema.__custom_root_type__ is not None:
-            from langchain_core.schema.messages import BaseMessage
+            from langchain_core.messages import BaseMessage

             fields: Dict = {}
             if self.input_messages_key and self.history_messages_key:
@@ -202,10 +202,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     def _get_input_messages(
         self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
     ) -> List[BaseMessage]:
-        from langchain_core.schema.messages import BaseMessage
+        from langchain_core.messages import BaseMessage

         if isinstance(input_val, str):
-            from langchain_core.schema.messages import HumanMessage
+            from langchain_core.messages import HumanMessage

             return [HumanMessage(content=input_val)]
         elif isinstance(input_val, BaseMessage):
@@ -221,13 +221,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     def _get_output_messages(
         self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
     ) -> List[BaseMessage]:
-        from langchain_core.schema.messages import BaseMessage
+        from langchain_core.messages import BaseMessage

         if isinstance(output_val, dict):
             output_val = output_val[self.output_messages_key or "output"]

         if isinstance(output_val, str):
-            from langchain_core.schema.messages import AIMessage
+            from langchain_core.messages import AIMessage

             return [AIMessage(content=output_val)]
         elif isinstance(output_val, BaseMessage):
@ -1,78 +0,0 @@
-"""**Schemas** are the LangChain Base Classes and Interfaces."""
-from langchain_core.schema.agent import AgentAction, AgentFinish
-from langchain_core.schema.cache import BaseCache
-from langchain_core.schema.chat_history import BaseChatMessageHistory
-from langchain_core.schema.document import BaseDocumentTransformer, Document
-from langchain_core.schema.exceptions import LangChainException
-from langchain_core.schema.memory import BaseMemory
-from langchain_core.schema.messages import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    FunctionMessage,
-    HumanMessage,
-    SystemMessage,
-    _message_from_dict,
-    _message_to_dict,
-    get_buffer_string,
-    messages_from_dict,
-    messages_to_dict,
-)
-from langchain_core.schema.output import (
-    ChatGeneration,
-    ChatResult,
-    Generation,
-    LLMResult,
-    RunInfo,
-)
-from langchain_core.schema.output_parser import (
-    BaseLLMOutputParser,
-    BaseOutputParser,
-    OutputParserException,
-    StrOutputParser,
-)
-from langchain_core.schema.prompt import PromptValue
-from langchain_core.schema.prompt_template import BasePromptTemplate, format_document
-from langchain_core.schema.retriever import BaseRetriever
-from langchain_core.schema.storage import BaseStore
-
-RUN_KEY = "__run"
-Memory = BaseMemory
-
-__all__ = [
-    "BaseCache",
-    "BaseMemory",
-    "BaseStore",
-    "AgentFinish",
-    "AgentAction",
-    "Document",
-    "BaseChatMessageHistory",
-    "BaseDocumentTransformer",
-    "BaseMessage",
-    "ChatMessage",
-    "FunctionMessage",
-    "HumanMessage",
-    "AIMessage",
-    "SystemMessage",
-    "messages_from_dict",
-    "messages_to_dict",
-    "_message_to_dict",
-    "_message_from_dict",
-    "get_buffer_string",
-    "RunInfo",
-    "LLMResult",
-    "ChatResult",
-    "ChatGeneration",
-    "Generation",
-    "PromptValue",
-    "LangChainException",
-    "BaseRetriever",
-    "RUN_KEY",
-    "Memory",
-    "OutputParserException",
-    "StrOutputParser",
-    "BaseOutputParser",
-    "BaseLLMOutputParser",
-    "BasePromptTemplate",
-    "format_document",
-]
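With the `langchain_core.schema` aggregator deleted, each of its re-exports now lives in a topical top-level module. A few representative before/after imports, all taken from replacements visible elsewhere in this diff:

    # Before (this commit removes langchain_core.schema entirely):
    # from langchain_core.schema import AIMessage, Document, BaseRetriever, LLMResult

    # After:
    from langchain_core.documents import Document
    from langchain_core.messages import AIMessage
    from langchain_core.outputs import LLMResult
    from langchain_core.retrievers import BaseRetriever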
@ -1,2 +0,0 @@
-class LangChainException(Exception):
-    """General LangChain exception."""
@ -1,415 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
-
-from typing_extensions import Literal
-
-from langchain_core.load.serializable import Serializable
-from langchain_core.pydantic_v1 import Extra, Field
-
-if TYPE_CHECKING:
-    from langchain_core.prompts.chat import ChatPromptTemplate
-
-
-def get_buffer_string(
-    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
-) -> str:
-    """Convert sequence of Messages to strings and concatenate them into one string.
-
-    Args:
-        messages: Messages to be converted to strings.
-        human_prefix: The prefix to prepend to contents of HumanMessages.
-        ai_prefix: THe prefix to prepend to contents of AIMessages.
-
-    Returns:
-        A single string concatenation of all input messages.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_core.schema import AIMessage, HumanMessage
-
-            messages = [
-                HumanMessage(content="Hi, how are you?"),
-                AIMessage(content="Good, how are you?"),
-            ]
-            get_buffer_string(messages)
-            # -> "Human: Hi, how are you?\nAI: Good, how are you?"
-    """
-    string_messages = []
-    for m in messages:
-        if isinstance(m, HumanMessage):
-            role = human_prefix
-        elif isinstance(m, AIMessage):
-            role = ai_prefix
-        elif isinstance(m, SystemMessage):
-            role = "System"
-        elif isinstance(m, FunctionMessage):
-            role = "Function"
-        elif isinstance(m, ChatMessage):
-            role = m.role
-        else:
-            raise ValueError(f"Got unsupported message type: {m}")
-        message = f"{role}: {m.content}"
-        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
-            message += f"{m.additional_kwargs['function_call']}"
-        string_messages.append(message)
-
-    return "\n".join(string_messages)
-
-
-class BaseMessage(Serializable):
-    """The base abstract Message class.
-
-    Messages are the inputs and outputs of ChatModels.
-    """
-
-    content: Union[str, List[Union[str, Dict]]]
-    """The string contents of the message."""
-
-    additional_kwargs: dict = Field(default_factory=dict)
-    """Any additional information."""
-
-    type: str
-
-    class Config:
-        extra = Extra.allow
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """Return whether this class is serializable."""
-        return True
-
-    def __add__(self, other: Any) -> ChatPromptTemplate:
-        from langchain_core.prompts.chat import ChatPromptTemplate
-
-        prompt = ChatPromptTemplate(messages=[self])
-        return prompt + other
-
-
-def merge_content(
-    first_content: Union[str, List[Union[str, Dict]]],
-    second_content: Union[str, List[Union[str, Dict]]],
-) -> Union[str, List[Union[str, Dict]]]:
-    # If first chunk is a string
-    if isinstance(first_content, str):
-        # If the second chunk is also a string, then merge them naively
-        if isinstance(second_content, str):
-            return first_content + second_content
-        # If the second chunk is a list, add the first chunk to the start of the list
-        else:
-            return_list: List[Union[str, Dict]] = [first_content]
-            return return_list + second_content
-    # If both are lists, merge them naively
-    elif isinstance(second_content, List):
-        return first_content + second_content
-    # If the first content is a list, and the second content is a string
-    else:
-        # If the last element of the first content is a string
-        # Add the second content to the last element
-        if isinstance(first_content[-1], str):
-            return first_content[:-1] + [first_content[-1] + second_content]
-        else:
-            # Otherwise, add the second content as a new element of the list
-            return first_content + [second_content]
-
-
-class BaseMessageChunk(BaseMessage):
-    """A Message chunk, which can be concatenated with other Message chunks."""
-
-    def _merge_kwargs_dict(
-        self, left: Dict[str, Any], right: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        """Merge additional_kwargs from another BaseMessageChunk into this one."""
-        merged = left.copy()
-        for k, v in right.items():
-            if k not in merged:
-                merged[k] = v
-            elif type(merged[k]) != type(v):
-                raise ValueError(
-                    f'additional_kwargs["{k}"] already exists in this message,'
-                    " but with a different type."
-                )
-            elif isinstance(merged[k], str):
-                merged[k] += v
-            elif isinstance(merged[k], dict):
-                merged[k] = self._merge_kwargs_dict(merged[k], v)
-            else:
-                raise ValueError(
-                    f"Additional kwargs key {k} already exists in this message."
-                )
-        return merged
-
-    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
-        if isinstance(other, BaseMessageChunk):
-            # If both are (subclasses of) BaseMessageChunk,
-            # concat into a single BaseMessageChunk
-
-            if isinstance(self, ChatMessageChunk):
-                return self.__class__(
-                    role=self.role,
-                    content=merge_content(self.content, other.content),
-                    additional_kwargs=self._merge_kwargs_dict(
-                        self.additional_kwargs, other.additional_kwargs
-                    ),
-                )
-            return self.__class__(
-                content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
-                    self.additional_kwargs, other.additional_kwargs
-                ),
-            )
-        else:
-            raise TypeError(
-                'unsupported operand type(s) for +: "'
-                f"{self.__class__.__name__}"
-                f'" and "{other.__class__.__name__}"'
-            )
-
-
-class HumanMessage(BaseMessage):
-    """A Message from a human."""
-
-    example: bool = False
-    """Whether this Message is being passed in to the model as part of an example
-        conversation.
-    """
-
-    type: Literal["human"] = "human"
-
-
-HumanMessage.update_forward_refs()
-
-
-class HumanMessageChunk(HumanMessage, BaseMessageChunk):
-    """A Human Message chunk."""
-
-    # Ignoring mypy re-assignment here since we're overriding the value
-    # to make sure that the chunk variant can be discriminated from the
-    # non-chunk variant.
-    type: Literal["HumanMessageChunk"] = "HumanMessageChunk"  # type: ignore[assignment] # noqa: E501
-
-
-class AIMessage(BaseMessage):
-    """A Message from an AI."""
-
-    example: bool = False
-    """Whether this Message is being passed in to the model as part of an example
-        conversation.
-    """
-
-    type: Literal["ai"] = "ai"
-
-
-AIMessage.update_forward_refs()
-
-
-class AIMessageChunk(AIMessage, BaseMessageChunk):
-    """A Message chunk from an AI."""
-
-    # Ignoring mypy re-assignment here since we're overriding the value
-    # to make sure that the chunk variant can be discriminated from the
-    # non-chunk variant.
-    type: Literal["AIMessageChunk"] = "AIMessageChunk"  # type: ignore[assignment] # noqa: E501
-
-    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
-        if isinstance(other, AIMessageChunk):
-            if self.example != other.example:
-                raise ValueError(
-                    "Cannot concatenate AIMessageChunks with different example values."
-                )
-
-            return self.__class__(
-                example=self.example,
-                content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
-                    self.additional_kwargs, other.additional_kwargs
-                ),
-            )
-
-        return super().__add__(other)
-
-
-class SystemMessage(BaseMessage):
-    """A Message for priming AI behavior, usually passed in as the first of a sequence
-    of input messages.
-    """
-
-    type: Literal["system"] = "system"
-
-
-SystemMessage.update_forward_refs()
-
-
-class SystemMessageChunk(SystemMessage, BaseMessageChunk):
-    """A System Message chunk."""
-
-    # Ignoring mypy re-assignment here since we're overriding the value
-    # to make sure that the chunk variant can be discriminated from the
-    # non-chunk variant.
-    type: Literal["SystemMessageChunk"] = "SystemMessageChunk"  # type: ignore[assignment] # noqa: E501
-
-
-class FunctionMessage(BaseMessage):
-    """A Message for passing the result of executing a function back to a model."""
-
-    name: str
-    """The name of the function that was executed."""
-
-    type: Literal["function"] = "function"
-
-
-FunctionMessage.update_forward_refs()
-
-
-class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
-    """A Function Message chunk."""
-
-    # Ignoring mypy re-assignment here since we're overriding the value
-    # to make sure that the chunk variant can be discriminated from the
-    # non-chunk variant.
-    type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk"  # type: ignore[assignment]
-
-    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
-        if isinstance(other, FunctionMessageChunk):
-            if self.name != other.name:
-                raise ValueError(
-                    "Cannot concatenate FunctionMessageChunks with different names."
-                )
-
-            return self.__class__(
-                name=self.name,
-                content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
-                    self.additional_kwargs, other.additional_kwargs
-                ),
-            )
-
-        return super().__add__(other)
-
-
-class ToolMessage(BaseMessage):
-    """A Message for passing the result of executing a tool back to a model."""
-
-    tool_call_id: str
-    """Tool call that this message is responding to."""
-
-    type: Literal["tool"] = "tool"
-
-
-ToolMessage.update_forward_refs()
-
-
-class ToolMessageChunk(ToolMessage, BaseMessageChunk):
-    """A Tool Message chunk."""
-
-    # Ignoring mypy re-assignment here since we're overriding the value
-    # to make sure that the chunk variant can be discriminated from the
-    # non-chunk variant.
-    type: Literal["ToolMessageChunk"] = "ToolMessageChunk"  # type: ignore[assignment]
-
-    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
-        if isinstance(other, ToolMessageChunk):
-            if self.tool_call_id != other.tool_call_id:
-                raise ValueError(
-                    "Cannot concatenate ToolMessageChunks with different names."
-                )
-
-            return self.__class__(
-                tool_call_id=self.tool_call_id,
-                content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
-                    self.additional_kwargs, other.additional_kwargs
-                ),
-            )
-
-        return super().__add__(other)
-
-
-class ChatMessage(BaseMessage):
-    """A Message that can be assigned an arbitrary speaker (i.e. role)."""
-
-    role: str
-    """The speaker / role of the Message."""
-
-    type: Literal["chat"] = "chat"
-
-
-ChatMessage.update_forward_refs()
-
-
-class ChatMessageChunk(ChatMessage, BaseMessageChunk):
-    """A Chat Message chunk."""
-
-    # Ignoring mypy re-assignment here since we're overriding the value
-    # to make sure that the chunk variant can be discriminated from the
-    # non-chunk variant.
-    type: Literal["ChatMessageChunk"] = "ChatMessageChunk"  # type: ignore
-
-    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
-        if isinstance(other, ChatMessageChunk):
-            if self.role != other.role:
-                raise ValueError(
-                    "Cannot concatenate ChatMessageChunks with different roles."
-                )
-
-            return self.__class__(
-                role=self.role,
-                content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
-                    self.additional_kwargs, other.additional_kwargs
-                ),
-            )
-
-        return super().__add__(other)
-
-
-AnyMessage = Union[
-    AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
-]
-
-
-def _message_to_dict(message: BaseMessage) -> dict:
-    return {"type": message.type, "data": message.dict()}
-
-
-def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
-    """Convert a sequence of Messages to a list of dictionaries.
-
-    Args:
-        messages: Sequence of messages (as BaseMessages) to convert.
-
-    Returns:
-        List of messages as dicts.
-    """
-    return [_message_to_dict(m) for m in messages]
-
-
-def _message_from_dict(message: dict) -> BaseMessage:
-    _type = message["type"]
-    if _type == "human":
-        return HumanMessage(**message["data"])
-    elif _type == "ai":
-        return AIMessage(**message["data"])
-    elif _type == "system":
-        return SystemMessage(**message["data"])
-    elif _type == "chat":
-        return ChatMessage(**message["data"])
-    elif _type == "function":
-        return FunctionMessage(**message["data"])
-    elif _type == "tool":
-        return ToolMessage(**message["data"])
-    else:
-        raise ValueError(f"Got unexpected message type: {_type}")
-
-
-def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
-    """Convert a sequence of messages from dicts to Message objects.
-
-    Args:
-        messages: Sequence of messages (as dicts) to convert.
-
-    Returns:
-        List of messages (BaseMessages).
-    """
-    return [_message_from_dict(m) for m in messages]
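The chunk classes in the deleted module exist so that streamed deltas can be folded together with `+`; the same behavior is preserved at the new `langchain_core.messages` location. A small usage sketch:

    from langchain_core.messages import AIMessageChunk

    # merge_content concatenates string contents; _merge_kwargs_dict merges
    # additional_kwargs and raises if a key reappears with a different type.
    chunk = AIMessageChunk(content="Hello,") + AIMessageChunk(content=" world")
    assert chunk.content == "Hello, world"
    assert isinstance(chunk, AIMessageChunk)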
@ -1,175 +0,0 @@
-from __future__ import annotations
-
-from copy import deepcopy
-from typing import Any, Dict, List, Literal, Optional
-from uuid import UUID
-
-from langchain_core.load.serializable import Serializable
-from langchain_core.pydantic_v1 import BaseModel, root_validator
-from langchain_core.schema.messages import BaseMessage, BaseMessageChunk
-
-
-class Generation(Serializable):
-    """A single text generation output."""
-
-    text: str
-    """Generated text output."""
-
-    generation_info: Optional[Dict[str, Any]] = None
-    """Raw response from the provider. May include things like the
-        reason for finishing or token log probabilities.
-    """
-    type: Literal["Generation"] = "Generation"
-    """Type is used exclusively for serialization purposes."""
-    # TODO: add log probs as separate attribute
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """Return whether this class is serializable."""
-        return True
-
-
-class GenerationChunk(Generation):
-    """A Generation chunk, which can be concatenated with other Generation chunks."""
-
-    def __add__(self, other: GenerationChunk) -> GenerationChunk:
-        if isinstance(other, GenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
-            )
-            return GenerationChunk(
-                text=self.text + other.text,
-                generation_info=generation_info,
-            )
-        else:
-            raise TypeError(
-                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
-            )
-
-
-class ChatGeneration(Generation):
-    """A single chat generation output."""
-
-    text: str = ""
-    """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
-    message: BaseMessage
-    """The message output by the chat model."""
-    # Override type to be ChatGeneration, ignore mypy error as this is intentional
-    type: Literal["ChatGeneration"] = "ChatGeneration"  # type: ignore[assignment]
-    """Type is used exclusively for serialization purposes."""
-
-    @root_validator
-    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
-        """Set the text attribute to be the contents of the message."""
-        try:
-            values["text"] = values["message"].content
-        except (KeyError, AttributeError) as e:
-            raise ValueError("Error while initializing ChatGeneration") from e
-        return values
-
-
-class ChatGenerationChunk(ChatGeneration):
-    """A ChatGeneration chunk, which can be concatenated with other
-    ChatGeneration chunks.
-
-    Attributes:
-        message: The message chunk output by the chat model.
-    """
-
-    message: BaseMessageChunk
-    # Override type to be ChatGeneration, ignore mypy error as this is intentional
-    type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk"  # type: ignore[assignment] # noqa: E501
-    """Type is used exclusively for serialization purposes."""
-
-    def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
-        if isinstance(other, ChatGenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
-            )
-            return ChatGenerationChunk(
-                message=self.message + other.message,
-                generation_info=generation_info,
-            )
-        else:
-            raise TypeError(
-                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
-            )
-
-
-class RunInfo(BaseModel):
-    """Class that contains metadata for a single execution of a Chain or model."""
-
-    run_id: UUID
-    """A unique identifier for the model or chain run."""
-
-
-class ChatResult(BaseModel):
-    """Class that contains all results for a single chat model call."""
-
-    generations: List[ChatGeneration]
-    """List of the chat generations. This is a List because an input can have multiple
-        candidate generations.
-    """
-    llm_output: Optional[dict] = None
-    """For arbitrary LLM provider specific output."""
-
-
-class LLMResult(BaseModel):
-    """Class that contains all results for a batched LLM call."""
-
-    generations: List[List[Generation]]
-    """List of generated outputs. This is a List[List[]] because
-    each input could have multiple candidate generations."""
-    llm_output: Optional[dict] = None
-    """Arbitrary LLM provider-specific output."""
-    run: Optional[List[RunInfo]] = None
-    """List of metadata info for model call for each input."""
-
-    def flatten(self) -> List[LLMResult]:
-        """Flatten generations into a single list.
-
-        Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
-        contains only a single Generation. If token usage information is available,
-        it is kept only for the LLMResult corresponding to the top-choice
-        Generation, to avoid over-counting of token usage downstream.
-
-        Returns:
-            List of LLMResults where each returned LLMResult contains a single
-                Generation.
-        """
-        llm_results = []
-        for i, gen_list in enumerate(self.generations):
-            # Avoid double counting tokens in OpenAICallback
-            if i == 0:
-                llm_results.append(
-                    LLMResult(
-                        generations=[gen_list],
-                        llm_output=self.llm_output,
-                    )
-                )
-            else:
-                if self.llm_output is not None:
-                    llm_output = deepcopy(self.llm_output)
-                    llm_output["token_usage"] = dict()
-                else:
-                    llm_output = None
-                llm_results.append(
-                    LLMResult(
-                        generations=[gen_list],
-                        llm_output=llm_output,
-                    )
-                )
-        return llm_results
-
-    def __eq__(self, other: object) -> bool:
-        """Check for LLMResult equality by ignoring any metadata related to runs."""
-        if not isinstance(other, LLMResult):
-            return NotImplemented
-        return (
-            self.generations == other.generations
-            and self.llm_output == other.llm_output
-        )
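`LLMResult.flatten` is the subtle piece of the deleted module: it splits a batched result into one `LLMResult` per input and empties `token_usage` on all but the first, so downstream callbacks do not double count tokens. A quick sketch against the new `langchain_core.outputs` path:

    from langchain_core.outputs import Generation, LLMResult

    batched = LLMResult(
        generations=[[Generation(text="a")], [Generation(text="b")]],
        llm_output={"token_usage": {"total_tokens": 7}},
    )
    first, second = batched.flatten()
    assert first.llm_output == {"token_usage": {"total_tokens": 7}}
    # Token usage is blanked on later results to avoid over-counting.
    assert second.llm_output == {"token_usage": {}}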
@ -1,228 +0,0 @@
-from __future__ import annotations
-
-import json
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
-
-import yaml
-
-from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
-from langchain_core.schema.document import Document
-from langchain_core.schema.output_parser import BaseOutputParser
-from langchain_core.schema.prompt import PromptValue
-
-
-class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
-    """Base class for all prompt templates, returning a prompt."""
-
-    input_variables: List[str]
-    """A list of the names of the variables the prompt template expects."""
-    input_types: Dict[str, Any] = Field(default_factory=dict)
-    """A dictionary of the types of the variables the prompt template expects.
-    If not provided, all variables are assumed to be strings."""
-    output_parser: Optional[BaseOutputParser] = None
-    """How to parse the output of calling an LLM on this formatted prompt."""
-    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
-        default_factory=dict
-    )
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """Return whether this class is serializable."""
-        return True
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @property
-    def OutputType(self) -> Any:
-        from langchain_core.prompts.base import StringPromptValue
-        from langchain_core.prompts.chat import ChatPromptValueConcrete
-
-        return Union[StringPromptValue, ChatPromptValueConcrete]
-
-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        # This is correct, but pydantic typings/mypy don't think so.
-        return create_model(  # type: ignore[call-overload]
-            "PromptInput",
-            **{k: (self.input_types.get(k, str), None) for k in self.input_variables},
-        )
-
-    def invoke(
-        self, input: Dict, config: Optional[RunnableConfig] = None
-    ) -> PromptValue:
-        return self._call_with_config(
-            lambda inner_input: self.format_prompt(
-                **{key: inner_input[key] for key in self.input_variables}
-            ),
-            input,
-            config,
-            run_type="prompt",
-        )
-
-    @abstractmethod
-    def format_prompt(self, **kwargs: Any) -> PromptValue:
-        """Create Chat Messages."""
-
-    @root_validator()
-    def validate_variable_names(cls, values: Dict) -> Dict:
-        """Validate variable names do not include restricted names."""
-        if "stop" in values["input_variables"]:
-            raise ValueError(
-                "Cannot have an input variable named 'stop', as it is used internally,"
-                " please rename."
-            )
-        if "stop" in values["partial_variables"]:
-            raise ValueError(
-                "Cannot have an partial variable named 'stop', as it is used "
-                "internally, please rename."
-            )
-
-        overall = set(values["input_variables"]).intersection(
-            values["partial_variables"]
-        )
-        if overall:
-            raise ValueError(
-                f"Found overlapping input and partial variables: {overall}"
-            )
-        return values
-
-    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
-        """Return a partial of the prompt template."""
-        prompt_dict = self.__dict__.copy()
-        prompt_dict["input_variables"] = list(
-            set(self.input_variables).difference(kwargs)
-        )
-        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
-        return type(self)(**prompt_dict)
-
-    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
-        # Get partial params:
-        partial_kwargs = {
-            k: v if isinstance(v, str) else v()
-            for k, v in self.partial_variables.items()
-        }
-        return {**partial_kwargs, **kwargs}
-
-    @abstractmethod
-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs.
-
-        Args:
-            kwargs: Any arguments to be passed to the prompt template.
-
-        Returns:
-            A formatted string.
-
-        Example:
-
-        .. code-block:: python
-
-            prompt.format(variable1="foo")
-        """
-
-    @property
-    def _prompt_type(self) -> str:
-        """Return the prompt type key."""
-        raise NotImplementedError
-
-    def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of prompt."""
-        prompt_dict = super().dict(**kwargs)
-        try:
-            prompt_dict["_type"] = self._prompt_type
-        except NotImplementedError:
-            pass
-        return prompt_dict
-
-    def save(self, file_path: Union[Path, str]) -> None:
-        """Save the prompt.
-
-        Args:
-            file_path: Path to directory to save prompt to.
-
-        Example:
-        .. code-block:: python
-
-            prompt.save(file_path="path/prompt.yaml")
-        """
-        if self.partial_variables:
-            raise ValueError("Cannot save prompt with partial variables.")
-
-        # Fetch dictionary to save
-        prompt_dict = self.dict()
-        if "_type" not in prompt_dict:
-            raise NotImplementedError(f"Prompt {self} does not support saving.")
-
-        # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
-
-        directory_path = save_path.parent
-        directory_path.mkdir(parents=True, exist_ok=True)
-
-        if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
-                json.dump(prompt_dict, f, indent=4)
-        elif save_path.suffix == ".yaml":
-            with open(file_path, "w") as f:
-                yaml.dump(prompt_dict, f, default_flow_style=False)
-        else:
-            raise ValueError(f"{save_path} must be json or yaml")
-
-
-def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
-    """Format a document into a string based on a prompt template.
-
-    First, this pulls information from the document from two sources:
-
-    1. `page_content`:
-        This takes the information from the `document.page_content`
-        and assigns it to a variable named `page_content`.
-    2. metadata:
-        This takes information from `document.metadata` and assigns
-        it to variables of the same name.
-
-    Those variables are then passed into the `prompt` to produce a formatted string.
-
-    Args:
-        doc: Document, the page_content and metadata will be used to create
-            the final string.
-        prompt: BasePromptTemplate, will be used to format the page_content
-            and metadata into the final string.
-
-    Returns:
-        string of the document formatted.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_core.schema import Document
-            from langchain_core.prompts import PromptTemplate
-
-            doc = Document(page_content="This is a joke", metadata={"page": "1"})
-            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
-            format_document(doc, prompt)
-            >>> "Page 1: This is a joke"
-    """
-    base_info = {"page_content": doc.page_content, **doc.metadata}
-    missing_metadata = set(prompt.input_variables).difference(base_info)
-    if len(missing_metadata) > 0:
-        required_metadata = [
-            iv for iv in prompt.input_variables if iv != "page_content"
-        ]
-        raise ValueError(
-            f"Document prompt requires documents to have metadata variables: "
-            f"{required_metadata}. Received document with missing metadata: "
-            f"{list(missing_metadata)}."
-        )
-    document_info = {k: base_info[k] for k in prompt.input_variables}
-    return prompt.format(**document_info)
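`partial` in the deleted `BasePromptTemplate` rebuilds the template with some variables pre-bound, which is easy to miss inside a whole-file deletion. A minimal sketch of the same behavior through the surviving `langchain_core.prompts` package:

    from langchain_core.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("{city} is in {country}")
    partial = prompt.partial(country="France")
    # "country" moves from input_variables into partial_variables.
    assert partial.input_variables == ["city"]
    assert partial.format(city="Paris") == "Paris is in France"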
libs/core/langchain_core/tracers/__init__.py (new file, 16 lines)
@ -0,0 +1,16 @@
+__all__ = [
+    "BaseTracer",
+    "EvaluatorCallbackHandler",
+    "LangChainTracer",
+    "ConsoleCallbackHandler",
+    "Run",
+    "RunLog",
+    "RunLogPatch",
+]
+
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.evaluation import EvaluatorCallbackHandler
+from langchain_core.tracers.langchain import LangChainTracer
+from langchain_core.tracers.log_stream import RunLog, RunLogPatch
+from langchain_core.tracers.schemas import Run
+from langchain_core.tracers.stdout import ConsoleCallbackHandler
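This new package gives tracing a single public surface, so downstream code can import handlers and run types from the package root instead of reaching into submodules:

    from langchain_core.tracers import ConsoleCallbackHandler, LangChainTracer, Run

    handler = ConsoleCallbackHandler()  # pretty-prints run trees to stdout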
@ -9,24 +9,21 @@ from uuid import UUID
 from tenacity import RetryCallState
 
-from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain_core.callbacks.tracers.schemas import Run
-from langchain_core.load.dump import dumpd
-from langchain_core.schema.document import Document
-from langchain_core.schema.output import (
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.documents import Document
+from langchain_core.exceptions import TracerException
+from langchain_core.load import dumpd
+from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
     GenerationChunk,
     LLMResult,
 )
+from langchain_core.tracers.schemas import Run
 
 logger = logging.getLogger(__name__)
 
-
-class TracerException(Exception):
-    """Base class for exceptions in tracers module."""
-
-
 class BaseTracer(BaseCallbackHandler, ABC):
     """Base interface for tracers."""
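One behavioral detail hides in the hunk above: `TracerException` is no longer defined next to `BaseTracer` but imported from `langchain_core.exceptions`, giving callers a single place to catch it. A hedged sketch (the `end_trace` helper is illustrative, not from this commit):

    from langchain_core.exceptions import TracerException


    def end_trace(run_map: dict, run_id: str) -> None:
        # Tracers raise TracerException when asked to end a run they never started.
        if run_id not in run_map:
            raise TracerException(f"No indexed run ID {run_id}.")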
@ -12,10 +12,10 @@ import langsmith
 from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
 
 from langchain_core.callbacks import manager
-from langchain_core.callbacks.tracers import langchain as langchain_tracer
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.langchain import _get_executor
-from langchain_core.callbacks.tracers.schemas import Run
+from langchain_core.tracers import langchain as langchain_tracer
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.langchain import _get_executor
+from langchain_core.tracers.schemas import Run
 
 logger = logging.getLogger(__name__)
@ -17,11 +17,11 @@ from tenacity import (
     wait_exponential_jitter,
 )
 
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.schemas import Run
 from langchain_core.env import get_runtime_environment
-from langchain_core.load.dump import dumpd
-from langchain_core.schema.messages import BaseMessage
+from langchain_core.load import dumpd
+from langchain_core.messages import BaseMessage
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import Run
 
 logger = logging.getLogger(__name__)
 _LOGGED = set()
@ -6,8 +6,9 @@ from typing import Any, Dict, Optional, Union
 
 import requests
 
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.schemas import (
+from langchain_core.messages import get_buffer_string
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import (
     ChainRun,
     LLMRun,
     Run,
@ -16,7 +17,6 @@ from langchain_core.callbacks.tracers.schemas import (
     TracerSessionV1,
     TracerSessionV1Base,
 )
-from langchain_core.schema.messages import get_buffer_string
 from langchain_core.utils import raise_for_status_with_text
 
 logger = logging.getLogger(__name__)
@ -18,10 +18,10 @@ from uuid import UUID
 import jsonpatch
 from anyio import create_memory_object_stream
 
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.schemas import Run
-from langchain_core.load.load import load
-from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
+from langchain_core.load import load
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import Run
 
 
 class LogEntry(TypedDict):
@ -1,12 +1,12 @@
 from typing import Callable, Optional, Union
 from uuid import UUID
 
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.schemas import Run
 from langchain_core.runnables.config import (
     RunnableConfig,
     call_func_with_variable_args,
 )
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import Run
 
 Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
@ -3,8 +3,8 @@
 from typing import Any, List, Optional, Union
 from uuid import UUID
 
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.schemas import Run
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import Run
 
 
 class RunCollectorCallbackHandler(BaseTracer):
@ -9,8 +9,8 @@ from uuid import UUID
 from langsmith.schemas import RunBase as BaseRunV2
 from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
 
+from langchain_core.outputs import LLMResult
 from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
-from langchain_core.schema import LLMResult
 
 
 def RunTypeEnum() -> Type[RunTypeEnumDep]:
@ -1,8 +1,8 @@
 import json
 from typing import Any, Callable, List
 
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.schemas import Run
+from langchain_core.tracers.base import BaseTracer
+from langchain_core.tracers.schemas import Run
 from langchain_core.utils.input import get_bolded_text, get_colored_text
@ -11,6 +11,7 @@ from langchain_core.utils.input import (
     get_colored_text,
     print_text,
 )
+from langchain_core.utils.loading import try_load_from_hub
 from langchain_core.utils.utils import (
     check_package_version,
     convert_to_secret_str,
@ -35,4 +36,5 @@ __all__ = [
     "print_text",
     "raise_for_status_with_text",
     "xor_args",
+    "try_load_from_hub",
 ]
@ -21,10 +21,10 @@ from typing import (
     TypeVar,
 )
 
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import Field, root_validator
-from langchain_core.schema import BaseRetriever
-from langchain_core.schema.document import Document
-from langchain_core.schema.embeddings import Embeddings
+from langchain_core.retrievers import BaseRetriever
 
 if TYPE_CHECKING:
     from langchain_core.callbacks.manager import (
@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-core"
-version = "0.0.1"
+version = "0.0.2"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
@ -6,6 +6,8 @@ EXPECTED_ALL = [
     "suppress_langchain_deprecation_warning",
     "surface_langchain_deprecation_warnings",
     "warn_deprecated",
+    "as_import_path",
+    "get_relative_path",
 ]
@ -1,10 +1,10 @@
 """Test functionality related to length based selector."""
 import pytest
 
-from langchain_core.prompts.example_selector.length_based import (
+from langchain_core.example_selectors import (
     LengthBasedExampleSelector,
 )
-from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.prompts import PromptTemplate
 
 EXAMPLES = [
     {"question": "Question: who are you?\nAnswer: foo"},
@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
 from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+from langchain_core.messages import BaseMessage
 from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.schema.messages import BaseMessage
 
 
 class BaseFakeCallbackHandler(BaseModel):
@ -7,10 +7,9 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.chat_model import BaseChatModel, SimpleChatModel
-from langchain_core.schema import ChatResult
-from langchain_core.schema.messages import AIMessageChunk, BaseMessage
-from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk
+from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 
 class FakeMessagesListChatModel(BaseChatModel):
@ -6,9 +6,8 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.llm import LLM
+from langchain_core.language_models import LLM, LanguageModelInput
 from langchain_core.runnables import RunnableConfig
-from langchain_core.schema.language_model import LanguageModelInput
 
 
 class FakeListLLM(LLM):
@ -1,10 +1,10 @@
 from typing import List
 
-from langchain_core.pydantic_v1 import BaseModel, Field
-from langchain_core.schema import (
+from langchain_core.chat_history import (
     BaseChatMessageHistory,
 )
-from langchain_core.schema.messages import BaseMessage
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
 
 
 class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
@ -3,6 +3,13 @@ from typing import Any, List, Union
 
 import pytest
 
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+    get_buffer_string,
+)
 from langchain_core.prompts import PromptTemplate
 from langchain_core.prompts.chat import (
     AIMessagePromptTemplate,
@ -15,13 +22,6 @@ from langchain_core.prompts.chat import (
     SystemMessagePromptTemplate,
     _convert_to_message,
 )
-from langchain_core.schema.messages import (
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    get_buffer_string,
-)
 
 
 def create_messages() -> List[BaseMessagePromptTemplate]:
@ -3,19 +3,19 @@ from typing import Any, Dict, List, Sequence, Tuple
 
 import pytest
 
+from langchain_core.example_selectors import BaseExampleSelector
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.prompts import (
     AIMessagePromptTemplate,
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
 )
 from langchain_core.prompts.chat import SystemMessagePromptTemplate
-from langchain_core.prompts.example_selector.base import BaseExampleSelector
 from langchain_core.prompts.few_shot import (
     FewShotChatMessagePromptTemplate,
     FewShotPromptTemplate,
 )
 from langchain_core.prompts.prompt import PromptTemplate
-from langchain_core.schema import AIMessage, HumanMessage, SystemMessage
 
 EXAMPLE_PROMPT = PromptTemplate(
     input_variables=["question", "answer"], template="{question}: {answer}"
@ -6,20 +6,22 @@ EXPECTED_ALL = [
     "BasePromptTemplate",
     "ChatMessagePromptTemplate",
     "ChatPromptTemplate",
+    "ChatPromptValueConcrete",
     "FewShotPromptTemplate",
     "FewShotPromptWithTemplates",
+    "FewShotChatMessagePromptTemplate",
+    "format_document",
+    "ChatPromptValue",
+    "PromptValue",
+    "StringPromptValue",
     "HumanMessagePromptTemplate",
-    "LengthBasedExampleSelector",
-    "MaxMarginalRelevanceExampleSelector",
     "MessagesPlaceholder",
     "PipelinePromptTemplate",
     "Prompt",
     "PromptTemplate",
-    "SemanticSimilarityExampleSelector",
     "StringPromptTemplate",
     "SystemMessagePromptTemplate",
     "load_prompt",
-    "FewShotChatMessagePromptTemplate",
 ]
@ -1,5 +1,5 @@
 """Test functionality related to prompt utils."""
-from langchain_core.prompts.example_selector.semantic_similarity import sorted_values
+from langchain_core.example_selectors import sorted_values
 
 
 def test_sorted_vals() -> None:
File diff suppressed because one or more lines are too long
@ -1,8 +1,8 @@
 from langchain_core.callbacks.manager import CallbackManager
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
 from langchain_core.runnables.config import RunnableConfig, merge_configs
+from langchain_core.tracers.stdout import ConsoleCallbackHandler
 
 
 def test_merge_config_callbacks() -> None:
@ -1,9 +1,10 @@
 from typing import Any, Callable, Sequence, Union
 
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.runnables import RunnableConfig, RunnableLambda
+from langchain_core.runnables.base import RunnableLambda
+from langchain_core.runnables.config import RunnableConfig
 from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_core.schema import AIMessage, BaseMessage, HumanMessage
 from tests.unit_tests.fake.memory import ChatMessageHistory
@@ -26,53 +26,55 @@ from langchain_core.callbacks.manager import (
     collect_runs,
     trace_as_chain_group,
 )
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch
-from langchain_core.callbacks.tracers.schemas import Run
-from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
-from langchain_core.load.dump import dumpd, dumps
-from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
-from langchain_core.prompts import PromptTemplate
-from langchain_core.prompts.base import StringPromptValue
-from langchain_core.prompts.chat import (
-    ChatPromptTemplate,
-    ChatPromptValue,
-    HumanMessagePromptTemplate,
-    MessagesPlaceholder,
-    SystemMessagePromptTemplate,
-)
-from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.runnables import (
-    RouterRunnable,
-    Runnable,
-    RunnableBranch,
-    RunnableConfig,
-    RunnableLambda,
-    RunnableParallel,
-    RunnablePassthrough,
-    RunnableSequence,
-    RunnableWithFallbacks,
-)
-from langchain_core.runnables.base import (
-    ConfigurableField,
-    RunnableBinding,
-    RunnableGenerator,
-)
-from langchain_core.runnables.utils import (
-    ConfigurableFieldMultiOption,
-    ConfigurableFieldSingleOption,
-    add,
-)
-from langchain_core.schema.document import Document
-from langchain_core.schema.messages import (
+from langchain_core.documents import Document
+from langchain_core.load import dumpd, dumps
+from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.schema.output_parser import BaseOutputParser, StrOutputParser
-from langchain_core.schema.retriever import BaseRetriever
-from langchain_core.tool import BaseTool, tool
+from langchain_core.output_parsers import (
+    BaseOutputParser,
+    CommaSeparatedListOutputParser,
+    StrOutputParser,
+)
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    ChatPromptValue,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+    PromptTemplate,
+    StringPromptValue,
+    SystemMessagePromptTemplate,
+)
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import (
+    ConfigurableField,
+    ConfigurableFieldMultiOption,
+    ConfigurableFieldSingleOption,
+    RouterRunnable,
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableConfig,
+    RunnableGenerator,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+    RunnableSequence,
+    RunnableWithFallbacks,
+    add,
+)
+from langchain_core.tools import BaseTool, tool
+from langchain_core.tracers import (
+    BaseTracer,
+    ConsoleCallbackHandler,
+    Run,
+    RunLog,
+    RunLogPatch,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel
 from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM

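The hunk above shows the import reshuffle in one place: deep paths like schema.document, runnables.base, and runnables.utils collapse into single top-level plural modules that re-export their public names. A small before/after sketch, using only paths that appear on the two sides of the hunk:

# Before (paths on the minus side of the hunk):
#   from langchain_core.schema.document import Document
#   from langchain_core.schema.output_parser import StrOutputParser
#   from langchain_core.tool import tool
# After (plural, top-level modules):
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import tool

doc = Document(page_content="hello")
parsed = StrOutputParser().invoke(doc.page_content)
assert parsed == "hello"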
@@ -1539,7 +1541,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain = prompt | chat
+    chain: Runnable = prompt | chat

     mock_start = mocker.Mock()
     mock_end = mocker.Mock()
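This and the next four hunks make the same one-line change: `prompt | chat` composes into a RunnableSequence with narrow inferred generics, so widening the variable to the Runnable interface presumably keeps mypy stable as the concrete types shift under the refactor. A minimal sketch under the test fixtures' assumptions:

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from tests.unit_tests.fake.chat_model import FakeListChatModel

prompt = ChatPromptTemplate.from_template("{question}")
chat = FakeListChatModel(responses=["foo"])  # replays "foo" on every call

# Annotating as the broad interface rather than the inferred
# RunnableSequence[...] is the change applied in the hunks that follow.
chain: Runnable = prompt | chat
assert chain.invoke({"question": "hi"}).content == "foo"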
@@ -1572,7 +1574,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain = prompt | chat
+    chain: Runnable = prompt | chat

     mock_start = mocker.Mock()
     mock_end = mocker.Mock()
@@ -1608,7 +1610,7 @@ def test_prompt_with_chat_model(
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain = prompt | chat
+    chain: Runnable = prompt | chat

     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
@@ -1712,7 +1714,7 @@ async def test_prompt_with_chat_model_async(
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain = prompt | chat
+    chain: Runnable = prompt | chat

     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
@@ -1819,7 +1821,7 @@ async def test_prompt_with_llm(
     )
     llm = FakeListLLM(responses=["foo", "bar"])

-    chain = prompt | llm
+    chain: Runnable = prompt | llm

     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
@@ -2325,13 +2327,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
 async def test_router_runnable(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
-    chain1 = ChatPromptTemplate.from_template(
+    chain1: Runnable = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    chain2 = ChatPromptTemplate.from_template(
+    chain2: Runnable = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    router = RouterRunnable({"math": chain1, "english": chain2})
+    router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
     chain: Runnable = {
         "key": lambda x: x["key"],
         "input": {"question": lambda x: x["question"]},
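The same widening applied to the router test. RouterRunnable dispatches on a payload with "key" and "input" fields, and each branch is itself a Runnable. A hedged sketch with a single branch; the response mirrors the fixture above:

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RouterRunnable, Runnable
from tests.unit_tests.fake.llm import FakeListLLM

math_branch: Runnable = ChatPromptTemplate.from_template(
    "You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])

router: Runnable = RouterRunnable({"math": math_branch})
# The "key" field selects the branch; "input" is passed to it.
answer = router.invoke({"key": "math", "input": {"question": "2 + 2"}})
assert answer == "4"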
@@ -2377,10 +2379,10 @@ async def test_router_runnable(
 async def test_higher_order_lambda_runnable(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
-    math_chain = ChatPromptTemplate.from_template(
+    math_chain: Runnable = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    english_chain = ChatPromptTemplate.from_template(
+    english_chain: Runnable = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
     input_map: Runnable = RunnableParallel(
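A hedged sketch of the higher-order pattern this test annotates: a RunnableLambda that selects and runs another chain at call time. The explicit dispatch function is illustrative; the test builds its chains from the same fixtures.

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda
from tests.unit_tests.fake.llm import FakeListLLM

math_chain: Runnable = ChatPromptTemplate.from_template(
    "You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])

def route(payload: dict) -> str:
    # Pick a chain by key, then invoke it with the nested input.
    if payload["key"] == "math":
        return math_chain.invoke(payload["input"])
    raise ValueError(f"unknown key {payload['key']!r}")

chain: Runnable = RunnableLambda(route)
assert chain.invoke({"key": "math", "input": {"question": "2 + 2"}}) == "4"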
@@ -3096,7 +3098,7 @@ async def test_deep_astream_assign() -> None:
 def test_runnable_sequence_transform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])

-    chain = llm | StrOutputParser()
+    chain: Runnable = llm | StrOutputParser()

     stream = chain.transform(llm.stream("Hi there!"))

@@ -3111,7 +3113,7 @@ def test_runnable_sequence_transform() -> None:
 async def test_runnable_sequence_atransform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])

-    chain = llm | StrOutputParser()
+    chain: Runnable = llm | StrOutputParser()

     stream = chain.atransform(llm.astream("Hi there!"))

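The two transform tests stream input through the chain instead of invoking it once. A sketch of the sync half, assuming the fixtures' fake streaming LLM, which yields its canned response piece by piece:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable
from tests.unit_tests.fake.llm import FakeStreamingListLLM

llm = FakeStreamingListLLM(responses=["foo-lish"])
chain: Runnable = llm | StrOutputParser()

# transform() consumes an input iterator lazily and yields output chunks.
chunks = list(chain.transform(llm.stream("Hi there!")))
assert "".join(chunks) == "foo-lish"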
@@ -1,43 +0,0 @@
-from langchain_core.schema import __all__
-
-EXPECTED_ALL = [
-    "BaseCache",
-    "BaseMemory",
-    "BaseStore",
-    "AgentFinish",
-    "AgentAction",
-    "Document",
-    "BaseChatMessageHistory",
-    "BaseDocumentTransformer",
-    "BaseMessage",
-    "ChatMessage",
-    "FunctionMessage",
-    "HumanMessage",
-    "AIMessage",
-    "SystemMessage",
-    "messages_from_dict",
-    "messages_to_dict",
-    "_message_to_dict",
-    "_message_from_dict",
-    "get_buffer_string",
-    "RunInfo",
-    "LLMResult",
-    "ChatResult",
-    "ChatGeneration",
-    "Generation",
-    "PromptValue",
-    "LangChainException",
-    "BaseRetriever",
-    "RUN_KEY",
-    "Memory",
-    "OutputParserException",
-    "StrOutputParser",
-    "BaseOutputParser",
-    "BaseLLMOutputParser",
-    "BasePromptTemplate",
-    "format_document",
-]
-
-
-def test_all_imports() -> None:
-    assert set(__all__) == set(EXPECTED_ALL)
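The schema barrel test is deleted outright rather than updated, since the module it covered is gone. For a few of the names in EXPECTED_ALL, the replacement paths used elsewhere in this diff look like this (a hedged mapping; the remaining entries follow the same one-module-per-concept pattern):

# langchain_core.schema is removed; its exports now live in concept modules:
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain_core.retrievers import BaseRetriever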
@@ -1,6 +1,6 @@
 import pytest

-from langchain_core.schema.messages import (
+from langchain_core.messages import (
     AIMessageChunk,
     ChatMessageChunk,
     FunctionMessageChunk,
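What these imports exercise: message chunks overload `+` so a stream of partial messages can be accumulated into one. A minimal sketch (values illustrative):

from langchain_core.messages import AIMessageChunk

merged = AIMessageChunk(content="I am") + AIMessageChunk(content=" indeed.")
assert merged.content == "I am indeed."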
@@ -1,5 +1,5 @@
-from langchain_core.schema.messages import HumanMessageChunk
-from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
+from langchain_core.messages import HumanMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk


 def test_generation_chunk() -> None:
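Generation chunks concatenate the same way, which is what test_generation_chunk checks; a short sketch with illustrative values:

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

combined = GenerationChunk(text="Hello, ") + GenerationChunk(text="world!")
assert combined.text == "Hello, world!"

# ChatGenerationChunk wraps a message chunk and derives its text from it.
chat_combined = ChatGenerationChunk(
    message=AIMessageChunk(content="Hello, ")
) + ChatGenerationChunk(message=AIMessageChunk(content="world!"))
assert chat_combined.text == "Hello, world!"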
@@ -7,12 +7,12 @@ from typing import Any, List, Optional, Type, Union

 import pytest

-from langchain_core.callbacks.manager import (
+from langchain_core.callbacks import (
     AsyncCallbackManagerForToolRun,
     CallbackManagerForToolRun,
 )
 from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.tool import (
+from langchain_core.tools import (
     BaseTool,
     SchemaAnnotationError,
     StructuredTool,
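Finally, the tool module goes plural. A hedged sketch of the decorator these tests cover; the divide tool is illustrative, not part of the test file:

from langchain_core.tools import BaseTool, StructuredTool, tool

@tool
def divide(a: float, b: float) -> float:
    """Divide a by b."""
    return a / b

# Multi-argument functions become StructuredTool instances (a BaseTool).
assert isinstance(divide, StructuredTool)
assert isinstance(divide, BaseTool)
assert divide.run({"a": 6.0, "b": 2.0}) == 3.0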