Compare commits

...

2 Commits

Author            SHA1        Message             Date
Harrison Chase    d22dd95dfa  cr                  2023-12-04 11:39:22 -08:00
Harrison Chase    310e946124  integrations start  2023-12-03 08:43:58 -08:00
1527 changed files with 121890 additions and 117061 deletions

View File

@@ -22,6 +22,7 @@ from langchain_core.utils.utils import (
raise_for_status_with_text,
xor_args,
)
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
__all__ = [
"StrictFormatter",
@@ -39,4 +40,6 @@ __all__ = [
"xor_args",
"try_load_from_hub",
"build_extra_kwargs",
"get_from_dict_or_env",
"get_from_env",
]

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import os
from typing import Any, Dict, Optional
def env_var_is_set(env_var: str) -> bool:
@@ -18,3 +19,26 @@ def env_var_is_set(env_var: str) -> bool:
"false",
"False",
)
def get_from_dict_or_env(
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
"""Get a value from a dictionary or an environment variable."""
if key in data and data[key]:
return data[key]
else:
return get_from_env(key, env_key, default=default)
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""Get a value from a dictionary or an environment variable."""
if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
elif default is not None:
return default
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
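
A minimal usage sketch of the helpers above (the key names are illustrative, not taken from this diff):

import os

from langchain_core.utils.env import get_from_dict_or_env, get_from_env

os.environ["MY_SERVICE_API_KEY"] = "secret-value"

# Taken from the dict when the key is present and truthy, else from the environment.
api_key = get_from_dict_or_env({}, "my_service_api_key", "MY_SERVICE_API_KEY")
# get_from_env reads the environment directly, or falls back to a default.
timeout = get_from_env("timeout", "MY_SERVICE_TIMEOUT", default="30")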

View File

@@ -0,0 +1,264 @@
from __future__ import annotations
import importlib
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Mapping,
Sequence,
Union,
overload,
)
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from typing_extensions import Literal
async def aenumerate(
iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
"""Async version of enumerate function."""
i = start
async for x in iterable:
yield i, x
i += 1
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
        # Fix for Azure
        # Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
additional_kwargs["function_call"] = dict(_dict["function_call"])
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
elif role == "tool":
return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
            # If there is only a function call, content should be None, not an empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            # If there are only tool calls, content should be None, not an empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
"""Convert dictionaries representing OpenAI messages to LangChain format.
Args:
messages: List of dictionaries representing OpenAI messages
Returns:
List of LangChain BaseMessage objects.
"""
return [convert_dict_to_message(m) for m in messages]
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
_dict: Dict[str, Any] = {}
if isinstance(chunk, AIMessageChunk):
if i == 0:
# Only shows up in the first chunk
_dict["role"] = "assistant"
if "function_call" in chunk.additional_kwargs:
_dict["function_call"] = chunk.additional_kwargs["function_call"]
            # If the first chunk is a function call, the content is None,
            # not an empty string and not missing.
if i == 0:
_dict["content"] = None
else:
_dict["content"] = chunk.content
else:
raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
# This only happens at the end of streams, and OpenAI returns as empty dict
if _dict == {"content": ""}:
_dict = {}
return {"choices": [{"delta": _dict}]}
class ChatCompletion:
"""Chat completion."""
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict:
...
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> Iterable:
...
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, Iterable]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
for i, c in enumerate(model_config.stream(converted_messages))
)
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict:
...
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> AsyncIterator:
...
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, AsyncIterator]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
async for i, c in aenumerate(model_config.astream(converted_messages))
)
def _has_assistant_message(session: ChatSession) -> bool:
"""Check if chat session has an assistant message."""
    return any(isinstance(m, AIMessage) for m in session["messages"])
def convert_messages_for_finetuning(
sessions: Iterable[ChatSession],
) -> List[List[dict]]:
"""Convert messages to a list of lists of dictionaries for fine-tuning.
Args:
sessions: The chat sessions.
Returns:
The list of lists of dictionaries.
"""
return [
[convert_message_to_dict(s) for s in session["messages"]]
for session in sessions
if _has_assistant_message(session)
]
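
With the adapter above, OpenAI-SDK-style calls are routed through LangChain chat models. A sketch, assuming valid OpenAI credentials are configured and `ChatCompletion` is imported from this module:

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name the capital of France."},
]

# Non-streaming: returns an OpenAI-shaped response dict.
result = ChatCompletion.create(messages, provider="ChatOpenAI", temperature=0)
print(result["choices"][0]["message"]["content"])

# Streaming: yields OpenAI-shaped delta chunks.
for chunk in ChatCompletion.create(messages, provider="ChatOpenAI", stream=True):
    print(chunk["choices"][0]["delta"].get("content") or "", end="")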

View File

@@ -0,0 +1,115 @@
"""Agent toolkits contain integrations with various resources and services.
LangChain has a large ecosystem of integrations with various external resources
like local and remote file systems, APIs and databases.
These integrations allow developers to create versatile applications that combine the
power of LLMs with the ability to access, interact with and manipulate external
resources.
When developing an application, developers should inspect the capabilities and
permissions of the tools that underlie the given agent toolkit, and determine
whether permissions of the given toolkit are appropriate for the application.
See [Security](https://python.langchain.com/docs/security) for more information.
"""
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
from langchain_integrations.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
#from langchain_integrations.agent_toolkits.amadeus.toolkit import AmadeusToolkit
from langchain_integrations.agent_toolkits.azure_cognitive_services import (
AzureCognitiveServicesToolkit,
)
#from langchain_integrations.agent_toolkits.conversational_retrieval.openai_functions import (
# create_conversational_retrieval_agent,
#)
from langchain_integrations.agent_toolkits.file_management.toolkit import (
FileManagementToolkit,
)
from langchain_integrations.agent_toolkits.gmail.toolkit import GmailToolkit
from langchain_integrations.agent_toolkits.jira.toolkit import JiraToolkit
#from langchain_integrations.agent_toolkits.json.base import create_json_agent
#from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
from langchain_integrations.agent_toolkits.multion.toolkit import MultionToolkit
#from langchain_integrations.agent_toolkits.nla.toolkit import NLAToolkit
from langchain_integrations.agent_toolkits.office365.toolkit import O365Toolkit
#from langchain_integrations.agent_toolkits.openapi.base import create_openapi_agent
#from langchain_integrations.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain_integrations.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
#from langchain_integrations.agent_toolkits.powerbi.base import create_pbi_agent
#from langchain_integrations.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
#from langchain_integrations.agent_toolkits.powerbi.toolkit import PowerBIToolkit
#from langchain_integrations.agent_toolkits.spark_sql.base import create_spark_sql_agent
#from langchain_integrations.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain_integrations.agent_toolkits.sql.base import create_sql_agent
from langchain_integrations.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_integrations.agent_toolkits.vectorstore.base import (
create_vectorstore_agent,
create_vectorstore_router_agent,
)
from langchain_integrations.agent_toolkits.vectorstore.toolkit import (
VectorStoreInfo,
VectorStoreRouterToolkit,
VectorStoreToolkit,
)
from langchain_integrations.agent_toolkits.zapier.toolkit import ZapierToolkit
from langchain_integrations.tools.retriever import create_retriever_tool
DEPRECATED_AGENTS = [
"create_csv_agent",
"create_pandas_dataframe_agent",
"create_xorbits_agent",
"create_python_agent",
"create_spark_dataframe_agent",
]
def __getattr__(name: str) -> Any:
"""Get attr name."""
if name in DEPRECATED_AGENTS:
relative_path = as_import_path(Path(__file__).parent, suffix=name)
old_path = "langchain." + relative_path
new_path = "langchain_experimental." + relative_path
raise ImportError(
f"{name} has been moved to langchain experimental. "
"See https://github.com/langchain-ai/langchain/discussions/11680"
"for more information.\n"
f"Please update your import statement from: `{old_path}` to `{new_path}`."
)
raise AttributeError(f"{name} does not exist")
__all__ = [
"AINetworkToolkit",
"AmadeusToolkit",
"AzureCognitiveServicesToolkit",
"FileManagementToolkit",
"GmailToolkit",
"JiraToolkit",
"JsonToolkit",
"MultionToolkit",
"NLAToolkit",
"O365Toolkit",
"OpenAPIToolkit",
"PlayWrightBrowserToolkit",
"PowerBIToolkit",
"SQLDatabaseToolkit",
"SparkSQLToolkit",
"VectorStoreInfo",
"VectorStoreRouterToolkit",
"VectorStoreToolkit",
"ZapierToolkit",
"create_json_agent",
"create_openapi_agent",
"create_pbi_agent",
"create_pbi_chat_agent",
"create_spark_sql_agent",
"create_sql_agent",
"create_vectorstore_agent",
"create_vectorstore_router_agent",
"create_conversational_retrieval_agent",
"create_retriever_tool",
]
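
The `__getattr__` hook above turns imports of relocated agents into actionable errors. A sketch of what a caller now sees:

try:
    from langchain_integrations.agent_toolkits import create_pandas_dataframe_agent
except ImportError as err:
    # The message points to langchain_experimental and the updated import path.
    print(err)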

View File

@@ -0,0 +1 @@
"""AINetwork toolkit."""

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Literal, Optional
from langchain_core.pydantic_v1 import root_validator
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.ainetwork.app import AINAppOps
from langchain_integrations.tools.ainetwork.owner import AINOwnerOps
from langchain_integrations.tools.ainetwork.rule import AINRuleOps
from langchain_integrations.tools.ainetwork.transfer import AINTransfer
from langchain_integrations.tools.ainetwork.utils import authenticate
from langchain_integrations.tools.ainetwork.value import AINValueOps
if TYPE_CHECKING:
from ain.ain import Ain
class AINetworkToolkit(BaseToolkit):
"""Toolkit for interacting with AINetwork Blockchain.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
See https://python.langchain.com/docs/security for more information.
"""
network: Optional[Literal["mainnet", "testnet"]] = "testnet"
interface: Optional[Ain] = None
@root_validator(pre=True)
def set_interface(cls, values: dict) -> dict:
if not values.get("interface"):
values["interface"] = authenticate(network=values.get("network", "testnet"))
return values
class Config:
"""Pydantic config."""
validate_all = True
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
AINAppOps(),
AINOwnerOps(),
AINRuleOps(),
AINTransfer(),
AINValueOps(),
]

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.amadeus.closest_airport import AmadeusClosestAirport
from langchain_integrations.tools.amadeus.flight_search import AmadeusFlightSearch
from langchain_integrations.tools.amadeus.utils import authenticate
if TYPE_CHECKING:
from amadeus import Client
class AmadeusToolkit(BaseToolkit):
"""Toolkit for interacting with Amadeus which offers APIs for travel search."""
client: Client = Field(default_factory=authenticate)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
AmadeusClosestAirport(),
AmadeusFlightSearch(),
]

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import sys
from typing import List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools.azure_cognitive_services import (
AzureCogsFormRecognizerTool,
AzureCogsImageAnalysisTool,
AzureCogsSpeech2TextTool,
AzureCogsText2SpeechTool,
AzureCogsTextAnalyticsHealthTool,
)
from langchain_core.tools import BaseTool
class AzureCognitiveServicesToolkit(BaseToolkit):
"""Toolkit for Azure Cognitive Services."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tools: List[BaseTool] = [
AzureCogsFormRecognizerTool(),
AzureCogsSpeech2TextTool(),
AzureCogsText2SpeechTool(),
AzureCogsTextAnalyticsHealthTool(),
]
        # TODO: Remove check once azure-ai-vision supports macOS.
if sys.platform.startswith("linux") or sys.platform.startswith("win"):
tools.append(AzureCogsImageAnalysisTool())
return tools

View File

@@ -0,0 +1,15 @@
"""Toolkits for agents."""
from abc import ABC, abstractmethod
from typing import List
from langchain_core.pydantic_v1 import BaseModel
from langchain_integrations.tools import BaseTool
class BaseToolkit(BaseModel, ABC):
"""Base Toolkit representing a collection of related tools."""
@abstractmethod
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""

View File

@@ -0,0 +1,108 @@
from typing import Dict, List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.clickup.prompt import (
CLICKUP_FOLDER_CREATE_PROMPT,
CLICKUP_GET_ALL_TEAMS_PROMPT,
CLICKUP_GET_FOLDERS_PROMPT,
CLICKUP_GET_LIST_PROMPT,
CLICKUP_GET_SPACES_PROMPT,
CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
CLICKUP_GET_TASK_PROMPT,
CLICKUP_LIST_CREATE_PROMPT,
CLICKUP_TASK_CREATE_PROMPT,
CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
CLICKUP_UPDATE_TASK_PROMPT,
)
from langchain_integrations.tools.clickup.tool import ClickupAction
from langchain_integrations.utilities.clickup import ClickupAPIWrapper
class ClickupToolkit(BaseToolkit):
"""Clickup Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
See https://python.langchain.com/docs/security for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_clickup_api_wrapper(
cls, clickup_api_wrapper: ClickupAPIWrapper
) -> "ClickupToolkit":
operations: List[Dict] = [
{
"mode": "get_task",
"name": "Get task",
"description": CLICKUP_GET_TASK_PROMPT,
},
{
"mode": "get_task_attribute",
"name": "Get task attribute",
"description": CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
},
{
"mode": "get_teams",
"name": "Get Teams",
"description": CLICKUP_GET_ALL_TEAMS_PROMPT,
},
{
"mode": "create_task",
"name": "Create Task",
"description": CLICKUP_TASK_CREATE_PROMPT,
},
{
"mode": "create_list",
"name": "Create List",
"description": CLICKUP_LIST_CREATE_PROMPT,
},
{
"mode": "create_folder",
"name": "Create Folder",
"description": CLICKUP_FOLDER_CREATE_PROMPT,
},
{
"mode": "get_list",
"name": "Get all lists in the space",
"description": CLICKUP_GET_LIST_PROMPT,
},
{
"mode": "get_folders",
"name": "Get all folders in the workspace",
"description": CLICKUP_GET_FOLDERS_PROMPT,
},
{
"mode": "get_spaces",
"name": "Get all spaces in the workspace",
"description": CLICKUP_GET_SPACES_PROMPT,
},
{
"mode": "update_task",
"name": "Update task",
"description": CLICKUP_UPDATE_TASK_PROMPT,
},
{
"mode": "update_task_assignees",
"name": "Update task assignees",
"description": CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
},
]
tools = [
ClickupAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=clickup_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
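
A construction sketch for this toolkit. How ClickupAPIWrapper is configured (access token, team id) is not shown in this diff, so treat the bare constructor as an assumption:

from langchain_integrations.utilities.clickup import ClickupAPIWrapper

# Assumes the wrapper resolves its Clickup credentials from its usual config.
wrapper = ClickupAPIWrapper()
toolkit = ClickupToolkit.from_clickup_api_wrapper(wrapper)
print([tool.name for tool in toolkit.get_tools()])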

View File

@@ -0,0 +1,88 @@
from typing import Any, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.memory import BaseMemory
from langchain_core.messages import SystemMessage
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)
from langchain_integrations.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain_openai.chat_model import ChatOpenAI
from langchain_integrations.memory.token_buffer import ConversationTokenBufferMemory
from langchain_core.tools import BaseTool
def _get_default_system_message() -> SystemMessage:
return SystemMessage(
content=(
"Do your best to answer the questions. "
"Feel free to use any tools available to look up "
"relevant information, only if necessary"
)
)
def create_conversational_retrieval_agent(
llm: BaseLanguageModel,
tools: List[BaseTool],
remember_intermediate_steps: bool = True,
memory_key: str = "chat_history",
system_message: Optional[SystemMessage] = None,
verbose: bool = False,
max_token_limit: int = 2000,
**kwargs: Any,
) -> AgentExecutor:
"""A convenience method for creating a conversational retrieval agent.
Args:
        llm: The language model to use; must be a ChatOpenAI instance.
tools: A list of tools the agent has access to
remember_intermediate_steps: Whether the agent should remember intermediate
steps or not. Intermediate steps refer to prior action/observation
pairs from previous questions. The benefit of remembering these is if
there is relevant information in there, the agent can use it to answer
follow up questions. The downside is it will take up more tokens.
memory_key: The name of the memory key in the prompt.
system_message: The system message to use. By default, a basic one will
be used.
        verbose: Whether the final AgentExecutor should be verbose.
            Defaults to False.
max_token_limit: The max number of tokens to keep around in memory.
Defaults to 2000.
Returns:
An agent executor initialized appropriately
"""
if not isinstance(llm, ChatOpenAI):
raise ValueError("Only supported with ChatOpenAI models.")
if remember_intermediate_steps:
memory: BaseMemory = AgentTokenBufferMemory(
memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
)
else:
memory = ConversationTokenBufferMemory(
memory_key=memory_key,
return_messages=True,
output_key="output",
llm=llm,
max_token_limit=max_token_limit,
)
_system_message = system_message or _get_default_system_message()
prompt = OpenAIFunctionsAgent.create_prompt(
system_message=_system_message,
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
return AgentExecutor(
agent=agent,
tools=tools,
memory=memory,
verbose=verbose,
return_intermediate_steps=remember_intermediate_steps,
**kwargs,
)
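
A usage sketch under stated assumptions: an OpenAI key is configured, and `lookup_order` below is a hypothetical stand-in for a real retrieval tool:

from langchain_core.tools import tool
from langchain_openai.chat_model import ChatOpenAI


@tool
def lookup_order(order_id: str) -> str:
    """Hypothetical stand-in for a retrieval tool."""
    return f"Order {order_id}: shipped."


llm = ChatOpenAI(temperature=0)
agent_executor = create_conversational_retrieval_agent(llm, [lookup_order])
result = agent_executor.invoke({"input": "What is the status of order 42?"})
print(result["output"])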

View File

@@ -0,0 +1,3 @@
from langchain_integrations.tools.retriever import create_retriever_tool
__all__ = ["create_retriever_tool"]

View File

@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
here = as_import_path(Path(__file__).parent)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
    raise AttributeError(
        "This agent has been moved to langchain experimental. "
        "This agent relies on the python REPL tool under the hood, so to use it "
        "safely please sandbox the python REPL. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "and https://github.com/langchain-ai/langchain/discussions/11680 "
        "for more information. "
        "To keep using this code as is, install langchain experimental and "
        f"update your import statement from:\n `{old_path}` to `{new_path}`."
    )

View File

@@ -0,0 +1,7 @@
"""Local file management toolkit."""
from langchain_integrations.agent_toolkits.file_management.toolkit import (
FileManagementToolkit,
)
__all__ = ["FileManagementToolkit"]

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from typing import List, Optional
from langchain_core.pydantic_v1 import root_validator
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.file_management.copy import CopyFileTool
from langchain_integrations.tools.file_management.delete import DeleteFileTool
from langchain_integrations.tools.file_management.file_search import FileSearchTool
from langchain_integrations.tools.file_management.list_dir import ListDirectoryTool
from langchain_integrations.tools.file_management.move import MoveFileTool
from langchain_integrations.tools.file_management.read import ReadFileTool
from langchain_integrations.tools.file_management.write import WriteFileTool
_FILE_TOOLS = {
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
for tool_cls in [
CopyFileTool,
DeleteFileTool,
FileSearchTool,
MoveFileTool,
ReadFileTool,
WriteFileTool,
ListDirectoryTool,
]
}
class FileManagementToolkit(BaseToolkit):
"""Toolkit for interacting with local files.
*Security Notice*: This toolkit provides methods to interact with local files.
    If providing this toolkit to an agent powered by an LLM, ensure you scope
the agent's permissions to only include the necessary permissions
to perform the desired operations.
By **default** the agent will have access to all files within
the root dir and will be able to Copy, Delete, Move, Read, Write
and List files in that directory.
Consider the following:
- Limit access to particular directories using `root_dir`.
- Use filesystem permissions to restrict access and permissions to only
the files and directories required by the agent.
- Limit the tools available to the agent to only the file operations
necessary for the agent's intended use.
- Sandbox the agent by running it in a container.
See https://python.langchain.com/docs/security for more information.
"""
root_dir: Optional[str] = None
"""If specified, all file operations are made relative to root_dir."""
selected_tools: Optional[List[str]] = None
"""If provided, only provide the selected tools. Defaults to all."""
@root_validator
def validate_tools(cls, values: dict) -> dict:
selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools:
if tool_name not in _FILE_TOOLS:
raise ValueError(
f"File Tool of name {tool_name} not supported."
f" Permitted tools: {list(_FILE_TOOLS)}"
)
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
tools: List[BaseTool] = []
for tool in allowed_tools:
tool_cls = _FILE_TOOLS[tool]
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
return tools
__all__ = ["FileManagementToolkit"]

View File

@@ -0,0 +1 @@
"""GitHub Toolkit."""

View File

@@ -0,0 +1,94 @@
"""GitHub Toolkit."""
from typing import Dict, List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.github.prompt import (
COMMENT_ON_ISSUE_PROMPT,
CREATE_FILE_PROMPT,
CREATE_PULL_REQUEST_PROMPT,
DELETE_FILE_PROMPT,
GET_ISSUE_PROMPT,
GET_ISSUES_PROMPT,
READ_FILE_PROMPT,
UPDATE_FILE_PROMPT,
)
from langchain_integrations.tools.github.tool import GitHubAction
from langchain_integrations.utilities.github import GitHubAPIWrapper
class GitHubToolkit(BaseToolkit):
"""GitHub Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
    the state of a service; e.g., by creating, deleting, updating, or
    reading underlying data.
For example, this toolkit can be used to create issues, pull requests,
and comments on GitHub.
See [Security](https://python.langchain.com/docs/security) for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_github_api_wrapper(
cls, github_api_wrapper: GitHubAPIWrapper
) -> "GitHubToolkit":
operations: List[Dict] = [
{
"mode": "get_issues",
"name": "Get Issues",
"description": GET_ISSUES_PROMPT,
},
{
"mode": "get_issue",
"name": "Get Issue",
"description": GET_ISSUE_PROMPT,
},
{
"mode": "comment_on_issue",
"name": "Comment on Issue",
"description": COMMENT_ON_ISSUE_PROMPT,
},
{
"mode": "create_pull_request",
"name": "Create Pull Request",
"description": CREATE_PULL_REQUEST_PROMPT,
},
{
"mode": "create_file",
"name": "Create File",
"description": CREATE_FILE_PROMPT,
},
{
"mode": "read_file",
"name": "Read File",
"description": READ_FILE_PROMPT,
},
{
"mode": "update_file",
"name": "Update File",
"description": UPDATE_FILE_PROMPT,
},
{
"mode": "delete_file",
"name": "Delete File",
"description": DELETE_FILE_PROMPT,
},
]
tools = [
GitHubAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=github_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools

View File

@@ -0,0 +1 @@
"""GitLab Toolkit."""

View File

@@ -0,0 +1,94 @@
"""GitHub Toolkit."""
from typing import Dict, List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.gitlab.prompt import (
COMMENT_ON_ISSUE_PROMPT,
CREATE_FILE_PROMPT,
CREATE_PULL_REQUEST_PROMPT,
DELETE_FILE_PROMPT,
GET_ISSUE_PROMPT,
GET_ISSUES_PROMPT,
READ_FILE_PROMPT,
UPDATE_FILE_PROMPT,
)
from langchain_integrations.tools.gitlab.tool import GitLabAction
from langchain_integrations.utilities.gitlab import GitLabAPIWrapper
class GitLabToolkit(BaseToolkit):
"""GitLab Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
    the state of a service; e.g., by creating, deleting, updating, or
    reading underlying data.
For example, this toolkit can be used to create issues, pull requests,
and comments on GitLab.
See https://python.langchain.com/docs/security for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_gitlab_api_wrapper(
cls, gitlab_api_wrapper: GitLabAPIWrapper
) -> "GitLabToolkit":
operations: List[Dict] = [
{
"mode": "get_issues",
"name": "Get Issues",
"description": GET_ISSUES_PROMPT,
},
{
"mode": "get_issue",
"name": "Get Issue",
"description": GET_ISSUE_PROMPT,
},
{
"mode": "comment_on_issue",
"name": "Comment on Issue",
"description": COMMENT_ON_ISSUE_PROMPT,
},
{
"mode": "create_pull_request",
"name": "Create Pull Request",
"description": CREATE_PULL_REQUEST_PROMPT,
},
{
"mode": "create_file",
"name": "Create File",
"description": CREATE_FILE_PROMPT,
},
{
"mode": "read_file",
"name": "Read File",
"description": READ_FILE_PROMPT,
},
{
"mode": "update_file",
"name": "Update File",
"description": UPDATE_FILE_PROMPT,
},
{
"mode": "delete_file",
"name": "Delete File",
"description": DELETE_FILE_PROMPT,
},
]
tools = [
GitLabAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=gitlab_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools

View File

@@ -0,0 +1 @@
"""Gmail toolkit."""

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.gmail.create_draft import GmailCreateDraft
from langchain_integrations.tools.gmail.get_message import GmailGetMessage
from langchain_integrations.tools.gmail.get_thread import GmailGetThread
from langchain_integrations.tools.gmail.search import GmailSearch
from langchain_integrations.tools.gmail.send_message import GmailSendMessage
from langchain_integrations.tools.gmail.utils import build_resource_service
if TYPE_CHECKING:
# This is for linting and IDE typehints
from googleapiclient.discovery import Resource
else:
try:
# We do this so pydantic can resolve the types when instantiating
from googleapiclient.discovery import Resource
except ImportError:
pass
SCOPES = ["https://mail.google.com/"]
class GmailToolkit(BaseToolkit):
"""Toolkit for interacting with Gmail.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
For example, this toolkit can be used to send emails on behalf of the
associated account.
See https://python.langchain.com/docs/security for more information.
"""
api_resource: Resource = Field(default_factory=build_resource_service)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
GmailCreateDraft(api_resource=self.api_resource),
GmailSendMessage(api_resource=self.api_resource),
GmailSearch(api_resource=self.api_resource),
GmailGetMessage(api_resource=self.api_resource),
GmailGetThread(api_resource=self.api_resource),
]

View File

@@ -0,0 +1 @@
"""Jira Toolkit."""

View File

@@ -0,0 +1,70 @@
from typing import Dict, List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.jira.prompt import (
JIRA_CATCH_ALL_PROMPT,
JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
JIRA_GET_ALL_PROJECTS_PROMPT,
JIRA_ISSUE_CREATE_PROMPT,
JIRA_JQL_PROMPT,
)
from langchain_integrations.tools.jira.tool import JiraAction
from langchain_integrations.utilities.jira import JiraAPIWrapper
class JiraToolkit(BaseToolkit):
"""Jira Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
    the state of a service; e.g., by creating, deleting, updating, or
    reading underlying data.
See https://python.langchain.com/docs/security for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_jira_api_wrapper(cls, jira_api_wrapper: JiraAPIWrapper) -> "JiraToolkit":
operations: List[Dict] = [
{
"mode": "jql",
"name": "JQL Query",
"description": JIRA_JQL_PROMPT,
},
{
"mode": "get_projects",
"name": "Get Projects",
"description": JIRA_GET_ALL_PROJECTS_PROMPT,
},
{
"mode": "create_issue",
"name": "Create Issue",
"description": JIRA_ISSUE_CREATE_PROMPT,
},
{
"mode": "other",
"name": "Catch all Jira API call",
"description": JIRA_CATCH_ALL_PROMPT,
},
{
"mode": "create_page",
"name": "Create confluence page",
"description": JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
},
]
tools = [
JiraAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=jira_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools

View File

@@ -0,0 +1 @@
"""Json agent."""

View File

@@ -0,0 +1,49 @@
"""Json agent."""
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
def create_json_agent(
llm: BaseLanguageModel,
toolkit: JsonToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = JSON_PREFIX,
suffix: str = JSON_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a json agent from an LLM and tools."""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
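
A wiring sketch for the JSON agent; the spec file name and the choice of OpenAI as the LLM are assumptions:

import yaml

from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
from langchain_integrations.tools.json.tool import JsonSpec
from langchain_openai.llm import OpenAI

with open("openapi.yml") as f:  # hypothetical spec file
    data = yaml.safe_load(f)

toolkit = JsonToolkit(spec=JsonSpec(dict_=data, max_value_length=4000))
agent_executor = create_json_agent(llm=OpenAI(temperature=0), toolkit=toolkit)
agent_executor.run("What keys exist at the top level of the spec?")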

View File

@@ -0,0 +1,25 @@
# flake8: noqa
JSON_PREFIX = """You are an agent designed to interact with JSON.
Your goal is to return a final answer by interacting with the JSON.
You have access to the following tools which help you learn more about the JSON you are interacting with.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
Do not make up any information that is not contained in the JSON.
Your input to the tools should be in the form of `data["key"][0]` where `data` is the JSON blob you are interacting with, and the syntax used is Python.
You should only use keys that you know for a fact exist. You must validate that a key exists by seeing it previously when calling `json_spec_list_keys`.
If you have not seen a key in one of those responses, you cannot use it.
You should only add one key at a time to the path. You cannot add multiple keys at once.
If you encounter a "KeyError", go back to the previous key, look at the available keys, and try again.
If the question does not seem to be related to the JSON, just return "I don't know" as the answer.
Always begin your interaction with the `json_spec_list_keys` tool with input "data" to see what keys exist in the JSON.
Note that sometimes the value at a given path is large. In this case, you will get an error "Value is a large dictionary, should explore its keys directly".
In this case, you should ALWAYS follow up by using the `json_spec_list_keys` tool to see what keys exist at that path.
Do not simply refer the user to the JSON or a section of the JSON, as this is not a valid answer. Keep digging until you find the answer and explicitly return it.
"""
JSON_SUFFIX = """Begin!"
Question: {input}
Thought: I should look at the keys that exist in data to see what I have access to
{agent_scratchpad}"""

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from typing import List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.json.tool import JsonGetValueTool, JsonListKeysTool, JsonSpec
class JsonToolkit(BaseToolkit):
"""Toolkit for interacting with a JSON spec."""
spec: JsonSpec
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
JsonListKeysTool(spec=self.spec),
JsonGetValueTool(spec=self.spec),
]

View File

@@ -0,0 +1 @@
"""MultiOn Toolkit."""

View File

@@ -0,0 +1,33 @@
"""MultiOn agent."""
from __future__ import annotations
from typing import List
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.multion.close_session import MultionCloseSession
from langchain_integrations.tools.multion.create_session import MultionCreateSession
from langchain_integrations.tools.multion.update_session import MultionUpdateSession
class MultionToolkit(BaseToolkit):
"""Toolkit for interacting with the Browser Agent.
    **Security Note**: This toolkit contains tools that interact with the
    user's browser via the MultiOn API, which grants an agent
    access to the user's browser.
    Please review the documentation for the MultiOn API to understand
    the security implications of using this toolkit.
See https://python.langchain.com/docs/security for more information.
"""
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [MultionCreateSession(), MultionUpdateSession(), MultionCloseSession()]

View File

@@ -0,0 +1,55 @@
"""Tool for interacting with a single API with natural language definition."""
from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents.tools import Tool
from langchain_integrations.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain_integrations.tools.openapi.utils.api_models import APIOperation
from langchain_integrations.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain_integrations.utilities.requests import Requests
class NLATool(Tool):
"""Natural Language API Tool."""
@classmethod
def from_open_api_endpoint_chain(
cls, chain: OpenAPIEndpointChain, api_title: str
) -> "NLATool":
"""Convert an endpoint chain to an API endpoint tool."""
expanded_name = (
f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}'
)
description = (
f"I'm an AI from {api_title}. Instruct what you want,"
" and I'll assist via an API with description:"
f" {chain.api_operation.description}"
)
return cls(name=expanded_name, func=chain.run, description=description)
@classmethod
def from_llm_and_method(
cls,
llm: BaseLanguageModel,
path: str,
method: str,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
**kwargs: Any,
) -> "NLATool":
"""Instantiate the tool from the specified path and method."""
api_operation = APIOperation.from_openapi_spec(spec, path, method)
chain = OpenAPIEndpointChain.from_api_operation(
api_operation,
llm,
requests=requests,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
**kwargs,
)
return cls.from_open_api_endpoint_chain(chain, spec.info.title)

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
from typing import Any, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.agent_toolkits.nla.tool import NLATool
from langchain_core.tools import BaseTool
from langchain_integrations.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain_integrations.tools.plugin import AIPlugin
from langchain_integrations.utilities.requests import Requests
class NLAToolkit(BaseToolkit):
"""Natural Language API Toolkit.
    *Security Note*: This toolkit creates tools that enable making calls
    to an OpenAPI-compliant API.
The tools created by this toolkit may be able to make GET, POST,
PATCH, PUT, DELETE requests to any of the exposed endpoints on
the API.
Control access to who can use this toolkit.
See https://python.langchain.com/docs/security for more information.
"""
nla_tools: Sequence[NLATool] = Field(...)
"""List of API Endpoint Tools."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools for all the API operations."""
return list(self.nla_tools)
@staticmethod
def _get_http_operation_tools(
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> List[NLATool]:
"""Get the tools for all the API operations."""
if not spec.paths:
return []
http_operation_tools = []
for path in spec.paths:
for method in spec.get_methods_for_path(path):
endpoint_tool = NLATool.from_llm_and_method(
llm=llm,
path=path,
method=method,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
http_operation_tools.append(endpoint_tool)
return http_operation_tools
@classmethod
def from_llm_and_spec(
cls,
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit by creating tools for each operation."""
http_operation_tools = cls._get_http_operation_tools(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
return cls(nla_tools=http_operation_tools)
@classmethod
def from_llm_and_url(
cls,
llm: BaseLanguageModel,
open_api_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(open_api_url)
return cls.from_llm_and_spec(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
@classmethod
def from_llm_and_ai_plugin(
cls,
llm: BaseLanguageModel,
ai_plugin: AIPlugin,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(ai_plugin.api.url)
# TODO: Merge optional Auth information with the `requests` argument
return cls.from_llm_and_spec(
llm=llm,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
@classmethod
def from_llm_and_ai_plugin_url(
cls,
llm: BaseLanguageModel,
ai_plugin_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
plugin = AIPlugin.from_url(ai_plugin_url)
return cls.from_llm_and_ai_plugin(
llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs
)
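
A construction sketch; the spec URL is illustrative, and any OpenAPI-compliant spec URL works the same way:

from langchain_openai.llm import OpenAI

llm = OpenAI(temperature=0)
toolkit = NLAToolkit.from_llm_and_url(llm, "https://api.speak.com/openapi.yaml")
print([tool.name for tool in toolkit.get_tools()])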

View File

@@ -0,0 +1 @@
"""Office365 toolkit."""

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.office365.create_draft_message import O365CreateDraftMessage
from langchain_integrations.tools.office365.events_search import O365SearchEvents
from langchain_integrations.tools.office365.messages_search import O365SearchEmails
from langchain_integrations.tools.office365.send_event import O365SendEvent
from langchain_integrations.tools.office365.send_message import O365SendMessage
from langchain_integrations.tools.office365.utils import authenticate
if TYPE_CHECKING:
from O365 import Account
class O365Toolkit(BaseToolkit):
"""Toolkit for interacting with Office 365.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
    For example, this toolkit can be used to search through emails and events,
send messages and event invites, and create draft messages.
Please make sure that the permissions given by this toolkit
are appropriate for your use case.
See https://python.langchain.com/docs/security for more information.
"""
account: Account = Field(default_factory=authenticate)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
O365SearchEvents(),
O365CreateDraftMessage(),
O365SearchEmails(),
O365SendEvent(),
O365SendMessage(),
]

View File

@@ -0,0 +1 @@
"""OpenAPI spec agent."""

View File

@@ -0,0 +1,73 @@
"""OpenAPI spec agent."""
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agent_toolkits.openapi.prompt import (
OPENAPI_PREFIX,
OPENAPI_SUFFIX,
)
from langchain_integrations.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
def create_openapi_agent(
llm: BaseLanguageModel,
toolkit: OpenAPIToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = OPENAPI_PREFIX,
suffix: str = OPENAPI_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
return_intermediate_steps: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct an OpenAPI agent from an LLM and tools.
*Security Note*: When creating an OpenAPI agent, check the permissions
and capabilities of the underlying toolkit.
    For example, the default implementation of OpenAPIToolkit
    uses the RequestsToolkit, which contains tools to make arbitrary
    network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE).
    Control access to who can use this toolkit and
    what network access it has.
See https://python.langchain.com/docs/security for more information.
"""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@@ -0,0 +1,357 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional
import yaml
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agent_toolkits.openapi.planner_prompt import (
API_CONTROLLER_PROMPT,
API_CONTROLLER_TOOL_DESCRIPTION,
API_CONTROLLER_TOOL_NAME,
API_ORCHESTRATOR_PROMPT,
API_PLANNER_PROMPT,
API_PLANNER_TOOL_DESCRIPTION,
API_PLANNER_TOOL_NAME,
PARSING_DELETE_PROMPT,
PARSING_GET_PROMPT,
PARSING_PATCH_PROMPT,
PARSING_POST_PROMPT,
PARSING_PUT_PROMPT,
REQUESTS_DELETE_TOOL_DESCRIPTION,
REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_PATCH_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION,
REQUESTS_PUT_TOOL_DESCRIPTION,
)
from langchain_integrations.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_integrations.agents.tools import Tool
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
from langchain_openai.llm import OpenAI
from langchain_integrations.memory import ReadOnlySharedMemory
from langchain_core.tools import BaseTool
from langchain_integrations.tools.requests.tool import BaseRequestsTool
from langchain_integrations.utilities.requests import RequestsWrapper
#
# Requests tools with LLM-instructed extraction of truncated responses.
#
# Of course, truncating so bluntly may lose a lot of valuable
# information in the response.
# However, the goal for now is to have only a single inference step.
MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned."""
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
return LLMChain(
llm=OpenAI(),
prompt=prompt,
)
def _get_default_llm_chain_factory(
prompt: BasePromptTemplate,
) -> Callable[[], LLMChain]:
"""Returns a default LLMChain factory."""
return partial(_get_default_llm_chain, prompt)
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_get"
"""Tool name."""
description = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = json.loads(text)
except json.JSONDecodeError as e:
raise e
data_params = data.get("params")
response = self.requests_wrapper.get(data["url"], params=data_params)
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_post"
"""Tool name."""
description = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = json.loads(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.post(data["url"], data["data"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_patch"
"""Tool name."""
description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = json.loads(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.patch(data["url"], data["data"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PUT tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_put"
"""Tool name."""
description = REQUESTS_PUT_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = json.loads(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.put(data["url"], data["data"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
"""A tool that sends a DELETE request and parses the response."""
name: str = "requests_delete"
"""The name of the tool."""
description = REQUESTS_DELETE_TOOL_DESCRIPTION
"""The description of the tool."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""The maximum length of the response."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
)
"""The LLM chain used to parse the response."""
def _run(self, text: str) -> str:
try:
data = json.loads(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.delete(data["url"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
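
Each of these parsing tools takes a single JSON string whose fields mirror the `_run` implementations above. A sketch of the expected input for the GET variant (URL and instructions are illustrative):

import json

tool_input = json.dumps(
    {
        "url": "https://example.com/api/items",
        "params": {"page": 1},
        "output_instructions": "Extract the ids of all items.",
    }
)
# With a configured RequestsWrapper and LLMChain, the tool would be invoked as:
# RequestsGetToolWithParsing(requests_wrapper=wrapper, llm_chain=chain).run(tool_input)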
#
# Orchestrator, planner, controller.
#
def _create_api_planner_tool(
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
) -> Tool:
endpoint_descriptions = [
f"{name} {description}" for name, description, _ in api_spec.endpoints
]
prompt = PromptTemplate(
template=API_PLANNER_PROMPT,
input_variables=["query"],
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
)
chain = LLMChain(llm=llm, prompt=prompt)
tool = Tool(
name=API_PLANNER_TOOL_NAME,
description=API_PLANNER_TOOL_DESCRIPTION,
func=chain.run,
)
return tool
def _create_api_controller_agent(
api_url: str,
api_docs: str,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> AgentExecutor:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [
RequestsGetToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
),
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
),
]
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,
input_variables=["input", "agent_scratchpad"],
partial_variables={
"api_url": api_url,
"api_docs": api_docs,
"tool_names": ", ".join([tool.name for tool in tools]),
"tool_descriptions": "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
),
},
)
agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=prompt),
allowed_tools=[tool.name for tool in tools],
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def _create_api_controller_tool(
api_spec: ReducedOpenAPISpec,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> Tool:
"""Expose controller as a tool.
The tool is invoked with a plan from the planner, and dynamically
creates a controller agent with relevant documentation only to
constrain the context.
"""
base_url = api_spec.servers[0]["url"] # TODO: do better.
def _create_and_run_api_controller_agent(plan_str: str) -> str:
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
matches = re.findall(pattern, plan_str)
endpoint_names = [
"{method} {route}".format(method=method, route=route.split("?")[0])
for method, route in matches
]
docs_str = ""
for endpoint_name in endpoint_names:
found_match = False
for name, _, docs in api_spec.endpoints:
regex_name = re.compile(re.sub(r"\{.*?\}", ".*", name))
if regex_name.match(endpoint_name):
found_match = True
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
if not found_match:
raise ValueError(f"{endpoint_name} endpoint does not exist.")
agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
return agent.run(plan_str)
return Tool(
name=API_CONTROLLER_TOOL_NAME,
func=_create_and_run_api_controller_agent,
description=API_CONTROLLER_TOOL_DESCRIPTION,
)
def create_openapi_agent(
api_spec: ReducedOpenAPISpec,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
shared_memory: Optional[ReadOnlySharedMemory] = None,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = True,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Instantiate OpenAI API planner and controller for a given spec.
Inject credentials via requests_wrapper.
We use a top-level "orchestrator" agent to invoke the planner and controller,
rather than a top-level planner
that invokes a controller with its plan. This is to keep the planner simple.
"""
tools = [
_create_api_planner_tool(api_spec, llm),
_create_api_controller_tool(api_spec, requests_wrapper, llm),
]
prompt = PromptTemplate(
template=API_ORCHESTRATOR_PROMPT,
input_variables=["input", "agent_scratchpad"],
partial_variables={
"tool_names": ", ".join([tool.name for tool in tools]),
"tool_descriptions": "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
),
},
)
agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
allowed_tools=[tool.name for tool in tools],
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
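# A minimal usage sketch (hedged: the spec path, auth header, and query are
# placeholders, and the import paths are assumed from this diff's module layout):
if __name__ == "__main__":
    from langchain_integrations.agent_toolkits.openapi.spec import (
        reduce_openapi_spec,
    )
    from langchain_openai.llm import OpenAI

    with open("openapi.yaml") as f:
        raw_spec = yaml.safe_load(f)
    reduced_spec = reduce_openapi_spec(raw_spec)
    # Inject credentials via the requests wrapper, per the docstring above.
    requests_wrapper = RequestsWrapper(headers={"Authorization": "Bearer <token>"})
    agent = create_openapi_agent(reduced_spec, requests_wrapper, OpenAI(temperature=0))
    agent.run("What items are in my cart?")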

View File

@@ -0,0 +1,235 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
You should:
1) evaluate whether the user query can be solved by the API documented below. If no, say why.
2) if yes, generate a plan of API calls and say what they are doing step by step.
3) If the plan includes a DELETE call, you should always return an ask from the User for authorization first unless the User has specifically asked to delete something.
You should only use API endpoints documented below ("Endpoints you can use:").
You can only use the DELETE tool if the User has specifically asked to delete something. Otherwise, you should request authorization from the User first.
Some user queries can be resolved in a single API call, but some will require several API calls.
The plan will be passed to an API controller that can format it into web requests and return the responses.
----
Here are some examples:
Fake endpoints for examples:
GET /user to get information about the current user
GET /products/search search across products
POST /users/{{id}}/cart to add products to a user's cart
PATCH /users/{{id}}/cart to update a user's cart
PUT /users/{{id}}/coupon to apply idempotent coupon to a user's cart
DELETE /users/{{id}}/cart to delete a user's cart
User query: tell me a joke
Plan: Sorry, this API's domain is shopping, not comedy.
User query: I want to buy a couch
Plan: 1. GET /products with a query param to search for couches
2. GET /user to find the user's id
3. POST /users/{{id}}/cart to add a couch to the user's cart
User query: I want to add a lamp to my cart
Plan: 1. GET /products with a query param to search for lamps
2. GET /user to find the user's id
3. PATCH /users/{{id}}/cart to add a lamp to the user's cart
User query: I want to add a coupon to my cart
Plan: 1. GET /user to find the user's id
2. PUT /users/{{id}}/coupon to apply the coupon
User query: I want to delete my cart
Plan: 1. GET /user to find the user's id
2. DELETE required. Did user specify DELETE or previously authorize? Yes, proceed.
3. DELETE /users/{{id}}/cart to delete the user's cart
User query: I want to start a new cart
Plan: 1. GET /user to find the user's id
2. DELETE required. Did user specify DELETE or previously authorize? No, ask for authorization.
3. Are you sure you want to delete your cart?
----
Here are endpoints you can use. Do not reference any of the endpoints above.
{endpoints}
----
User query: {query}
Plan:"""
API_PLANNER_TOOL_NAME = "api_planner"
API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to call the API controller."
# Execution.
API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.
Here is documentation on the API:
Base url: {api_url}
Endpoints:
{api_docs}
Here are tools to execute requests against the API: {tool_descriptions}
Starting below, you should follow this format:
Plan: the plan of API calls to execute
Thought: you should always think about what to do
Action: the action to take, should be one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the output of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.
Begin!
Plan: {input}
Thought:
{agent_scratchpad}
"""
API_CONTROLLER_TOOL_NAME = "api_controller"
API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)."
# Orchestrate planning + execution.
# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while
# keeping planning (and specifically the planning prompt) simple.
API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against an API, things like querying information or creating resources.
Some user queries can be resolved in a single API call, particularly if you can find appropriate params from the OpenAPI spec; though some require several API calls.
You should always plan your API calls first, and then execute the plan second.
If the plan includes a DELETE call, be sure to ask the User for authorization first unless the User has specifically asked to delete something.
You should never return information without executing the api_controller tool.
Here are the tools to plan and execute API requests: {tool_descriptions}
Starting below, you should follow this format:
User query: the query a User wants help with related to the API
Thought: you should always think about what to do
Action: the action to take, should be one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing a plan and have the information the user asked for or the data the user asked to create
Final Answer: the final output from executing the plan
Example:
User query: can you add some trendy stuff to my shopping cart.
Thought: I should plan API calls first.
Action: api_planner
Action Input: I need to find the right API calls to add trendy items to the user's shopping cart
Observation: 1) GET /items with params 'trending' is 'True' to get trending item ids
2) GET /user to get user
3) POST /cart to post the trending items to the user's cart
Thought: I'm ready to execute the API calls.
Action: api_controller
Action Input: 1) GET /items params 'trending' is 'True' to get trending item ids
2) GET /user to get user
3) POST /cart to post the trending items to the user's cart
...
Begin!
User query: {input}
Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller.
{agent_scratchpad}"""
REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website.
Input to the tool should be a json string with 3 keys: "url", "params" and "output_instructions".
The value of "url" should be a string.
The value of "params" should be a dict of the needed and available parameters from the OpenAPI spec related to the endpoint.
If parameters are not needed, or not available, leave it empty.
The value of "output_instructions" should be instructions on what information to extract from the response,
for example the id(s) for a resource(s) that the GET request fetches.
"""
PARSING_GET_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs you want to POST to the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates.
Always use double quotes for strings in the json string."""
PARSING_POST_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_PATCH_TOOL_DESCRIPTION = """Use this when you want to PATCH content on a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs of the body params available in the OpenAPI spec you want to PATCH the content with at the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PATCH request creates.
Always use double quotes for strings in the json string."""
PARSING_PATCH_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_PUT_TOOL_DESCRIPTION = """Use this when you want to PUT to a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs you want to PUT to the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PUT request creates.
Always use double quotes for strings in the json string."""
PARSING_PUT_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_DELETE_TOOL_DESCRIPTION = """ONLY USE THIS TOOL WHEN THE USER HAS SPECIFICALLY REQUESTED TO DELETE CONTENT FROM A WEBSITE.
Input to the tool should be a json string with 2 keys: "url", and "output_instructions".
The value of "url" should be a string.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the DELETE request creates.
Always use double quotes for strings in the json string.
ONLY USE THIS TOOL IF THE USER HAS SPECIFICALLY REQUESTED TO DELETE SOMETHING."""
PARSING_DELETE_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)

View File

@@ -0,0 +1,29 @@
# flake8: noqa
OPENAPI_PREFIX = """You are an agent designed to answer questions by making web requests to an API given the openapi spec.
If the question does not seem related to the API, return "I don't know". Do not make up an answer.
Only use information provided by the tools to construct your response.
First, find the base URL needed to make the request.
Second, find the relevant paths needed to answer the question. Take note that, sometimes, you might need to make more than one request to more than one path to answer the question.
Third, find the required parameters needed to make the request. For GET requests, these are usually URL parameters and for POST requests, these are request body parameters.
Fourth, make the requests needed to answer the question. Ensure that you are sending the correct parameters to the request by checking which parameters are required. For parameters with a fixed set of values, please use the spec to look at which values are allowed.
Use the exact parameter names as listed in the spec, do not make up any names or abbreviate the names of parameters.
If you get a not found error, ensure that you are using a path that actually exists in the spec.
"""
OPENAPI_SUFFIX = """Begin!
Question: {input}
Thought: I should explore the spec to find the base url for the API.
{agent_scratchpad}"""
DESCRIPTION = """Can be used to answer questions about the openapi spec for the API. Always use this tool before trying to make a request.
Example inputs to this tool:
'What are the required query parameters for a GET request to the /bar endpoint?'
'What are the required parameters in the request body for a POST request to the /foo endpoint?'
Always give this tool a specific question."""

View File

@@ -0,0 +1,75 @@
"""Quick and dirty representation for OpenAPI specs."""
from dataclasses import dataclass
from typing import List, Tuple
from langchain_core.utils.json_schema import dereference_refs
@dataclass(frozen=True)
class ReducedOpenAPISpec:
"""A reduced OpenAPI spec.
This is a quick and dirty representation for OpenAPI specs.
Attributes:
servers: The servers in the spec.
description: The description of the spec.
endpoints: The endpoints in the spec.
"""
servers: List[dict]
description: str
endpoints: List[Tuple[str, str, dict]]
def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec:
"""Simplify/distill/minify a spec somehow.
I want a smaller target for retrieval and (more importantly)
I want smaller results from retrieval.
I was hoping https://openapi.tools/ would have some useful bits
to this end, but it doesn't seem so.
"""
# 1. Consider only get, post, patch, put, delete endpoints.
endpoints = [
(f"{operation_name.upper()} {route}", docs.get("description"), docs)
for route, operation in spec["paths"].items()
for operation_name, docs in operation.items()
if operation_name in ["get", "post", "patch", "put", "delete"]
]
# 2. Replace any refs so that complete docs are retrieved.
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
if dereference:
endpoints = [
(name, description, dereference_refs(docs, full_schema=spec))
for name, description, docs in endpoints
]
# 3. Strip docs down to required request args + happy path response.
def reduce_endpoint_docs(docs: dict) -> dict:
out = {}
if docs.get("description"):
out["description"] = docs.get("description")
if docs.get("parameters"):
out["parameters"] = [
parameter
for parameter in docs.get("parameters", [])
if parameter.get("required")
]
if "200" in docs["responses"]:
out["responses"] = docs["responses"]["200"]
if docs.get("requestBody"):
out["requestBody"] = docs.get("requestBody")
return out
endpoints = [
(name, description, reduce_endpoint_docs(docs))
for name, description, docs in endpoints
]
return ReducedOpenAPISpec(
servers=spec["servers"],
description=spec["info"].get("description", ""),
endpoints=endpoints,
)
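# A runnable sketch of the reduction on a deliberately tiny spec (hedged:
# minimal fields only, far from a complete OpenAPI document):
if __name__ == "__main__":
    demo_spec = {
        "servers": [{"url": "https://api.example.com"}],
        "info": {"description": "Demo API"},
        "paths": {
            "/pets": {
                "get": {
                    "description": "List pets",
                    "responses": {"200": {"description": "A list of pets"}},
                }
            }
        },
    }
    reduced = reduce_openapi_spec(demo_spec)
    # -> [('GET /pets', 'List pets', {'description': ..., 'responses': ...})]
    print(reduced.endpoints)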

View File

@@ -0,0 +1,91 @@
"""Requests toolkit."""
from __future__ import annotations
from typing import Any, List
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.agent_toolkits.json.base import create_json_agent
from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
from langchain_integrations.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain_integrations.agents.tools import Tool
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.json.tool import JsonSpec
from langchain_integrations.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool,
RequestsPatchTool,
RequestsPostTool,
RequestsPutTool,
)
from langchain_integrations.utilities.requests import TextRequestsWrapper
class RequestsToolkit(BaseToolkit):
"""Toolkit for making REST requests.
*Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT,
and DELETE requests to an API.
Exercise care in who is allowed to use this toolkit. If exposing
to end users, consider that users will be able to make arbitrary
requests on behalf of the server hosting the code. For example,
users could ask the server to make a request to a private API
that is only accessible from the server.
Control who is able to submit requests using this toolkit and
what network access it has.
See https://python.langchain.com/docs/security for more information.
"""
requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]:
"""Return a list of tools."""
return [
RequestsGetTool(requests_wrapper=self.requests_wrapper),
RequestsPostTool(requests_wrapper=self.requests_wrapper),
RequestsPatchTool(requests_wrapper=self.requests_wrapper),
RequestsPutTool(requests_wrapper=self.requests_wrapper),
RequestsDeleteTool(requests_wrapper=self.requests_wrapper),
]
class OpenAPIToolkit(BaseToolkit):
"""Toolkit for interacting with an OpenAPI API.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, or updating,
reading underlying data.
For example, this toolkit can be used to delete data exposed via
an OpenAPI compliant API.
"""
json_agent: AgentExecutor
requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
json_agent_tool = Tool(
name="json_explorer",
func=self.json_agent.run,
description=DESCRIPTION,
)
request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
return [*request_toolkit.get_tools(), json_agent_tool]
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
json_spec: JsonSpec,
requests_wrapper: TextRequestsWrapper,
**kwargs: Any,
) -> OpenAPIToolkit:
"""Create json agent from llm, then initialize."""
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
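# A minimal construction sketch (hedged: the spec dict and headers are
# placeholders; JsonSpec bounds how much of the raw spec the json_explorer
# tool prints at once):
if __name__ == "__main__":
    from langchain_openai.llm import OpenAI

    json_spec = JsonSpec(dict_={"openapi": "3.0.0", "paths": {}}, max_value_length=4000)
    requests_wrapper = TextRequestsWrapper(headers={})
    toolkit = OpenAPIToolkit.from_llm(OpenAI(temperature=0), json_spec, requests_wrapper)
    print([tool.name for tool in toolkit.get_tools()])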

View File

@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
here = as_import_path(Path(__file__).parent)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
raise AttributeError(
"This agent has been moved to langchain experiment. "
"This agent relies on python REPL tool under the hood, so to use it "
"safely please sandbox the python REPL. "
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
"and https://github.com/langchain-ai/langchain/discussions/11680"
"To keep using this code as is, install langchain experimental and "
f"update your import statement from:\n `{old_path}` to `{new_path}`."
)

View File

@@ -0,0 +1,4 @@
"""Playwright browser toolkit."""
from langchain_integrations.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
__all__ = ["PlayWrightBrowserToolkit"]

View File

@@ -0,0 +1,108 @@
"""Playwright web browser toolkit."""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, Type, cast
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_core.tools import BaseTool
from langchain_integrations.tools.playwright.base import (
BaseBrowserTool,
lazy_import_playwright_browsers,
)
from langchain_integrations.tools.playwright.click import ClickTool
from langchain_integrations.tools.playwright.current_page import CurrentWebPageTool
from langchain_integrations.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool
from langchain_integrations.tools.playwright.extract_text import ExtractTextTool
from langchain_integrations.tools.playwright.get_elements import GetElementsTool
from langchain_integrations.tools.playwright.navigate import NavigateTool
from langchain_integrations.tools.playwright.navigate_back import NavigateBackTool
if TYPE_CHECKING:
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
else:
try:
# We do this so pydantic can resolve the types when instantiating
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
except ImportError:
pass
class PlayWrightBrowserToolkit(BaseToolkit):
"""Toolkit for PlayWright browser tools.
**Security Note**: This toolkit provides code to control a web-browser.
Be careful if exposing this toolkit to end-users. The tools in the toolkit
are capable of navigating to arbitrary webpages, clicking on arbitrary
elements, and extracting arbitrary text and hyperlinks from webpages.
Specifically, by default this toolkit allows navigating to:
- Any URL (including any internal network URLs)
- And local files
If exposing to end-users, consider limiting network access to the
server that hosts the agent; in addition, it is advised to create a
custom NavigationTool with an args_schema that limits the URLs that
can be navigated to (e.g., only allow navigating to URLs that
start with a particular prefix).
Remember to scope permissions to the minimal permissions necessary for
the application. If the default tool selection is not appropriate for
the application, consider creating a custom toolkit with the appropriate
tools.
See https://python.langchain.com/docs/security for more information.
"""
sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
lazy_import_playwright_browsers()
if values.get("async_browser") is None and values.get("sync_browser") is None:
raise ValueError("Either async_browser or sync_browser must be specified.")
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tool_classes: List[Type[BaseBrowserTool]] = [
ClickTool,
NavigateTool,
NavigateBackTool,
ExtractTextTool,
ExtractHyperlinksTool,
GetElementsTool,
CurrentWebPageTool,
]
tools = [
tool_cls.from_browser(
sync_browser=self.sync_browser, async_browser=self.async_browser
)
for tool_cls in tool_classes
]
return cast(List[BaseTool], tools)
@classmethod
def from_browser(
cls,
sync_browser: Optional[SyncBrowser] = None,
async_browser: Optional[AsyncBrowser] = None,
) -> PlayWrightBrowserToolkit:
"""Instantiate the toolkit."""
# This is to raise a better error than the forward ref ones Pydantic would have
lazy_import_playwright_browsers()
return cls(sync_browser=sync_browser, async_browser=async_browser)
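# A minimal sync usage sketch (hedged: requires `pip install playwright` and
# `playwright install` to fetch browser binaries first):
if __name__ == "__main__":
    from playwright.sync_api import sync_playwright

    pw = sync_playwright().start()
    browser = pw.chromium.launch(headless=True)
    toolkit = PlayWrightBrowserToolkit.from_browser(sync_browser=browser)
    print([tool.name for tool in toolkit.get_tools()])
    browser.close()
    pw.stop()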

View File

@@ -0,0 +1 @@
"""Power BI agent."""

View File

@@ -0,0 +1,63 @@
"""Power BI agent."""
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents import AgentExecutor
from langchain_integrations.agent_toolkits.powerbi.prompt import (
POWERBI_PREFIX,
POWERBI_SUFFIX,
)
from langchain_integrations.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
from langchain_integrations.utilities.powerbi import PowerBIDataset
def create_pbi_agent(
llm: BaseLanguageModel,
toolkit: Optional[PowerBIToolkit] = None,
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = POWERBI_PREFIX,
suffix: str = POWERBI_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
examples: Optional[str] = None,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Power BI agent from an LLM and tools."""
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
tools = toolkit.get_tools()
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
agent = ZeroShotAgent(
llm_chain=LLMChain(
llm=llm,
prompt=ZeroShotAgent.create_prompt(
tools,
prefix=prefix.format(top_k=top_k).format(tables=tables),
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
),
callback_manager=callback_manager, # type: ignore
verbose=verbose,
),
allowed_tools=[tool.name for tool in tools],
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
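# A minimal usage sketch (hedged: the dataset id and table name are
# placeholders, and PowerBIDataset is assumed to accept an Azure credential
# as in this diff's utilities module):
if __name__ == "__main__":
    from azure.identity import DefaultAzureCredential
    from langchain_openai.llm import OpenAI

    dataset = PowerBIDataset(
        dataset_id="<dataset-id>",
        table_names=["Sales"],
        credential=DefaultAzureCredential(),
    )
    agent = create_pbi_agent(llm=OpenAI(temperature=0), powerbi=dataset)
    agent.run("How many rows does the Sales table have?")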

View File

@@ -0,0 +1,64 @@
"""Power BI agent."""
from typing import Any, Dict, List, Optional
from langchain_integrations.agents import AgentExecutor
from langchain_integrations.agents.agent import AgentOutputParser
from langchain_integrations.agent_toolkits.powerbi.prompt import (
POWERBI_CHAT_PREFIX,
POWERBI_CHAT_SUFFIX,
)
from langchain_integrations.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain_integrations.agents.conversational_chat.base import ConversationalChatAgent
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chat_models.base import BaseChatModel
from langchain_integrations.memory import ConversationBufferMemory
from langchain_integrations.memory.chat_memory import BaseChatMemory
from langchain_integrations.utilities.powerbi import PowerBIDataset
def create_pbi_chat_agent(
llm: BaseChatModel,
toolkit: Optional[PowerBIToolkit] = None,
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = POWERBI_CHAT_PREFIX,
suffix: str = POWERBI_CHAT_SUFFIX,
examples: Optional[str] = None,
input_variables: Optional[List[str]] = None,
memory: Optional[BaseChatMemory] = None,
top_k: int = 10,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Power BI agent from a Chat LLM and tools.
If you supply only a toolkit and no Power BI dataset, the same LLM is used for both.
"""
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
tools = toolkit.get_tools()
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
agent = ConversationalChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
system_message=prefix.format(top_k=top_k).format(tables=tables),
human_message=suffix,
input_variables=input_variables,
callback_manager=callback_manager,
output_parser=output_parser,
verbose=verbose,
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
memory=memory
or ConversationBufferMemory(memory_key="chat_history", return_messages=True),
verbose=verbose,
**(agent_executor_kwargs or {}),
)

View File

@@ -0,0 +1,38 @@
# flake8: noqa
"""Prompts for PowerBI agent."""
POWERBI_PREFIX = """You are an agent designed to help users interact with a PowerBI Dataset.
The agent has access to a tool that can write a query based on the question and then run it against PowerBI, Microsoft's business intelligence tool. Questions from users should be interpreted as relating to the dataset that is available, not as general questions about the world. If the question does not seem related to the dataset, return "This does not appear to be part of this dataset." as the answer.
Given an input question, ask to run it against the dataset, then look at the results and return the answer. The answer should be a complete sentence that answers the question; if multiple rows are requested, write them in an easily readable format for a human, and make sure to represent numbers in readable ways, like 1M instead of 1000000. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
"""
POWERBI_SUFFIX = """Begin!
Question: {input}
Thought: I can first ask which tables I have, then how each table is defined and then ask the query tool the question I need, and finally create a nice sentence that answers the question.
{agent_scratchpad}"""
POWERBI_CHAT_PREFIX = """Assistant is a large language model built to help users interact with a PowerBI Dataset.
Assistant should try to create a correct and complete answer to the question from the user. If the user asks a question not related to the dataset, it should return "This does not appear to be part of this dataset." as the answer. The user might make a mistake with the spelling of certain values; if you think that is the case, ask the user to confirm the spelling of the value and then run the query again. Unless the user specifies a specific number of examples they wish to obtain, and the results are too large, limit your query to at most {top_k} results, but make it clear when answering which field was used for the filtering. The user has access to these tables: {{tables}}.
The answer should be a complete sentence that answers the question; if multiple rows are requested, write them in an easily readable format for a human, and make sure to represent numbers in readable ways, like 1M instead of 1000000.
"""
POWERBI_CHAT_SUFFIX = """TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the user's original question. The tools the human can use are:
{{tools}}
{format_instructions}
USER'S INPUT
--------------------
Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{{{{input}}}}
"""

View File

@@ -0,0 +1,102 @@
"""Toolkit for interacting with a Power BI dataset."""
from typing import List, Optional, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
from langchain_integrations.chat_models.base import BaseChatModel
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.powerbi.prompt import (
QUESTION_TO_QUERY_BASE,
SINGLE_QUESTION_TO_QUERY,
USER_INPUT,
)
from langchain_integrations.tools.powerbi.tool import (
InfoPowerBITool,
ListPowerBITool,
QueryPowerBITool,
)
from langchain_integrations.utilities.powerbi import PowerBIDataset
class PowerBIToolkit(BaseToolkit):
"""Toolkit for interacting with Power BI dataset.
*Security Note*: This toolkit interacts with an external service.
Control access to who can use this toolkit.
Make sure that the capabilities given by this toolkit to the calling
code are appropriately scoped to the application.
See https://python.langchain.com/docs/security for more information.
"""
powerbi: PowerBIDataset = Field(exclude=True)
llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)
examples: Optional[str] = None
max_iterations: int = 5
callback_manager: Optional[BaseCallbackManager] = None
output_token_limit: Optional[int] = None
tiktoken_model_name: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QueryPowerBITool(
llm_chain=self._get_chain(),
powerbi=self.powerbi,
examples=self.examples,
max_iterations=self.max_iterations,
output_token_limit=self.output_token_limit,
tiktoken_model_name=self.tiktoken_model_name,
),
InfoPowerBITool(powerbi=self.powerbi),
ListPowerBITool(powerbi=self.powerbi),
]
def _get_chain(self) -> LLMChain:
"""Construct the chain based on the callback manager and model type."""
if isinstance(self.llm, BaseLanguageModel):
return LLMChain(
llm=self.llm,
callback_manager=self.callback_manager
if self.callback_manager
else None,
prompt=PromptTemplate(
template=SINGLE_QUESTION_TO_QUERY,
input_variables=["tool_input", "tables", "schemas", "examples"],
),
)
system_prompt = SystemMessagePromptTemplate(
prompt=PromptTemplate(
template=QUESTION_TO_QUERY_BASE,
input_variables=["tables", "schemas", "examples"],
)
)
human_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template=USER_INPUT,
input_variables=["tool_input"],
)
)
return LLMChain(
llm=self.llm,
callback_manager=self.callback_manager if self.callback_manager else None,
prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]),
)

View File

@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
here = as_import_path(Path(__file__).parent)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
raise AttributeError(
"This agent has been moved to langchain experiment. "
"This agent relies on python REPL tool under the hood, so to use it "
"safely please sandbox the python REPL. "
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
"and https://github.com/langchain-ai/langchain/discussions/11680"
"To keep using this code as is, install langchain experimental and "
f"update your import statement from:\n `{old_path}` to `{new_path}`."
)

View File

@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
here = as_import_path(Path(__file__).parent)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
raise AttributeError(
"This agent has been moved to langchain experiment. "
"This agent relies on python REPL tool under the hood, so to use it "
"safely please sandbox the python REPL. "
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
"and https://github.com/langchain-ai/langchain/discussions/11680"
"To keep using this code as is, install langchain experimental and "
f"update your import statement from:\n `{old_path}` to `{new_path}`."
)

View File

@@ -0,0 +1 @@
"""Spark SQL agent."""

View File

@@ -0,0 +1,60 @@
"""Spark SQL agent."""
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain_integrations.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
from langchain_integrations.chains.llm import LLMChain
def create_spark_sql_agent(
llm: BaseLanguageModel,
toolkit: SparkSQLToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Spark SQL agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(top_k=top_k)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
callbacks=callbacks,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
callbacks=callbacks,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
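# A minimal usage sketch (hedged: assumes an active SparkSession and that
# SparkSQL picks it up when none is passed, as in this diff's utilities module):
if __name__ == "__main__":
    from langchain_integrations.utilities.spark_sql import SparkSQL
    from langchain_openai.llm import OpenAI

    spark_sql = SparkSQL(schema="default")
    llm = OpenAI(temperature=0)
    toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
    agent = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
    agent.run("How many tables are there in the default schema?")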

View File

@@ -0,0 +1,21 @@
# flake8: noqa
SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}"""

View File

@@ -0,0 +1,36 @@
"""Toolkit for interacting with Spark SQL."""
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.spark_sql.tool import (
InfoSparkSQLTool,
ListSparkSQLTool,
QueryCheckerTool,
QuerySparkSQLTool,
)
from langchain_integrations.utilities.spark_sql import SparkSQL
class SparkSQLToolkit(BaseToolkit):
"""Toolkit for interacting with Spark SQL."""
db: SparkSQL = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QuerySparkSQLTool(db=self.db),
InfoSparkSQLTool(db=self.db),
ListSparkSQLTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
]

View File

@@ -0,0 +1 @@
"""SQL agent."""

View File

@@ -0,0 +1,96 @@
"""SQL agent."""
from typing import Any, Dict, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_integrations.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain_integrations.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain_integrations.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_integrations.agents.agent_types import AgentType
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain_integrations.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
from langchain_integrations.tools import BaseTool
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
suffix: Optional[str] = None,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (),
**kwargs: Any,
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
tools = toolkit.get_tools() + list(extra_tools)
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
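# A minimal usage sketch against a local SQLite database (hedged: the file
# path and question are placeholders):
if __name__ == "__main__":
    from langchain_integrations.utilities.sql_database import SQLDatabase
    from langchain_openai.llm import OpenAI

    db = SQLDatabase.from_uri("sqlite:///chinook.db")
    llm = OpenAI(temperature=0)
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
    agent.run("List the tables and say which one looks largest.")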

View File

@@ -0,0 +1,23 @@
# flake8: noqa
SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""

View File

@@ -0,0 +1,71 @@
"""Toolkit for interacting with an SQL database."""
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
)
from langchain_integrations.utilities.sql_database import SQLDatabase
class SQLDatabaseToolkit(BaseToolkit):
"""Toolkit for interacting with SQL databases."""
db: SQLDatabase = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
@property
def dialect(self) -> str:
"""Return string representation of SQL dialect to use."""
return self.db.dialect
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: table1, table2, table3"
)
info_sql_database_tool = InfoSQLDatabaseTool(
db=self.db, description=info_sql_database_tool_description
)
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDataBaseTool(
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing "
"it. Always use this tool before executing a query with "
f"{query_sql_database_tool.name}!"
)
query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
)
return [
query_sql_database_tool,
info_sql_database_tool,
list_sql_database_tool,
query_sql_checker_tool,
]

View File

@@ -0,0 +1 @@
"""Agent toolkit for interacting with vector stores."""

View File

@@ -0,0 +1,96 @@
"""VectorStore agent."""
from typing import Any, Dict, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_integrations.agents.agent import AgentExecutor
from langchain_integrations.agent_toolkits.vectorstore.prompt import PREFIX, ROUTER_PREFIX
from langchain_integrations.agent_toolkits.vectorstore.toolkit import (
VectorStoreRouterToolkit,
VectorStoreToolkit,
)
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_integrations.chains.llm import LLMChain
def create_vectorstore_agent(
llm: BaseLanguageModel,
toolkit: VectorStoreToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a VectorStore agent from an LLM and tools.
Args:
llm (BaseLanguageModel): LLM that will be used by the agent
toolkit (VectorStoreToolkit): Set of tools for the agent
callback_manager (Optional[BaseCallbackManager], optional): Object to handle callbacks. Defaults to None.
prefix (str, optional): The prefix prompt for the agent. If not provided, uses the default PREFIX.
verbose (bool, optional): Whether to show the content of the scratchpad. Defaults to False.
agent_executor_kwargs (Optional[Dict[str, Any]], optional): Any additional parameters to pass to the AgentExecutor. Defaults to None.
**kwargs: Additional named parameters to pass to the ZeroShotAgent.
Returns:
AgentExecutor: A callable AgentExecutor object. Call it directly or use the run method with a query to get the response.
""" # noqa: E501
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
def create_vectorstore_router_agent(
llm: BaseLanguageModel,
toolkit: VectorStoreRouterToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = ROUTER_PREFIX,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a VectorStore router agent from an LLM and tools.
Args:
llm (BaseLanguageModel): LLM that will be used by the agent
toolkit (VectorStoreRouterToolkit): Set of tools for the agent, with routing capability across multiple vector stores
callback_manager (Optional[BaseCallbackManager], optional): Object to handle callbacks. Defaults to None.
prefix (str, optional): The prefix prompt for the router agent. If not provided, uses the default ROUTER_PREFIX.
verbose (bool, optional): Whether to show the content of the scratchpad. Defaults to False.
agent_executor_kwargs (Optional[Dict[str, Any]], optional): Any additional parameters to pass to the AgentExecutor. Defaults to None.
**kwargs: Additional named parameters to pass to the ZeroShotAgent.
Returns:
AgentExecutor: A callable AgentExecutor object. Call it directly or use the run method with a query to get the response.
""" # noqa: E501
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
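# A minimal usage sketch (hedged: FAISS and OpenAIEmbeddings stand in for any
# vector store and embedding model, and these import paths are assumed from
# this diff's module layout rather than confirmed by it):
if __name__ == "__main__":
    from langchain_integrations.agent_toolkits.vectorstore.toolkit import (
        VectorStoreInfo,
    )
    from langchain_integrations.vectorstores.faiss import FAISS
    from langchain_openai.embeddings import OpenAIEmbeddings
    from langchain_openai.llm import OpenAI

    store = FAISS.from_texts(["LangChain ships agent toolkits."], OpenAIEmbeddings())
    info = VectorStoreInfo(
        vectorstore=store,
        name="docs",
        description="Snippets of project documentation",
    )
    toolkit = VectorStoreToolkit(vectorstore_info=info)
    agent = create_vectorstore_agent(OpenAI(temperature=0), toolkit)
    agent.run("What does LangChain ship?")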

View File

@@ -0,0 +1,13 @@
# flake8: noqa
PREFIX = """You are an agent designed to answer questions about sets of documents.
You have access to tools for interacting with the documents, and the inputs to the tools are questions.
Sometimes, you will be asked to provide sources for your answers, in which case you should use the appropriate tool to do so.
If the question does not seem relevant to any of the tools provided, just return "I don't know" as the answer.
"""
ROUTER_PREFIX = """You are an agent designed to answer questions.
You have access to tools for interacting with different sources, and the inputs to the tools are questions.
Your main task is to decide which of the tools is relevant for answering the question at hand.
For complex questions, you can break the question down into sub-questions and use tools to answer the sub-questions.
"""

View File

@@ -0,0 +1,89 @@
"""Toolkit for interacting with a vector store."""
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.vectorstores import VectorStore
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_openai.llm import OpenAI
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.vectorstore.tool import (
VectorStoreQATool,
VectorStoreQAWithSourcesTool,
)
class VectorStoreInfo(BaseModel):
"""Information about a VectorStore."""
vectorstore: VectorStore = Field(exclude=True)
name: str
description: str
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
class VectorStoreToolkit(BaseToolkit):
"""Toolkit for interacting with a Vector Store."""
vectorstore_info: VectorStoreInfo = Field(exclude=True)
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
description = VectorStoreQATool.get_description(
self.vectorstore_info.name, self.vectorstore_info.description
)
qa_tool = VectorStoreQATool(
name=self.vectorstore_info.name,
description=description,
vectorstore=self.vectorstore_info.vectorstore,
llm=self.llm,
)
description = VectorStoreQAWithSourcesTool.get_description(
self.vectorstore_info.name, self.vectorstore_info.description
)
qa_with_sources_tool = VectorStoreQAWithSourcesTool(
name=f"{self.vectorstore_info.name}_with_sources",
description=description,
vectorstore=self.vectorstore_info.vectorstore,
llm=self.llm,
)
return [qa_tool, qa_with_sources_tool]
class VectorStoreRouterToolkit(BaseToolkit):
"""Toolkit for routing between Vector Stores."""
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tools: List[BaseTool] = []
for vectorstore_info in self.vectorstores:
description = VectorStoreQATool.get_description(
vectorstore_info.name, vectorstore_info.description
)
qa_tool = VectorStoreQATool(
name=vectorstore_info.name,
description=description,
vectorstore=vectorstore_info.vectorstore,
llm=self.llm,
)
tools.append(qa_tool)
return tools

View File

@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
here = as_import_path(Path(__file__).parent)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
raise AttributeError(
"This agent has been moved to langchain experiment. "
"This agent relies on python REPL tool under the hood, so to use it "
"safely please sandbox the python REPL. "
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
"and https://github.com/langchain-ai/langchain/discussions/11680"
"To keep using this code as is, install langchain experimental and "
f"update your import statement from:\n `{old_path}` to `{new_path}`."
)

View File

@@ -0,0 +1 @@
"""Zapier Toolkit."""

View File

@@ -0,0 +1,60 @@
"""[DEPRECATED] Zapier Toolkit."""
from typing import List
from langchain_core._api import warn_deprecated
from langchain_integrations.agent_toolkits.base import BaseToolkit
from langchain_integrations.tools import BaseTool
from langchain_integrations.tools.zapier.tool import ZapierNLARunAction
from langchain_integrations.utilities.zapier import ZapierNLAWrapper
class ZapierToolkit(BaseToolkit):
"""Zapier Toolkit."""
tools: List[BaseTool] = []
@classmethod
def from_zapier_nla_wrapper(
cls, zapier_nla_wrapper: ZapierNLAWrapper
) -> "ZapierToolkit":
"""Create a toolkit from a ZapierNLAWrapper."""
actions = zapier_nla_wrapper.list()
tools = [
ZapierNLARunAction(
action_id=action["id"],
zapier_description=action["description"],
params_schema=action["params"],
api_wrapper=zapier_nla_wrapper,
)
for action in actions
]
return cls(tools=tools)
@classmethod
async def async_from_zapier_nla_wrapper(
cls, zapier_nla_wrapper: ZapierNLAWrapper
) -> "ZapierToolkit":
"""Create a toolkit from a ZapierNLAWrapper."""
actions = await zapier_nla_wrapper.alist()
tools = [
ZapierNLARunAction(
action_id=action["id"],
zapier_description=action["description"],
params_schema=action["params"],
api_wrapper=zapier_nla_wrapper,
)
for action in actions
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
warn_deprecated(
since="0.0.319",
message=(
"This tool will be deprecated on 2023-11-17. See "
"https://nla.zapier.com/sunset/ for details"
),
)
return self.tools

View File

@@ -0,0 +1,19 @@
"""**Chat Loaders** load chat messages from common communications platforms.
Load chat messages from various
communications platforms such as Facebook Messenger, Telegram, and
WhatsApp. The loaded chat messages can be used for fine-tuning models.
**Class hierarchy:**
.. code-block::
BaseChatLoader --> <name>ChatLoader # Examples: WhatsAppChatLoader, IMessageChatLoader
**Main helpers:**
.. code-block::
ChatSession
""" # noqa: E501

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Iterator, List
from langchain_core.chat_sessions import ChatSession
class BaseChatLoader(ABC):
"""Base class for chat loaders."""
@abstractmethod
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy load the chat sessions."""
def load(self) -> List[ChatSession]:
"""Eagerly load the chat sessions into memory."""
return list(self.lazy_load())
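To illustrate the contract, here is a minimal hypothetical subclass: only lazy_load must be implemented, and the eager load comes for free:
from typing import Iterator, List

from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage


class InMemoryChatLoader(BaseChatLoader):
    """Toy loader that yields one session wrapping a list of strings."""

    def __init__(self, texts: List[str]) -> None:
        self.texts = texts

    def lazy_load(self) -> Iterator[ChatSession]:
        yield ChatSession(
            messages=[HumanMessage(content=text) for text in self.texts]
        )


sessions = InMemoryChatLoader(["hi", "how are you?"]).load()
assert len(sessions[0]["messages"]) == 2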

View File

@@ -0,0 +1,79 @@
import json
import logging
from pathlib import Path
from typing import Iterator, Union
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage
from langchain_integrations.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class SingleFileFacebookMessengerChatLoader(BaseChatLoader):
"""Load `Facebook Messenger` chat data from a single file.
Args:
path (Union[Path, str]): The path to the chat file.
Attributes:
file_path (Path): The path to the chat file.
"""
def __init__(self, path: Union[Path, str]) -> None:
super().__init__()
self.file_path = path if isinstance(path, Path) else Path(path)
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy loads the chat data from the file.
Yields:
ChatSession: A chat session containing the loaded messages.
"""
with open(self.file_path) as f:
data = json.load(f)
sorted_data = sorted(data["messages"], key=lambda x: x["timestamp_ms"])
messages = []
for m in sorted_data:
messages.append(
HumanMessage(
content=m["content"], additional_kwargs={"sender": m["sender_name"]}
)
)
yield ChatSession(messages=messages)
class FolderFacebookMessengerChatLoader(BaseChatLoader):
"""Load `Facebook Messenger` chat data from a folder.
Args:
path (Union[str, Path]): The path to the directory
containing the chat files.
Attributes:
directory_path (Path): The path to the directory containing the chat files.
"""
def __init__(self, path: Union[str, Path]) -> None:
super().__init__()
self.directory_path = Path(path) if isinstance(path, str) else path
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy loads the chat data from the folder.
Yields:
ChatSession: A chat session containing the loaded messages.
"""
inbox_path = self.directory_path / "inbox"
for _dir in inbox_path.iterdir():
if _dir.is_dir():
for _file in _dir.iterdir():
if _file.suffix.lower() == ".json":
file_loader = SingleFileFacebookMessengerChatLoader(path=_file)
for result in file_loader.lazy_load():
yield result
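A usage sketch; the path below is a placeholder for an unzipped Facebook "Download Your Information" export containing an inbox/ folder:
# "./facebook_export" is a placeholder path.
loader = FolderFacebookMessengerChatLoader(path="./facebook_export")
for session in loader.lazy_load():
    print(len(session["messages"]))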

View File

@@ -0,0 +1,112 @@
import base64
import re
from typing import Any, Iterator
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage
from langchain_integrations.chat_loaders.base import BaseChatLoader
def _extract_email_content(msg: Any) -> HumanMessage:
from_email = None
for values in msg["payload"]["headers"]:
name = values["name"]
if name == "From":
from_email = values["value"]
if from_email is None:
raise ValueError("Could not find a 'From' header in the message.")
for part in msg["payload"]["parts"]:
if part["mimeType"] == "text/plain":
data = part["body"]["data"]
data = base64.urlsafe_b64decode(data).decode("utf-8")
# Regular expression to split the email body at the first
# occurrence of a line that starts with "On ... wrote:"
pattern = re.compile(r"\r\nOn .+(\r\n)*wrote:\r\n")
# Split the email body and extract the first part
newest_response = re.split(pattern, data)[0]
message = HumanMessage(
content=newest_response, additional_kwargs={"sender": from_email}
)
return message
raise ValueError("No 'text/plain' part found in the message payload.")
def _get_message_data(service: Any, message: Any) -> ChatSession:
msg = service.users().messages().get(userId="me", id=message["id"]).execute()
message_content = _extract_email_content(msg)
in_reply_to = None
email_data = msg["payload"]["headers"]
for values in email_data:
name = values["name"]
if name == "In-Reply-To":
in_reply_to = values["value"]
if in_reply_to is None:
raise ValueError("Message is not a reply: no 'In-Reply-To' header found.")
thread_id = msg["threadId"]
thread = service.users().threads().get(userId="me", id=thread_id).execute()
messages = thread["messages"]
response_email = None
for message in messages:
email_data = message["payload"]["headers"]
for values in email_data:
if values["name"] == "Message-ID":
message_id = values["value"]
if message_id == in_reply_to:
response_email = message
if response_email is None:
raise ValueError("Could not find the email being replied to in the thread.")
starter_content = _extract_email_content(response_email)
return ChatSession(messages=[starter_content, message_content])
class GMailLoader(BaseChatLoader):
"""Load data from `GMail`.
There are many ways you could want to load data from GMail.
This loader is currently fairly opinionated in how to do so.
It first looks for all messages that you have sent.
It then looks for messages where you are responding to a previous email.
It then fetches that previous email, and creates a training example
of that email, followed by your email.
Note that there are clear limitations here. For example,
all examples created are only looking at the previous email for context.
To use:
- Set up a Google Developer Account:
Go to the Google Developer Console, create a project,
and enable the Gmail API for that project.
This will give you a credentials.json file that you'll need later.
"""
def __init__(self, creds: Any, n: int = 100, raise_error: bool = False) -> None:
super().__init__()
self.creds = creds
self.n = n
self.raise_error = raise_error
def lazy_load(self) -> Iterator[ChatSession]:
from googleapiclient.discovery import build
service = build("gmail", "v1", credentials=self.creds)
results = (
service.users()
.messages()
.list(userId="me", labelIds=["SENT"], maxResults=self.n)
.execute()
)
messages = results.get("messages", [])
for message in messages:
try:
yield _get_message_data(service, message)
except Exception as e:
# TODO: handle errors better
if self.raise_error:
raise e
else:
pass
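A usage sketch, assuming credentials were produced by the standard google-auth OAuth flow; the token.json filename is a placeholder:
# "token.json" is a placeholder produced by the usual OAuth consent flow.
from google.oauth2.credentials import Credentials

creds = Credentials.from_authorized_user_file(
    "token.json", scopes=["https://www.googleapis.com/auth/gmail.readonly"]
)
loader = GMailLoader(creds=creds, n=10)
sessions = loader.load()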

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage
from langchain_integrations.chat_loaders.base import BaseChatLoader
if TYPE_CHECKING:
import sqlite3
class IMessageChatLoader(BaseChatLoader):
"""Load chat sessions from the `iMessage` chat.db SQLite file.
It only works on macOS when you have iMessage enabled and have the chat.db file.
The chat.db file is likely located at ~/Library/Messages/chat.db. However, your
terminal may not have permission to access this file. To resolve this, you can
copy the file to a different location, change the permissions of the file, or
grant full disk access for your terminal emulator
in System Settings > Security and Privacy > Full Disk Access.
"""
def __init__(self, path: Optional[Union[str, Path]] = None):
"""
Initialize the IMessageChatLoader.
Args:
path (str or Path, optional): Path to the chat.db SQLite file.
Defaults to None, in which case the default path
~/Library/Messages/chat.db will be used.
"""
if path is None:
path = Path.home() / "Library" / "Messages" / "chat.db"
self.db_path = path if isinstance(path, Path) else Path(path)
if not self.db_path.exists():
raise FileNotFoundError(f"File {self.db_path} not found")
try:
import sqlite3 # noqa: F401
except ImportError as e:
raise ImportError(
"The sqlite3 module is required to load iMessage chats.\n"
"Please install it with `pip install pysqlite3`"
) from e
def _parse_attributedBody(self, attributedBody: bytes) -> str:
"""
Parse the attributedBody field of the message table
for the text content of the message.
The attributedBody field is a binary blob that contains
the message content after the byte string b"NSString":
5 bytes 1-3 bytes `len` bytes
... | b"NSString" | preamble | `len` | contents | ...
The 5 preamble bytes are always b"\x01\x94\x84\x01+"
The size of `len` is either 1 byte or 3 bytes:
- If the first byte in `len` is b"\x81" then `len` is 3 bytes long.
So the message length is the 2 bytes after, in little Endian.
- Otherwise, the size of `len` is 1 byte, and the message length is
that byte.
Args:
attributedBody (bytes): attributedBody field of the message table.
Return:
str: Text content of the message.
"""
content = attributedBody.split(b"NSString")[1][5:]
length, start = content[0], 1
if content[0] == 129:
length, start = int.from_bytes(content[1:3], "little"), 3
return content[start : start + length].decode("utf-8", errors="ignore")
def _load_single_chat_session(
self, cursor: "sqlite3.Cursor", chat_id: int
) -> ChatSession:
"""
Load a single chat session from the iMessage chat.db.
Args:
cursor: SQLite cursor object.
chat_id (int): ID of the chat session to load.
Returns:
ChatSession: Loaded chat session.
"""
results: List[HumanMessage] = []
query = """
SELECT message.date, handle.id, message.text, message.attributedBody
FROM message
JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
JOIN handle ON message.handle_id = handle.ROWID
WHERE chat_message_join.chat_id = ?
ORDER BY message.date ASC;
"""
cursor.execute(query, (chat_id,))
messages = cursor.fetchall()
for date, sender, text, attributedBody in messages:
if text:
content = text
elif attributedBody:
content = self._parse_attributedBody(attributedBody)
else: # Skip messages with no content
continue
results.append(
HumanMessage(
role=sender,
content=content,
additional_kwargs={
"message_time": date,
"sender": sender,
},
)
)
return ChatSession(messages=results)
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the iMessage chat.db
and yield them in the required format.
Yields:
ChatSession: Loaded chat session.
"""
import sqlite3
try:
conn = sqlite3.connect(self.db_path)
except sqlite3.OperationalError as e:
raise ValueError(
f"Could not open iMessage DB file {self.db_path}.\n"
"Make sure your terminal emulator has disk access to this file.\n"
" You can either copy the DB file to an accessible location"
" or grant full disk access for your terminal emulator."
" You can grant full disk access for your terminal emulator"
" in System Settings > Security and Privacy > Full Disk Access."
) from e
cursor = conn.cursor()
# Fetch the list of chat IDs sorted by time (most recent first)
query = """SELECT chat_id
FROM message
JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
GROUP BY chat_id
ORDER BY MAX(date) DESC;"""
cursor.execute(query)
chat_ids = [row[0] for row in cursor.fetchall()]
for chat_id in chat_ids:
yield self._load_single_chat_session(cursor, chat_id)
conn.close()
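A usage sketch; this assumes the default chat.db is readable (copy the file to an accessible location first if your terminal lacks Full Disk Access):
loader = IMessageChatLoader()  # defaults to ~/Library/Messages/chat.db
for session in loader.lazy_load():
    print(len(session["messages"]))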

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast
from langchain_core.chat_sessions import ChatSession
from langchain_core.load import load
from langchain_integrations.chat_loaders.base import BaseChatLoader
if TYPE_CHECKING:
from langsmith.client import Client
from langsmith.schemas import Run
logger = logging.getLogger(__name__)
class LangSmithRunChatLoader(BaseChatLoader):
"""
Load chat sessions from a list of LangSmith "llm" runs.
Attributes:
runs (Iterable[Union[str, Run]]): The list of LLM run IDs or run objects.
client (Client): Instance of LangSmith client for fetching data.
"""
def __init__(
self, runs: Iterable[Union[str, Run]], client: Optional["Client"] = None
):
"""
Initialize a new LangSmithRunChatLoader instance.
:param runs: List of LLM run IDs or run objects.
:param client: An instance of LangSmith client, if not provided,
a new client instance will be created.
"""
from langsmith.client import Client
self.runs = runs
self.client = client or Client()
def _load_single_chat_session(self, llm_run: "Run") -> ChatSession:
"""
Convert an individual LangSmith LLM run to a ChatSession.
:param llm_run: The LLM run object.
:return: A chat session representing the run's data.
"""
chat_session = LangSmithRunChatLoader._get_messages_from_llm_run(llm_run)
functions = LangSmithRunChatLoader._get_functions_from_llm_run(llm_run)
if functions:
chat_session["functions"] = functions
return chat_session
@staticmethod
def _get_messages_from_llm_run(llm_run: "Run") -> ChatSession:
"""
Extract messages from a LangSmith LLM run.
:param llm_run: The LLM run object.
:return: ChatSession with the extracted messages.
"""
if llm_run.run_type != "llm":
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
if "messages" not in llm_run.inputs:
raise ValueError(f"Run has no 'messages' inputs. Got {llm_run.inputs}")
if not llm_run.outputs:
raise ValueError("Cannot convert pending run")
messages = load(llm_run.inputs)["messages"]
message_chunk = load(llm_run.outputs)["generations"][0]["message"]
return ChatSession(messages=messages + [message_chunk])
@staticmethod
def _get_functions_from_llm_run(llm_run: "Run") -> Optional[List[Dict]]:
"""
Extract functions from a LangSmith LLM run if they exist.
:param llm_run: The LLM run object.
:return: Functions from the run or None.
"""
if llm_run.run_type != "llm":
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
return (llm_run.extra or {}).get("invocation_params", {}).get("functions")
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the iterable of run IDs.
This method fetches the runs and converts them to chat sessions on-the-fly,
yielding one session at a time.
:return: Iterator of chat sessions containing messages.
"""
from langsmith.schemas import Run
for run_obj in self.runs:
try:
if hasattr(run_obj, "id"):
run = run_obj
else:
run = self.client.read_run(run_obj)
session = self._load_single_chat_session(cast(Run, run))
yield session
except ValueError as e:
logger.warning(f"Could not load run {run_obj}: {repr(e)}")
continue
class LangSmithDatasetChatLoader(BaseChatLoader):
"""
Load chat sessions from a LangSmith dataset with the "chat" data type.
Attributes:
dataset_name (str): The name of the LangSmith dataset.
client (Client): Instance of LangSmith client for fetching data.
"""
def __init__(self, *, dataset_name: str, client: Optional["Client"] = None):
"""
Initialize a new LangSmithDatasetChatLoader instance.
:param dataset_name: The name of the LangSmith dataset.
:param client: An instance of LangSmith client; if not provided,
a new client instance will be created.
"""
try:
from langsmith.client import Client
except ImportError as e:
raise ImportError(
"The LangSmith client is required to load LangSmith datasets.\n"
"Please install it with `pip install langsmith`"
) from e
self.dataset_name = dataset_name
self.client = client or Client()
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the specified LangSmith dataset.
This method fetches the chat data from the dataset and
converts each data point to chat sessions on-the-fly,
yielding one session at a time.
:return: Iterator of chat sessions containing messages.
"""
from langchain_integrations.adapters import openai as oai_adapter # noqa: E402
data = self.client.read_dataset_openai_finetuning(
dataset_name=self.dataset_name
)
for data_point in data:
yield ChatSession(
messages=[
oai_adapter.convert_dict_to_message(m)
for m in data_point.get("messages", [])
],
functions=data_point.get("functions"),
)
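A usage sketch; the dataset name is a placeholder, and the LangSmith client is assumed to pick up its API key from the environment:
# "my-chat-dataset" is a placeholder dataset name.
loader = LangSmithDatasetChatLoader(dataset_name="my-chat-dataset")
for session in loader.lazy_load():
    print(len(session["messages"]))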

View File

@@ -0,0 +1,86 @@
import json
import logging
import re
import zipfile
from pathlib import Path
from typing import Dict, Iterator, List, Union
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, HumanMessage
from langchain_integrations.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class SlackChatLoader(BaseChatLoader):
"""Load `Slack` conversations from a dump zip file."""
def __init__(
self,
path: Union[str, Path],
):
"""
Initialize the chat loader with the path to the exported Slack dump zip file.
:param path: Path to the exported Slack dump zip file.
"""
self.zip_path = path if isinstance(path, Path) else Path(path)
if not self.zip_path.exists():
raise FileNotFoundError(f"File {self.zip_path} not found")
def _load_single_chat_session(self, messages: List[Dict]) -> ChatSession:
results: List[Union[AIMessage, HumanMessage]] = []
previous_sender = None
for message in messages:
if not isinstance(message, dict):
continue
text = message.get("text", "")
timestamp = message.get("ts", "")
sender = message.get("user", "")
if not sender:
continue
skip_pattern = re.compile(
r"<@U\d+> has joined the channel", flags=re.IGNORECASE
)
if skip_pattern.match(text):
continue
if sender == previous_sender:
results[-1].content += "\n\n" + text
results[-1].additional_kwargs["events"].append(
{"message_time": timestamp}
)
else:
results.append(
HumanMessage(
role=sender,
content=text,
additional_kwargs={
"sender": sender,
"events": [{"message_time": timestamp}],
},
)
)
previous_sender = sender
return ChatSession(messages=results)
def _read_json(self, zip_file: zipfile.ZipFile, file_path: str) -> List[dict]:
"""Read JSON data from a zip subfile."""
with zip_file.open(file_path, "r") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError(f"Expected list of dictionaries, got {type(data)}")
return data
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the Slack dump file and yield them
in the required format.
:return: Iterator of chat sessions containing messages.
"""
with zipfile.ZipFile(str(self.zip_path), "r") as zip_file:
for file_path in zip_file.namelist():
if file_path.endswith(".json"):
messages = self._read_json(zip_file, file_path)
yield self._load_single_chat_session(messages)
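A usage sketch; the zip path is a placeholder for a Slack workspace export:
loader = SlackChatLoader(path="./slack_dump.zip")  # placeholder path
sessions = loader.load()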

View File

@@ -0,0 +1,151 @@
import json
import logging
import os
import tempfile
import zipfile
from pathlib import Path
from typing import Iterator, List, Union
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_integrations.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class TelegramChatLoader(BaseChatLoader):
"""Load `telegram` conversations to LangChain chat messages.
To export, use the Telegram Desktop app from
https://desktop.telegram.org/, select a conversation, click the three dots
in the top right corner, and select "Export chat history". Then select
"Machine-readable JSON" (preferred) to export. Note: the 'lite' versions of
the desktop app (like "Telegram for MacOS") do not support exporting chat
history.
"""
def __init__(
self,
path: Union[str, Path],
):
"""Initialize the TelegramChatLoader.
Args:
path (Union[str, Path]): Path to the exported Telegram chat zip,
directory, json, or HTML file.
"""
self.path = path if isinstance(path, str) else str(path)
def _load_single_chat_session_html(self, file_path: str) -> ChatSession:
"""Load a single chat session from an HTML file.
Args:
file_path (str): Path to the HTML file.
Returns:
ChatSession: The loaded chat session.
"""
try:
from bs4 import BeautifulSoup
except ImportError:
raise ImportError(
"Please install the 'beautifulsoup4' package to load"
" Telegram HTML files. You can do this by running"
"'pip install beautifulsoup4' in your terminal."
)
with open(file_path, "r", encoding="utf-8") as file:
soup = BeautifulSoup(file, "html.parser")
results: List[Union[HumanMessage, AIMessage]] = []
previous_sender = None
for message in soup.select(".message.default"):
timestamp = message.select_one(".pull_right.date.details")["title"]
from_name_element = message.select_one(".from_name")
if from_name_element is None and previous_sender is None:
logger.debug("from_name not found in message")
continue
elif from_name_element is None:
from_name = previous_sender
else:
from_name = from_name_element.text.strip()
text = message.select_one(".text").text.strip()
results.append(
HumanMessage(
content=text,
additional_kwargs={
"sender": from_name,
"events": [{"message_time": timestamp}],
},
)
)
previous_sender = from_name
return ChatSession(messages=results)
def _load_single_chat_session_json(self, file_path: str) -> ChatSession:
"""Load a single chat session from a JSON file.
Args:
file_path (str): Path to the JSON file.
Returns:
ChatSession: The loaded chat session.
"""
with open(file_path, "r", encoding="utf-8") as file:
data = json.load(file)
messages = data.get("messages", [])
results: List[BaseMessage] = []
for message in messages:
text = message.get("text", "")
timestamp = message.get("date", "")
from_name = message.get("from", "")
results.append(
HumanMessage(
content=text,
additional_kwargs={
"sender": from_name,
"events": [{"message_time": timestamp}],
},
)
)
return ChatSession(messages=results)
def _iterate_files(self, path: str) -> Iterator[str]:
"""Iterate over files in a directory or zip file.
Args:
path (str): Path to the directory or zip file.
Yields:
str: Path to each file.
"""
if os.path.isfile(path) and path.endswith((".html", ".json")):
yield path
elif os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
if file.endswith((".html", ".json")):
yield os.path.join(root, file)
elif zipfile.is_zipfile(path):
with zipfile.ZipFile(path) as zip_file:
for file in zip_file.namelist():
if file.endswith((".html", ".json")):
with tempfile.TemporaryDirectory() as temp_dir:
yield zip_file.extract(file, path=temp_dir)
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy load the messages from the chat file and yield them
in as chat sessions.
Yields:
ChatSession: The loaded chat session.
"""
for file_path in self._iterate_files(self.path):
if file_path.endswith(".html"):
yield self._load_single_chat_session_html(file_path)
elif file_path.endswith(".json"):
yield self._load_single_chat_session_json(file_path)
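A usage sketch; the path is a placeholder for a "Machine-readable JSON" export:
loader = TelegramChatLoader(path="./telegram_export.json")  # placeholder path
sessions = loader.load()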

View File

@@ -0,0 +1,95 @@
"""Utilities for chat loaders."""
from copy import deepcopy
from typing import Iterable, Iterator, List
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, BaseMessage
def merge_chat_runs_in_session(
chat_session: ChatSession, delimiter: str = "\n\n"
) -> ChatSession:
"""Merge chat runs together in a chat session.
A chat run is a sequence of messages from the same sender.
Args:
chat_session: A chat session.
Returns:
A chat session with merged chat runs.
"""
messages: List[BaseMessage] = []
for message in chat_session["messages"]:
if not isinstance(message.content, str):
raise ValueError(
"Chat Loaders only support messages with content type string, "
f"got {message.content}"
)
if not messages:
messages.append(deepcopy(message))
elif (
isinstance(message, type(messages[-1]))
and messages[-1].additional_kwargs.get("sender") is not None
and messages[-1].additional_kwargs["sender"]
== message.additional_kwargs.get("sender")
):
if not isinstance(messages[-1].content, str):
raise ValueError(
"Chat Loaders only support messages with content type string, "
f"got {messages[-1].content}"
)
messages[-1].content = (
messages[-1].content + delimiter + message.content
).strip()
messages[-1].additional_kwargs.get("events", []).extend(
message.additional_kwargs.get("events") or []
)
else:
messages.append(deepcopy(message))
return ChatSession(messages=messages)
def merge_chat_runs(chat_sessions: Iterable[ChatSession]) -> Iterator[ChatSession]:
"""Merge chat runs together.
A chat run is a sequence of messages from the same sender.
Args:
chat_sessions: A list of chat sessions.
Returns:
A list of chat sessions with merged chat runs.
"""
for chat_session in chat_sessions:
yield merge_chat_runs_in_session(chat_session)
def map_ai_messages_in_session(chat_session: ChatSession, sender: str) -> ChatSession:
"""Convert messages from the specified 'sender' to AI messages.
This is useful for fine-tuning the AI to adapt to your voice.
"""
messages = []
num_converted = 0
for message in chat_sessions["messages"]:
if message.additional_kwargs.get("sender") == sender:
message = AIMessage(
content=message.content,
additional_kwargs=message.additional_kwargs.copy(),
example=getattr(message, "example", None),
)
num_converted += 1
messages.append(message)
return ChatSession(messages=messages)
def map_ai_messages(
chat_sessions: Iterable[ChatSession], sender: str
) -> Iterator[ChatSession]:
"""Convert messages from the specified 'sender' to AI messages.
This is useful for fine-tuning the AI to adapt to your voice.
"""
for chat_session in chat_sessions:
yield map_ai_messages_in_session(chat_session, sender)
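Putting the two helpers together, a sketch of the usual preprocessing pipeline: merge consecutive messages per sender, then recast one sender's messages as AI messages (the senders and contents are illustrative):
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage

raw_sessions = [
    ChatSession(
        messages=[
            HumanMessage(content="hi", additional_kwargs={"sender": "alice"}),
            HumanMessage(content="there", additional_kwargs={"sender": "alice"}),
            HumanMessage(content="hello!", additional_kwargs={"sender": "bob"}),
        ]
    )
]
prepared = list(map_ai_messages(merge_chat_runs(raw_sessions), sender="alice"))
# alice's run is now a single AIMessage with content "hi\n\nthere"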

View File

@@ -0,0 +1,119 @@
import logging
import os
import re
import zipfile
from typing import Iterator, List, Union
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, HumanMessage
from langchain_integrations.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class WhatsAppChatLoader(BaseChatLoader):
"""Load `WhatsApp` conversations from a dump zip file or directory."""
def __init__(self, path: str):
"""Initialize the WhatsAppChatLoader.
Args:
path (str): Path to the exported WhatsApp chat
zip file, directory, or .txt file.
To generate the dump, open the chat, click the three dots in the top
right corner, and select "More". Then select "Export chat" and
choose "Without media".
"""
self.path = path
ignore_lines = [
"This message was deleted",
"<Media omitted>",
"image omitted",
"Messages and calls are end-to-end encrypted. No one outside of this chat,"
" not even WhatsApp, can read or listen to them.",
]
self._ignore_lines = re.compile(
r"(" + "|".join([r"\u200E*" + line for line in ignore_lines]) + r")",
flags=re.IGNORECASE,
)
self._message_line_regex = re.compile(
r"\u200E*\[?(\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{2}:\d{2} (?:AM|PM))\]?[ \u200E]*([^:]+): (.+)", # noqa
flags=re.IGNORECASE,
)
def _load_single_chat_session(self, file_path: str) -> ChatSession:
"""Load a single chat session from a file.
Args:
file_path (str): Path to the chat file.
Returns:
ChatSession: The loaded chat session.
"""
with open(file_path, "r", encoding="utf-8") as file:
txt = file.read()
# Split messages by newlines, but keep multi-line messages grouped
chat_lines: List[str] = []
current_message = ""
for line in txt.split("\n"):
if self._message_line_regex.match(line):
if current_message:
chat_lines.append(current_message)
current_message = line
else:
current_message += " " + line.strip()
if current_message:
chat_lines.append(current_message)
results: List[Union[HumanMessage, AIMessage]] = []
for line in chat_lines:
result = self._message_line_regex.match(line.strip())
if result:
timestamp, sender, text = result.groups()
if not self._ignore_lines.match(text.strip()):
results.append(
HumanMessage(
role=sender,
content=text,
additional_kwargs={
"sender": sender,
"events": [{"message_time": timestamp}],
},
)
)
else:
logger.debug(f"Could not parse line: {line}")
return ChatSession(messages=results)
def _iterate_files(self, path: str) -> Iterator[str]:
"""Iterate over the files in a directory or zip file.
Args:
path (str): Path to the directory or zip file.
Yields:
str: The path to each file.
"""
if os.path.isfile(path):
yield path
elif os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
if file.endswith(".txt"):
yield os.path.join(root, file)
elif zipfile.is_zipfile(path):
with zipfile.ZipFile(path) as zip_file:
for file in zip_file.namelist():
if file.endswith(".txt"):
yield zip_file.extract(file)
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy load the messages from the chat file and yield
them as chat sessions.
Yields:
Iterator[ChatSession]: The loaded chat sessions.
"""
for file_path in self._iterate_files(self.path):
yield self._load_single_chat_session(file_path)
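A usage sketch; the path is a placeholder for an exported chat folder, zip, or .txt file:
loader = WhatsAppChatLoader(path="./whatsapp_export")  # placeholder path
sessions = loader.load()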

View File

@@ -0,0 +1,57 @@
from langchain_integrations.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
from langchain_integrations.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
from langchain_integrations.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
from langchain_integrations.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
from langchain_integrations.chat_message_histories.elasticsearch import (
ElasticsearchChatMessageHistory,
)
from langchain_integrations.chat_message_histories.file import FileChatMessageHistory
from langchain_integrations.chat_message_histories.firestore import (
FirestoreChatMessageHistory,
)
from langchain_integrations.chat_message_histories.in_memory import ChatMessageHistory
from langchain_integrations.chat_message_histories.momento import MomentoChatMessageHistory
from langchain_integrations.chat_message_histories.mongodb import MongoDBChatMessageHistory
from langchain_integrations.chat_message_histories.neo4j import Neo4jChatMessageHistory
from langchain_integrations.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain_integrations.chat_message_histories.redis import RedisChatMessageHistory
from langchain_integrations.chat_message_histories.rocksetdb import RocksetChatMessageHistory
from langchain_integrations.chat_message_histories.singlestoredb import (
SingleStoreDBChatMessageHistory,
)
from langchain_integrations.chat_message_histories.sql import SQLChatMessageHistory
from langchain_integrations.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
)
from langchain_integrations.chat_message_histories.upstash_redis import (
UpstashRedisChatMessageHistory,
)
from langchain_integrations.chat_message_histories.xata import XataChatMessageHistory
from langchain_integrations.chat_message_histories.zep import ZepChatMessageHistory
__all__ = [
"AstraDBChatMessageHistory",
"ChatMessageHistory",
"CassandraChatMessageHistory",
"CosmosDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
"FirestoreChatMessageHistory",
"MomentoChatMessageHistory",
"MongoDBChatMessageHistory",
"PostgresChatMessageHistory",
"RedisChatMessageHistory",
"RocksetChatMessageHistory",
"SQLChatMessageHistory",
"StreamlitChatMessageHistory",
"SingleStoreDBChatMessageHistory",
"XataChatMessageHistory",
"ZepChatMessageHistory",
"UpstashRedisChatMessageHistory",
"Neo4jChatMessageHistory",
]

View File

@@ -0,0 +1,114 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations
import json
import time
import typing
from typing import List, Optional
if typing.TYPE_CHECKING:
from astrapy.db import AstraDB as LibAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
DEFAULT_COLLECTION_NAME = "langchain_message_store"
class AstraDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Astra DB.
Args (only keyword-arguments accepted):
session_id: arbitrary key that is used to store the messages
of a single chat session.
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""
def __init__(
self,
*,
session_id: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
namespace: Optional[str] = None,
) -> None:
"""Create an Astra DB chat message history."""
try:
from astrapy.db import AstraDB as LibAstraDB
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
self.session_id = session_id
self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
self.collection = self.astra_db.create_collection(self.collection_name)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all session messages from DB"""
message_blobs = [
doc["body_blob"]
for doc in sorted(
self.collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
),
key=lambda _doc: _doc["timestamp"],
)
]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Write a message to the table"""
self.collection.insert_one(
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
)
def clear(self) -> None:
"""Clear session memory from DB"""
self.collection.delete_many(filter={"session_id": self.session_id})

View File

@@ -0,0 +1,72 @@
"""Cassandra-based chat message history, based on cassIO."""
from __future__ import annotations
import json
import typing
from typing import List
if typing.TYPE_CHECKING:
from cassandra.cluster import Session
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
DEFAULT_TABLE_NAME = "message_store"
DEFAULT_TTL_SECONDS = None
class CassandraChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Cassandra.
Args:
session_id: arbitrary key that is used to store the messages
of a single chat session.
session: a Cassandra `Session` object (an open DB connection)
keyspace: name of the keyspace to use.
table_name: name of the table to use.
ttl_seconds: time-to-live (seconds) for automatic expiration
of stored entries. None (default) for no expiration.
"""
def __init__(
self,
session_id: str,
session: Session,
keyspace: str,
table_name: str = DEFAULT_TABLE_NAME,
ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS,
) -> None:
try:
from cassio.history import StoredBlobHistory
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
self.session_id = session_id
self.ttl_seconds = ttl_seconds
self.blob_history = StoredBlobHistory(session, keyspace, table_name)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all session messages from DB"""
message_blobs = self.blob_history.retrieve(
self.session_id,
)
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Write a message to the table"""
self.blob_history.store(
self.session_id, json.dumps(message_to_dict(message)), self.ttl_seconds
)
def clear(self) -> None:
"""Clear session memory from DB"""
self.blob_history.clear_session_id(self.session_id)
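A usage sketch, assuming a reachable Cassandra cluster and an existing keyspace; the contact point and keyspace name are placeholders:
from cassandra.cluster import Cluster

session = Cluster(["127.0.0.1"]).connect()  # placeholder contact point
history = CassandraChatMessageHistory(
    session_id="session-1",
    session=session,
    keyspace="chat_ks",  # placeholder keyspace
)
history.add_user_message("hi!")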

View File

@@ -0,0 +1,172 @@
"""Azure CosmosDB Memory History."""
from __future__ import annotations
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
messages_from_dict,
messages_to_dict,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from azure.cosmos import ContainerProxy
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history backed by Azure CosmosDB."""
def __init__(
self,
cosmos_endpoint: str,
cosmos_database: str,
cosmos_container: str,
session_id: str,
user_id: str,
credential: Any = None,
connection_string: Optional[str] = None,
ttl: Optional[int] = None,
cosmos_client_kwargs: Optional[dict] = None,
):
"""
Initializes a new instance of the CosmosDBChatMessageHistory class.
Make sure to call prepare_cosmos or use the context manager to make
sure your database is ready.
Either a credential or a connection string must be provided.
:param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
:param cosmos_database: The name of the database to use.
:param cosmos_container: The name of the container to use.
:param session_id: The session ID to use, can be overwritten while loading.
:param user_id: The user ID to use, can be overwritten while loading.
:param credential: The credential to use to authenticate to Azure Cosmos DB.
:param connection_string: The connection string to use to authenticate.
:param ttl: The time to live (in seconds) to use for documents in the container.
:param cosmos_client_kwargs: Additional kwargs to pass to the CosmosClient.
"""
self.cosmos_endpoint = cosmos_endpoint
self.cosmos_database = cosmos_database
self.cosmos_container = cosmos_container
self.credential = credential
self.conn_string = connection_string
self.session_id = session_id
self.user_id = user_id
self.ttl = ttl
self.messages: List[BaseMessage] = []
try:
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
CosmosClient,
)
except ImportError as exc:
raise ImportError(
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
"Please install it with `pip install azure-cosmos`."
) from exc
if self.credential:
self._client = CosmosClient(
url=self.cosmos_endpoint,
credential=self.credential,
**cosmos_client_kwargs or {},
)
elif self.conn_string:
self._client = CosmosClient.from_connection_string(
conn_str=self.conn_string,
**cosmos_client_kwargs or {},
)
else:
raise ValueError("Either a connection string or a credential must be set.")
self._container: Optional[ContainerProxy] = None
def prepare_cosmos(self) -> None:
"""Prepare the CosmosDB client.
Use this function or the context manager to make sure your database is ready.
"""
try:
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
PartitionKey,
)
except ImportError as exc:
raise ImportError(
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
"Please install it with `pip install azure-cosmos`."
) from exc
database = self._client.create_database_if_not_exists(self.cosmos_database)
self._container = database.create_container_if_not_exists(
self.cosmos_container,
partition_key=PartitionKey("/user_id"),
default_ttl=self.ttl,
)
self.load_messages()
def __enter__(self) -> "CosmosDBChatMessageHistory":
"""Context manager entry point."""
self._client.__enter__()
self.prepare_cosmos()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Context manager exit"""
self.upsert_messages()
self._client.__exit__(exc_type, exc_val, traceback)
def load_messages(self) -> None:
"""Retrieve the messages from Cosmos"""
if not self._container:
raise ValueError("Container not initialized")
try:
from azure.cosmos.exceptions import ( # pylint: disable=import-outside-toplevel # noqa: E501
CosmosHttpResponseError,
)
except ImportError as exc:
raise ImportError(
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
"Please install it with `pip install azure-cosmos`."
) from exc
try:
item = self._container.read_item(
item=self.session_id, partition_key=self.user_id
)
except CosmosHttpResponseError:
logger.info("no session found")
return
if "messages" in item and len(item["messages"]) > 0:
self.messages = messages_from_dict(item["messages"])
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)
self.upsert_messages()
def upsert_messages(self) -> None:
"""Update the cosmosdb item."""
if not self._container:
raise ValueError("Container not initialized")
self._container.upsert_item(
body={
"id": self.session_id,
"user_id": self.user_id,
"messages": messages_to_dict(self.messages),
}
)
def clear(self) -> None:
"""Clear session memory from this memory and cosmos."""
self.messages = []
if self._container:
self._container.delete_item(
item=self.session_id, partition_key=self.user_id
)
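A usage sketch; the endpoint and connection string are placeholders. The context manager calls prepare_cosmos() on entry and upsert_messages() on exit:
with CosmosDBChatMessageHistory(
    cosmos_endpoint="https://<account>.documents.azure.com:443/",  # placeholder
    cosmos_database="chat",
    cosmos_container="messages",
    session_id="session-1",
    user_id="user-1",
    connection_string="AccountEndpoint=...;AccountKey=...;",  # placeholder
) as history:
    history.add_user_message("hi!")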

View File

@@ -0,0 +1,153 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
messages_to_dict,
)
if TYPE_CHECKING:
from boto3.session import Session
logger = logging.getLogger(__name__)
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in AWS DynamoDB.
This class expects that a DynamoDB table exists with name `table_name`
Args:
table_name: name of the DynamoDB table
session_id: arbitrary key that is used to store the messages
of a single chat session.
endpoint_url: URL of the AWS endpoint to connect to. This argument
is optional and useful for test purposes, like using Localstack.
If you plan to use AWS cloud service, you normally don't have to
worry about setting the endpoint_url.
primary_key_name: name of the primary key of the DynamoDB table. This argument
is optional, defaulting to "SessionId".
key: an optional dictionary with a custom primary and secondary key.
This argument is optional, but useful when using composite dynamodb keys, or
isolating records based off of application details such as a user id.
This may also contain global and local secondary index keys.
kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias for
client-side encryption
"""
def __init__(
self,
table_name: str,
session_id: str,
endpoint_url: Optional[str] = None,
primary_key_name: str = "SessionId",
key: Optional[Dict[str, str]] = None,
boto3_session: Optional[Session] = None,
kms_key_id: Optional[str] = None,
):
if boto3_session:
client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url)
else:
try:
import boto3
except ImportError as e:
raise ImportError(
"Unable to import boto3, please install with `pip install boto3`."
) from e
if endpoint_url:
client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
else:
client = boto3.resource("dynamodb")
self.table = client.Table(table_name)
self.session_id = session_id
self.key: Dict = key or {primary_key_name: session_id}
if kms_key_id:
try:
from dynamodb_encryption_sdk.encrypted.table import EncryptedTable
from dynamodb_encryption_sdk.identifiers import CryptoAction
from dynamodb_encryption_sdk.material_providers.aws_kms import (
AwsKmsCryptographicMaterialsProvider,
)
from dynamodb_encryption_sdk.structures import AttributeActions
except ImportError as e:
raise ImportError(
"Unable to import dynamodb_encryption_sdk, please install with "
"`pip install dynamodb-encryption-sdk`."
) from e
actions = AttributeActions(
default_action=CryptoAction.DO_NOTHING,
attribute_actions={"History": CryptoAction.ENCRYPT_AND_SIGN},
)
aws_kms_cmp = AwsKmsCryptographicMaterialsProvider(key_id=kms_key_id)
self.table = EncryptedTable(
table=self.table,
materials_provider=aws_kms_cmp,
attribute_actions=actions,
auto_refresh_table_indexes=False,
)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from DynamoDB"""
try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e
response = None
try:
response = self.table.get_item(Key=self.key)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning("No record found with session id: %s", self.session_id)
else:
logger.error(error)
if response and "Item" in response:
items = response["Item"]["History"]
else:
items = []
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB"""
try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e
messages = messages_to_dict(self.messages)
_message = message_to_dict(message)
messages.append(_message)
try:
self.table.put_item(Item={**self.key, "History": messages})
except ClientError as err:
logger.error(err)
def clear(self) -> None:
"""Clear session memory from DynamoDB"""
try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e
try:
self.table.delete_item(Key=self.key)
except ClientError as err:
logger.error(err)
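A usage sketch, assuming a DynamoDB table with a "SessionId" partition key already exists; the table name is a placeholder:
history = DynamoDBChatMessageHistory(
    table_name="SessionTable",  # placeholder table name
    session_id="session-1",
)
history.add_user_message("hi!")
print(history.messages)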

View File

@@ -0,0 +1,195 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
logger = logging.getLogger(__name__)
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Elasticsearch.
Args:
es_url: URL of the Elasticsearch instance to connect to.
es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
es_user: Username to use when connecting to Elasticsearch.
es_password: Password to use when connecting to Elasticsearch.
es_api_key: API key to use when connecting to Elasticsearch.
es_connection: Optional pre-existing Elasticsearch connection.
index: Name of the index to use.
session_id: Arbitrary key that is used to store the messages
of a single chat session.
"""
def __init__(
self,
index: str,
session_id: str,
*,
es_connection: Optional["Elasticsearch"] = None,
es_url: Optional[str] = None,
es_cloud_id: Optional[str] = None,
es_user: Optional[str] = None,
es_api_key: Optional[str] = None,
es_password: Optional[str] = None,
):
self.index: str = index
self.session_id: str = session_id
# Initialize Elasticsearch client from passed client arg or connection info
if es_connection is not None:
self.client = es_connection.options(
headers={"user-agent": self.get_user_agent()}
)
elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
es_url=es_url,
username=es_user,
password=es_password,
cloud_id=es_cloud_id,
api_key=es_api_key,
)
else:
raise ValueError(
"""Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
)
if self.client.indices.exists(index=index):
logger.debug(
f"Chat history index {index} already exists, skipping creation."
)
else:
logger.debug(f"Creating index {index} for storing chat history.")
self.client.indices.create(
index=index,
mappings={
"properties": {
"session_id": {"type": "keyword"},
"created_at": {"type": "date"},
"history": {"type": "text"},
}
},
)
@staticmethod
def get_user_agent() -> str:
from langchain import __version__
return f"langchain-py-ms/{__version__}"
@staticmethod
def connect_to_elasticsearch(
*,
es_url: Optional[str] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
) -> "Elasticsearch":
try:
import elasticsearch
except ImportError:
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
if es_url and cloud_id:
raise ValueError(
"Both es_url and cloud_id are defined. Please provide only one."
)
connection_params: Dict[str, Any] = {}
if es_url:
connection_params["hosts"] = [es_url]
elif cloud_id:
connection_params["cloud_id"] = cloud_id
else:
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
if api_key:
connection_params["api_key"] = api_key
elif username and password:
connection_params["basic_auth"] = (username, password)
es_client = elasticsearch.Elasticsearch(
**connection_params,
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
)
try:
es_client.info()
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
return es_client
@property
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch"""
try:
from elasticsearch import ApiError
result = self.client.search(
index=self.index,
query={"term": {"session_id": self.session_id}},
sort="created_at:asc",
)
except ApiError as err:
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
raise err
if result and len(result["hits"]["hits"]) > 0:
items = [
json.loads(document["_source"]["history"])
for document in result["hits"]["hits"]
]
else:
items = []
return messages_from_dict(items)
def add_message(self, message: BaseMessage) -> None:
"""Add a message to the chat session in Elasticsearch"""
try:
from elasticsearch import ApiError
self.client.index(
index=self.index,
document={
"session_id": self.session_id,
"created_at": round(time() * 1000),
"history": json.dumps(message_to_dict(message)),
},
refresh=True,
)
except ApiError as err:
logger.error(f"Could not add message to Elasticsearch: {err}")
raise err
def clear(self) -> None:
"""Clear session memory in Elasticsearch"""
try:
from elasticsearch import ApiError
self.client.delete_by_query(
index=self.index,
query={"term": {"session_id": self.session_id}},
refresh=True,
)
except ApiError as err:
logger.error(f"Could not clear session memory in Elasticsearch: {err}")
raise err
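A usage sketch; the URL assumes a local, unauthenticated development Elasticsearch node:
history = ElasticsearchChatMessageHistory(
    index="chat-history",
    session_id="session-1",
    es_url="http://localhost:9200",  # placeholder dev node
)
history.add_user_message("hi!")
print(history.messages)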

View File

@@ -0,0 +1,45 @@
import json
import logging
from pathlib import Path
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
messages_from_dict,
messages_to_dict,
)
logger = logging.getLogger(__name__)
class FileChatMessageHistory(BaseChatMessageHistory):
"""
Chat message history that stores history in a local file.
Args:
file_path: path of the local file to store the messages.
"""
def __init__(self, file_path: str):
self.file_path = Path(file_path)
if not self.file_path.exists():
self.file_path.touch()
self.file_path.write_text(json.dumps([]))
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from the local file"""
items = json.loads(self.file_path.read_text())
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in the local file"""
messages = messages_to_dict(self.messages)
messages.append(messages_to_dict([message])[0])
self.file_path.write_text(json.dumps(messages))
def clear(self) -> None:
"""Clear session memory from the local file"""
self.file_path.write_text(json.dumps([]))
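A runnable usage sketch; note that the messages property is re-read from disk on every access:
history = FileChatMessageHistory("chat_history.json")
history.add_user_message("hi!")
history.add_ai_message("hello!")
print(history.messages)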

View File

@@ -0,0 +1,105 @@
"""Firestore Chat Message History."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
messages_from_dict,
messages_to_dict,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from google.cloud.firestore import Client, DocumentReference
def _get_firestore_client() -> Client:
try:
import firebase_admin
from firebase_admin import firestore
except ImportError:
raise ImportError(
"Could not import firebase-admin python package. "
"Please install it with `pip install firebase-admin`."
)
# For multiple instances, only initialize the app once.
try:
firebase_admin.get_app()
except ValueError as e:
logger.debug("Initializing Firebase app: %s", e)
firebase_admin.initialize_app()
return firestore.client()
class FirestoreChatMessageHistory(BaseChatMessageHistory):
"""Chat message history backed by Google Firestore."""
def __init__(
self,
collection_name: str,
session_id: str,
user_id: str,
firestore_client: Optional[Client] = None,
):
"""
Initialize a new instance of the FirestoreChatMessageHistory class.
:param collection_name: The name of the collection to use.
:param session_id: The session ID for the chat.
:param user_id: The user ID for the chat.
"""
self.collection_name = collection_name
self.session_id = session_id
self.user_id = user_id
self._document: Optional[DocumentReference] = None
self.messages: List[BaseMessage] = []
self.firestore_client = firestore_client or _get_firestore_client()
self.prepare_firestore()
def prepare_firestore(self) -> None:
"""Prepare the Firestore client.
Use this function to make sure your database is ready.
"""
self._document = self.firestore_client.collection(
self.collection_name
).document(self.session_id)
self.load_messages()
def load_messages(self) -> None:
"""Retrieve the messages from Firestore"""
if not self._document:
raise ValueError("Document not initialized")
doc = self._document.get()
if doc.exists:
data = doc.to_dict()
if "messages" in data and len(data["messages"]) > 0:
self.messages = messages_from_dict(data["messages"])
def add_message(self, message: BaseMessage) -> None:
self.messages.append(message)
self.upsert_messages()
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
"""Update the Firestore document."""
if not self._document:
raise ValueError("Document not initialized")
self._document.set(
{
"id": self.session_id,
"user_id": self.user_id,
"messages": messages_to_dict(self.messages),
}
)
def clear(self) -> None:
"""Clear session memory from this memory and Firestore."""
self.messages = []
if self._document:
self._document.delete()
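A usage sketch, assuming Firebase Admin credentials are discoverable (for example via the GOOGLE_APPLICATION_CREDENTIALS environment variable):
history = FirestoreChatMessageHistory(
    collection_name="chat_history",
    session_id="session-1",
    user_id="user-1",
)
history.add_user_message("hi!")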

View File

@@ -0,0 +1,21 @@
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
"""
messages: List[BaseMessage] = Field(default_factory=list)
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)
def clear(self) -> None:
self.messages = []
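A runnable sketch of the in-memory history:
history = ChatMessageHistory()
history.add_user_message("hi!")
history.add_ai_message("hello! how can I help?")
assert len(history.messages) == 2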

View File

@@ -0,0 +1,190 @@
from __future__ import annotations
import json
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
from langchain_core.utils import get_from_env
if TYPE_CHECKING:
import momento
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
"""Create cache if it doesn't exist.
Raises:
SdkException: Momento service or network error
Exception: Unexpected response
"""
from momento.responses import CreateCache
create_cache_response = cache_client.create_cache(cache_name)
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
create_cache_response, CreateCache.CacheAlreadyExists
):
return None
elif isinstance(create_cache_response, CreateCache.Error):
raise create_cache_response.inner_exception
else:
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
class MomentoChatMessageHistory(BaseChatMessageHistory):
"""Chat message history cache that uses Momento as a backend.
See https://gomomento.com/"""
def __init__(
self,
session_id: str,
cache_client: momento.CacheClient,
cache_name: str,
*,
key_prefix: str = "message_store:",
ttl: Optional[timedelta] = None,
ensure_cache_exists: bool = True,
):
"""Instantiate a chat message history cache that uses Momento as a backend.
Note: to instantiate the cache client passed to MomentoChatMessageHistory,
you must have a Momento account at https://gomomento.com/.
Args:
session_id (str): The session ID to use for this chat session.
cache_client (CacheClient): The Momento cache client.
cache_name (str): The name of the cache to use to store the messages.
key_prefix (str, optional): The prefix to apply to the cache key.
Defaults to "message_store:".
ttl (Optional[timedelta], optional): The TTL to use for the messages.
Defaults to None, i.e. the default TTL of the cache will be used.
ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
Defaults to True.
Raises:
ImportError: Momento python package is not installed.
TypeError: cache_client is not of type momento.CacheClientObject
"""
try:
from momento import CacheClient
from momento.requests import CollectionTtl
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
if not isinstance(cache_client, CacheClient):
raise TypeError("cache_client must be a momento.CacheClient object.")
if ensure_cache_exists:
_ensure_cache_exists(cache_client, cache_name)
self.key = key_prefix + session_id
self.cache_client = cache_client
self.cache_name = cache_name
if ttl is not None:
self.ttl = CollectionTtl.of(ttl)
else:
self.ttl = CollectionTtl.from_cache_ttl()
@classmethod
def from_client_params(
cls,
session_id: str,
cache_name: str,
ttl: timedelta,
*,
configuration: Optional[momento.config.Configuration] = None,
api_key: Optional[str] = None,
auth_token: Optional[str] = None, # for backwards compatibility
**kwargs: Any,
) -> MomentoChatMessageHistory:
"""Construct cache from CacheClient parameters."""
try:
from momento import CacheClient, Configurations, CredentialProvider
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
if configuration is None:
configuration = Configurations.Laptop.v1()
# Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility
try:
api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
except ValueError:
api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY")
credentials = CredentialProvider.from_string(api_key)
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
@property
def messages(self) -> list[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Momento.
Raises:
SdkException: Momento service or network error
Exception: Unexpected response
Returns:
list[BaseMessage]: List of cached messages
"""
from momento.responses import CacheListFetch
fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)
if isinstance(fetch_response, CacheListFetch.Hit):
items = [json.loads(m) for m in fetch_response.value_list_string]
return messages_from_dict(items)
elif isinstance(fetch_response, CacheListFetch.Miss):
return []
elif isinstance(fetch_response, CacheListFetch.Error):
raise fetch_response.inner_exception
else:
raise Exception(f"Unexpected response: {fetch_response}")
def add_message(self, message: BaseMessage) -> None:
"""Store a message in the cache.
Args:
message (BaseMessage): The message object to store.
Raises:
SdkException: Momento service or network error.
Exception: Unexpected response.
"""
from momento.responses import CacheListPushBack
item = json.dumps(message_to_dict(message))
push_response = self.cache_client.list_push_back(
self.cache_name, self.key, item, ttl=self.ttl
)
if isinstance(push_response, CacheListPushBack.Success):
return None
elif isinstance(push_response, CacheListPushBack.Error):
raise push_response.inner_exception
else:
raise Exception(f"Unexpected response: {push_response}")
def clear(self) -> None:
"""Remove the session's messages from the cache.
Raises:
SdkException: Momento service or network error.
Exception: Unexpected response.
"""
from momento.responses import CacheDelete
delete_response = self.cache_client.delete(self.cache_name, self.key)
if isinstance(delete_response, CacheDelete.Success):
return None
elif isinstance(delete_response, CacheDelete.Error):
raise delete_response.inner_exception
else:
raise Exception(f"Unexpected response: {delete_response}")

View File

@@ -0,0 +1,91 @@
import json
import logging
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
logger = logging.getLogger(__name__)
DEFAULT_DBNAME = "chat_history"
DEFAULT_COLLECTION_NAME = "message_store"
class MongoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in MongoDB.
Args:
connection_string: connection string to connect to MongoDB
session_id: arbitrary key that is used to store the messages
of a single chat session.
database_name: name of the database to use
collection_name: name of the collection to use
"""
def __init__(
self,
connection_string: str,
session_id: str,
database_name: str = DEFAULT_DBNAME,
collection_name: str = DEFAULT_COLLECTION_NAME,
):
from pymongo import MongoClient, errors
self.connection_string = connection_string
self.session_id = session_id
self.database_name = database_name
self.collection_name = collection_name
try:
self.client: MongoClient = MongoClient(connection_string)
except errors.ConnectionFailure as error:
logger.error(error)
self.db = self.client[database_name]
self.collection = self.db[collection_name]
self.collection.create_index("SessionId")
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from MongoDB"""
from pymongo import errors
cursor = None
try:
cursor = self.collection.find({"SessionId": self.session_id})
except errors.OperationFailure as error:
logger.error(error)
if cursor:
items = [json.loads(document["History"]) for document in cursor]
else:
items = []
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in MongoDB"""
from pymongo import errors
try:
self.collection.insert_one(
{
"SessionId": self.session_id,
"History": json.dumps(message_to_dict(message)),
}
)
except errors.WriteError as err:
logger.error(err)
def clear(self) -> None:
"""Clear session memory from MongoDB"""
from pymongo import errors
try:
self.collection.delete_many({"SessionId": self.session_id})
except errors.WriteError as err:
logger.error(err)
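
A minimal sketch of the class above against a local MongoDB; the connection string and session id are placeholders:

    from langchain_integrations.chat_message_histories import (
        MongoDBChatMessageHistory,
    )

    history = MongoDBChatMessageHistory(
        connection_string="mongodb://mongo_user:password123@localhost:27017",
        session_id="session-1",
    )
    history.add_user_message("hi!")
    print(history.messages)  # fetched via the SessionId index created above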

View File

@@ -0,0 +1,113 @@
from typing import List, Optional, Union
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict
from langchain_core.utils import get_from_env
class Neo4jChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Neo4j database."""
def __init__(
self,
session_id: Union[str, int],
url: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
database: str = "neo4j",
node_label: str = "Session",
window: int = 3,
):
try:
import neo4j
except ImportError:
raise ImportError(
"Could not import neo4j python package. "
"Please install it with `pip install neo4j`."
)
# Make sure session id is not null
if not session_id:
raise ValueError("Please ensure that the session_id parameter is provided")
url = get_from_env("url", "NEO4J_URI", url)
username = get_from_env("username", "NEO4J_USERNAME", username)
password = get_from_env("password", "NEO4J_PASSWORD", password)
database = get_from_env("database", "NEO4J_DATABASE", database)
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self._session_id = session_id
self._node_label = node_label
self._window = window
# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the url is correct"
)
except neo4j.exceptions.AuthError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the username and password are correct"
)
# Create session node
self._driver.execute_query(
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
{"session_id": self._session_id},
).summary
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
f"{self._window*2}]-() WITH p, length(p) AS length "
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
"RETURN {data:{content: node.content}, type:node.type} AS result"
)
records, _, _ = self._driver.execute_query(
query, {"session_id": self._session_id}
)
messages = messages_from_dict([el["result"] for el in records])
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
"SET new += {type:$type, content:$content} "
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
"CREATE (last_message)-[:NEXT]->(new) "
"DELETE lm"
)
self._driver.execute_query(
query,
{
"type": message.type,
"content": message.content,
"session_id": self._session_id,
},
).summary
def clear(self) -> None:
"""Clear session memory from Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
"UNWIND nodes(p) as node DETACH DELETE node;"
)
self._driver.execute_query(query, {"session_id": self._session_id}).summary
def __del__(self) -> None:
if self._driver:
self._driver.close()
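
A minimal sketch of the Neo4j-backed history above; the url, username, and password may instead come from the NEO4J_URI, NEO4J_USERNAME, and NEO4J_PASSWORD environment variables, and the values below are placeholders:

    from langchain_integrations.chat_message_histories import (
        Neo4jChatMessageHistory,
    )

    history = Neo4jChatMessageHistory(
        session_id="session-1",
        url="bolt://localhost:7687",
        username="neo4j",
        password="password",
    )
    history.add_user_message("hi!")
    print(history.messages)  # bounded by the `window` parameter (default 3)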

View File

@@ -0,0 +1,82 @@
import json
import logging
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
logger = logging.getLogger(__name__)
DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_history"
class PostgresChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Postgres database."""
def __init__(
self,
session_id: str,
connection_string: str = DEFAULT_CONNECTION_STRING,
table_name: str = "message_store",
):
import psycopg
from psycopg.rows import dict_row
try:
self.connection = psycopg.connect(connection_string)
self.cursor = self.connection.cursor(row_factory=dict_row)
except psycopg.OperationalError as error:
logger.error(error)
self.session_id = session_id
self.table_name = table_name
self._create_table_if_not_exists()
def _create_table_if_not_exists(self) -> None:
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
message JSONB NOT NULL
);"""
self.cursor.execute(create_table_query)
self.connection.commit()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from PostgreSQL"""
query = (
f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;"
)
self.cursor.execute(query, (self.session_id,))
items = [record["message"] for record in self.cursor.fetchall()]
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
from psycopg import sql
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
sql.Identifier(self.table_name)
)
self.cursor.execute(
query, (self.session_id, json.dumps(message_to_dict(message)))
)
self.connection.commit()
def clear(self) -> None:
"""Clear session memory from PostgreSQL"""
query = f"DELETE FROM {self.table_name} WHERE session_id = %s;"
self.cursor.execute(query, (self.session_id,))
self.connection.commit()
def __del__(self) -> None:
if self.cursor:
self.cursor.close()
if self.connection:
self.connection.close()
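
A minimal sketch of the Postgres-backed history above; the connection string is a placeholder and the table is created on first use:

    from langchain_integrations.chat_message_histories import (
        PostgresChatMessageHistory,
    )

    history = PostgresChatMessageHistory(
        session_id="session-1",
        connection_string="postgresql://postgres:mypassword@localhost/chat_history",
    )
    history.add_user_message("hi!")
    print(history.messages)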

View File

@@ -0,0 +1,65 @@
import json
import logging
from typing import List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
from langchain_integrations.utilities.redis import get_client
logger = logging.getLogger(__name__)
class RedisChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Redis database."""
def __init__(
self,
session_id: str,
url: str = "redis://localhost:6379/0",
key_prefix: str = "message_store:",
ttl: Optional[int] = None,
):
try:
import redis
except ImportError:
raise ImportError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
self.redis_client = get_client(redis_url=url)
except redis.exceptions.ConnectionError as error:
logger.error(error)
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
@property
def key(self) -> str:
"""Construct the record key to use"""
return self.key_prefix + self.session_id
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Redis"""
_items = self.redis_client.lrange(self.key, 0, -1)
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
if self.ttl:
self.redis_client.expire(self.key, self.ttl)
def clear(self) -> None:
"""Clear session memory from Redis"""
self.redis_client.delete(self.key)
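
A minimal sketch of the Redis-backed history above against a local instance; ttl is in seconds and optional:

    from langchain_integrations.chat_message_histories import (
        RedisChatMessageHistory,
    )

    history = RedisChatMessageHistory(
        session_id="session-1",
        url="redis://localhost:6379/0",
        ttl=600,
    )
    history.add_user_message("hi!")
    print(history.messages)  # stored newest-first via LPUSH, reversed on read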

View File

@@ -0,0 +1,268 @@
from datetime import datetime
from time import sleep
from typing import Any, Callable, List, Union
from uuid import uuid4
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
class RocksetChatMessageHistory(BaseChatMessageHistory):
"""Uses Rockset to store chat messages.
To use, ensure that the `rockset` python package is installed.
Example:
.. code-block:: python
from langchain_integrations.chat_message_histories import (
RocksetChatMessageHistory
)
from rockset import RocksetClient
history = RocksetChatMessageHistory(
session_id="MySession",
client=RocksetClient(),
collection="langchain_demo",
sync=True
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages)
"""
# You should set these values based on your VI.
# These values are configured for the typical
# free VI. Read more about VIs here:
# https://rockset.com/docs/instances
SLEEP_INTERVAL_MS: int = 5
ADD_TIMEOUT_MS: int = 5000
CREATE_TIMEOUT_MS: int = 20000
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
"""Sleeps until meth() evaluates to true. Passes kwargs into
meth.
"""
start = datetime.now()
while not method(**method_params):
curr = datetime.now()
if (curr - start).total_seconds() * 1000 > timeout:
raise TimeoutError(f"{method} timed out at {timeout} ms")
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
def _query(self, query: str, **query_params: Any) -> List[Any]:
"""Executes an SQL statement and returns the result
Args:
- query: The SQL string
- **query_params: Parameters to pass into the query
"""
return self.client.sql(query, params=query_params).results
def _create_collection(self) -> None:
"""Creates a collection for this message history"""
self.client.Collections.create_s3_collection(
name=self.collection, workspace=self.workspace
)
def _collection_exists(self) -> bool:
"""Checks whether a collection exists for this message history"""
try:
self.client.Collections.get(collection=self.collection)
except self.rockset.exceptions.NotFoundException:
return False
return True
def _collection_is_ready(self) -> bool:
"""Checks whether the collection for this message history is ready
to be queried
"""
return (
self.client.Collections.get(collection=self.collection).data.status
== "READY"
)
def _document_exists(self) -> bool:
return (
len(
self._query(
f"""
SELECT 1
FROM {self.location}
WHERE _id=:session_id
LIMIT 1
""",
session_id=self.session_id,
)
)
!= 0
)
def _wait_until_collection_created(self) -> None:
"""Sleeps until the collection for this message history is ready
to be queried
"""
self._wait_until(
lambda: self._collection_is_ready(),
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
)
def _wait_until_message_added(self, message_id: str) -> None:
"""Sleeps until a message is added to the messages list"""
self._wait_until(
lambda message_id: len(
self._query(
f"""
SELECT *
FROM UNNEST((
SELECT {self.messages_key}
FROM {self.location}
WHERE _id = :session_id
)) AS message
WHERE message.data.additional_kwargs.id = :message_id
LIMIT 1
""",
session_id=self.session_id,
message_id=message_id,
),
)
!= 0,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
message_id=message_id,
)
def _create_empty_doc(self) -> None:
"""Creates or replaces a document for this message history with no
messages"""
self.client.Documents.add_documents(
collection=self.collection,
workspace=self.workspace,
data=[{"_id": self.session_id, self.messages_key: []}],
)
def __init__(
self,
session_id: str,
client: Any,
collection: str,
workspace: str = "commons",
messages_key: str = "messages",
sync: bool = False,
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
) -> None:
"""Constructs a new RocksetChatMessageHistory.
Args:
- session_id: The ID of the chat session
- client: The RocksetClient object to use to query
- collection: The name of the collection to use to store chat
messages. If a collection with the given name
does not exist in the workspace, it is created.
- workspace: The workspace containing `collection`. Defaults
to `"commons"`
- messages_key: The DB column containing message history.
Defaults to `"messages"`
- sync: Whether to wait for messages to be added. Defaults
to `False`. NOTE: setting this to `True` will reduce
performance.
- message_uuid_method: The method that generates message IDs.
If set, all messages will have an `id` field within the
`additional_kwargs` property. If this param is not set
and `sync` is `False`, message IDs will not be created.
If this param is not set and `sync` is `True`, the
`uuid.uuid4` method will be used to create message IDs.
"""
try:
import rockset
except ImportError:
raise ImportError(
"Could not import rockset client python package. "
"Please install it with `pip install rockset`."
)
if not isinstance(client, rockset.RocksetClient):
raise ValueError(
f"client should be an instance of rockset.RocksetClient, "
f"got {type(client)}"
)
self.session_id = session_id
self.client = client
self.collection = collection
self.workspace = workspace
self.location = f'"{self.workspace}"."{self.collection}"'
self.rockset = rockset
self.messages_key = messages_key
self.message_uuid_method = message_uuid_method
self.sync = sync
try:
self.client.set_application("langchain")
except AttributeError:
# ignore
pass
if not self._collection_exists():
self._create_collection()
self._wait_until_collection_created()
self._create_empty_doc()
elif not self._document_exists():
self._create_empty_doc()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Messages in this chat history."""
return messages_from_dict(
self._query(
f"""
SELECT *
FROM UNNEST ((
SELECT "{self.messages_key}"
FROM {self.location}
WHERE _id = :session_id
))
""",
session_id=self.session_id,
)
)
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the history.
Args:
message: A BaseMessage object to store.
"""
if self.sync and "id" not in message.additional_kwargs:
message.additional_kwargs["id"] = self.message_uuid_method()
self.client.Documents.patch_documents(
collection=self.collection,
workspace=self.workspace,
data=[
self.rockset.model.patch_document.PatchDocument(
id=self.session_id,
patch=[
self.rockset.model.patch_operation.PatchOperation(
op="ADD",
path=f"/{self.messages_key}/-",
value=message_to_dict(message),
)
],
)
],
)
if self.sync:
self._wait_until_message_added(message.additional_kwargs["id"])
def clear(self) -> None:
"""Removes all messages from the chat history"""
self._create_empty_doc()
if self.sync:
self._wait_until(
lambda: not self.messages,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
)

View File

@@ -0,0 +1,277 @@
import json
import logging
import re
from typing import (
Any,
List,
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
logger = logging.getLogger(__name__)
class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a SingleStoreDB database."""
def __init__(
self,
session_id: str,
*,
table_name: str = "message_store",
id_field: str = "id",
session_id_field: str = "session_id",
message_field: str = "message",
pool_size: int = 5,
max_overflow: int = 10,
timeout: float = 30,
**kwargs: Any,
):
"""Initialize with necessary components.
Args:
table_name (str, optional): Specifies the name of the table in use.
Defaults to "message_store".
id_field (str, optional): Specifies the name of the id field in the table.
Defaults to "id".
session_id_field (str, optional): Specifies the name of the session_id
field in the table. Defaults to "session_id".
message_field (str, optional): Specifies the name of the message field
in the table. Defaults to "message".
The following arguments pertain to the connection pool:
pool_size (int, optional): Determines the number of active connections in
the pool. Defaults to 5.
max_overflow (int, optional): Determines the maximum number of connections
allowed beyond the pool_size. Defaults to 10.
timeout (float, optional): Specifies the maximum wait time in seconds for
establishing a connection. Defaults to 30.
The following arguments pertain to the database connection:
host (str, optional): Specifies the hostname, IP address, or URL for the
database connection. The default scheme is "mysql".
user (str, optional): Database username.
password (str, optional): Database password.
port (int, optional): Database port. Defaults to 3306 for non-HTTP
connections, 80 for HTTP connections, and 443 for HTTPS connections.
database (str, optional): Database name.
Additional optional arguments provide further customization over the
database connection:
pure_python (bool, optional): Toggles the connector mode. If True,
operates in pure Python mode.
local_infile (bool, optional): Allows local file uploads.
charset (str, optional): Specifies the character set for string values.
ssl_key (str, optional): Specifies the path of the file containing the SSL
key.
ssl_cert (str, optional): Specifies the path of the file containing the SSL
certificate.
ssl_ca (str, optional): Specifies the path of the file containing the SSL
certificate authority.
ssl_cipher (str, optional): Sets the SSL cipher list.
ssl_disabled (bool, optional): Disables SSL usage.
ssl_verify_cert (bool, optional): Verifies the server's certificate.
Automatically enabled if ``ssl_ca`` is specified.
ssl_verify_identity (bool, optional): Verifies the server's identity.
conv (dict[int, Callable], optional): A dictionary of data conversion
functions.
credential_type (str, optional): Specifies the type of authentication to
use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO.
autocommit (bool, optional): Enables autocommits.
results_type (str, optional): Determines the structure of the query results:
tuples, namedtuples, dicts.
results_format (str, optional): Deprecated. This option has been renamed to
results_type.
Examples:
Basic Usage:
.. code-block:: python
from langchain_integrations.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
message_history = SingleStoreDBChatMessageHistory(
session_id="my-session",
host="https://user:password@127.0.0.1:3306/database"
)
Advanced Usage:
.. code-block:: python
from langchain_integrations.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
message_history = SingleStoreDBChatMessageHistory(
session_id="my-session",
host="127.0.0.1",
port=3306,
user="user",
password="password",
database="db",
table_name="my_custom_table",
pool_size=10,
timeout=60,
)
Using environment variables:
.. code-block:: python
from langchain_integrations.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
message_history = SingleStoreDBChatMessageHistory("my-session")
"""
self.table_name = self._sanitize_input(table_name)
self.session_id = self._sanitize_input(session_id)
self.id_field = self._sanitize_input(id_field)
self.session_id_field = self._sanitize_input(session_id_field)
self.message_field = self._sanitize_input(message_field)
# Pass the rest of the kwargs to the connection.
self.connection_kwargs = kwargs
# Add connection attributes to the connection kwargs.
if "conn_attrs" not in self.connection_kwargs:
self.connection_kwargs["conn_attrs"] = dict()
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
# Create a connection pool.
try:
from sqlalchemy.pool import QueuePool
except ImportError:
raise ImportError(
"Could not import sqlalchemy.pool python package. "
"Please install it with `pip install singlestoredb`."
)
self.connection_pool = QueuePool(
self._get_connection,
max_overflow=max_overflow,
pool_size=pool_size,
timeout=timeout,
)
self.table_created = False
def _sanitize_input(self, input_str: str) -> str:
# Remove characters that are not alphanumeric or underscores
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
def _get_connection(self) -> Any:
try:
import singlestoredb as s2
except ImportError:
raise ImportError(
"Could not import singlestoredb python package. "
"Please install it with `pip install singlestoredb`."
)
return s2.connect(**self.connection_kwargs)
def _create_table_if_not_exists(self) -> None:
"""Create table if it doesn't exist."""
if self.table_created:
return
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""CREATE TABLE IF NOT EXISTS {}
({} BIGINT PRIMARY KEY AUTO_INCREMENT,
{} TEXT NOT NULL,
{} JSON NOT NULL);""".format(
self.table_name,
self.id_field,
self.session_id_field,
self.message_field,
),
)
self.table_created = True
finally:
cur.close()
finally:
conn.close()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
items = []
try:
cur = conn.cursor()
try:
cur.execute(
"""SELECT {} FROM {} WHERE {} = %s""".format(
self.message_field,
self.table_name,
self.session_id_field,
),
(self.session_id,),
)
for row in cur.fetchall():
items.append(row[0])
finally:
cur.close()
finally:
conn.close()
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format(
self.table_name,
self.session_id_field,
self.message_field,
),
(self.session_id, json.dumps(message_to_dict(message))),
)
finally:
cur.close()
finally:
conn.close()
def clear(self) -> None:
"""Clear session memory from SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""DELETE FROM {} WHERE {} = %s""".format(
self.table_name,
self.session_id_field,
),
(self.session_id,),
)
finally:
cur.close()
finally:
conn.close()

View File

@@ -0,0 +1,140 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from sqlalchemy import Column, Integer, Text, create_engine
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
from sqlalchemy.orm import sessionmaker
logger = logging.getLogger(__name__)
class BaseMessageConverter(ABC):
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
@abstractmethod
def from_sql_model(self, sql_message: Any) -> BaseMessage:
"""Convert a SQLAlchemy model to a BaseMessage instance."""
raise NotImplementedError
@abstractmethod
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
"""Convert a BaseMessage instance to a SQLAlchemy model."""
raise NotImplementedError
@abstractmethod
def get_sql_model_class(self) -> Any:
"""Get the SQLAlchemy model class."""
raise NotImplementedError
def create_message_model(table_name, DynamicBase): # type: ignore
"""
Create a message model for a given table name.
Args:
table_name: The name of the table to use.
DynamicBase: The base class to use for the model.
Returns:
The model class.
"""
# Model declared inside a function to allow a dynamic table name
class Message(DynamicBase):
__tablename__ = table_name
id = Column(Integer, primary_key=True)
session_id = Column(Text)
message = Column(Text)
return Message
class DefaultMessageConverter(BaseMessageConverter):
"""The default message converter for SQLChatMessageHistory."""
def __init__(self, table_name: str):
self.model_class = create_message_model(table_name, declarative_base())
def from_sql_model(self, sql_message: Any) -> BaseMessage:
return messages_from_dict([json.loads(sql_message.message)])[0]
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
return self.model_class(
session_id=session_id, message=json.dumps(message_to_dict(message))
)
def get_sql_model_class(self) -> Any:
return self.model_class
class SQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an SQL database."""
def __init__(
self,
session_id: str,
connection_string: str,
table_name: str = "message_store",
session_id_field_name: str = "session_id",
custom_message_converter: Optional[BaseMessageConverter] = None,
):
self.connection_string = connection_string
self.engine = create_engine(connection_string, echo=False)
self.session_id_field_name = session_id_field_name
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
self._create_table_if_not_exists()
self.session_id = session_id
self.Session = sessionmaker(self.engine)
def _create_table_if_not_exists(self) -> None:
self.sql_model_class.metadata.create_all(self.engine)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
with self.Session() as session:
result = (
session.query(self.sql_model_class)
.where(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
)
.order_by(self.sql_model_class.id.asc())
)
messages = []
for record in result:
messages.append(self.converter.from_sql_model(record))
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()
def clear(self) -> None:
"""Clear session memory from db"""
with self.Session() as session:
session.query(self.sql_model_class).filter(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
).delete()
session.commit()
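
Because SQLChatMessageHistory accepts a custom_message_converter, here is a hedged sketch of a converter that adds a created-at column; the model, converter class, and import path are illustrative:

    import json
    from datetime import datetime
    from typing import Any

    from sqlalchemy import Column, DateTime, Integer, Text
    from sqlalchemy.orm import declarative_base  # SQLAlchemy >= 1.4

    from langchain_core.messages import (
        BaseMessage,
        message_to_dict,
        messages_from_dict,
    )
    # Assumed importable from the module above; the path is illustrative.
    from langchain_integrations.chat_message_histories import (
        BaseMessageConverter,
        SQLChatMessageHistory,
    )

    Base = declarative_base()

    class TimestampedMessage(Base):  # hypothetical model
        __tablename__ = "timestamped_message_store"
        id = Column(Integer, primary_key=True)
        session_id = Column(Text)
        message = Column(Text)
        created_at = Column(DateTime, default=datetime.utcnow)

    class TimestampedMessageConverter(BaseMessageConverter):
        def from_sql_model(self, sql_message: Any) -> BaseMessage:
            return messages_from_dict([json.loads(sql_message.message)])[0]

        def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
            return TimestampedMessage(
                session_id=session_id,
                message=json.dumps(message_to_dict(message)),
            )

        def get_sql_model_class(self) -> Any:
            return TimestampedMessage

    history = SQLChatMessageHistory(
        session_id="session-1",
        connection_string="sqlite:///chat_history.db",
        custom_message_converter=TimestampedMessageConverter(),
    )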

View File

@@ -0,0 +1,38 @@
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
class StreamlitChatMessageHistory(BaseChatMessageHistory):
"""
Chat message history that stores messages in Streamlit session state.
Args:
key: The key to use in Streamlit session state for storing messages.
"""
def __init__(self, key: str = "langchain_messages"):
try:
import streamlit as st
except ImportError as e:
raise ImportError(
"Unable to import streamlit, please run `pip install streamlit`."
) from e
if key not in st.session_state:
st.session_state[key] = []
self._messages = st.session_state[key]
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the current list of messages"""
return self._messages
def add_message(self, message: BaseMessage) -> None:
"""Add a message to the session memory"""
self._messages.append(message)
def clear(self) -> None:
"""Clear session memory"""
self._messages.clear()
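
A minimal sketch of the Streamlit-backed history above, meant to run inside a Streamlit app; the key is a placeholder (it defaults to "langchain_messages"):

    import streamlit as st

    from langchain_integrations.chat_message_histories import (
        StreamlitChatMessageHistory,
    )

    history = StreamlitChatMessageHistory(key="chat_messages")
    history.add_user_message("hi!")
    for msg in history.messages:  # survives script reruns via session_state
        st.write(msg.content)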

View File

@@ -0,0 +1,69 @@
import json
import logging
from typing import List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
logger = logging.getLogger(__name__)
class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an Upstash Redis database."""
def __init__(
self,
session_id: str,
url: str = "",
token: str = "",
key_prefix: str = "message_store:",
ttl: Optional[int] = None,
):
try:
from upstash_redis import Redis
except ImportError:
raise ImportError(
"Could not import upstash redis python package. "
"Please install it with `pip install upstash_redis`."
)
if url == "" or token == "":
raise ValueError(
"UPSTASH_REDIS_REST_URL and UPSTASH_REDIS_REST_TOKEN are needed."
)
try:
self.redis_client = Redis(url=url, token=token)
except Exception:
logger.error("Upstash Redis instance could not be initiated.")
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
@property
def key(self) -> str:
"""Construct the record key to use"""
return self.key_prefix + self.session_id
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Upstash Redis"""
_items = self.redis_client.lrange(self.key, 0, -1)
items = [json.loads(m) for m in _items[::-1]]
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Upstash Redis"""
self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
if self.ttl:
self.redis_client.expire(self.key, self.ttl)
def clear(self) -> None:
"""Clear session memory from Upstash Redis"""
self.redis_client.delete(self.key)
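
A minimal sketch of the Upstash-backed history above; the REST url and token come from the Upstash console, read here from environment variables for illustration:

    import os

    from langchain_integrations.chat_message_histories import (
        UpstashRedisChatMessageHistory,
    )

    history = UpstashRedisChatMessageHistory(
        session_id="session-1",
        url=os.environ["UPSTASH_REDIS_REST_URL"],
        token=os.environ["UPSTASH_REDIS_REST_TOKEN"],
        ttl=600,
    )
    history.add_user_message("hi!")
    print(history.messages)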

View File

@@ -0,0 +1,134 @@
import json
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
class XataChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Xata database."""
def __init__(
self,
session_id: str,
db_url: str,
api_key: str,
branch_name: str = "main",
table_name: str = "messages",
create_table: bool = True,
) -> None:
"""Initialize with Xata client."""
try:
from xata.client import XataClient # noqa: F401
except ImportError:
raise ImportError(
"Could not import xata python package. "
"Please install it with `pip install xata`."
)
self._client = XataClient(
api_key=api_key, db_url=db_url, branch_name=branch_name
)
self._table_name = table_name
self._session_id = session_id
if create_table:
self._create_table_if_not_exists()
def _create_table_if_not_exists(self) -> None:
r = self._client.table().get_schema(self._table_name)
if r.status_code <= 299:
return
if r.status_code != 404:
raise Exception(
f"Error checking if table exists in Xata: {r.status_code} {r}"
)
r = self._client.table().create(self._table_name)
if r.status_code > 299:
raise Exception(f"Error creating table in Xata: {r.status_code} {r}")
r = self._client.table().set_schema(
self._table_name,
payload={
"columns": [
{"name": "sessionId", "type": "string"},
{"name": "type", "type": "string"},
{"name": "role", "type": "string"},
{"name": "content", "type": "text"},
{"name": "name", "type": "string"},
{"name": "additionalKwargs", "type": "json"},
]
},
)
if r.status_code > 299:
raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the Xata table"""
msg = message_to_dict(message)
r = self._client.records().insert(
self._table_name,
{
"sessionId": self._session_id,
"type": msg["type"],
"content": message.content,
"additionalKwargs": json.dumps(message.additional_kwargs),
"role": msg["data"].get("role"),
"name": msg["data"].get("name"),
},
)
if r.status_code > 299:
raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
@property
def messages(self) -> List[BaseMessage]: # type: ignore
r = self._client.data().query(
self._table_name,
payload={
"filter": {
"sessionId": self._session_id,
},
"sort": {"xata.createdAt": "asc"},
},
)
if r.status_code != 200:
raise Exception(f"Error running query: {r.status_code} {r}")
msgs = messages_from_dict(
[
{
"type": m["type"],
"data": {
"content": m["content"],
"role": m.get("role"),
"name": m.get("name"),
"additional_kwargs": json.loads(m["additionalKwargs"]),
},
}
for m in r["records"]
]
)
return msgs
def clear(self) -> None:
"""Delete session from Xata table."""
while True:
r = self._client.data().query(
self._table_name,
payload={
"columns": ["id"],
"filter": {
"sessionId": self._session_id,
},
},
)
if r.status_code != 200:
raise Exception(f"Error running query: {r.status_code} {r}")
ids = [rec["id"] for rec in r["records"]]
if len(ids) == 0:
break
operations = [
{"delete": {"table": self._table_name, "id": id}} for id in ids
]
self._client.records().transaction(payload={"operations": operations})

Some files were not shown because too many files have changed in this diff.