mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
manual mapping (#14422)
This commit is contained in:
parent
c24f277b7c
commit
f5befe3b89
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, Sequence, Union
|
||||
from typing import Any, List, Literal, Sequence, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
@ -40,6 +40,11 @@ class AgentAction(Serializable):
|
||||
"""Return whether or not the class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "agent"]
|
||||
|
||||
@property
|
||||
def messages(self) -> Sequence[BaseMessage]:
|
||||
"""Return the messages that correspond to this action."""
|
||||
@ -98,6 +103,11 @@ class AgentFinish(Serializable):
|
||||
"""Return whether or not the class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "agent"]
|
||||
|
||||
@property
|
||||
def messages(self) -> Sequence[BaseMessage]:
|
||||
"""Return the messages that correspond to this observation."""
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import List, Literal
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
@ -21,3 +21,8 @@ class Document(Serializable):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "document"]
|
||||
|
@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
|
||||
from langchain_core.load.serializable import Serializable
|
||||
|
||||
DEFAULT_NAMESPACES = ["langchain", "langchain_core"]
|
||||
@ -62,8 +63,21 @@ class Reviver:
|
||||
if len(namespace) == 1 and namespace[0] == "langchain":
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
|
||||
mod = importlib.import_module(".".join(namespace))
|
||||
cls = getattr(mod, name)
|
||||
# Get the importable path
|
||||
key = tuple(namespace + [name])
|
||||
if key not in SERIALIZABLE_MAPPING:
|
||||
raise ValueError(
|
||||
"Trying to deserialize something that cannot "
|
||||
"be deserialized in current version of langchain-core: "
|
||||
f"{key}"
|
||||
)
|
||||
import_path = SERIALIZABLE_MAPPING[key]
|
||||
# Split into module and name
|
||||
import_dir, import_obj = import_path[:-1], import_path[-1]
|
||||
# Import module
|
||||
mod = importlib.import_module(".".join(import_dir))
|
||||
# Import class
|
||||
cls = getattr(mod, import_obj)
|
||||
|
||||
# The class must be a subclass of Serializable.
|
||||
if not issubclass(cls, Serializable):
|
||||
|
478
libs/core/langchain_core/load/mapping.py
Normal file
478
libs/core/langchain_core/load/mapping.py
Normal file
@ -0,0 +1,478 @@
|
||||
# First value is the value that it is serialized as
|
||||
# Second value is the path to load it from
|
||||
SERIALIZABLE_MAPPING = {
|
||||
("langchain", "schema", "messages", "AIMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"ai",
|
||||
"AIMessage",
|
||||
),
|
||||
("langchain", "schema", "messages", "AIMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"ai",
|
||||
"AIMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "BaseMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"base",
|
||||
"BaseMessage",
|
||||
),
|
||||
("langchain", "schema", "messages", "BaseMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"base",
|
||||
"BaseMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "ChatMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"chat",
|
||||
"ChatMessage",
|
||||
),
|
||||
("langchain", "schema", "messages", "FunctionMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"function",
|
||||
"FunctionMessage",
|
||||
),
|
||||
("langchain", "schema", "messages", "HumanMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"human",
|
||||
"HumanMessage",
|
||||
),
|
||||
("langchain", "schema", "messages", "SystemMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"system",
|
||||
"SystemMessage",
|
||||
),
|
||||
("langchain", "schema", "messages", "ToolMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"tool",
|
||||
"ToolMessage",
|
||||
),
|
||||
("langchain", "schema", "agent", "AgentAction"): (
|
||||
"langchain_core",
|
||||
"agents",
|
||||
"AgentAction",
|
||||
),
|
||||
("langchain", "schema", "agent", "AgentFinish"): (
|
||||
"langchain_core",
|
||||
"agents",
|
||||
"AgentFinish",
|
||||
),
|
||||
("langchain", "schema", "prompt_template", "BasePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"base",
|
||||
"BasePromptTemplate",
|
||||
),
|
||||
("langchain", "chains", "llm", "LLMChain"): (
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain",
|
||||
),
|
||||
("langchain", "prompts", "prompt", "PromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "chat", "MessagesPlaceholder"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"MessagesPlaceholder",
|
||||
),
|
||||
("langchain", "llms", "openai", "OpenAI"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI",
|
||||
),
|
||||
("langchain", "prompts", "chat", "ChatPromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatPromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "chat", "HumanMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"HumanMessagePromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "chat", "SystemMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"SystemMessagePromptTemplate",
|
||||
),
|
||||
("langchain", "schema", "agent", "AgentActionMessageLog"): (
|
||||
"langchain_core",
|
||||
"agents",
|
||||
"AgentActionMessageLog",
|
||||
),
|
||||
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
|
||||
"langchain",
|
||||
"agents",
|
||||
"output_parsers",
|
||||
"openai_tools",
|
||||
"OpenAIToolAgentAction",
|
||||
),
|
||||
("langchain", "prompts", "chat", "BaseMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"BaseMessagePromptTemplate",
|
||||
),
|
||||
("langchain", "schema", "output", "ChatGeneration"): (
|
||||
"langchain_core",
|
||||
"outputs",
|
||||
"chat_generation",
|
||||
"ChatGeneration",
|
||||
),
|
||||
("langchain", "schema", "output", "Generation"): (
|
||||
"langchain_core",
|
||||
"outputs",
|
||||
"generation",
|
||||
"Generation",
|
||||
),
|
||||
("langchain", "schema", "document", "Document"): (
|
||||
"langchain_core",
|
||||
"documents",
|
||||
"base",
|
||||
"Document",
|
||||
),
|
||||
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
|
||||
"langchain",
|
||||
"output_parsers",
|
||||
"fix",
|
||||
"OutputFixingParser",
|
||||
),
|
||||
("langchain", "prompts", "chat", "AIMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"AIMessagePromptTemplate",
|
||||
),
|
||||
("langchain", "output_parsers", "regex", "RegexParser"): (
|
||||
"langchain",
|
||||
"output_parsers",
|
||||
"regex",
|
||||
"RegexParser",
|
||||
),
|
||||
("langchain", "schema", "runnable", "DynamicRunnable"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"configurable",
|
||||
"DynamicRunnable",
|
||||
),
|
||||
("langchain", "schema", "prompt", "PromptValue"): (
|
||||
"langchain_core",
|
||||
"prompt_values",
|
||||
"PromptValue",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableBinding"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableBinding",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableBranch"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"branch",
|
||||
"RunnableBranch",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableWithFallbacks"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"fallbacks",
|
||||
"RunnableWithFallbacks",
|
||||
),
|
||||
("langchain", "schema", "output_parser", "StrOutputParser"): (
|
||||
"langchain_core",
|
||||
"output_parsers",
|
||||
"string",
|
||||
"StrOutputParser",
|
||||
),
|
||||
("langchain", "chat_models", "openai", "ChatOpenAI"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"openai",
|
||||
"ChatOpenAI",
|
||||
),
|
||||
("langchain", "output_parsers", "list", "CommaSeparatedListOutputParser"): (
|
||||
"langchain_core",
|
||||
"output_parsers",
|
||||
"list",
|
||||
"CommaSeparatedListOutputParser",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableParallel"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableParallel",
|
||||
),
|
||||
("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"azure_openai",
|
||||
"AzureChatOpenAI",
|
||||
),
|
||||
("langchain", "chat_models", "bedrock", "BedrockChat"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"bedrock",
|
||||
"BedrockChat",
|
||||
),
|
||||
("langchain", "chat_models", "anthropic", "ChatAnthropic"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"anthropic",
|
||||
"ChatAnthropic",
|
||||
),
|
||||
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"fireworks",
|
||||
"ChatFireworks",
|
||||
),
|
||||
("langchain", "chat_models", "google_palm", "ChatGooglePalm"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"google_palm",
|
||||
"ChatGooglePalm",
|
||||
),
|
||||
("langchain", "chat_models", "vertexai", "ChatVertexAI"): (
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"vertexai",
|
||||
"ChatVertexAI",
|
||||
),
|
||||
("langchain", "schema", "output", "ChatGenerationChunk"): (
|
||||
"langchain_core",
|
||||
"outputs",
|
||||
"chat_generation",
|
||||
"ChatGenerationChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "ChatMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"chat",
|
||||
"ChatMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "HumanMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"human",
|
||||
"HumanMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "FunctionMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"function",
|
||||
"FunctionMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "SystemMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"system",
|
||||
"SystemMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "messages", "ToolMessageChunk"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
"tool",
|
||||
"ToolMessageChunk",
|
||||
),
|
||||
("langchain", "schema", "output", "GenerationChunk"): (
|
||||
"langchain_core",
|
||||
"outputs",
|
||||
"generation",
|
||||
"GenerationChunk",
|
||||
),
|
||||
("langchain", "llms", "openai", "BaseOpenAI"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"BaseOpenAI",
|
||||
),
|
||||
("langchain", "llms", "bedrock", "Bedrock"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"bedrock",
|
||||
"Bedrock",
|
||||
),
|
||||
("langchain", "llms", "fireworks", "Fireworks"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"fireworks",
|
||||
"Fireworks",
|
||||
),
|
||||
("langchain", "llms", "google_palm", "GooglePalm"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"google_palm",
|
||||
"GooglePalm",
|
||||
),
|
||||
("langchain", "llms", "openai", "AzureOpenAI"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"AzureOpenAI",
|
||||
),
|
||||
("langchain", "llms", "replicate", "Replicate"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"replicate",
|
||||
"Replicate",
|
||||
),
|
||||
("langchain", "llms", "vertexai", "VertexAI"): (
|
||||
"langchain",
|
||||
"llms",
|
||||
"vertexai",
|
||||
"VertexAI",
|
||||
),
|
||||
("langchain", "output_parsers", "combining", "CombiningOutputParser"): (
|
||||
"langchain",
|
||||
"output_parsers",
|
||||
"combining",
|
||||
"CombiningOutputParser",
|
||||
),
|
||||
("langchain", "schema", "prompt_template", "BaseChatPromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"BaseChatPromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "chat", "ChatMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatMessagePromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "few_shot_with_templates", "FewShotPromptWithTemplates"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"few_shot_with_templates",
|
||||
"FewShotPromptWithTemplates",
|
||||
),
|
||||
("langchain", "prompts", "pipeline", "PipelinePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"pipeline",
|
||||
"PipelinePromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "base", "StringPromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"string",
|
||||
"StringPromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "base", "StringPromptValue"): (
|
||||
"langchain_core",
|
||||
"prompt_values",
|
||||
"StringPromptValue",
|
||||
),
|
||||
("langchain", "prompts", "chat", "BaseStringMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
"chat",
|
||||
"BaseStringMessagePromptTemplate",
|
||||
),
|
||||
("langchain", "prompts", "chat", "ChatPromptValue"): (
|
||||
"langchain_core",
|
||||
"prompt_values",
|
||||
"ChatPromptValue",
|
||||
),
|
||||
("langchain", "prompts", "chat", "ChatPromptValueConcrete"): (
|
||||
"langchain_core",
|
||||
"prompt_values",
|
||||
"ChatPromptValueConcrete",
|
||||
),
|
||||
("langchain", "schema", "runnable", "HubRunnable"): (
|
||||
"langchain",
|
||||
"runnables",
|
||||
"hub",
|
||||
"HubRunnable",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableBindingBase"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableBindingBase",
|
||||
),
|
||||
("langchain", "schema", "runnable", "OpenAIFunctionsRouter"): (
|
||||
"langchain",
|
||||
"runnables",
|
||||
"openai_functions",
|
||||
"OpenAIFunctionsRouter",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RouterRunnable"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"router",
|
||||
"RouterRunnable",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnablePassthrough"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"passthrough",
|
||||
"RunnablePassthrough",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableSequence"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableSequence",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableEach"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableEach",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableEachBase"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableEachBase",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableConfigurableAlternatives"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"configurable",
|
||||
"RunnableConfigurableAlternatives",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableConfigurableFields"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"configurable",
|
||||
"RunnableConfigurableFields",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableWithMessageHistory"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"history",
|
||||
"RunnableWithMessageHistory",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableAssign"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"passthrough",
|
||||
"RunnableAssign",
|
||||
),
|
||||
("langchain", "schema", "runnable", "RunnableRetry"): (
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"retry",
|
||||
"RunnableRetry",
|
||||
),
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -17,6 +17,11 @@ class AIMessage(BaseMessage):
|
||||
|
||||
type: Literal["ai"] = "ai"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
AIMessage.update_forward_refs()
|
||||
|
||||
@ -29,6 +34,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
# non-chunk variant.
|
||||
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, AIMessageChunk):
|
||||
if self.example != other.example:
|
||||
|
@ -31,6 +31,11 @@ class BaseMessage(Serializable):
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
@ -68,6 +73,11 @@ def merge_content(
|
||||
class BaseMessageChunk(BaseMessage):
|
||||
"""A Message chunk, which can be concatenated with other Message chunks."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def _merge_kwargs_dict(
|
||||
self, left: Dict[str, Any], right: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -15,6 +15,11 @@ class ChatMessage(BaseMessage):
|
||||
|
||||
type: Literal["chat"] = "chat"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
ChatMessage.update_forward_refs()
|
||||
|
||||
@ -27,6 +32,11 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
# non-chunk variant.
|
||||
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, ChatMessageChunk):
|
||||
if self.role != other.role:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -15,6 +15,11 @@ class FunctionMessage(BaseMessage):
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
FunctionMessage.update_forward_refs()
|
||||
|
||||
@ -27,6 +32,11 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
# non-chunk variant.
|
||||
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, FunctionMessageChunk):
|
||||
if self.name != other.name:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import List, Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
@ -13,6 +13,11 @@ class HumanMessage(BaseMessage):
|
||||
|
||||
type: Literal["human"] = "human"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
HumanMessage.update_forward_refs()
|
||||
|
||||
@ -24,3 +29,8 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import List, Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
@ -10,6 +10,11 @@ class SystemMessage(BaseMessage):
|
||||
|
||||
type: Literal["system"] = "system"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
SystemMessage.update_forward_refs()
|
||||
|
||||
@ -21,3 +26,8 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -15,6 +15,11 @@ class ToolMessage(BaseMessage):
|
||||
|
||||
type: Literal["tool"] = "tool"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
ToolMessage.update_forward_refs()
|
||||
|
||||
@ -27,6 +32,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
# non-chunk variant.
|
||||
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, ToolMessageChunk):
|
||||
if self.tool_call_id != other.tool_call_id:
|
||||
|
@ -26,6 +26,11 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "output_parsers", "list"]
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return (
|
||||
"Your response should be a list of comma separated values, "
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
|
||||
@ -9,6 +11,11 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output_parser"]
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
@ -27,6 +27,11 @@ class ChatGeneration(Generation):
|
||||
raise ValueError("Error while initializing ChatGeneration") from e
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
|
||||
class ChatGenerationChunk(ChatGeneration):
|
||||
"""A ChatGeneration chunk, which can be concatenated with other
|
||||
@ -41,6 +46,11 @@ class ChatGenerationChunk(ChatGeneration):
|
||||
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
|
||||
if isinstance(other, ChatGenerationChunk):
|
||||
generation_info = (
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
|
||||
@ -24,10 +24,20 @@ class Generation(Serializable):
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
|
||||
class GenerationChunk(Generation):
|
||||
"""A Generation chunk, which can be concatenated with other Generation chunks."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
def __add__(self, other: GenerationChunk) -> GenerationChunk:
|
||||
if isinstance(other, GenerationChunk):
|
||||
generation_info = (
|
||||
|
@ -24,6 +24,11 @@ class PromptValue(Serializable, ABC):
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "prompt"]
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
@ -40,6 +45,11 @@ class StringPromptValue(PromptValue):
|
||||
"""Prompt text."""
|
||||
type: Literal["StringPromptValue"] = "StringPromptValue"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "base"]
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
@ -66,6 +76,11 @@ class ChatPromptValue(PromptValue):
|
||||
"""Return prompt as a list of messages."""
|
||||
return list(self.messages)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
|
||||
class ChatPromptValueConcrete(ChatPromptValue):
|
||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||
@ -74,3 +89,8 @@ class ChatPromptValueConcrete(ChatPromptValue):
|
||||
messages: Sequence[AnyMessage]
|
||||
|
||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
@ -44,6 +44,11 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "prompt_template"]
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
|
@ -43,6 +43,11 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
"""Return whether or not the class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
@ -82,6 +87,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
variable_name: str
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, **kwargs: Any):
|
||||
return super().__init__(variable_name=variable_name, **kwargs)
|
||||
|
||||
@ -132,6 +142,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: Type[MessagePromptTemplateT],
|
||||
@ -221,6 +236,11 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str
|
||||
"""Role of the message."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
@ -239,6 +259,11 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
@ -255,6 +280,11 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""AI message prompt template. This is a message sent from the AI."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
@ -273,6 +303,11 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
This is a message that is not sent to the user.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
@ -368,6 +403,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
|
||||
|
@ -42,6 +42,11 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "few_shot_with_templates"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
|
@ -28,6 +28,11 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
|
||||
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "pipeline"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Get input variables."""
|
||||
|
@ -54,6 +54,11 @@ class PromptTemplate(StringPromptTemplate):
|
||||
"template_format": self.template_format,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "prompt"]
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
|
@ -151,6 +151,11 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||
class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""String prompt that exposes the format method, returning a prompt."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "base"]
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
return StringPromptValue(text=self.format(**kwargs))
|
||||
|
@ -1349,6 +1349,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
last: Runnable[Any, Output]
|
||||
"""The last runnable in the sequence."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def steps(self) -> List[Runnable[Any, Any]]:
|
||||
"""All the runnables that make up the sequence in order."""
|
||||
@ -1358,10 +1363,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@ -1939,7 +1940,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -2705,7 +2707,8 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
@ -2746,6 +2749,11 @@ class RunnableEach(RunnableEachBase[Input, Output]):
|
||||
with each element of the input sequence.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
|
||||
return RunnableEach(bound=self.bound.bind(**kwargs))
|
||||
|
||||
@ -2910,7 +2918,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
config = merge_configs(self.config, *configs)
|
||||
@ -3086,6 +3095,11 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
runnable_binding.invoke('Say "Parrot-MAGIC"') # Should return `Parrot`
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
"""Bind additional kwargs to a Runnable, returning a new Runnable.
|
||||
|
||||
|
@ -132,8 +132,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
|
@ -53,7 +53,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
@ -217,6 +218,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
|
||||
fields: Dict[str, AnyConfigurableField]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
@ -318,6 +324,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by
|
||||
the alternative named "gpt3" becomes "model==gpt3/temperature"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
with _enums_for_spec_lock:
|
||||
|
@ -125,7 +125,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||
|
@ -86,6 +86,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
output_messages_key: Optional[str] = None
|
||||
history_messages_key: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Runnable[
|
||||
|
@ -167,7 +167,8 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
@ -312,7 +313,8 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
|
@ -114,6 +114,11 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
max_attempt_number: int = 3
|
||||
"""The maximum number of attempts to retry the runnable."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def _kwargs_retrying(self) -> Dict[str, Any]:
|
||||
kwargs: Dict[str, Any] = dict()
|
||||
|
@ -77,7 +77,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def invoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
|
File diff suppressed because one or more lines are too long
@ -2029,7 +2029,7 @@ async def test_prompt_with_llm(
|
||||
):
|
||||
del op["value"]["id"]
|
||||
|
||||
assert stream_log == [
|
||||
expected = [
|
||||
RunLogPatch(
|
||||
{
|
||||
"op": "replace",
|
||||
@ -2113,6 +2113,7 @@ async def test_prompt_with_llm(
|
||||
{"op": "replace", "path": "/final_output", "value": "foo"},
|
||||
),
|
||||
]
|
||||
assert stream_log == expected
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
|
@ -105,6 +105,11 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "anthropic"]
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
|
||||
"""Format a list of messages into a full prompt for the Anthropic model
|
||||
Args:
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Union
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
@ -94,6 +94,11 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
infer if it is a base_url or azure_endpoint and update accordingly.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "azure_openai"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
@ -50,6 +50,11 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "bedrock"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
@ -101,6 +101,11 @@ class ChatFireworks(BaseChatModel):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "fireworks"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key in environment."""
|
||||
|
@ -256,6 +256,11 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "google_palm"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, top_p, and top_k."""
|
||||
|
@ -160,6 +160,11 @@ class ChatOpenAI(BaseChatModel):
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "openai"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
@ -127,6 +127,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "vertexai"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in environment."""
|
||||
|
@ -357,6 +357,11 @@ class Bedrock(LLM, BedrockBase):
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "bedrock"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
@ -51,6 +51,11 @@ class Fireworks(BaseLLM):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "fireworks"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key in environment."""
|
||||
|
@ -75,6 +75,11 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "google_palm"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists."""
|
||||
|
@ -149,6 +149,11 @@ class BaseOpenAI(BaseLLM):
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "openai"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
@ -736,6 +741,11 @@ class OpenAI(BaseOpenAI):
|
||||
openai = OpenAI(model_name="text-davinci-003")
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "openai"]
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**{"model": self.model_name}, **super()._invocation_params}
|
||||
@ -794,6 +804,11 @@ class AzureOpenAI(BaseOpenAI):
|
||||
infer if it is a base_url or azure_endpoint and update accordingly.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "openai"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
@ -70,6 +70,11 @@ class Replicate(LLM):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "replicate"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
|
@ -224,6 +224,11 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "vertexai"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in environment."""
|
||||
|
@ -9,12 +9,12 @@ from langchain_core.pydantic_v1 import root_validator
|
||||
class CombiningOutputParser(BaseOutputParser):
|
||||
"""Combine multiple output parsers into one."""
|
||||
|
||||
parsers: List[BaseOutputParser]
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
parsers: List[BaseOutputParser]
|
||||
|
||||
@root_validator()
|
||||
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate the parsers."""
|
||||
|
@ -97,7 +97,7 @@
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
@ -152,7 +152,7 @@
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatPromptTemplate"
|
||||
@ -166,7 +166,7 @@
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"HumanMessagePromptTemplate"
|
||||
@ -176,7 +176,7 @@
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
@ -236,7 +236,7 @@
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
|
55
libs/langchain/tests/unit_tests/load/test_serializable.py
Normal file
55
libs/langchain/tests/unit_tests/load/test_serializable.py
Normal file
@ -0,0 +1,55 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
|
||||
|
||||
|
||||
def import_all_modules(package_name: str) -> dict:
|
||||
package = importlib.import_module(package_name)
|
||||
classes: dict = {}
|
||||
|
||||
for attribute_name in dir(package):
|
||||
attribute = getattr(package, attribute_name)
|
||||
if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type):
|
||||
if (
|
||||
isinstance(attribute.is_lc_serializable(), bool) # type: ignore
|
||||
and attribute.is_lc_serializable() # type: ignore
|
||||
):
|
||||
key = tuple(attribute.lc_id()) # type: ignore
|
||||
value = tuple(attribute.__module__.split(".") + [attribute.__name__])
|
||||
if key in classes and classes[key] != value:
|
||||
raise ValueError
|
||||
classes[key] = value
|
||||
if hasattr(package, "__path__"):
|
||||
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||
package.__path__, package_name + "."
|
||||
):
|
||||
if module_name not in (
|
||||
"langchain.chains.llm_bash",
|
||||
"langchain.chains.llm_symbolic_math",
|
||||
"langchain.tools.python",
|
||||
"langchain.vectorstores._pgvector_data_models",
|
||||
):
|
||||
importlib.import_module(module_name)
|
||||
new_classes = import_all_modules(module_name)
|
||||
for k, v in new_classes.items():
|
||||
if k in classes and classes[k] != v:
|
||||
raise ValueError
|
||||
classes[k] = v
|
||||
return classes
|
||||
|
||||
|
||||
def test_serializable_mapping() -> None:
|
||||
serializable_modules = import_all_modules("langchain")
|
||||
missing = set(SERIALIZABLE_MAPPING).difference(serializable_modules)
|
||||
assert missing == set()
|
||||
extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
|
||||
assert extra == set()
|
||||
|
||||
for k, import_path in serializable_modules.items():
|
||||
import_dir, import_obj = import_path[:-1], import_path[-1]
|
||||
# Import module
|
||||
mod = importlib.import_module(".".join(import_dir))
|
||||
# Import class
|
||||
cls = getattr(mod, import_obj)
|
||||
assert list(k) == cls.lc_id()
|
Loading…
Reference in New Issue
Block a user