mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00
Compare commits
2 Commits
v1.0
...
harrison/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d22dd95dfa | ||
|
|
310e946124 |
@@ -22,6 +22,7 @@ from langchain_core.utils.utils import (
|
||||
raise_for_status_with_text,
|
||||
xor_args,
|
||||
)
|
||||
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
|
||||
|
||||
__all__ = [
|
||||
"StrictFormatter",
|
||||
@@ -39,4 +40,6 @@ __all__ = [
|
||||
"xor_args",
|
||||
"try_load_from_hub",
|
||||
"build_extra_kwargs",
|
||||
"get_from_dict_or_env",
|
||||
"get_from_env",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def env_var_is_set(env_var: str) -> bool:
|
||||
@@ -18,3 +19,26 @@ def env_var_is_set(env_var: str) -> bool:
|
||||
"false",
|
||||
"False",
|
||||
)
|
||||
|
||||
def get_from_dict_or_env(
|
||||
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get a value from a dictionary or an environment variable."""
|
||||
if key in data and data[key]:
|
||||
return data[key]
|
||||
else:
|
||||
return get_from_env(key, env_key, default=default)
|
||||
|
||||
|
||||
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
|
||||
"""Get a value from a dictionary or an environment variable."""
|
||||
if env_key in os.environ and os.environ[env_key]:
|
||||
return os.environ[env_key]
|
||||
elif default is not None:
|
||||
return default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{env_key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
264
libs/integrations/langchain_integrations/adapters/openai.py
Normal file
264
libs/integrations/langchain_integrations/adapters/openai.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
async def aenumerate(
|
||||
iterable: AsyncIterator[Any], start: int = 0
|
||||
) -> AsyncIterator[tuple[int, Any]]:
|
||||
"""Async version of enumerate function."""
|
||||
i = start
|
||||
async for x in iterable:
|
||||
yield i, x
|
||||
i += 1
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
"""Convert a dictionary to a LangChain message.
|
||||
|
||||
Args:
|
||||
_dict: The dictionary.
|
||||
|
||||
Returns:
|
||||
The LangChain message.
|
||||
"""
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(_dict["function_call"])
|
||||
if _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||
elif role == "tool":
|
||||
return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
message: The LangChain message.
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
# If tool calls only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
|
||||
"""Convert dictionaries representing OpenAI messages to LangChain format.
|
||||
|
||||
Args:
|
||||
messages: List of dictionaries representing OpenAI messages
|
||||
|
||||
Returns:
|
||||
List of LangChain BaseMessage objects.
|
||||
"""
|
||||
return [convert_dict_to_message(m) for m in messages]
|
||||
|
||||
|
||||
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
|
||||
_dict: Dict[str, Any] = {}
|
||||
if isinstance(chunk, AIMessageChunk):
|
||||
if i == 0:
|
||||
# Only shows up in the first chunk
|
||||
_dict["role"] = "assistant"
|
||||
if "function_call" in chunk.additional_kwargs:
|
||||
_dict["function_call"] = chunk.additional_kwargs["function_call"]
|
||||
# If the first chunk is a function call, the content is not empty string,
|
||||
# not missing, but None.
|
||||
if i == 0:
|
||||
_dict["content"] = None
|
||||
else:
|
||||
_dict["content"] = chunk.content
|
||||
else:
|
||||
raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
|
||||
# This only happens at the end of streams, and OpenAI returns as empty dict
|
||||
if _dict == {"content": ""}:
|
||||
_dict = {}
|
||||
return {"choices": [{"delta": _dict}]}
|
||||
|
||||
|
||||
class ChatCompletion:
|
||||
"""Chat completion."""
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def create(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def create(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Iterable:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[dict, Iterable]:
|
||||
models = importlib.import_module("langchain.chat_models")
|
||||
model_cls = getattr(models, provider)
|
||||
model_config = model_cls(**kwargs)
|
||||
converted_messages = convert_openai_messages(messages)
|
||||
if not stream:
|
||||
result = model_config.invoke(converted_messages)
|
||||
return {"choices": [{"message": convert_message_to_dict(result)}]}
|
||||
else:
|
||||
return (
|
||||
_convert_message_chunk_to_delta(c, i)
|
||||
for i, c in enumerate(model_config.stream(converted_messages))
|
||||
)
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
async def acreate(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
async def acreate(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
async def acreate(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[dict, AsyncIterator]:
|
||||
models = importlib.import_module("langchain.chat_models")
|
||||
model_cls = getattr(models, provider)
|
||||
model_config = model_cls(**kwargs)
|
||||
converted_messages = convert_openai_messages(messages)
|
||||
if not stream:
|
||||
result = await model_config.ainvoke(converted_messages)
|
||||
return {"choices": [{"message": convert_message_to_dict(result)}]}
|
||||
else:
|
||||
return (
|
||||
_convert_message_chunk_to_delta(c, i)
|
||||
async for i, c in aenumerate(model_config.astream(converted_messages))
|
||||
)
|
||||
|
||||
|
||||
def _has_assistant_message(session: ChatSession) -> bool:
|
||||
"""Check if chat session has an assistant message."""
|
||||
return any([isinstance(m, AIMessage) for m in session["messages"]])
|
||||
|
||||
|
||||
def convert_messages_for_finetuning(
|
||||
sessions: Iterable[ChatSession],
|
||||
) -> List[List[dict]]:
|
||||
"""Convert messages to a list of lists of dictionaries for fine-tuning.
|
||||
|
||||
Args:
|
||||
sessions: The chat sessions.
|
||||
|
||||
Returns:
|
||||
The list of lists of dictionaries.
|
||||
"""
|
||||
return [
|
||||
[convert_message_to_dict(s) for s in session["messages"]]
|
||||
for session in sessions
|
||||
if _has_assistant_message(session)
|
||||
]
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Agent toolkits contain integrations with various resources and services.
|
||||
|
||||
LangChain has a large ecosystem of integrations with various external resources
|
||||
like local and remote file systems, APIs and databases.
|
||||
|
||||
These integrations allow developers to create versatile applications that combine the
|
||||
power of LLMs with the ability to access, interact with and manipulate external
|
||||
resources.
|
||||
|
||||
When developing an application, developers should inspect the capabilities and
|
||||
permissions of the tools that underlie the given agent toolkit, and determine
|
||||
whether permissions of the given toolkit are appropriate for the application.
|
||||
|
||||
See [Security](https://python.langchain.com/docs/security) for more information.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.path import as_import_path
|
||||
|
||||
from langchain_integrations.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
|
||||
#from langchain_integrations.agent_toolkits.amadeus.toolkit import AmadeusToolkit
|
||||
from langchain_integrations.agent_toolkits.azure_cognitive_services import (
|
||||
AzureCognitiveServicesToolkit,
|
||||
)
|
||||
#from langchain_integrations.agent_toolkits.conversational_retrieval.openai_functions import (
|
||||
# create_conversational_retrieval_agent,
|
||||
#)
|
||||
from langchain_integrations.agent_toolkits.file_management.toolkit import (
|
||||
FileManagementToolkit,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.gmail.toolkit import GmailToolkit
|
||||
from langchain_integrations.agent_toolkits.jira.toolkit import JiraToolkit
|
||||
#from langchain_integrations.agent_toolkits.json.base import create_json_agent
|
||||
#from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain_integrations.agent_toolkits.multion.toolkit import MultionToolkit
|
||||
#from langchain_integrations.agent_toolkits.nla.toolkit import NLAToolkit
|
||||
from langchain_integrations.agent_toolkits.office365.toolkit import O365Toolkit
|
||||
#from langchain_integrations.agent_toolkits.openapi.base import create_openapi_agent
|
||||
#from langchain_integrations.agent_toolkits.openapi.toolkit import OpenAPIToolkit
|
||||
from langchain_integrations.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
|
||||
#from langchain_integrations.agent_toolkits.powerbi.base import create_pbi_agent
|
||||
#from langchain_integrations.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
|
||||
#from langchain_integrations.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
#from langchain_integrations.agent_toolkits.spark_sql.base import create_spark_sql_agent
|
||||
#from langchain_integrations.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
|
||||
from langchain_integrations.agent_toolkits.sql.base import create_sql_agent
|
||||
from langchain_integrations.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain_integrations.agent_toolkits.vectorstore.base import (
|
||||
create_vectorstore_agent,
|
||||
create_vectorstore_router_agent,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.vectorstore.toolkit import (
|
||||
VectorStoreInfo,
|
||||
VectorStoreRouterToolkit,
|
||||
VectorStoreToolkit,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.zapier.toolkit import ZapierToolkit
|
||||
from langchain_integrations.tools.retriever import create_retriever_tool
|
||||
|
||||
DEPRECATED_AGENTS = [
|
||||
"create_csv_agent",
|
||||
"create_pandas_dataframe_agent",
|
||||
"create_xorbits_agent",
|
||||
"create_python_agent",
|
||||
"create_spark_dataframe_agent",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
if name in DEPRECATED_AGENTS:
|
||||
relative_path = as_import_path(Path(__file__).parent, suffix=name)
|
||||
old_path = "langchain." + relative_path
|
||||
new_path = "langchain_experimental." + relative_path
|
||||
raise ImportError(
|
||||
f"{name} has been moved to langchain experimental. "
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"for more information.\n"
|
||||
f"Please update your import statement from: `{old_path}` to `{new_path}`."
|
||||
)
|
||||
raise AttributeError(f"{name} does not exist")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AINetworkToolkit",
|
||||
"AmadeusToolkit",
|
||||
"AzureCognitiveServicesToolkit",
|
||||
"FileManagementToolkit",
|
||||
"GmailToolkit",
|
||||
"JiraToolkit",
|
||||
"JsonToolkit",
|
||||
"MultionToolkit",
|
||||
"NLAToolkit",
|
||||
"O365Toolkit",
|
||||
"OpenAPIToolkit",
|
||||
"PlayWrightBrowserToolkit",
|
||||
"PowerBIToolkit",
|
||||
"SQLDatabaseToolkit",
|
||||
"SparkSQLToolkit",
|
||||
"VectorStoreInfo",
|
||||
"VectorStoreRouterToolkit",
|
||||
"VectorStoreToolkit",
|
||||
"ZapierToolkit",
|
||||
"create_json_agent",
|
||||
"create_openapi_agent",
|
||||
"create_pbi_agent",
|
||||
"create_pbi_chat_agent",
|
||||
"create_spark_sql_agent",
|
||||
"create_sql_agent",
|
||||
"create_vectorstore_agent",
|
||||
"create_vectorstore_router_agent",
|
||||
"create_conversational_retrieval_agent",
|
||||
"create_retriever_tool",
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""AINetwork toolkit."""
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.ainetwork.app import AINAppOps
|
||||
from langchain_integrations.tools.ainetwork.owner import AINOwnerOps
|
||||
from langchain_integrations.tools.ainetwork.rule import AINRuleOps
|
||||
from langchain_integrations.tools.ainetwork.transfer import AINTransfer
|
||||
from langchain_integrations.tools.ainetwork.utils import authenticate
|
||||
from langchain_integrations.tools.ainetwork.value import AINValueOps
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ain.ain import Ain
|
||||
|
||||
|
||||
class AINetworkToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with AINetwork Blockchain.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by reading, creating, updating, deleting
|
||||
data associated with this service.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
network: Optional[Literal["mainnet", "testnet"]] = "testnet"
|
||||
interface: Optional[Ain] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_interface(cls, values: dict) -> dict:
|
||||
if not values.get("interface"):
|
||||
values["interface"] = authenticate(network=values.get("network", "testnet"))
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
validate_all = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
AINAppOps(),
|
||||
AINOwnerOps(),
|
||||
AINRuleOps(),
|
||||
AINTransfer(),
|
||||
AINValueOps(),
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.amadeus.closest_airport import AmadeusClosestAirport
|
||||
from langchain_integrations.tools.amadeus.flight_search import AmadeusFlightSearch
|
||||
from langchain_integrations.tools.amadeus.utils import authenticate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amadeus import Client
|
||||
|
||||
|
||||
class AmadeusToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Amadeus which offers APIs for travel search."""
|
||||
|
||||
client: Client = Field(default_factory=authenticate)
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
AmadeusClosestAirport(),
|
||||
AmadeusFlightSearch(),
|
||||
]
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools.azure_cognitive_services import (
|
||||
AzureCogsFormRecognizerTool,
|
||||
AzureCogsImageAnalysisTool,
|
||||
AzureCogsSpeech2TextTool,
|
||||
AzureCogsText2SpeechTool,
|
||||
AzureCogsTextAnalyticsHealthTool,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
class AzureCognitiveServicesToolkit(BaseToolkit):
|
||||
"""Toolkit for Azure Cognitive Services."""
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
tools: List[BaseTool] = [
|
||||
AzureCogsFormRecognizerTool(),
|
||||
AzureCogsSpeech2TextTool(),
|
||||
AzureCogsText2SpeechTool(),
|
||||
AzureCogsTextAnalyticsHealthTool(),
|
||||
]
|
||||
|
||||
# TODO: Remove check once azure-ai-vision supports MacOS.
|
||||
if sys.platform.startswith("linux") or sys.platform.startswith("win"):
|
||||
tools.append(AzureCogsImageAnalysisTool())
|
||||
return tools
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Toolkits for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_integrations.tools import BaseTool
|
||||
|
||||
|
||||
class BaseToolkit(BaseModel, ABC):
|
||||
"""Base Toolkit representing a collection of related tools."""
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
@@ -0,0 +1,108 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.clickup.prompt import (
|
||||
CLICKUP_FOLDER_CREATE_PROMPT,
|
||||
CLICKUP_GET_ALL_TEAMS_PROMPT,
|
||||
CLICKUP_GET_FOLDERS_PROMPT,
|
||||
CLICKUP_GET_LIST_PROMPT,
|
||||
CLICKUP_GET_SPACES_PROMPT,
|
||||
CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
|
||||
CLICKUP_GET_TASK_PROMPT,
|
||||
CLICKUP_LIST_CREATE_PROMPT,
|
||||
CLICKUP_TASK_CREATE_PROMPT,
|
||||
CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
|
||||
CLICKUP_UPDATE_TASK_PROMPT,
|
||||
)
|
||||
from langchain_integrations.tools.clickup.tool import ClickupAction
|
||||
from langchain_integrations.utilities.clickup import ClickupAPIWrapper
|
||||
|
||||
|
||||
class ClickupToolkit(BaseToolkit):
|
||||
"""Clickup Toolkit.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by reading, creating, updating, deleting
|
||||
data associated with this service.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
tools: List[BaseTool] = []
|
||||
|
||||
@classmethod
|
||||
def from_clickup_api_wrapper(
|
||||
cls, clickup_api_wrapper: ClickupAPIWrapper
|
||||
) -> "ClickupToolkit":
|
||||
operations: List[Dict] = [
|
||||
{
|
||||
"mode": "get_task",
|
||||
"name": "Get task",
|
||||
"description": CLICKUP_GET_TASK_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_task_attribute",
|
||||
"name": "Get task attribute",
|
||||
"description": CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_teams",
|
||||
"name": "Get Teams",
|
||||
"description": CLICKUP_GET_ALL_TEAMS_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_task",
|
||||
"name": "Create Task",
|
||||
"description": CLICKUP_TASK_CREATE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_list",
|
||||
"name": "Create List",
|
||||
"description": CLICKUP_LIST_CREATE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_folder",
|
||||
"name": "Create Folder",
|
||||
"description": CLICKUP_FOLDER_CREATE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_list",
|
||||
"name": "Get all lists in the space",
|
||||
"description": CLICKUP_GET_LIST_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_folders",
|
||||
"name": "Get all folders in the workspace",
|
||||
"description": CLICKUP_GET_FOLDERS_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_spaces",
|
||||
"name": "Get all spaces in the workspace",
|
||||
"description": CLICKUP_GET_SPACES_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "update_task",
|
||||
"name": "Update task",
|
||||
"description": CLICKUP_UPDATE_TASK_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "update_task_assignees",
|
||||
"name": "Update task assignees",
|
||||
"description": CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
ClickupAction(
|
||||
name=action["name"],
|
||||
description=action["description"],
|
||||
mode=action["mode"],
|
||||
api_wrapper=clickup_api_wrapper,
|
||||
)
|
||||
for action in operations
|
||||
]
|
||||
return cls(tools=tools)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return self.tools
|
||||
@@ -0,0 +1,88 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.prompts.chat import MessagesPlaceholder
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agents.openai_functions_agent.agent_token_buffer_memory import (
|
||||
AgentTokenBufferMemory,
|
||||
)
|
||||
from langchain_integrations.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain_openai.chat_model import ChatOpenAI
|
||||
from langchain_integrations.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
def _get_default_system_message() -> SystemMessage:
|
||||
return SystemMessage(
|
||||
content=(
|
||||
"Do your best to answer the questions. "
|
||||
"Feel free to use any tools available to look up "
|
||||
"relevant information, only if necessary"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_conversational_retrieval_agent(
|
||||
llm: BaseLanguageModel,
|
||||
tools: List[BaseTool],
|
||||
remember_intermediate_steps: bool = True,
|
||||
memory_key: str = "chat_history",
|
||||
system_message: Optional[SystemMessage] = None,
|
||||
verbose: bool = False,
|
||||
max_token_limit: int = 2000,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""A convenience method for creating a conversational retrieval agent.
|
||||
|
||||
Args:
|
||||
llm: The language model to use, should be ChatOpenAI
|
||||
tools: A list of tools the agent has access to
|
||||
remember_intermediate_steps: Whether the agent should remember intermediate
|
||||
steps or not. Intermediate steps refer to prior action/observation
|
||||
pairs from previous questions. The benefit of remembering these is if
|
||||
there is relevant information in there, the agent can use it to answer
|
||||
follow up questions. The downside is it will take up more tokens.
|
||||
memory_key: The name of the memory key in the prompt.
|
||||
system_message: The system message to use. By default, a basic one will
|
||||
be used.
|
||||
verbose: Whether or not the final AgentExecutor should be verbose or not,
|
||||
defaults to False.
|
||||
max_token_limit: The max number of tokens to keep around in memory.
|
||||
Defaults to 2000.
|
||||
|
||||
Returns:
|
||||
An agent executor initialized appropriately
|
||||
"""
|
||||
|
||||
if not isinstance(llm, ChatOpenAI):
|
||||
raise ValueError("Only supported with ChatOpenAI models.")
|
||||
if remember_intermediate_steps:
|
||||
memory: BaseMemory = AgentTokenBufferMemory(
|
||||
memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
|
||||
)
|
||||
else:
|
||||
memory = ConversationTokenBufferMemory(
|
||||
memory_key=memory_key,
|
||||
return_messages=True,
|
||||
output_key="output",
|
||||
llm=llm,
|
||||
max_token_limit=max_token_limit,
|
||||
)
|
||||
|
||||
_system_message = system_message or _get_default_system_message()
|
||||
prompt = OpenAIFunctionsAgent.create_prompt(
|
||||
system_message=_system_message,
|
||||
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
|
||||
)
|
||||
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
|
||||
return AgentExecutor(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
memory=memory,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=remember_intermediate_steps,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from langchain_integrations.tools.retriever import create_retriever_tool
|
||||
|
||||
__all__ = ["create_retriever_tool"]
|
||||
@@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise AttributeError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
@@ -0,0 +1,7 @@
|
||||
"""Local file management toolkit."""
|
||||
|
||||
from langchain_integrations.agent_toolkits.file_management.toolkit import (
|
||||
FileManagementToolkit,
|
||||
)
|
||||
|
||||
__all__ = ["FileManagementToolkit"]
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.file_management.copy import CopyFileTool
|
||||
from langchain_integrations.tools.file_management.delete import DeleteFileTool
|
||||
from langchain_integrations.tools.file_management.file_search import FileSearchTool
|
||||
from langchain_integrations.tools.file_management.list_dir import ListDirectoryTool
|
||||
from langchain_integrations.tools.file_management.move import MoveFileTool
|
||||
from langchain_integrations.tools.file_management.read import ReadFileTool
|
||||
from langchain_integrations.tools.file_management.write import WriteFileTool
|
||||
|
||||
_FILE_TOOLS = {
|
||||
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
|
||||
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
|
||||
for tool_cls in [
|
||||
CopyFileTool,
|
||||
DeleteFileTool,
|
||||
FileSearchTool,
|
||||
MoveFileTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
ListDirectoryTool,
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class FileManagementToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with local files.
|
||||
|
||||
*Security Notice*: This toolkit provides methods to interact with local files.
|
||||
If providing this toolkit to an agent on an LLM, ensure you scope
|
||||
the agent's permissions to only include the necessary permissions
|
||||
to perform the desired operations.
|
||||
|
||||
By **default** the agent will have access to all files within
|
||||
the root dir and will be able to Copy, Delete, Move, Read, Write
|
||||
and List files in that directory.
|
||||
|
||||
Consider the following:
|
||||
- Limit access to particular directories using `root_dir`.
|
||||
- Use filesystem permissions to restrict access and permissions to only
|
||||
the files and directories required by the agent.
|
||||
- Limit the tools available to the agent to only the file operations
|
||||
necessary for the agent's intended use.
|
||||
- Sandbox the agent by running it in a container.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
root_dir: Optional[str] = None
|
||||
"""If specified, all file operations are made relative to root_dir."""
|
||||
selected_tools: Optional[List[str]] = None
|
||||
"""If provided, only provide the selected tools. Defaults to all."""
|
||||
|
||||
@root_validator
|
||||
def validate_tools(cls, values: dict) -> dict:
|
||||
selected_tools = values.get("selected_tools") or []
|
||||
for tool_name in selected_tools:
|
||||
if tool_name not in _FILE_TOOLS:
|
||||
raise ValueError(
|
||||
f"File Tool of name {tool_name} not supported."
|
||||
f" Permitted tools: {list(_FILE_TOOLS)}"
|
||||
)
|
||||
return values
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
|
||||
tools: List[BaseTool] = []
|
||||
for tool in allowed_tools:
|
||||
tool_cls = _FILE_TOOLS[tool]
|
||||
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
|
||||
return tools
|
||||
|
||||
|
||||
__all__ = ["FileManagementToolkit"]
|
||||
@@ -0,0 +1 @@
|
||||
"""GitHub Toolkit."""
|
||||
@@ -0,0 +1,94 @@
|
||||
"""GitHub Toolkit."""
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.github.prompt import (
|
||||
COMMENT_ON_ISSUE_PROMPT,
|
||||
CREATE_FILE_PROMPT,
|
||||
CREATE_PULL_REQUEST_PROMPT,
|
||||
DELETE_FILE_PROMPT,
|
||||
GET_ISSUE_PROMPT,
|
||||
GET_ISSUES_PROMPT,
|
||||
READ_FILE_PROMPT,
|
||||
UPDATE_FILE_PROMPT,
|
||||
)
|
||||
from langchain_integrations.tools.github.tool import GitHubAction
|
||||
from langchain_integrations.utilities.github import GitHubAPIWrapper
|
||||
|
||||
|
||||
class GitHubToolkit(BaseToolkit):
|
||||
"""GitHub Toolkit.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by creating, deleting, or updating,
|
||||
reading underlying data.
|
||||
|
||||
For example, this toolkit can be used to create issues, pull requests,
|
||||
and comments on GitHub.
|
||||
|
||||
See [Security](https://python.langchain.com/docs/security) for more information.
|
||||
"""
|
||||
|
||||
tools: List[BaseTool] = []
|
||||
|
||||
@classmethod
|
||||
def from_github_api_wrapper(
|
||||
cls, github_api_wrapper: GitHubAPIWrapper
|
||||
) -> "GitHubToolkit":
|
||||
operations: List[Dict] = [
|
||||
{
|
||||
"mode": "get_issues",
|
||||
"name": "Get Issues",
|
||||
"description": GET_ISSUES_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_issue",
|
||||
"name": "Get Issue",
|
||||
"description": GET_ISSUE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "comment_on_issue",
|
||||
"name": "Comment on Issue",
|
||||
"description": COMMENT_ON_ISSUE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_pull_request",
|
||||
"name": "Create Pull Request",
|
||||
"description": CREATE_PULL_REQUEST_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_file",
|
||||
"name": "Create File",
|
||||
"description": CREATE_FILE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "read_file",
|
||||
"name": "Read File",
|
||||
"description": READ_FILE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "update_file",
|
||||
"name": "Update File",
|
||||
"description": UPDATE_FILE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "delete_file",
|
||||
"name": "Delete File",
|
||||
"description": DELETE_FILE_PROMPT,
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
GitHubAction(
|
||||
name=action["name"],
|
||||
description=action["description"],
|
||||
mode=action["mode"],
|
||||
api_wrapper=github_api_wrapper,
|
||||
)
|
||||
for action in operations
|
||||
]
|
||||
return cls(tools=tools)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return self.tools
|
||||
@@ -0,0 +1 @@
|
||||
"""GitLab Toolkit."""
|
||||
@@ -0,0 +1,94 @@
|
||||
"""GitHub Toolkit."""
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.gitlab.prompt import (
|
||||
COMMENT_ON_ISSUE_PROMPT,
|
||||
CREATE_FILE_PROMPT,
|
||||
CREATE_PULL_REQUEST_PROMPT,
|
||||
DELETE_FILE_PROMPT,
|
||||
GET_ISSUE_PROMPT,
|
||||
GET_ISSUES_PROMPT,
|
||||
READ_FILE_PROMPT,
|
||||
UPDATE_FILE_PROMPT,
|
||||
)
|
||||
from langchain_integrations.tools.gitlab.tool import GitLabAction
|
||||
from langchain_integrations.utilities.gitlab import GitLabAPIWrapper
|
||||
|
||||
|
||||
class GitLabToolkit(BaseToolkit):
|
||||
"""GitLab Toolkit.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by creating, deleting, or updating,
|
||||
reading underlying data.
|
||||
|
||||
For example, this toolkit can be used to create issues, pull requests,
|
||||
and comments on GitLab.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
tools: List[BaseTool] = []
|
||||
|
||||
@classmethod
|
||||
def from_gitlab_api_wrapper(
|
||||
cls, gitlab_api_wrapper: GitLabAPIWrapper
|
||||
) -> "GitLabToolkit":
|
||||
operations: List[Dict] = [
|
||||
{
|
||||
"mode": "get_issues",
|
||||
"name": "Get Issues",
|
||||
"description": GET_ISSUES_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_issue",
|
||||
"name": "Get Issue",
|
||||
"description": GET_ISSUE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "comment_on_issue",
|
||||
"name": "Comment on Issue",
|
||||
"description": COMMENT_ON_ISSUE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_pull_request",
|
||||
"name": "Create Pull Request",
|
||||
"description": CREATE_PULL_REQUEST_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_file",
|
||||
"name": "Create File",
|
||||
"description": CREATE_FILE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "read_file",
|
||||
"name": "Read File",
|
||||
"description": READ_FILE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "update_file",
|
||||
"name": "Update File",
|
||||
"description": UPDATE_FILE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "delete_file",
|
||||
"name": "Delete File",
|
||||
"description": DELETE_FILE_PROMPT,
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
GitLabAction(
|
||||
name=action["name"],
|
||||
description=action["description"],
|
||||
mode=action["mode"],
|
||||
api_wrapper=gitlab_api_wrapper,
|
||||
)
|
||||
for action in operations
|
||||
]
|
||||
return cls(tools=tools)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return self.tools
|
||||
@@ -0,0 +1 @@
|
||||
"""Gmail toolkit."""
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.gmail.create_draft import GmailCreateDraft
|
||||
from langchain_integrations.tools.gmail.get_message import GmailGetMessage
|
||||
from langchain_integrations.tools.gmail.get_thread import GmailGetThread
|
||||
from langchain_integrations.tools.gmail.search import GmailSearch
|
||||
from langchain_integrations.tools.gmail.send_message import GmailSendMessage
|
||||
from langchain_integrations.tools.gmail.utils import build_resource_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# This is for linting and IDE typehints
|
||||
from googleapiclient.discovery import Resource
|
||||
else:
|
||||
try:
|
||||
# We do this so pydantic can resolve the types when instantiating
|
||||
from googleapiclient.discovery import Resource
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
SCOPES = ["https://mail.google.com/"]
|
||||
|
||||
|
||||
class GmailToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Gmail.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by reading, creating, updating, deleting
|
||||
data associated with this service.
|
||||
|
||||
For example, this toolkit can be used to send emails on behalf of the
|
||||
associated account.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
api_resource: Resource = Field(default_factory=build_resource_service)
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
GmailCreateDraft(api_resource=self.api_resource),
|
||||
GmailSendMessage(api_resource=self.api_resource),
|
||||
GmailSearch(api_resource=self.api_resource),
|
||||
GmailGetMessage(api_resource=self.api_resource),
|
||||
GmailGetThread(api_resource=self.api_resource),
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""Jira Toolkit."""
|
||||
@@ -0,0 +1,70 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.jira.prompt import (
|
||||
JIRA_CATCH_ALL_PROMPT,
|
||||
JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
|
||||
JIRA_GET_ALL_PROJECTS_PROMPT,
|
||||
JIRA_ISSUE_CREATE_PROMPT,
|
||||
JIRA_JQL_PROMPT,
|
||||
)
|
||||
from langchain_integrations.tools.jira.tool import JiraAction
|
||||
from langchain_integrations.utilities.jira import JiraAPIWrapper
|
||||
|
||||
|
||||
class JiraToolkit(BaseToolkit):
|
||||
"""Jira Toolkit.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by creating, deleting, or updating,
|
||||
reading underlying data.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
tools: List[BaseTool] = []
|
||||
|
||||
@classmethod
|
||||
def from_jira_api_wrapper(cls, jira_api_wrapper: JiraAPIWrapper) -> "JiraToolkit":
|
||||
operations: List[Dict] = [
|
||||
{
|
||||
"mode": "jql",
|
||||
"name": "JQL Query",
|
||||
"description": JIRA_JQL_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "get_projects",
|
||||
"name": "Get Projects",
|
||||
"description": JIRA_GET_ALL_PROJECTS_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_issue",
|
||||
"name": "Create Issue",
|
||||
"description": JIRA_ISSUE_CREATE_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "other",
|
||||
"name": "Catch all Jira API call",
|
||||
"description": JIRA_CATCH_ALL_PROMPT,
|
||||
},
|
||||
{
|
||||
"mode": "create_page",
|
||||
"name": "Create confluence page",
|
||||
"description": JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
JiraAction(
|
||||
name=action["name"],
|
||||
description=action["description"],
|
||||
mode=action["mode"],
|
||||
api_wrapper=jira_api_wrapper,
|
||||
)
|
||||
for action in operations
|
||||
]
|
||||
return cls(tools=tools)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return self.tools
|
||||
@@ -0,0 +1 @@
|
||||
"""Json agent."""
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Json agent."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
|
||||
from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
|
||||
|
||||
def create_json_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: JsonToolkit,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = JSON_PREFIX,
|
||||
suffix: str = JSON_SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a json agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
# flake8: noqa
|
||||
|
||||
JSON_PREFIX = """You are an agent designed to interact with JSON.
|
||||
Your goal is to return a final answer by interacting with the JSON.
|
||||
You have access to the following tools which help you learn more about the JSON you are interacting with.
|
||||
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
|
||||
Do not make up any information that is not contained in the JSON.
|
||||
Your input to the tools should be in the form of `data["key"][0]` where `data` is the JSON blob you are interacting with, and the syntax used is Python.
|
||||
You should only use keys that you know for a fact exist. You must validate that a key exists by seeing it previously when calling `json_spec_list_keys`.
|
||||
If you have not seen a key in one of those responses, you cannot use it.
|
||||
You should only add one key at a time to the path. You cannot add multiple keys at once.
|
||||
If you encounter a "KeyError", go back to the previous key, look at the available keys, and try again.
|
||||
|
||||
If the question does not seem to be related to the JSON, just return "I don't know" as the answer.
|
||||
Always begin your interaction with the `json_spec_list_keys` tool with input "data" to see what keys exist in the JSON.
|
||||
|
||||
Note that sometimes the value at a given path is large. In this case, you will get an error "Value is a large dictionary, should explore its keys directly".
|
||||
In this case, you should ALWAYS follow up by using the `json_spec_list_keys` tool to see what keys exist at that path.
|
||||
Do not simply refer the user to the JSON or a section of the JSON, as this is not a valid answer. Keep digging until you find the answer and explicitly return it.
|
||||
"""
|
||||
JSON_SUFFIX = """Begin!"
|
||||
|
||||
Question: {input}
|
||||
Thought: I should look at the keys that exist in data to see what I have access to
|
||||
{agent_scratchpad}"""
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.json.tool import JsonGetValueTool, JsonListKeysTool, JsonSpec
|
||||
|
||||
|
||||
class JsonToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with a JSON spec."""
|
||||
|
||||
spec: JsonSpec
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
JsonListKeysTool(spec=self.spec),
|
||||
JsonGetValueTool(spec=self.spec),
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""MultiOn Toolkit."""
|
||||
@@ -0,0 +1,33 @@
|
||||
"""MultiOn agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.multion.close_session import MultionCloseSession
|
||||
from langchain_integrations.tools.multion.create_session import MultionCreateSession
|
||||
from langchain_integrations.tools.multion.update_session import MultionUpdateSession
|
||||
|
||||
|
||||
class MultionToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with the Browser Agent.
|
||||
|
||||
**Security Note**: This toolkit contains tools that interact with the
|
||||
user's browser via the multion API which grants an agent
|
||||
access to the user's browser.
|
||||
|
||||
Please review the documentation for the multion API to understand
|
||||
the security implications of using this toolkit.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [MultionCreateSession(), MultionUpdateSession(), MultionCloseSession()]
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tool for interacting with a single API with natural language definition."""
|
||||
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents.tools import Tool
|
||||
from langchain_integrations.chains.api.openapi.chain import OpenAPIEndpointChain
|
||||
from langchain_integrations.tools.openapi.utils.api_models import APIOperation
|
||||
from langchain_integrations.tools.openapi.utils.openapi_utils import OpenAPISpec
|
||||
from langchain_integrations.utilities.requests import Requests
|
||||
|
||||
|
||||
class NLATool(Tool):
|
||||
"""Natural Language API Tool."""
|
||||
|
||||
@classmethod
|
||||
def from_open_api_endpoint_chain(
|
||||
cls, chain: OpenAPIEndpointChain, api_title: str
|
||||
) -> "NLATool":
|
||||
"""Convert an endpoint chain to an API endpoint tool."""
|
||||
expanded_name = (
|
||||
f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}'
|
||||
)
|
||||
description = (
|
||||
f"I'm an AI from {api_title}. Instruct what you want,"
|
||||
" and I'll assist via an API with description:"
|
||||
f" {chain.api_operation.description}"
|
||||
)
|
||||
return cls(name=expanded_name, func=chain.run, description=description)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_method(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
path: str,
|
||||
method: str,
|
||||
spec: OpenAPISpec,
|
||||
requests: Optional[Requests] = None,
|
||||
verbose: bool = False,
|
||||
return_intermediate_steps: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "NLATool":
|
||||
"""Instantiate the tool from the specified path and method."""
|
||||
api_operation = APIOperation.from_openapi_spec(spec, path, method)
|
||||
chain = OpenAPIEndpointChain.from_api_operation(
|
||||
api_operation,
|
||||
llm,
|
||||
requests=requests,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=return_intermediate_steps,
|
||||
**kwargs,
|
||||
)
|
||||
return cls.from_open_api_endpoint_chain(chain, spec.info.title)
|
||||
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.agent_toolkits.nla.tool import NLATool
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_integrations.tools.openapi.utils.openapi_utils import OpenAPISpec
|
||||
from langchain_integrations.tools.plugin import AIPlugin
|
||||
from langchain_integrations.utilities.requests import Requests
|
||||
|
||||
|
||||
class NLAToolkit(BaseToolkit):
|
||||
"""Natural Language API Toolkit.
|
||||
|
||||
*Security Note*: This toolkit creates tools that enable making calls
|
||||
to an Open API compliant API.
|
||||
|
||||
The tools created by this toolkit may be able to make GET, POST,
|
||||
PATCH, PUT, DELETE requests to any of the exposed endpoints on
|
||||
the API.
|
||||
|
||||
Control access to who can use this toolkit.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
nla_tools: Sequence[NLATool] = Field(...)
|
||||
"""List of API Endpoint Tools."""
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools for all the API operations."""
|
||||
return list(self.nla_tools)
|
||||
|
||||
@staticmethod
|
||||
def _get_http_operation_tools(
|
||||
llm: BaseLanguageModel,
|
||||
spec: OpenAPISpec,
|
||||
requests: Optional[Requests] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[NLATool]:
|
||||
"""Get the tools for all the API operations."""
|
||||
if not spec.paths:
|
||||
return []
|
||||
http_operation_tools = []
|
||||
for path in spec.paths:
|
||||
for method in spec.get_methods_for_path(path):
|
||||
endpoint_tool = NLATool.from_llm_and_method(
|
||||
llm=llm,
|
||||
path=path,
|
||||
method=method,
|
||||
spec=spec,
|
||||
requests=requests,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
http_operation_tools.append(endpoint_tool)
|
||||
return http_operation_tools
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_spec(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
spec: OpenAPISpec,
|
||||
requests: Optional[Requests] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> NLAToolkit:
|
||||
"""Instantiate the toolkit by creating tools for each operation."""
|
||||
http_operation_tools = cls._get_http_operation_tools(
|
||||
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
|
||||
)
|
||||
return cls(nla_tools=http_operation_tools)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_url(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
open_api_url: str,
|
||||
requests: Optional[Requests] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> NLAToolkit:
|
||||
"""Instantiate the toolkit from an OpenAPI Spec URL"""
|
||||
spec = OpenAPISpec.from_url(open_api_url)
|
||||
return cls.from_llm_and_spec(
|
||||
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_ai_plugin(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
ai_plugin: AIPlugin,
|
||||
requests: Optional[Requests] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> NLAToolkit:
|
||||
"""Instantiate the toolkit from an OpenAPI Spec URL"""
|
||||
spec = OpenAPISpec.from_url(ai_plugin.api.url)
|
||||
# TODO: Merge optional Auth information with the `requests` argument
|
||||
return cls.from_llm_and_spec(
|
||||
llm=llm,
|
||||
spec=spec,
|
||||
requests=requests,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_ai_plugin_url(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
ai_plugin_url: str,
|
||||
requests: Optional[Requests] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> NLAToolkit:
|
||||
"""Instantiate the toolkit from an OpenAPI Spec URL"""
|
||||
plugin = AIPlugin.from_url(ai_plugin_url)
|
||||
return cls.from_llm_and_ai_plugin(
|
||||
llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Office365 toolkit."""
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.office365.create_draft_message import O365CreateDraftMessage
|
||||
from langchain_integrations.tools.office365.events_search import O365SearchEvents
|
||||
from langchain_integrations.tools.office365.messages_search import O365SearchEmails
|
||||
from langchain_integrations.tools.office365.send_event import O365SendEvent
|
||||
from langchain_integrations.tools.office365.send_message import O365SendMessage
|
||||
from langchain_integrations.tools.office365.utils import authenticate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from O365 import Account
|
||||
|
||||
|
||||
class O365Toolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Office 365.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by reading, creating, updating, deleting
|
||||
data associated with this service.
|
||||
|
||||
For example, this toolkit can be used search through emails and events,
|
||||
send messages and event invites, and create draft messages.
|
||||
|
||||
Please make sure that the permissions given by this toolkit
|
||||
are appropriate for your use case.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
account: Account = Field(default_factory=authenticate)
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
O365SearchEvents(),
|
||||
O365CreateDraftMessage(),
|
||||
O365SearchEmails(),
|
||||
O365SendEvent(),
|
||||
O365SendMessage(),
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""OpenAPI spec agent."""
|
||||
@@ -0,0 +1,73 @@
|
||||
"""OpenAPI spec agent."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.openapi.prompt import (
|
||||
OPENAPI_PREFIX,
|
||||
OPENAPI_SUFFIX,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.openapi.toolkit import OpenAPIToolkit
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
|
||||
|
||||
def create_openapi_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: OpenAPIToolkit,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = OPENAPI_PREFIX,
|
||||
suffix: str = OPENAPI_SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
max_iterations: Optional[int] = 15,
|
||||
max_execution_time: Optional[float] = None,
|
||||
early_stopping_method: str = "force",
|
||||
verbose: bool = False,
|
||||
return_intermediate_steps: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct an OpenAPI agent from an LLM and tools.
|
||||
|
||||
*Security Note*: When creating an OpenAPI agent, check the permissions
|
||||
and capabilities of the underlying toolkit.
|
||||
|
||||
For example, if the default implementation of OpenAPIToolkit
|
||||
uses the RequestsToolkit which contains tools to make arbitrary
|
||||
network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE),
|
||||
|
||||
Control access to who can submit issue requests using this toolkit and
|
||||
what network access it has.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=return_intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
early_stopping_method=early_stopping_method,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,357 @@
|
||||
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.openapi.planner_prompt import (
|
||||
API_CONTROLLER_PROMPT,
|
||||
API_CONTROLLER_TOOL_DESCRIPTION,
|
||||
API_CONTROLLER_TOOL_NAME,
|
||||
API_ORCHESTRATOR_PROMPT,
|
||||
API_PLANNER_PROMPT,
|
||||
API_PLANNER_TOOL_DESCRIPTION,
|
||||
API_PLANNER_TOOL_NAME,
|
||||
PARSING_DELETE_PROMPT,
|
||||
PARSING_GET_PROMPT,
|
||||
PARSING_PATCH_PROMPT,
|
||||
PARSING_POST_PROMPT,
|
||||
PARSING_PUT_PROMPT,
|
||||
REQUESTS_DELETE_TOOL_DESCRIPTION,
|
||||
REQUESTS_GET_TOOL_DESCRIPTION,
|
||||
REQUESTS_PATCH_TOOL_DESCRIPTION,
|
||||
REQUESTS_POST_TOOL_DESCRIPTION,
|
||||
REQUESTS_PUT_TOOL_DESCRIPTION,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.openapi.spec import ReducedOpenAPISpec
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_integrations.agents.tools import Tool
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
from langchain_openai.llm import OpenAI
|
||||
from langchain_integrations.memory import ReadOnlySharedMemory
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_integrations.tools.requests.tool import BaseRequestsTool
|
||||
from langchain_integrations.utilities.requests import RequestsWrapper
|
||||
|
||||
#
|
||||
# Requests tools with LLM-instructed extraction of truncated responses.
|
||||
#
|
||||
# Of course, truncating so bluntly may lose a lot of valuable
|
||||
# information in the response.
|
||||
# However, the goal for now is to have only a single inference step.
|
||||
MAX_RESPONSE_LENGTH = 5000
|
||||
"""Maximum length of the response to be returned."""
|
||||
|
||||
|
||||
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
||||
return LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def _get_default_llm_chain_factory(
|
||||
prompt: BasePromptTemplate,
|
||||
) -> Callable[[], LLMChain]:
|
||||
"""Returns a default LLMChain factory."""
|
||||
return partial(_get_default_llm_chain, prompt)
|
||||
|
||||
|
||||
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_get"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
data_params = data.get("params")
|
||||
response = self.requests_wrapper.get(data["url"], params=data_params)
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_post"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.post(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_patch"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.patch(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""Requests PUT tool with LLM-instructed extraction of truncated responses."""
|
||||
|
||||
name: str = "requests_put"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PUT_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
|
||||
)
|
||||
"""LLMChain used to extract the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.put(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
"""A tool that sends a DELETE request and parses the response."""
|
||||
|
||||
name: str = "requests_delete"
|
||||
"""The name of the tool."""
|
||||
description = REQUESTS_DELETE_TOOL_DESCRIPTION
|
||||
"""The description of the tool."""
|
||||
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
"""The maximum length of the response."""
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
|
||||
)
|
||||
"""The LLM chain used to parse the response."""
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.delete(data["url"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
#
|
||||
# Orchestrator, planner, controller.
|
||||
#
|
||||
def _create_api_planner_tool(
|
||||
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
|
||||
) -> Tool:
|
||||
endpoint_descriptions = [
|
||||
f"{name} {description}" for name, description, _ in api_spec.endpoints
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_PLANNER_PROMPT,
|
||||
input_variables=["query"],
|
||||
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
tool = Tool(
|
||||
name=API_PLANNER_TOOL_NAME,
|
||||
description=API_PLANNER_TOOL_DESCRIPTION,
|
||||
func=chain.run,
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
def _create_api_controller_agent(
|
||||
api_url: str,
|
||||
api_docs: str,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> AgentExecutor:
|
||||
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
||||
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||
tools: List[BaseTool] = [
|
||||
RequestsGetToolWithParsing(
|
||||
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
|
||||
),
|
||||
RequestsPostToolWithParsing(
|
||||
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
|
||||
),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_CONTROLLER_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
partial_variables={
|
||||
"api_url": api_url,
|
||||
"api_docs": api_docs,
|
||||
"tool_names": ", ".join([tool.name for tool in tools]),
|
||||
"tool_descriptions": "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in tools]
|
||||
),
|
||||
},
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(llm=llm, prompt=prompt),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
|
||||
def _create_api_controller_tool(
|
||||
api_spec: ReducedOpenAPISpec,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> Tool:
|
||||
"""Expose controller as a tool.
|
||||
|
||||
The tool is invoked with a plan from the planner, and dynamically
|
||||
creates a controller agent with relevant documentation only to
|
||||
constrain the context.
|
||||
"""
|
||||
|
||||
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
||||
|
||||
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
||||
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
|
||||
matches = re.findall(pattern, plan_str)
|
||||
endpoint_names = [
|
||||
"{method} {route}".format(method=method, route=route.split("?")[0])
|
||||
for method, route in matches
|
||||
]
|
||||
docs_str = ""
|
||||
for endpoint_name in endpoint_names:
|
||||
found_match = False
|
||||
for name, _, docs in api_spec.endpoints:
|
||||
regex_name = re.compile(re.sub("\{.*?\}", ".*", name))
|
||||
if regex_name.match(endpoint_name):
|
||||
found_match = True
|
||||
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
|
||||
if not found_match:
|
||||
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
||||
|
||||
agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
|
||||
return agent.run(plan_str)
|
||||
|
||||
return Tool(
|
||||
name=API_CONTROLLER_TOOL_NAME,
|
||||
func=_create_and_run_api_controller_agent,
|
||||
description=API_CONTROLLER_TOOL_DESCRIPTION,
|
||||
)
|
||||
|
||||
|
||||
def create_openapi_agent(
|
||||
api_spec: ReducedOpenAPISpec,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
shared_memory: Optional[ReadOnlySharedMemory] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
verbose: bool = True,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Instantiate OpenAI API planner and controller for a given spec.
|
||||
|
||||
Inject credentials via requests_wrapper.
|
||||
|
||||
We use a top-level "orchestrator" agent to invoke the planner and controller,
|
||||
rather than a top-level planner
|
||||
that invokes a controller with its plan. This is to keep the planner simple.
|
||||
"""
|
||||
tools = [
|
||||
_create_api_planner_tool(api_spec, llm),
|
||||
_create_api_controller_tool(api_spec, requests_wrapper, llm),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_ORCHESTRATOR_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
partial_variables={
|
||||
"tool_names": ", ".join([tool.name for tool in tools]),
|
||||
"tool_descriptions": "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in tools]
|
||||
),
|
||||
},
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,235 @@
|
||||
# flake8: noqa
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
|
||||
|
||||
You should:
|
||||
1) evaluate whether the user query can be solved by the API documentated below. If no, say why.
|
||||
2) if yes, generate a plan of API calls and say what they are doing step by step.
|
||||
3) If the plan includes a DELETE call, you should always return an ask from the User for authorization first unless the User has specifically asked to delete something.
|
||||
|
||||
You should only use API endpoints documented below ("Endpoints you can use:").
|
||||
You can only use the DELETE tool if the User has specifically asked to delete something. Otherwise, you should return a request authorization from the User first.
|
||||
Some user queries can be resolved in a single API call, but some will require several API calls.
|
||||
The plan will be passed to an API controller that can format it into web requests and return the responses.
|
||||
|
||||
----
|
||||
|
||||
Here are some examples:
|
||||
|
||||
Fake endpoints for examples:
|
||||
GET /user to get information about the current user
|
||||
GET /products/search search across products
|
||||
POST /users/{{id}}/cart to add products to a user's cart
|
||||
PATCH /users/{{id}}/cart to update a user's cart
|
||||
PUT /users/{{id}}/coupon to apply idempotent coupon to a user's cart
|
||||
DELETE /users/{{id}}/cart to delete a user's cart
|
||||
|
||||
User query: tell me a joke
|
||||
Plan: Sorry, this API's domain is shopping, not comedy.
|
||||
|
||||
User query: I want to buy a couch
|
||||
Plan: 1. GET /products with a query param to search for couches
|
||||
2. GET /user to find the user's id
|
||||
3. POST /users/{{id}}/cart to add a couch to the user's cart
|
||||
|
||||
User query: I want to add a lamp to my cart
|
||||
Plan: 1. GET /products with a query param to search for lamps
|
||||
2. GET /user to find the user's id
|
||||
3. PATCH /users/{{id}}/cart to add a lamp to the user's cart
|
||||
|
||||
User query: I want to add a coupon to my cart
|
||||
Plan: 1. GET /user to find the user's id
|
||||
2. PUT /users/{{id}}/coupon to apply the coupon
|
||||
|
||||
User query: I want to delete my cart
|
||||
Plan: 1. GET /user to find the user's id
|
||||
2. DELETE required. Did user specify DELETE or previously authorize? Yes, proceed.
|
||||
3. DELETE /users/{{id}}/cart to delete the user's cart
|
||||
|
||||
User query: I want to start a new cart
|
||||
Plan: 1. GET /user to find the user's id
|
||||
2. DELETE required. Did user specify DELETE or previously authorize? No, ask for authorization.
|
||||
3. Are you sure you want to delete your cart?
|
||||
----
|
||||
|
||||
Here are endpoints you can use. Do not reference any of the endpoints above.
|
||||
|
||||
{endpoints}
|
||||
|
||||
----
|
||||
|
||||
User query: {query}
|
||||
Plan:"""
|
||||
API_PLANNER_TOOL_NAME = "api_planner"
|
||||
API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to call the API controller."
|
||||
|
||||
# Execution.
|
||||
API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
|
||||
If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.
|
||||
|
||||
|
||||
Here is documentation on the API:
|
||||
Base url: {api_url}
|
||||
Endpoints:
|
||||
{api_docs}
|
||||
|
||||
|
||||
Here are tools to execute requests against the API: {tool_descriptions}
|
||||
|
||||
|
||||
Starting below, you should follow this format:
|
||||
|
||||
Plan: the plan of API calls to execute
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of the tools [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the output of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
|
||||
Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.
|
||||
|
||||
|
||||
Begin!
|
||||
|
||||
Plan: {input}
|
||||
Thought:
|
||||
{agent_scratchpad}
|
||||
"""
|
||||
API_CONTROLLER_TOOL_NAME = "api_controller"
|
||||
API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)."
|
||||
|
||||
# Orchestrate planning + execution.
|
||||
# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while
|
||||
# keeping planning (and specifically the planning prompt) simple.
|
||||
API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against API, things like querying information or creating resources.
|
||||
Some user queries can be resolved in a single API call, particularly if you can find appropriate params from the OpenAPI spec; though some require several API calls.
|
||||
You should always plan your API calls first, and then execute the plan second.
|
||||
If the plan includes a DELETE call, be sure to ask the User for authorization first unless the User has specifically asked to delete something.
|
||||
You should never return information without executing the api_controller tool.
|
||||
|
||||
|
||||
Here are the tools to plan and execute API requests: {tool_descriptions}
|
||||
|
||||
|
||||
Starting below, you should follow this format:
|
||||
|
||||
User query: the query a User wants help with related to the API
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of the tools [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I am finished executing a plan and have the information the user asked for or the data the user asked to create
|
||||
Final Answer: the final output from executing the plan
|
||||
|
||||
|
||||
Example:
|
||||
User query: can you add some trendy stuff to my shopping cart.
|
||||
Thought: I should plan API calls first.
|
||||
Action: api_planner
|
||||
Action Input: I need to find the right API calls to add trendy items to the users shopping cart
|
||||
Observation: 1) GET /items with params 'trending' is 'True' to get trending item ids
|
||||
2) GET /user to get user
|
||||
3) POST /cart to post the trending items to the user's cart
|
||||
Thought: I'm ready to execute the API calls.
|
||||
Action: api_controller
|
||||
Action Input: 1) GET /items params 'trending' is 'True' to get trending item ids
|
||||
2) GET /user to get user
|
||||
3) POST /cart to post the trending items to the user's cart
|
||||
...
|
||||
|
||||
Begin!
|
||||
|
||||
User query: {input}
|
||||
Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller.
|
||||
{agent_scratchpad}"""
|
||||
|
||||
REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website.
|
||||
Input to the tool should be a json string with 3 keys: "url", "params" and "output_instructions".
|
||||
The value of "url" should be a string.
|
||||
The value of "params" should be a dict of the needed and available parameters from the OpenAPI spec related to the endpoint.
|
||||
If parameters are not needed, or not available, leave it empty.
|
||||
The value of "output_instructions" should be instructions on what information to extract from the response,
|
||||
for example the id(s) for a resource(s) that the GET request fetches.
|
||||
"""
|
||||
|
||||
PARSING_GET_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
|
||||
REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website.
|
||||
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
|
||||
The value of "url" should be a string.
|
||||
The value of "data" should be a dictionary of key-value pairs you want to POST to the url.
|
||||
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates.
|
||||
Always use double quotes for strings in the json string."""
|
||||
|
||||
PARSING_POST_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
|
||||
REQUESTS_PATCH_TOOL_DESCRIPTION = """Use this when you want to PATCH content on a website.
|
||||
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
|
||||
The value of "url" should be a string.
|
||||
The value of "data" should be a dictionary of key-value pairs of the body params available in the OpenAPI spec you want to PATCH the content with at the url.
|
||||
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PATCH request creates.
|
||||
Always use double quotes for strings in the json string."""
|
||||
|
||||
PARSING_PATCH_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
|
||||
REQUESTS_PUT_TOOL_DESCRIPTION = """Use this when you want to PUT to a website.
|
||||
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
|
||||
The value of "url" should be a string.
|
||||
The value of "data" should be a dictionary of key-value pairs you want to PUT to the url.
|
||||
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PUT request creates.
|
||||
Always use double quotes for strings in the json string."""
|
||||
|
||||
PARSING_PUT_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
|
||||
REQUESTS_DELETE_TOOL_DESCRIPTION = """ONLY USE THIS TOOL WHEN THE USER HAS SPECIFICALLY REQUESTED TO DELETE CONTENT FROM A WEBSITE.
|
||||
Input to the tool should be a json string with 2 keys: "url", and "output_instructions".
|
||||
The value of "url" should be a string.
|
||||
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the DELETE request creates.
|
||||
Always use double quotes for strings in the json string.
|
||||
ONLY USE THIS TOOL IF THE USER HAS SPECIFICALLY REQUESTED TO DELETE SOMETHING."""
|
||||
|
||||
PARSING_DELETE_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
# flake8: noqa
|
||||
|
||||
OPENAPI_PREFIX = """You are an agent designed to answer questions by making web requests to an API given the openapi spec.
|
||||
|
||||
If the question does not seem related to the API, return I don't know. Do not make up an answer.
|
||||
Only use information provided by the tools to construct your response.
|
||||
|
||||
First, find the base URL needed to make the request.
|
||||
|
||||
Second, find the relevant paths needed to answer the question. Take note that, sometimes, you might need to make more than one request to more than one path to answer the question.
|
||||
|
||||
Third, find the required parameters needed to make the request. For GET requests, these are usually URL parameters and for POST requests, these are request body parameters.
|
||||
|
||||
Fourth, make the requests needed to answer the question. Ensure that you are sending the correct parameters to the request by checking which parameters are required. For parameters with a fixed set of values, please use the spec to look at which values are allowed.
|
||||
|
||||
Use the exact parameter names as listed in the spec, do not make up any names or abbreviate the names of parameters.
|
||||
If you get a not found error, ensure that you are using a path that actually exists in the spec.
|
||||
"""
|
||||
OPENAPI_SUFFIX = """Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought: I should explore the spec to find the base url for the API.
|
||||
{agent_scratchpad}"""
|
||||
|
||||
DESCRIPTION = """Can be used to answer questions about the openapi spec for the API. Always use this tool before trying to make a request.
|
||||
Example inputs to this tool:
|
||||
'What are the required query parameters for a GET request to the /bar endpoint?`
|
||||
'What are the required parameters in the request body for a POST request to the /foo endpoint?'
|
||||
Always give this tool a specific question."""
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Quick and dirty representation for OpenAPI specs."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReducedOpenAPISpec:
|
||||
"""A reduced OpenAPI spec.
|
||||
|
||||
This is a quick and dirty representation for OpenAPI specs.
|
||||
|
||||
Attributes:
|
||||
servers: The servers in the spec.
|
||||
description: The description of the spec.
|
||||
endpoints: The endpoints in the spec.
|
||||
"""
|
||||
|
||||
servers: List[dict]
|
||||
description: str
|
||||
endpoints: List[Tuple[str, str, dict]]
|
||||
|
||||
|
||||
def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec:
|
||||
"""Simplify/distill/minify a spec somehow.
|
||||
|
||||
I want a smaller target for retrieval and (more importantly)
|
||||
I want smaller results from retrieval.
|
||||
I was hoping https://openapi.tools/ would have some useful bits
|
||||
to this end, but doesn't seem so.
|
||||
"""
|
||||
# 1. Consider only get, post, patch, put, delete endpoints.
|
||||
endpoints = [
|
||||
(f"{operation_name.upper()} {route}", docs.get("description"), docs)
|
||||
for route, operation in spec["paths"].items()
|
||||
for operation_name, docs in operation.items()
|
||||
if operation_name in ["get", "post", "patch", "put", "delete"]
|
||||
]
|
||||
|
||||
# 2. Replace any refs so that complete docs are retrieved.
|
||||
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
|
||||
if dereference:
|
||||
endpoints = [
|
||||
(name, description, dereference_refs(docs, full_schema=spec))
|
||||
for name, description, docs in endpoints
|
||||
]
|
||||
|
||||
# 3. Strip docs down to required request args + happy path response.
|
||||
def reduce_endpoint_docs(docs: dict) -> dict:
|
||||
out = {}
|
||||
if docs.get("description"):
|
||||
out["description"] = docs.get("description")
|
||||
if docs.get("parameters"):
|
||||
out["parameters"] = [
|
||||
parameter
|
||||
for parameter in docs.get("parameters", [])
|
||||
if parameter.get("required")
|
||||
]
|
||||
if "200" in docs["responses"]:
|
||||
out["responses"] = docs["responses"]["200"]
|
||||
if docs.get("requestBody"):
|
||||
out["requestBody"] = docs.get("requestBody")
|
||||
return out
|
||||
|
||||
endpoints = [
|
||||
(name, description, reduce_endpoint_docs(docs))
|
||||
for name, description, docs in endpoints
|
||||
]
|
||||
return ReducedOpenAPISpec(
|
||||
servers=spec["servers"],
|
||||
description=spec["info"].get("description", ""),
|
||||
endpoints=endpoints,
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Requests toolkit."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.agent_toolkits.json.base import create_json_agent
|
||||
from langchain_integrations.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain_integrations.agent_toolkits.openapi.prompt import DESCRIPTION
|
||||
from langchain_integrations.agents.tools import Tool
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.json.tool import JsonSpec
|
||||
from langchain_integrations.tools.requests.tool import (
|
||||
RequestsDeleteTool,
|
||||
RequestsGetTool,
|
||||
RequestsPatchTool,
|
||||
RequestsPostTool,
|
||||
RequestsPutTool,
|
||||
)
|
||||
from langchain_integrations.utilities.requests import TextRequestsWrapper
|
||||
|
||||
|
||||
class RequestsToolkit(BaseToolkit):
|
||||
"""Toolkit for making REST requests.
|
||||
|
||||
*Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT,
|
||||
and DELETE requests to an API.
|
||||
|
||||
Exercise care in who is allowed to use this toolkit. If exposing
|
||||
to end users, consider that users will be able to make arbitrary
|
||||
requests on behalf of the server hosting the code. For example,
|
||||
users could ask the server to make a request to a private API
|
||||
that is only accessible from the server.
|
||||
|
||||
Control access to who can submit issue requests using this toolkit and
|
||||
what network access it has.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Return a list of tools."""
|
||||
return [
|
||||
RequestsGetTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPostTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPatchTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPutTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsDeleteTool(requests_wrapper=self.requests_wrapper),
|
||||
]
|
||||
|
||||
|
||||
class OpenAPIToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with an OpenAPI API.
|
||||
|
||||
*Security Note*: This toolkit contains tools that can read and modify
|
||||
the state of a service; e.g., by creating, deleting, or updating,
|
||||
reading underlying data.
|
||||
|
||||
For example, this toolkit can be used to delete data exposed via
|
||||
an OpenAPI compliant API.
|
||||
"""
|
||||
|
||||
json_agent: AgentExecutor
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
json_agent_tool = Tool(
|
||||
name="json_explorer",
|
||||
func=self.json_agent.run,
|
||||
description=DESCRIPTION,
|
||||
)
|
||||
request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
|
||||
return [*request_toolkit.get_tools(), json_agent_tool]
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
json_spec: JsonSpec,
|
||||
requests_wrapper: TextRequestsWrapper,
|
||||
**kwargs: Any,
|
||||
) -> OpenAPIToolkit:
|
||||
"""Create json agent from llm, then initialize."""
|
||||
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
|
||||
return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
|
||||
@@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise AttributeError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
@@ -0,0 +1,4 @@
|
||||
"""Playwright browser toolkit."""
|
||||
from langchain_integrations.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
|
||||
|
||||
__all__ = ["PlayWrightBrowserToolkit"]
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Playwright web browser toolkit."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Type, cast
|
||||
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_integrations.tools.playwright.base import (
|
||||
BaseBrowserTool,
|
||||
lazy_import_playwright_browsers,
|
||||
)
|
||||
from langchain_integrations.tools.playwright.click import ClickTool
|
||||
from langchain_integrations.tools.playwright.current_page import CurrentWebPageTool
|
||||
from langchain_integrations.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool
|
||||
from langchain_integrations.tools.playwright.extract_text import ExtractTextTool
|
||||
from langchain_integrations.tools.playwright.get_elements import GetElementsTool
|
||||
from langchain_integrations.tools.playwright.navigate import NavigateTool
|
||||
from langchain_integrations.tools.playwright.navigate_back import NavigateBackTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from playwright.async_api import Browser as AsyncBrowser
|
||||
from playwright.sync_api import Browser as SyncBrowser
|
||||
else:
|
||||
try:
|
||||
# We do this so pydantic can resolve the types when instantiating
|
||||
from playwright.async_api import Browser as AsyncBrowser
|
||||
from playwright.sync_api import Browser as SyncBrowser
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class PlayWrightBrowserToolkit(BaseToolkit):
|
||||
"""Toolkit for PlayWright browser tools.
|
||||
|
||||
**Security Note**: This toolkit provides code to control a web-browser.
|
||||
|
||||
Careful if exposing this toolkit to end-users. The tools in the toolkit
|
||||
are capable of navigating to arbitrary webpages, clicking on arbitrary
|
||||
elements, and extracting arbitrary text and hyperlinks from webpages.
|
||||
|
||||
Specifically, by default this toolkit allows navigating to:
|
||||
|
||||
- Any URL (including any internal network URLs)
|
||||
- And local files
|
||||
|
||||
If exposing to end-users, consider limiting network access to the
|
||||
server that hosts the agent; in addition, consider it is advised
|
||||
to create a custom NavigationTool wht an args_schema that limits the URLs
|
||||
that can be navigated to (e.g., only allow navigating to URLs that
|
||||
start with a particular prefix).
|
||||
|
||||
Remember to scope permissions to the minimal permissions necessary for
|
||||
the application. If the default tool selection is not appropriate for
|
||||
the application, consider creating a custom toolkit with the appropriate
|
||||
tools.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
sync_browser: Optional["SyncBrowser"] = None
|
||||
async_browser: Optional["AsyncBrowser"] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
|
||||
"""Check that the arguments are valid."""
|
||||
lazy_import_playwright_browsers()
|
||||
if values.get("async_browser") is None and values.get("sync_browser") is None:
|
||||
raise ValueError("Either async_browser or sync_browser must be specified.")
|
||||
return values
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
tool_classes: List[Type[BaseBrowserTool]] = [
|
||||
ClickTool,
|
||||
NavigateTool,
|
||||
NavigateBackTool,
|
||||
ExtractTextTool,
|
||||
ExtractHyperlinksTool,
|
||||
GetElementsTool,
|
||||
CurrentWebPageTool,
|
||||
]
|
||||
|
||||
tools = [
|
||||
tool_cls.from_browser(
|
||||
sync_browser=self.sync_browser, async_browser=self.async_browser
|
||||
)
|
||||
for tool_cls in tool_classes
|
||||
]
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
@classmethod
|
||||
def from_browser(
|
||||
cls,
|
||||
sync_browser: Optional[SyncBrowser] = None,
|
||||
async_browser: Optional[AsyncBrowser] = None,
|
||||
) -> PlayWrightBrowserToolkit:
|
||||
"""Instantiate the toolkit."""
|
||||
# This is to raise a better error than the forward ref ones Pydantic would have
|
||||
lazy_import_playwright_browsers()
|
||||
return cls(sync_browser=sync_browser, async_browser=async_browser)
|
||||
@@ -0,0 +1 @@
|
||||
"""Power BI agent."""
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Power BI agent."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.powerbi.prompt import (
|
||||
POWERBI_PREFIX,
|
||||
POWERBI_SUFFIX,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
from langchain_integrations.utilities.powerbi import PowerBIDataset
|
||||
|
||||
|
||||
def create_pbi_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: Optional[PowerBIToolkit] = None,
|
||||
powerbi: Optional[PowerBIDataset] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = POWERBI_PREFIX,
|
||||
suffix: str = POWERBI_SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
examples: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
top_k: int = 10,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a Power BI agent from an LLM and tools."""
|
||||
if toolkit is None:
|
||||
if powerbi is None:
|
||||
raise ValueError("Must provide either a toolkit or powerbi dataset")
|
||||
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
|
||||
tools = toolkit.get_tools()
|
||||
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(
|
||||
llm=llm,
|
||||
prompt=ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix.format(top_k=top_k).format(tables=tables),
|
||||
suffix=suffix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
),
|
||||
callback_manager=callback_manager, # type: ignore
|
||||
verbose=verbose,
|
||||
),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Power BI agent."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_integrations.agents import AgentExecutor
|
||||
from langchain_integrations.agents.agent import AgentOutputParser
|
||||
from langchain_integrations.agent_toolkits.powerbi.prompt import (
|
||||
POWERBI_CHAT_PREFIX,
|
||||
POWERBI_CHAT_SUFFIX,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
from langchain_integrations.agents.conversational_chat.base import ConversationalChatAgent
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chat_models.base import BaseChatModel
|
||||
from langchain_integrations.memory import ConversationBufferMemory
|
||||
from langchain_integrations.memory.chat_memory import BaseChatMemory
|
||||
from langchain_integrations.utilities.powerbi import PowerBIDataset
|
||||
|
||||
|
||||
def create_pbi_chat_agent(
|
||||
llm: BaseChatModel,
|
||||
toolkit: Optional[PowerBIToolkit] = None,
|
||||
powerbi: Optional[PowerBIDataset] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = POWERBI_CHAT_PREFIX,
|
||||
suffix: str = POWERBI_CHAT_SUFFIX,
|
||||
examples: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory: Optional[BaseChatMemory] = None,
|
||||
top_k: int = 10,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a Power BI agent from a Chat LLM and tools.
|
||||
|
||||
If you supply only a toolkit and no Power BI dataset, the same LLM is used for both.
|
||||
"""
|
||||
if toolkit is None:
|
||||
if powerbi is None:
|
||||
raise ValueError("Must provide either a toolkit or powerbi dataset")
|
||||
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
|
||||
tools = toolkit.get_tools()
|
||||
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
|
||||
agent = ConversationalChatAgent.from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
system_message=prefix.format(top_k=top_k).format(tables=tables),
|
||||
human_message=suffix,
|
||||
input_variables=input_variables,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
memory=memory
|
||||
or ConversationBufferMemory(memory_key="chat_history", return_messages=True),
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
# flake8: noqa
|
||||
"""Prompts for PowerBI agent."""
|
||||
|
||||
|
||||
POWERBI_PREFIX = """You are an agent designed to help users interact with a PowerBI Dataset.
|
||||
|
||||
Agent has access to a tool that can write a query based on the question and then run those against PowerBI, Microsofts business intelligence tool. The questions from the users should be interpreted as related to the dataset that is available and not general questions about the world. If the question does not seem related to the dataset, return "This does not appear to be part of this dataset." as the answer.
|
||||
|
||||
Given an input question, ask to run the questions against the dataset, then look at the results and return the answer, the answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in a easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
|
||||
"""
|
||||
|
||||
POWERBI_SUFFIX = """Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought: I can first ask which tables I have, then how each table is defined and then ask the query tool the question I need, and finally create a nice sentence that answers the question.
|
||||
{agent_scratchpad}"""
|
||||
|
||||
POWERBI_CHAT_PREFIX = """Assistant is a large language model built to help users interact with a PowerBI Dataset.
|
||||
|
||||
Assistant should try to create a correct and complete answer to the question from the user. If the user asks a question not related to the dataset it should return "This does not appear to be part of this dataset." as the answer. The user might make a mistake with the spelling of certain values, if you think that is the case, ask the user to confirm the spelling of the value and then run the query again. Unless the user specifies a specific number of examples they wish to obtain, and the results are too large, limit your query to at most {top_k} results, but make it clear when answering which field was used for the filtering. The user has access to these tables: {{tables}}.
|
||||
|
||||
The answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in a easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000.
|
||||
"""
|
||||
|
||||
POWERBI_CHAT_SUFFIX = """TOOLS
|
||||
------
|
||||
Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:
|
||||
|
||||
{{tools}}
|
||||
|
||||
{format_instructions}
|
||||
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
|
||||
|
||||
{{{{input}}}}
|
||||
"""
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Toolkit for interacting with a Power BI dataset."""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
from langchain_integrations.chat_models.base import BaseChatModel
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.powerbi.prompt import (
|
||||
QUESTION_TO_QUERY_BASE,
|
||||
SINGLE_QUESTION_TO_QUERY,
|
||||
USER_INPUT,
|
||||
)
|
||||
from langchain_integrations.tools.powerbi.tool import (
|
||||
InfoPowerBITool,
|
||||
ListPowerBITool,
|
||||
QueryPowerBITool,
|
||||
)
|
||||
from langchain_integrations.utilities.powerbi import PowerBIDataset
|
||||
|
||||
|
||||
class PowerBIToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Power BI dataset.
|
||||
|
||||
*Security Note*: This toolkit interacts with an external service.
|
||||
|
||||
Control access to who can use this toolkit.
|
||||
|
||||
Make sure that the capabilities given by this toolkit to the calling
|
||||
code are appropriately scoped to the application.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
powerbi: PowerBIDataset = Field(exclude=True)
|
||||
llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)
|
||||
examples: Optional[str] = None
|
||||
max_iterations: int = 5
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
output_token_limit: Optional[int] = None
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
QueryPowerBITool(
|
||||
llm_chain=self._get_chain(),
|
||||
powerbi=self.powerbi,
|
||||
examples=self.examples,
|
||||
max_iterations=self.max_iterations,
|
||||
output_token_limit=self.output_token_limit,
|
||||
tiktoken_model_name=self.tiktoken_model_name,
|
||||
),
|
||||
InfoPowerBITool(powerbi=self.powerbi),
|
||||
ListPowerBITool(powerbi=self.powerbi),
|
||||
]
|
||||
|
||||
def _get_chain(self) -> LLMChain:
|
||||
"""Construct the chain based on the callback manager and model type."""
|
||||
if isinstance(self.llm, BaseLanguageModel):
|
||||
return LLMChain(
|
||||
llm=self.llm,
|
||||
callback_manager=self.callback_manager
|
||||
if self.callback_manager
|
||||
else None,
|
||||
prompt=PromptTemplate(
|
||||
template=SINGLE_QUESTION_TO_QUERY,
|
||||
input_variables=["tool_input", "tables", "schemas", "examples"],
|
||||
),
|
||||
)
|
||||
|
||||
system_prompt = SystemMessagePromptTemplate(
|
||||
prompt=PromptTemplate(
|
||||
template=QUESTION_TO_QUERY_BASE,
|
||||
input_variables=["tables", "schemas", "examples"],
|
||||
)
|
||||
)
|
||||
human_prompt = HumanMessagePromptTemplate(
|
||||
prompt=PromptTemplate(
|
||||
template=USER_INPUT,
|
||||
input_variables=["tool_input"],
|
||||
)
|
||||
)
|
||||
return LLMChain(
|
||||
llm=self.llm,
|
||||
callback_manager=self.callback_manager if self.callback_manager else None,
|
||||
prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]),
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise AttributeError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise AttributeError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Spark SQL agent."""
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Spark SQL agent."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
|
||||
from langchain_integrations.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
|
||||
|
||||
def create_spark_sql_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: SparkSQLToolkit,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
prefix: str = SQL_PREFIX,
|
||||
suffix: str = SQL_SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
top_k: int = 10,
|
||||
max_iterations: Optional[int] = 15,
|
||||
max_execution_time: Optional[float] = None,
|
||||
early_stopping_method: str = "force",
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a Spark SQL agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools()
|
||||
prefix = prefix.format(top_k=top_k)
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
verbose=verbose,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
early_stopping_method=early_stopping_method,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
# flake8: noqa
|
||||
|
||||
SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
|
||||
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
|
||||
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
|
||||
You can order the results by a relevant column to return the most interesting examples in the database.
|
||||
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
||||
You have access to tools for interacting with the database.
|
||||
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
|
||||
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
|
||||
|
||||
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
|
||||
If the question does not seem related to the database, just return "I don't know" as the answer.
|
||||
"""
|
||||
|
||||
SQL_SUFFIX = """Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought: I should look at the tables in the database to see what I can query.
|
||||
{agent_scratchpad}"""
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Toolkit for interacting with Spark SQL."""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.spark_sql.tool import (
|
||||
InfoSparkSQLTool,
|
||||
ListSparkSQLTool,
|
||||
QueryCheckerTool,
|
||||
QuerySparkSQLTool,
|
||||
)
|
||||
from langchain_integrations.utilities.spark_sql import SparkSQL
|
||||
|
||||
|
||||
class SparkSQLToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Spark SQL."""
|
||||
|
||||
db: SparkSQL = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(exclude=True)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
QuerySparkSQLTool(db=self.db),
|
||||
InfoSparkSQLTool(db=self.db),
|
||||
ListSparkSQLTool(db=self.db),
|
||||
QueryCheckerTool(db=self.db, llm=self.llm),
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""SQL agent."""
|
||||
@@ -0,0 +1,96 @@
|
||||
"""SQL agent."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain_integrations.agent_toolkits.sql.prompt import (
|
||||
SQL_FUNCTIONS_SUFFIX,
|
||||
SQL_PREFIX,
|
||||
SQL_SUFFIX,
|
||||
)
|
||||
from langchain_integrations.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain_integrations.agents.agent_types import AgentType
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_integrations.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain_integrations.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
from langchain_integrations.tools import BaseTool
|
||||
|
||||
|
||||
def create_sql_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: SQLDatabaseToolkit,
|
||||
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = SQL_PREFIX,
|
||||
suffix: Optional[str] = None,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
top_k: int = 10,
|
||||
max_iterations: Optional[int] = 15,
|
||||
max_execution_time: Optional[float] = None,
|
||||
early_stopping_method: str = "force",
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct an SQL agent from an LLM and tools."""
|
||||
tools = toolkit.get_tools() + list(extra_tools)
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
agent: BaseSingleActionAgent
|
||||
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix or SQL_SUFFIX,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
messages = [
|
||||
SystemMessage(content=prefix),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
early_stopping_method=early_stopping_method,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
# flake8: noqa
|
||||
|
||||
SQL_PREFIX = """You are an agent designed to interact with a SQL database.
|
||||
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
|
||||
You can order the results by a relevant column to return the most interesting examples in the database.
|
||||
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
||||
You have access to tools for interacting with the database.
|
||||
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
|
||||
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
|
||||
|
||||
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
|
||||
If the question does not seem related to the database, just return "I don't know" as the answer.
|
||||
"""
|
||||
|
||||
SQL_SUFFIX = """Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
|
||||
{agent_scratchpad}"""
|
||||
|
||||
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Toolkit for interacting with an SQL database."""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
)
|
||||
from langchain_integrations.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
class SQLDatabaseToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with SQL databases."""
|
||||
|
||||
db: SQLDatabase = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(exclude=True)
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return string representation of SQL dialect to use."""
|
||||
return self.db.dialect
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
|
||||
info_sql_database_tool_description = (
|
||||
"Input to this tool is a comma-separated list of tables, output is the "
|
||||
"schema and sample rows for those tables. "
|
||||
"Be sure that the tables actually exist by calling "
|
||||
f"{list_sql_database_tool.name} first! "
|
||||
"Example Input: table1, table2, table3"
|
||||
)
|
||||
info_sql_database_tool = InfoSQLDatabaseTool(
|
||||
db=self.db, description=info_sql_database_tool_description
|
||||
)
|
||||
query_sql_database_tool_description = (
|
||||
"Input to this tool is a detailed and correct SQL query, output is a "
|
||||
"result from the database. If the query is not correct, an error message "
|
||||
"will be returned. If an error is returned, rewrite the query, check the "
|
||||
"query, and try again. If you encounter an issue with Unknown column "
|
||||
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
|
||||
"to query the correct table fields."
|
||||
)
|
||||
query_sql_database_tool = QuerySQLDataBaseTool(
|
||||
db=self.db, description=query_sql_database_tool_description
|
||||
)
|
||||
query_sql_checker_tool_description = (
|
||||
"Use this tool to double check if your query is correct before executing "
|
||||
"it. Always use this tool before executing a query with "
|
||||
f"{query_sql_database_tool.name}!"
|
||||
)
|
||||
query_sql_checker_tool = QuerySQLCheckerTool(
|
||||
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
|
||||
)
|
||||
return [
|
||||
query_sql_database_tool,
|
||||
info_sql_database_tool,
|
||||
list_sql_database_tool,
|
||||
query_sql_checker_tool,
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
"""Agent toolkit for interacting with vector stores."""
|
||||
@@ -0,0 +1,96 @@
|
||||
"""VectorStore agent."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from langchain_integrations.agents.agent import AgentExecutor
|
||||
from langchain_integrations.agent_toolkits.vectorstore.prompt import PREFIX, ROUTER_PREFIX
|
||||
from langchain_integrations.agent_toolkits.vectorstore.toolkit import (
|
||||
VectorStoreRouterToolkit,
|
||||
VectorStoreToolkit,
|
||||
)
|
||||
from langchain_integrations.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_integrations.chains.llm import LLMChain
|
||||
|
||||
|
||||
def create_vectorstore_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: VectorStoreToolkit,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = PREFIX,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a VectorStore agent from an LLM and tools.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): LLM that will be used by the agent
|
||||
toolkit (VectorStoreToolkit): Set of tools for the agent
|
||||
callback_manager (Optional[BaseCallbackManager], optional): Object to handle the callback [ Defaults to None. ]
|
||||
prefix (str, optional): The prefix prompt for the agent. If not provided uses default PREFIX.
|
||||
verbose (bool, optional): If you want to see the content of the scratchpad. [ Defaults to False ]
|
||||
agent_executor_kwargs (Optional[Dict[str, Any]], optional): If there is any other parameter you want to send to the agent. [ Defaults to None ]
|
||||
**kwargs: Additional named parameters to pass to the ZeroShotAgent.
|
||||
|
||||
Returns:
|
||||
AgentExecutor: Returns a callable AgentExecutor object. Either you can call it or use run method with the query to get the response
|
||||
""" # noqa: E501
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
|
||||
|
||||
def create_vectorstore_router_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: VectorStoreRouterToolkit,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = ROUTER_PREFIX,
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a VectorStore router agent from an LLM and tools.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): LLM that will be used by the agent
|
||||
toolkit (VectorStoreRouterToolkit): Set of tools for the agent which have routing capability with multiple vector stores
|
||||
callback_manager (Optional[BaseCallbackManager], optional): Object to handle the callback [ Defaults to None. ]
|
||||
prefix (str, optional): The prefix prompt for the router agent. If not provided uses default ROUTER_PREFIX.
|
||||
verbose (bool, optional): If you want to see the content of the scratchpad. [ Defaults to False ]
|
||||
agent_executor_kwargs (Optional[Dict[str, Any]], optional): If there is any other parameter you want to send to the agent. [ Defaults to None ]
|
||||
**kwargs: Additional named parameters to pass to the ZeroShotAgent.
|
||||
|
||||
Returns:
|
||||
AgentExecutor: Returns a callable AgentExecutor object. Either you can call it or use run method with the query to get the response.
|
||||
""" # noqa: E501
|
||||
tools = toolkit.get_tools()
|
||||
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
||||
@@ -0,0 +1,13 @@
|
||||
# flake8: noqa
|
||||
|
||||
PREFIX = """You are an agent designed to answer questions about sets of documents.
|
||||
You have access to tools for interacting with the documents, and the inputs to the tools are questions.
|
||||
Sometimes, you will be asked to provide sources for your questions, in which case you should use the appropriate tool to do so.
|
||||
If the question does not seem relevant to any of the tools provided, just return "I don't know" as the answer.
|
||||
"""
|
||||
|
||||
ROUTER_PREFIX = """You are an agent designed to answer questions.
|
||||
You have access to tools for interacting with different sources, and the inputs to the tools are questions.
|
||||
Your main task is to decide which of the tools is relevant for answering question at hand.
|
||||
For complex questions, you can break the question down into sub questions and use tools to answers the sub questions.
|
||||
"""
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Toolkit for interacting with a vector store."""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_openai.llm import OpenAI
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.vectorstore.tool import (
|
||||
VectorStoreQATool,
|
||||
VectorStoreQAWithSourcesTool,
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreInfo(BaseModel):
|
||||
"""Information about a VectorStore."""
|
||||
|
||||
vectorstore: VectorStore = Field(exclude=True)
|
||||
name: str
|
||||
description: str
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class VectorStoreToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with a Vector Store."""
|
||||
|
||||
vectorstore_info: VectorStoreInfo = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
description = VectorStoreQATool.get_description(
|
||||
self.vectorstore_info.name, self.vectorstore_info.description
|
||||
)
|
||||
qa_tool = VectorStoreQATool(
|
||||
name=self.vectorstore_info.name,
|
||||
description=description,
|
||||
vectorstore=self.vectorstore_info.vectorstore,
|
||||
llm=self.llm,
|
||||
)
|
||||
description = VectorStoreQAWithSourcesTool.get_description(
|
||||
self.vectorstore_info.name, self.vectorstore_info.description
|
||||
)
|
||||
qa_with_sources_tool = VectorStoreQAWithSourcesTool(
|
||||
name=f"{self.vectorstore_info.name}_with_sources",
|
||||
description=description,
|
||||
vectorstore=self.vectorstore_info.vectorstore,
|
||||
llm=self.llm,
|
||||
)
|
||||
return [qa_tool, qa_with_sources_tool]
|
||||
|
||||
|
||||
class VectorStoreRouterToolkit(BaseToolkit):
|
||||
"""Toolkit for routing between Vector Stores."""
|
||||
|
||||
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
tools: List[BaseTool] = []
|
||||
for vectorstore_info in self.vectorstores:
|
||||
description = VectorStoreQATool.get_description(
|
||||
vectorstore_info.name, vectorstore_info.description
|
||||
)
|
||||
qa_tool = VectorStoreQATool(
|
||||
name=vectorstore_info.name,
|
||||
description=description,
|
||||
vectorstore=vectorstore_info.vectorstore,
|
||||
llm=self.llm,
|
||||
)
|
||||
tools.append(qa_tool)
|
||||
return tools
|
||||
@@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise AttributeError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Zapier Toolkit."""
|
||||
@@ -0,0 +1,60 @@
|
||||
"""[DEPRECATED] Zapier Toolkit."""
|
||||
from typing import List
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
from langchain_integrations.agent_toolkits.base import BaseToolkit
|
||||
from langchain_integrations.tools import BaseTool
|
||||
from langchain_integrations.tools.zapier.tool import ZapierNLARunAction
|
||||
from langchain_integrations.utilities.zapier import ZapierNLAWrapper
|
||||
|
||||
|
||||
class ZapierToolkit(BaseToolkit):
|
||||
"""Zapier Toolkit."""
|
||||
|
||||
tools: List[BaseTool] = []
|
||||
|
||||
@classmethod
|
||||
def from_zapier_nla_wrapper(
|
||||
cls, zapier_nla_wrapper: ZapierNLAWrapper
|
||||
) -> "ZapierToolkit":
|
||||
"""Create a toolkit from a ZapierNLAWrapper."""
|
||||
actions = zapier_nla_wrapper.list()
|
||||
tools = [
|
||||
ZapierNLARunAction(
|
||||
action_id=action["id"],
|
||||
zapier_description=action["description"],
|
||||
params_schema=action["params"],
|
||||
api_wrapper=zapier_nla_wrapper,
|
||||
)
|
||||
for action in actions
|
||||
]
|
||||
return cls(tools=tools)
|
||||
|
||||
@classmethod
|
||||
async def async_from_zapier_nla_wrapper(
|
||||
cls, zapier_nla_wrapper: ZapierNLAWrapper
|
||||
) -> "ZapierToolkit":
|
||||
"""Create a toolkit from a ZapierNLAWrapper."""
|
||||
actions = await zapier_nla_wrapper.alist()
|
||||
tools = [
|
||||
ZapierNLARunAction(
|
||||
action_id=action["id"],
|
||||
zapier_description=action["description"],
|
||||
params_schema=action["params"],
|
||||
api_wrapper=zapier_nla_wrapper,
|
||||
)
|
||||
for action in actions
|
||||
]
|
||||
return cls(tools=tools)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
warn_deprecated(
|
||||
since="0.0.319",
|
||||
message=(
|
||||
"This tool will be deprecated on 2023-11-17. See "
|
||||
"https://nla.zapier.com/sunset/ for details"
|
||||
),
|
||||
)
|
||||
return self.tools
|
||||
@@ -0,0 +1,19 @@
|
||||
"""**Chat Loaders** load chat messages from common communications platforms.
|
||||
|
||||
Load chat messages from various
|
||||
communications platforms such as Facebook Messenger, Telegram, and
|
||||
WhatsApp. The loaded chat messages can be used for fine-tuning models.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseChatLoader --> <name>ChatLoader # Examples: WhatsAppChatLoader, IMessageChatLoader
|
||||
|
||||
**Main helpers:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
ChatSession
|
||||
|
||||
""" # noqa: E501
|
||||
@@ -0,0 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterator, List
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
|
||||
|
||||
class BaseChatLoader(ABC):
|
||||
"""Base class for chat loaders."""
|
||||
|
||||
@abstractmethod
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""Lazy load the chat sessions."""
|
||||
|
||||
def load(self) -> List[ChatSession]:
|
||||
"""Eagerly load the chat sessions into memory."""
|
||||
return list(self.lazy_load())
|
||||
@@ -0,0 +1,79 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Union
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class SingleFileFacebookMessengerChatLoader(BaseChatLoader):
|
||||
"""Load `Facebook Messenger` chat data from a single file.
|
||||
|
||||
Args:
|
||||
path (Union[Path, str]): The path to the chat file.
|
||||
|
||||
Attributes:
|
||||
path (Path): The path to the chat file.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[Path, str]) -> None:
|
||||
super().__init__()
|
||||
self.file_path = path if isinstance(path, Path) else Path(path)
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""Lazy loads the chat data from the file.
|
||||
|
||||
Yields:
|
||||
ChatSession: A chat session containing the loaded messages.
|
||||
|
||||
"""
|
||||
with open(self.file_path) as f:
|
||||
data = json.load(f)
|
||||
sorted_data = sorted(data["messages"], key=lambda x: x["timestamp_ms"])
|
||||
messages = []
|
||||
for m in sorted_data:
|
||||
messages.append(
|
||||
HumanMessage(
|
||||
content=m["content"], additional_kwargs={"sender": m["sender_name"]}
|
||||
)
|
||||
)
|
||||
yield ChatSession(messages=messages)
|
||||
|
||||
|
||||
class FolderFacebookMessengerChatLoader(BaseChatLoader):
|
||||
"""Load `Facebook Messenger` chat data from a folder.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]): The path to the directory
|
||||
containing the chat files.
|
||||
|
||||
Attributes:
|
||||
path (Path): The path to the directory containing the chat files.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[str, Path]) -> None:
|
||||
super().__init__()
|
||||
self.directory_path = Path(path) if isinstance(path, str) else path
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""Lazy loads the chat data from the folder.
|
||||
|
||||
Yields:
|
||||
ChatSession: A chat session containing the loaded messages.
|
||||
|
||||
"""
|
||||
inbox_path = self.directory_path / "inbox"
|
||||
for _dir in inbox_path.iterdir():
|
||||
if _dir.is_dir():
|
||||
for _file in _dir.iterdir():
|
||||
if _file.suffix.lower() == ".json":
|
||||
file_loader = SingleFileFacebookMessengerChatLoader(path=_file)
|
||||
for result in file_loader.lazy_load():
|
||||
yield result
|
||||
112
libs/integrations/langchain_integrations/chat_loaders/gmail.py
Normal file
112
libs/integrations/langchain_integrations/chat_loaders/gmail.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import base64
|
||||
import re
|
||||
from typing import Any, Iterator
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
|
||||
def _extract_email_content(msg: Any) -> HumanMessage:
|
||||
from_email = None
|
||||
for values in msg["payload"]["headers"]:
|
||||
name = values["name"]
|
||||
if name == "From":
|
||||
from_email = values["value"]
|
||||
if from_email is None:
|
||||
raise ValueError
|
||||
for part in msg["payload"]["parts"]:
|
||||
if part["mimeType"] == "text/plain":
|
||||
data = part["body"]["data"]
|
||||
data = base64.urlsafe_b64decode(data).decode("utf-8")
|
||||
# Regular expression to split the email body at the first
|
||||
# occurrence of a line that starts with "On ... wrote:"
|
||||
pattern = re.compile(r"\r\nOn .+(\r\n)*wrote:\r\n")
|
||||
# Split the email body and extract the first part
|
||||
newest_response = re.split(pattern, data)[0]
|
||||
message = HumanMessage(
|
||||
content=newest_response, additional_kwargs={"sender": from_email}
|
||||
)
|
||||
return message
|
||||
raise ValueError
|
||||
|
||||
|
||||
def _get_message_data(service: Any, message: Any) -> ChatSession:
|
||||
msg = service.users().messages().get(userId="me", id=message["id"]).execute()
|
||||
message_content = _extract_email_content(msg)
|
||||
in_reply_to = None
|
||||
email_data = msg["payload"]["headers"]
|
||||
for values in email_data:
|
||||
name = values["name"]
|
||||
if name == "In-Reply-To":
|
||||
in_reply_to = values["value"]
|
||||
if in_reply_to is None:
|
||||
raise ValueError
|
||||
|
||||
thread_id = msg["threadId"]
|
||||
|
||||
thread = service.users().threads().get(userId="me", id=thread_id).execute()
|
||||
messages = thread["messages"]
|
||||
|
||||
response_email = None
|
||||
for message in messages:
|
||||
email_data = message["payload"]["headers"]
|
||||
for values in email_data:
|
||||
if values["name"] == "Message-ID":
|
||||
message_id = values["value"]
|
||||
if message_id == in_reply_to:
|
||||
response_email = message
|
||||
if response_email is None:
|
||||
raise ValueError
|
||||
starter_content = _extract_email_content(response_email)
|
||||
return ChatSession(messages=[starter_content, message_content])
|
||||
|
||||
|
||||
class GMailLoader(BaseChatLoader):
|
||||
"""Load data from `GMail`.
|
||||
|
||||
There are many ways you could want to load data from GMail.
|
||||
This loader is currently fairly opinionated in how to do so.
|
||||
The way it does it is it first looks for all messages that you have sent.
|
||||
It then looks for messages where you are responding to a previous email.
|
||||
It then fetches that previous email, and creates a training example
|
||||
of that email, followed by your email.
|
||||
|
||||
Note that there are clear limitations here. For example,
|
||||
all examples created are only looking at the previous email for context.
|
||||
|
||||
To use:
|
||||
|
||||
- Set up a Google Developer Account:
|
||||
Go to the Google Developer Console, create a project,
|
||||
and enable the Gmail API for that project.
|
||||
This will give you a credentials.json file that you'll need later.
|
||||
"""
|
||||
|
||||
def __init__(self, creds: Any, n: int = 100, raise_error: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.creds = creds
|
||||
self.n = n
|
||||
self.raise_error = raise_error
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
service = build("gmail", "v1", credentials=self.creds)
|
||||
results = (
|
||||
service.users()
|
||||
.messages()
|
||||
.list(userId="me", labelIds=["SENT"], maxResults=self.n)
|
||||
.execute()
|
||||
)
|
||||
messages = results.get("messages", [])
|
||||
for message in messages:
|
||||
try:
|
||||
yield _get_message_data(service, message)
|
||||
except Exception as e:
|
||||
# TODO: handle errors better
|
||||
if self.raise_error:
|
||||
raise e
|
||||
else:
|
||||
pass
|
||||
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sqlite3
|
||||
|
||||
|
||||
class IMessageChatLoader(BaseChatLoader):
|
||||
"""Load chat sessions from the `iMessage` chat.db SQLite file.
|
||||
|
||||
It only works on macOS when you have iMessage enabled and have the chat.db file.
|
||||
|
||||
The chat.db file is likely located at ~/Library/Messages/chat.db. However, your
|
||||
terminal may not have permission to access this file. To resolve this, you can
|
||||
copy the file to a different location, change the permissions of the file, or
|
||||
grant full disk access for your terminal emulator
|
||||
in System Settings > Security and Privacy > Full Disk Access.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Optional[Union[str, Path]] = None):
|
||||
"""
|
||||
Initialize the IMessageChatLoader.
|
||||
|
||||
Args:
|
||||
path (str or Path, optional): Path to the chat.db SQLite file.
|
||||
Defaults to None, in which case the default path
|
||||
~/Library/Messages/chat.db will be used.
|
||||
"""
|
||||
if path is None:
|
||||
path = Path.home() / "Library" / "Messages" / "chat.db"
|
||||
self.db_path = path if isinstance(path, Path) else Path(path)
|
||||
if not self.db_path.exists():
|
||||
raise FileNotFoundError(f"File {self.db_path} not found")
|
||||
try:
|
||||
import sqlite3 # noqa: F401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The sqlite3 module is required to load iMessage chats.\n"
|
||||
"Please install it with `pip install pysqlite3`"
|
||||
) from e
|
||||
|
||||
def _parse_attributedBody(self, attributedBody: bytes) -> str:
|
||||
"""
|
||||
Parse the attributedBody field of the message table
|
||||
for the text content of the message.
|
||||
|
||||
The attributedBody field is a binary blob that contains
|
||||
the message content after the byte string b"NSString":
|
||||
|
||||
5 bytes 1-3 bytes `len` bytes
|
||||
... | b"NSString" | preamble | `len` | contents | ...
|
||||
|
||||
The 5 preamble bytes are always b"\x01\x94\x84\x01+"
|
||||
|
||||
The size of `len` is either 1 byte or 3 bytes:
|
||||
- If the first byte in `len` is b"\x81" then `len` is 3 bytes long.
|
||||
So the message length is the 2 bytes after, in little Endian.
|
||||
- Otherwise, the size of `len` is 1 byte, and the message length is
|
||||
that byte.
|
||||
|
||||
Args:
|
||||
attributedBody (bytes): attributedBody field of the message table.
|
||||
Return:
|
||||
str: Text content of the message.
|
||||
"""
|
||||
content = attributedBody.split(b"NSString")[1][5:]
|
||||
length, start = content[0], 1
|
||||
if content[0] == 129:
|
||||
length, start = int.from_bytes(content[1:3], "little"), 3
|
||||
return content[start : start + length].decode("utf-8", errors="ignore")
|
||||
|
||||
def _load_single_chat_session(
|
||||
self, cursor: "sqlite3.Cursor", chat_id: int
|
||||
) -> ChatSession:
|
||||
"""
|
||||
Load a single chat session from the iMessage chat.db.
|
||||
|
||||
Args:
|
||||
cursor: SQLite cursor object.
|
||||
chat_id (int): ID of the chat session to load.
|
||||
|
||||
Returns:
|
||||
ChatSession: Loaded chat session.
|
||||
"""
|
||||
results: List[HumanMessage] = []
|
||||
|
||||
query = """
|
||||
SELECT message.date, handle.id, message.text, message.attributedBody
|
||||
FROM message
|
||||
JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
|
||||
JOIN handle ON message.handle_id = handle.ROWID
|
||||
WHERE chat_message_join.chat_id = ?
|
||||
ORDER BY message.date ASC;
|
||||
"""
|
||||
cursor.execute(query, (chat_id,))
|
||||
messages = cursor.fetchall()
|
||||
|
||||
for date, sender, text, attributedBody in messages:
|
||||
if text:
|
||||
content = text
|
||||
elif attributedBody:
|
||||
content = self._parse_attributedBody(attributedBody)
|
||||
else: # Skip messages with no content
|
||||
continue
|
||||
|
||||
results.append(
|
||||
HumanMessage(
|
||||
role=sender,
|
||||
content=content,
|
||||
additional_kwargs={
|
||||
"message_time": date,
|
||||
"sender": sender,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(messages=results)
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""
|
||||
Lazy load the chat sessions from the iMessage chat.db
|
||||
and yield them in the required format.
|
||||
|
||||
Yields:
|
||||
ChatSession: Loaded chat session.
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
except sqlite3.OperationalError as e:
|
||||
raise ValueError(
|
||||
f"Could not open iMessage DB file {self.db_path}.\n"
|
||||
"Make sure your terminal emulator has disk access to this file.\n"
|
||||
" You can either copy the DB file to an accessible location"
|
||||
" or grant full disk access for your terminal emulator."
|
||||
" You can grant full disk access for your terminal emulator"
|
||||
" in System Settings > Security and Privacy > Full Disk Access."
|
||||
) from e
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Fetch the list of chat IDs sorted by time (most recent first)
|
||||
query = """SELECT chat_id
|
||||
FROM message
|
||||
JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
|
||||
GROUP BY chat_id
|
||||
ORDER BY MAX(date) DESC;"""
|
||||
cursor.execute(query)
|
||||
chat_ids = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
for chat_id in chat_ids:
|
||||
yield self._load_single_chat_session(cursor, chat_id)
|
||||
|
||||
conn.close()
|
||||
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.load import load
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langsmith.client import Client
|
||||
from langsmith.schemas import Run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LangSmithRunChatLoader(BaseChatLoader):
|
||||
"""
|
||||
Load chat sessions from a list of LangSmith "llm" runs.
|
||||
|
||||
Attributes:
|
||||
runs (Iterable[Union[str, Run]]): The list of LLM run IDs or run objects.
|
||||
client (Client): Instance of LangSmith client for fetching data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, runs: Iterable[Union[str, Run]], client: Optional["Client"] = None
|
||||
):
|
||||
"""
|
||||
Initialize a new LangSmithRunChatLoader instance.
|
||||
|
||||
:param runs: List of LLM run IDs or run objects.
|
||||
:param client: An instance of LangSmith client, if not provided,
|
||||
a new client instance will be created.
|
||||
"""
|
||||
from langsmith.client import Client
|
||||
|
||||
self.runs = runs
|
||||
self.client = client or Client()
|
||||
|
||||
def _load_single_chat_session(self, llm_run: "Run") -> ChatSession:
|
||||
"""
|
||||
Convert an individual LangSmith LLM run to a ChatSession.
|
||||
|
||||
:param llm_run: The LLM run object.
|
||||
:return: A chat session representing the run's data.
|
||||
"""
|
||||
chat_session = LangSmithRunChatLoader._get_messages_from_llm_run(llm_run)
|
||||
functions = LangSmithRunChatLoader._get_functions_from_llm_run(llm_run)
|
||||
if functions:
|
||||
chat_session["functions"] = functions
|
||||
return chat_session
|
||||
|
||||
@staticmethod
|
||||
def _get_messages_from_llm_run(llm_run: "Run") -> ChatSession:
|
||||
"""
|
||||
Extract messages from a LangSmith LLM run.
|
||||
|
||||
:param llm_run: The LLM run object.
|
||||
:return: ChatSession with the extracted messages.
|
||||
"""
|
||||
if llm_run.run_type != "llm":
|
||||
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
|
||||
if "messages" not in llm_run.inputs:
|
||||
raise ValueError(f"Run has no 'messages' inputs. Got {llm_run.inputs}")
|
||||
if not llm_run.outputs:
|
||||
raise ValueError("Cannot convert pending run")
|
||||
messages = load(llm_run.inputs)["messages"]
|
||||
message_chunk = load(llm_run.outputs)["generations"][0]["message"]
|
||||
return ChatSession(messages=messages + [message_chunk])
|
||||
|
||||
@staticmethod
|
||||
def _get_functions_from_llm_run(llm_run: "Run") -> Optional[List[Dict]]:
|
||||
"""
|
||||
Extract functions from a LangSmith LLM run if they exist.
|
||||
|
||||
:param llm_run: The LLM run object.
|
||||
:return: Functions from the run or None.
|
||||
"""
|
||||
if llm_run.run_type != "llm":
|
||||
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
|
||||
return (llm_run.extra or {}).get("invocation_params", {}).get("functions")
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""
|
||||
Lazy load the chat sessions from the iterable of run IDs.
|
||||
|
||||
This method fetches the runs and converts them to chat sessions on-the-fly,
|
||||
yielding one session at a time.
|
||||
|
||||
:return: Iterator of chat sessions containing messages.
|
||||
"""
|
||||
from langsmith.schemas import Run
|
||||
|
||||
for run_obj in self.runs:
|
||||
try:
|
||||
if hasattr(run_obj, "id"):
|
||||
run = run_obj
|
||||
else:
|
||||
run = self.client.read_run(run_obj)
|
||||
session = self._load_single_chat_session(cast(Run, run))
|
||||
yield session
|
||||
except ValueError as e:
|
||||
logger.warning(f"Could not load run {run_obj}: {repr(e)}")
|
||||
continue
|
||||
|
||||
|
||||
class LangSmithDatasetChatLoader(BaseChatLoader):
|
||||
"""
|
||||
Load chat sessions from a LangSmith dataset with the "chat" data type.
|
||||
|
||||
Attributes:
|
||||
dataset_name (str): The name of the LangSmith dataset.
|
||||
client (Client): Instance of LangSmith client for fetching data.
|
||||
"""
|
||||
|
||||
def __init__(self, *, dataset_name: str, client: Optional["Client"] = None):
|
||||
"""
|
||||
Initialize a new LangSmithChatDatasetLoader instance.
|
||||
|
||||
:param dataset_name: The name of the LangSmith dataset.
|
||||
:param client: An instance of LangSmith client; if not provided,
|
||||
a new client instance will be created.
|
||||
"""
|
||||
try:
|
||||
from langsmith.client import Client
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The LangSmith client is required to load LangSmith datasets.\n"
|
||||
"Please install it with `pip install langsmith`"
|
||||
) from e
|
||||
|
||||
self.dataset_name = dataset_name
|
||||
self.client = client or Client()
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""
|
||||
Lazy load the chat sessions from the specified LangSmith dataset.
|
||||
|
||||
This method fetches the chat data from the dataset and
|
||||
converts each data point to chat sessions on-the-fly,
|
||||
yielding one session at a time.
|
||||
|
||||
:return: Iterator of chat sessions containing messages.
|
||||
"""
|
||||
from langchain_integrations.adapters import openai as oai_adapter # noqa: E402
|
||||
|
||||
data = self.client.read_dataset_openai_finetuning(
|
||||
dataset_name=self.dataset_name
|
||||
)
|
||||
for data_point in data:
|
||||
yield ChatSession(
|
||||
messages=[
|
||||
oai_adapter.convert_dict_to_message(m)
|
||||
for m in data_point.get("messages", [])
|
||||
],
|
||||
functions=data_point.get("functions"),
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List, Union
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackChatLoader(BaseChatLoader):
|
||||
"""Load `Slack` conversations from a dump zip file."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
):
|
||||
"""
|
||||
Initialize the chat loader with the path to the exported Slack dump zip file.
|
||||
|
||||
:param path: Path to the exported Slack dump zip file.
|
||||
"""
|
||||
self.zip_path = path if isinstance(path, Path) else Path(path)
|
||||
if not self.zip_path.exists():
|
||||
raise FileNotFoundError(f"File {self.zip_path} not found")
|
||||
|
||||
def _load_single_chat_session(self, messages: List[Dict]) -> ChatSession:
|
||||
results: List[Union[AIMessage, HumanMessage]] = []
|
||||
previous_sender = None
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
text = message.get("text", "")
|
||||
timestamp = message.get("ts", "")
|
||||
sender = message.get("user", "")
|
||||
if not sender:
|
||||
continue
|
||||
skip_pattern = re.compile(
|
||||
r"<@U\d+> has joined the channel", flags=re.IGNORECASE
|
||||
)
|
||||
if skip_pattern.match(text):
|
||||
continue
|
||||
if sender == previous_sender:
|
||||
results[-1].content += "\n\n" + text
|
||||
results[-1].additional_kwargs["events"].append(
|
||||
{"message_time": timestamp}
|
||||
)
|
||||
else:
|
||||
results.append(
|
||||
HumanMessage(
|
||||
role=sender,
|
||||
content=text,
|
||||
additional_kwargs={
|
||||
"sender": sender,
|
||||
"events": [{"message_time": timestamp}],
|
||||
},
|
||||
)
|
||||
)
|
||||
previous_sender = sender
|
||||
return ChatSession(messages=results)
|
||||
|
||||
def _read_json(self, zip_file: zipfile.ZipFile, file_path: str) -> List[dict]:
|
||||
"""Read JSON data from a zip subfile."""
|
||||
with zip_file.open(file_path, "r") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError(f"Expected list of dictionaries, got {type(data)}")
|
||||
return data
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""
|
||||
Lazy load the chat sessions from the Slack dump file and yield them
|
||||
in the required format.
|
||||
|
||||
:return: Iterator of chat sessions containing messages.
|
||||
"""
|
||||
with zipfile.ZipFile(str(self.zip_path), "r") as zip_file:
|
||||
for file_path in zip_file.namelist():
|
||||
if file_path.endswith(".json"):
|
||||
messages = self._read_json(zip_file, file_path)
|
||||
yield self._load_single_chat_session(messages)
|
||||
@@ -0,0 +1,151 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Union
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramChatLoader(BaseChatLoader):
|
||||
"""Load `telegram` conversations to LangChain chat messages.
|
||||
|
||||
To export, use the Telegram Desktop app from
|
||||
https://desktop.telegram.org/, select a conversation, click the three dots
|
||||
in the top right corner, and select "Export chat history". Then select
|
||||
"Machine-readable JSON" (preferred) to export. Note: the 'lite' versions of
|
||||
the desktop app (like "Telegram for MacOS") do not support exporting chat
|
||||
history.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
):
|
||||
"""Initialize the TelegramChatLoader.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]): Path to the exported Telegram chat zip,
|
||||
directory, json, or HTML file.
|
||||
"""
|
||||
self.path = path if isinstance(path, str) else str(path)
|
||||
|
||||
def _load_single_chat_session_html(self, file_path: str) -> ChatSession:
|
||||
"""Load a single chat session from an HTML file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the HTML file.
|
||||
|
||||
Returns:
|
||||
ChatSession: The loaded chat session.
|
||||
"""
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install the 'beautifulsoup4' package to load"
|
||||
" Telegram HTML files. You can do this by running"
|
||||
"'pip install beautifulsoup4' in your terminal."
|
||||
)
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
soup = BeautifulSoup(file, "html.parser")
|
||||
|
||||
results: List[Union[HumanMessage, AIMessage]] = []
|
||||
previous_sender = None
|
||||
for message in soup.select(".message.default"):
|
||||
timestamp = message.select_one(".pull_right.date.details")["title"]
|
||||
from_name_element = message.select_one(".from_name")
|
||||
if from_name_element is None and previous_sender is None:
|
||||
logger.debug("from_name not found in message")
|
||||
continue
|
||||
elif from_name_element is None:
|
||||
from_name = previous_sender
|
||||
else:
|
||||
from_name = from_name_element.text.strip()
|
||||
text = message.select_one(".text").text.strip()
|
||||
results.append(
|
||||
HumanMessage(
|
||||
content=text,
|
||||
additional_kwargs={
|
||||
"sender": from_name,
|
||||
"events": [{"message_time": timestamp}],
|
||||
},
|
||||
)
|
||||
)
|
||||
previous_sender = from_name
|
||||
|
||||
return ChatSession(messages=results)
|
||||
|
||||
def _load_single_chat_session_json(self, file_path: str) -> ChatSession:
|
||||
"""Load a single chat session from a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
ChatSession: The loaded chat session.
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
|
||||
messages = data.get("messages", [])
|
||||
results: List[BaseMessage] = []
|
||||
for message in messages:
|
||||
text = message.get("text", "")
|
||||
timestamp = message.get("date", "")
|
||||
from_name = message.get("from", "")
|
||||
|
||||
results.append(
|
||||
HumanMessage(
|
||||
content=text,
|
||||
additional_kwargs={
|
||||
"sender": from_name,
|
||||
"events": [{"message_time": timestamp}],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(messages=results)
|
||||
|
||||
def _iterate_files(self, path: str) -> Iterator[str]:
|
||||
"""Iterate over files in a directory or zip file.
|
||||
|
||||
Args:
|
||||
path (str): Path to the directory or zip file.
|
||||
|
||||
Yields:
|
||||
str: Path to each file.
|
||||
"""
|
||||
if os.path.isfile(path) and path.endswith((".html", ".json")):
|
||||
yield path
|
||||
elif os.path.isdir(path):
|
||||
for root, _, files in os.walk(path):
|
||||
for file in files:
|
||||
if file.endswith((".html", ".json")):
|
||||
yield os.path.join(root, file)
|
||||
elif zipfile.is_zipfile(path):
|
||||
with zipfile.ZipFile(path) as zip_file:
|
||||
for file in zip_file.namelist():
|
||||
if file.endswith((".html", ".json")):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield zip_file.extract(file, path=temp_dir)
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""Lazy load the messages from the chat file and yield them
|
||||
in as chat sessions.
|
||||
|
||||
Yields:
|
||||
ChatSession: The loaded chat session.
|
||||
"""
|
||||
for file_path in self._iterate_files(self.path):
|
||||
if file_path.endswith(".html"):
|
||||
yield self._load_single_chat_session_html(file_path)
|
||||
elif file_path.endswith(".json"):
|
||||
yield self._load_single_chat_session_json(file_path)
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Utilities for chat loaders."""
|
||||
from copy import deepcopy
|
||||
from typing import Iterable, Iterator, List
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
|
||||
|
||||
def merge_chat_runs_in_session(
|
||||
chat_session: ChatSession, delimiter: str = "\n\n"
|
||||
) -> ChatSession:
|
||||
"""Merge chat runs together in a chat session.
|
||||
|
||||
A chat run is a sequence of messages from the same sender.
|
||||
|
||||
Args:
|
||||
chat_session: A chat session.
|
||||
|
||||
Returns:
|
||||
A chat session with merged chat runs.
|
||||
"""
|
||||
messages: List[BaseMessage] = []
|
||||
for message in chat_session["messages"]:
|
||||
if not isinstance(message.content, str):
|
||||
raise ValueError(
|
||||
"Chat Loaders only support messages with content type string, "
|
||||
f"got {message.content}"
|
||||
)
|
||||
if not messages:
|
||||
messages.append(deepcopy(message))
|
||||
elif (
|
||||
isinstance(message, type(messages[-1]))
|
||||
and messages[-1].additional_kwargs.get("sender") is not None
|
||||
and messages[-1].additional_kwargs["sender"]
|
||||
== message.additional_kwargs.get("sender")
|
||||
):
|
||||
if not isinstance(messages[-1].content, str):
|
||||
raise ValueError(
|
||||
"Chat Loaders only support messages with content type string, "
|
||||
f"got {messages[-1].content}"
|
||||
)
|
||||
messages[-1].content = (
|
||||
messages[-1].content + delimiter + message.content
|
||||
).strip()
|
||||
messages[-1].additional_kwargs.get("events", []).extend(
|
||||
message.additional_kwargs.get("events") or []
|
||||
)
|
||||
else:
|
||||
messages.append(deepcopy(message))
|
||||
return ChatSession(messages=messages)
|
||||
|
||||
|
||||
def merge_chat_runs(chat_sessions: Iterable[ChatSession]) -> Iterator[ChatSession]:
|
||||
"""Merge chat runs together.
|
||||
|
||||
A chat run is a sequence of messages from the same sender.
|
||||
|
||||
Args:
|
||||
chat_sessions: A list of chat sessions.
|
||||
|
||||
Returns:
|
||||
A list of chat sessions with merged chat runs.
|
||||
"""
|
||||
for chat_session in chat_sessions:
|
||||
yield merge_chat_runs_in_session(chat_session)
|
||||
|
||||
|
||||
def map_ai_messages_in_session(chat_sessions: ChatSession, sender: str) -> ChatSession:
|
||||
"""Convert messages from the specified 'sender' to AI messages.
|
||||
|
||||
This is useful for fine-tuning the AI to adapt to your voice.
|
||||
"""
|
||||
messages = []
|
||||
num_converted = 0
|
||||
for message in chat_sessions["messages"]:
|
||||
if message.additional_kwargs.get("sender") == sender:
|
||||
message = AIMessage(
|
||||
content=message.content,
|
||||
additional_kwargs=message.additional_kwargs.copy(),
|
||||
example=getattr(message, "example", None),
|
||||
)
|
||||
num_converted += 1
|
||||
messages.append(message)
|
||||
return ChatSession(messages=messages)
|
||||
|
||||
|
||||
def map_ai_messages(
|
||||
chat_sessions: Iterable[ChatSession], sender: str
|
||||
) -> Iterator[ChatSession]:
|
||||
"""Convert messages from the specified 'sender' to AI messages.
|
||||
|
||||
This is useful for fine-tuning the AI to adapt to your voice.
|
||||
"""
|
||||
for chat_session in chat_sessions:
|
||||
yield map_ai_messages_in_session(chat_session, sender)
|
||||
@@ -0,0 +1,119 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from typing import Iterator, List, Union
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain_integrations.chat_loaders.base import BaseChatLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhatsAppChatLoader(BaseChatLoader):
|
||||
"""Load `WhatsApp` conversations from a dump zip file or directory."""
|
||||
|
||||
def __init__(self, path: str):
|
||||
"""Initialize the WhatsAppChatLoader.
|
||||
|
||||
Args:
|
||||
path (str): Path to the exported WhatsApp chat
|
||||
zip directory, folder, or file.
|
||||
|
||||
To generate the dump, open the chat, click the three dots in the top
|
||||
right corner, and select "More". Then select "Export chat" and
|
||||
choose "Without media".
|
||||
"""
|
||||
self.path = path
|
||||
ignore_lines = [
|
||||
"This message was deleted",
|
||||
"<Media omitted>",
|
||||
"image omitted",
|
||||
"Messages and calls are end-to-end encrypted. No one outside of this chat,"
|
||||
" not even WhatsApp, can read or listen to them.",
|
||||
]
|
||||
self._ignore_lines = re.compile(
|
||||
r"(" + "|".join([r"\u200E*" + line for line in ignore_lines]) + r")",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
self._message_line_regex = re.compile(
|
||||
r"\u200E*\[?(\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{2}:\d{2} (?:AM|PM))\]?[ \u200E]*([^:]+): (.+)", # noqa
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
def _load_single_chat_session(self, file_path: str) -> ChatSession:
|
||||
"""Load a single chat session from a file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the chat file.
|
||||
|
||||
Returns:
|
||||
ChatSession: The loaded chat session.
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
txt = file.read()
|
||||
|
||||
# Split messages by newlines, but keep multi-line messages grouped
|
||||
chat_lines: List[str] = []
|
||||
current_message = ""
|
||||
for line in txt.split("\n"):
|
||||
if self._message_line_regex.match(line):
|
||||
if current_message:
|
||||
chat_lines.append(current_message)
|
||||
current_message = line
|
||||
else:
|
||||
current_message += " " + line.strip()
|
||||
if current_message:
|
||||
chat_lines.append(current_message)
|
||||
results: List[Union[HumanMessage, AIMessage]] = []
|
||||
for line in chat_lines:
|
||||
result = self._message_line_regex.match(line.strip())
|
||||
if result:
|
||||
timestamp, sender, text = result.groups()
|
||||
if not self._ignore_lines.match(text.strip()):
|
||||
results.append(
|
||||
HumanMessage(
|
||||
role=sender,
|
||||
content=text,
|
||||
additional_kwargs={
|
||||
"sender": sender,
|
||||
"events": [{"message_time": timestamp}],
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Could not parse line: {line}")
|
||||
return ChatSession(messages=results)
|
||||
|
||||
def _iterate_files(self, path: str) -> Iterator[str]:
|
||||
"""Iterate over the files in a directory or zip file.
|
||||
|
||||
Args:
|
||||
path (str): Path to the directory or zip file.
|
||||
|
||||
Yields:
|
||||
str: The path to each file.
|
||||
"""
|
||||
if os.path.isfile(path):
|
||||
yield path
|
||||
elif os.path.isdir(path):
|
||||
for root, _, files in os.walk(path):
|
||||
for file in files:
|
||||
if file.endswith(".txt"):
|
||||
yield os.path.join(root, file)
|
||||
elif zipfile.is_zipfile(path):
|
||||
with zipfile.ZipFile(path) as zip_file:
|
||||
for file in zip_file.namelist():
|
||||
if file.endswith(".txt"):
|
||||
yield zip_file.extract(file)
|
||||
|
||||
def lazy_load(self) -> Iterator[ChatSession]:
|
||||
"""Lazy load the messages from the chat file and yield
|
||||
them as chat sessions.
|
||||
|
||||
Yields:
|
||||
Iterator[ChatSession]: The loaded chat sessions.
|
||||
"""
|
||||
yield self._load_single_chat_session(self.path)
|
||||
@@ -0,0 +1,57 @@
|
||||
from langchain_integrations.chat_message_histories.astradb import (
|
||||
AstraDBChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.cassandra import (
|
||||
CassandraChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.elasticsearch import (
|
||||
ElasticsearchChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.file import FileChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.firestore import (
|
||||
FirestoreChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.in_memory import ChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.momento import MomentoChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.mongodb import MongoDBChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.neo4j import Neo4jChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.postgres import PostgresChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.redis import RedisChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.rocksetdb import RocksetChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.singlestoredb import (
|
||||
SingleStoreDBChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.sql import SQLChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.streamlit import (
|
||||
StreamlitChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.upstash_redis import (
|
||||
UpstashRedisChatMessageHistory,
|
||||
)
|
||||
from langchain_integrations.chat_message_histories.xata import XataChatMessageHistory
|
||||
from langchain_integrations.chat_message_histories.zep import ZepChatMessageHistory
|
||||
|
||||
__all__ = [
|
||||
"AstraDBChatMessageHistory",
|
||||
"ChatMessageHistory",
|
||||
"CassandraChatMessageHistory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"FirestoreChatMessageHistory",
|
||||
"MomentoChatMessageHistory",
|
||||
"MongoDBChatMessageHistory",
|
||||
"PostgresChatMessageHistory",
|
||||
"RedisChatMessageHistory",
|
||||
"RocksetChatMessageHistory",
|
||||
"SQLChatMessageHistory",
|
||||
"StreamlitChatMessageHistory",
|
||||
"SingleStoreDBChatMessageHistory",
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
"Neo4jChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,114 @@
|
||||
"""Astra DB - based chat message history, based on astrapy."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
from typing import List, Optional
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB as LibAstraDB
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "langchain_message_store"
|
||||
|
||||
|
||||
class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in Astra DB.
|
||||
|
||||
Args (only keyword-arguments accepted):
|
||||
session_id: arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
collection_name (str): name of the Astra DB collection to create/use.
|
||||
token (Optional[str]): API token for Astra DB usage.
|
||||
api_endpoint (Optional[str]): full URL to the API endpoint,
|
||||
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
|
||||
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AstraDB' instance.
|
||||
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create an Astra DB chat message history."""
|
||||
try:
|
||||
from astrapy.db import AstraDB as LibAstraDB
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
|
||||
self.session_id = session_id
|
||||
self.collection_name = collection_name
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
self.namespace = namespace
|
||||
if astra_db_client is not None:
|
||||
self.astra_db = astra_db_client
|
||||
else:
|
||||
self.astra_db = LibAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(self.collection_name)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve all session messages from DB"""
|
||||
message_blobs = [
|
||||
doc["body_blob"]
|
||||
for doc in sorted(
|
||||
self.collection.paginated_find(
|
||||
filter={
|
||||
"session_id": self.session_id,
|
||||
},
|
||||
projection={
|
||||
"timestamp": 1,
|
||||
"body_blob": 1,
|
||||
},
|
||||
),
|
||||
key=lambda _doc: _doc["timestamp"],
|
||||
)
|
||||
]
|
||||
items = [json.loads(message_blob) for message_blob in message_blobs]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Write a message to the table"""
|
||||
self.collection.insert_one(
|
||||
{
|
||||
"timestamp": time.time(),
|
||||
"session_id": self.session_id,
|
||||
"body_blob": json.dumps(message_to_dict(message)),
|
||||
}
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from DB"""
|
||||
self.collection.delete_many(filter={"session_id": self.session_id})
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Cassandra-based chat message history, based on cassIO."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from typing import List
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from cassandra.cluster import Session
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
DEFAULT_TABLE_NAME = "message_store"
|
||||
DEFAULT_TTL_SECONDS = None
|
||||
|
||||
|
||||
class CassandraChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in Cassandra.
|
||||
|
||||
Args:
|
||||
session_id: arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
session: a Cassandra `Session` object (an open DB connection)
|
||||
keyspace: name of the keyspace to use.
|
||||
table_name: name of the table to use.
|
||||
ttl_seconds: time-to-live (seconds) for automatic expiration
|
||||
of stored entries. None (default) for no expiration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
session: Session,
|
||||
keyspace: str,
|
||||
table_name: str = DEFAULT_TABLE_NAME,
|
||||
ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS,
|
||||
) -> None:
|
||||
try:
|
||||
from cassio.history import StoredBlobHistory
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import cassio python package. "
|
||||
"Please install it with `pip install cassio`."
|
||||
)
|
||||
self.session_id = session_id
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.blob_history = StoredBlobHistory(session, keyspace, table_name)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve all session messages from DB"""
|
||||
message_blobs = self.blob_history.retrieve(
|
||||
self.session_id,
|
||||
)
|
||||
items = [json.loads(message_blob) for message_blob in message_blobs]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Write a message to the table"""
|
||||
self.blob_history.store(
|
||||
self.session_id, json.dumps(message_to_dict(message)), self.ttl_seconds
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from DB"""
|
||||
self.blob_history.clear_session_id(self.session_id)
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Azure CosmosDB Memory History."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Type
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.cosmos import ContainerProxy
|
||||
|
||||
|
||||
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history backed by Azure CosmosDB."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cosmos_endpoint: str,
|
||||
cosmos_database: str,
|
||||
cosmos_container: str,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
credential: Any = None,
|
||||
connection_string: Optional[str] = None,
|
||||
ttl: Optional[int] = None,
|
||||
cosmos_client_kwargs: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a new instance of the CosmosDBChatMessageHistory class.
|
||||
|
||||
Make sure to call prepare_cosmos or use the context manager to make
|
||||
sure your database is ready.
|
||||
|
||||
Either a credential or a connection string must be provided.
|
||||
|
||||
:param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
|
||||
:param cosmos_database: The name of the database to use.
|
||||
:param cosmos_container: The name of the container to use.
|
||||
:param session_id: The session ID to use, can be overwritten while loading.
|
||||
:param user_id: The user ID to use, can be overwritten while loading.
|
||||
:param credential: The credential to use to authenticate to Azure Cosmos DB.
|
||||
:param connection_string: The connection string to use to authenticate.
|
||||
:param ttl: The time to live (in seconds) to use for documents in the container.
|
||||
:param cosmos_client_kwargs: Additional kwargs to pass to the CosmosClient.
|
||||
"""
|
||||
self.cosmos_endpoint = cosmos_endpoint
|
||||
self.cosmos_database = cosmos_database
|
||||
self.cosmos_container = cosmos_container
|
||||
self.credential = credential
|
||||
self.conn_string = connection_string
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.ttl = ttl
|
||||
|
||||
self.messages: List[BaseMessage] = []
|
||||
try:
|
||||
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||
CosmosClient,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
|
||||
"Please install it with `pip install azure-cosmos`."
|
||||
) from exc
|
||||
if self.credential:
|
||||
self._client = CosmosClient(
|
||||
url=self.cosmos_endpoint,
|
||||
credential=self.credential,
|
||||
**cosmos_client_kwargs or {},
|
||||
)
|
||||
elif self.conn_string:
|
||||
self._client = CosmosClient.from_connection_string(
|
||||
conn_str=self.conn_string,
|
||||
**cosmos_client_kwargs or {},
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either a connection string or a credential must be set.")
|
||||
self._container: Optional[ContainerProxy] = None
|
||||
|
||||
def prepare_cosmos(self) -> None:
|
||||
"""Prepare the CosmosDB client.
|
||||
|
||||
Use this function or the context manager to make sure your database is ready.
|
||||
"""
|
||||
try:
|
||||
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||
PartitionKey,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
|
||||
"Please install it with `pip install azure-cosmos`."
|
||||
) from exc
|
||||
database = self._client.create_database_if_not_exists(self.cosmos_database)
|
||||
self._container = database.create_container_if_not_exists(
|
||||
self.cosmos_container,
|
||||
partition_key=PartitionKey("/user_id"),
|
||||
default_ttl=self.ttl,
|
||||
)
|
||||
self.load_messages()
|
||||
|
||||
def __enter__(self) -> "CosmosDBChatMessageHistory":
|
||||
"""Context manager entry point."""
|
||||
self._client.__enter__()
|
||||
self.prepare_cosmos()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
"""Context manager exit"""
|
||||
self.upsert_messages()
|
||||
self._client.__exit__(exc_type, exc_val, traceback)
|
||||
|
||||
def load_messages(self) -> None:
|
||||
"""Retrieve the messages from Cosmos"""
|
||||
if not self._container:
|
||||
raise ValueError("Container not initialized")
|
||||
try:
|
||||
from azure.cosmos.exceptions import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||
CosmosHttpResponseError,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
|
||||
"Please install it with `pip install azure-cosmos`."
|
||||
) from exc
|
||||
try:
|
||||
item = self._container.read_item(
|
||||
item=self.session_id, partition_key=self.user_id
|
||||
)
|
||||
except CosmosHttpResponseError:
|
||||
logger.info("no session found")
|
||||
return
|
||||
if "messages" in item and len(item["messages"]) > 0:
|
||||
self.messages = messages_from_dict(item["messages"])
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a self-created message to the store"""
|
||||
self.messages.append(message)
|
||||
self.upsert_messages()
|
||||
|
||||
def upsert_messages(self) -> None:
|
||||
"""Update the cosmosdb item."""
|
||||
if not self._container:
|
||||
raise ValueError("Container not initialized")
|
||||
self._container.upsert_item(
|
||||
body={
|
||||
"id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"messages": messages_to_dict(self.messages),
|
||||
}
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from this memory and cosmos."""
|
||||
self.messages = []
|
||||
if self._container:
|
||||
self._container.delete_item(
|
||||
item=self.session_id, partition_key=self.user_id
|
||||
)
|
||||
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from boto3.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in AWS DynamoDB.
|
||||
|
||||
This class expects that a DynamoDB table exists with name `table_name`
|
||||
|
||||
Args:
|
||||
table_name: name of the DynamoDB table
|
||||
session_id: arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
endpoint_url: URL of the AWS endpoint to connect to. This argument
|
||||
is optional and useful for test purposes, like using Localstack.
|
||||
If you plan to use AWS cloud service, you normally don't have to
|
||||
worry about setting the endpoint_url.
|
||||
primary_key_name: name of the primary key of the DynamoDB table. This argument
|
||||
is optional, defaulting to "SessionId".
|
||||
key: an optional dictionary with a custom primary and secondary key.
|
||||
This argument is optional, but useful when using composite dynamodb keys, or
|
||||
isolating records based off of application details such as a user id.
|
||||
This may also contain global and local secondary index keys.
|
||||
kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias for
|
||||
client-side encryption
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_name: str,
|
||||
session_id: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
primary_key_name: str = "SessionId",
|
||||
key: Optional[Dict[str, str]] = None,
|
||||
boto3_session: Optional[Session] = None,
|
||||
kms_key_id: Optional[str] = None,
|
||||
):
|
||||
if boto3_session:
|
||||
client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url)
|
||||
else:
|
||||
try:
|
||||
import boto3
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import boto3, please install with `pip install boto3`."
|
||||
) from e
|
||||
if endpoint_url:
|
||||
client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
|
||||
else:
|
||||
client = boto3.resource("dynamodb")
|
||||
self.table = client.Table(table_name)
|
||||
self.session_id = session_id
|
||||
self.key: Dict = key or {primary_key_name: session_id}
|
||||
|
||||
if kms_key_id:
|
||||
try:
|
||||
from dynamodb_encryption_sdk.encrypted.table import EncryptedTable
|
||||
from dynamodb_encryption_sdk.identifiers import CryptoAction
|
||||
from dynamodb_encryption_sdk.material_providers.aws_kms import (
|
||||
AwsKmsCryptographicMaterialsProvider,
|
||||
)
|
||||
from dynamodb_encryption_sdk.structures import AttributeActions
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import dynamodb_encryption_sdk, please install with "
|
||||
"`pip install dynamodb-encryption-sdk`."
|
||||
) from e
|
||||
|
||||
actions = AttributeActions(
|
||||
default_action=CryptoAction.DO_NOTHING,
|
||||
attribute_actions={"History": CryptoAction.ENCRYPT_AND_SIGN},
|
||||
)
|
||||
aws_kms_cmp = AwsKmsCryptographicMaterialsProvider(key_id=kms_key_id)
|
||||
self.table = EncryptedTable(
|
||||
table=self.table,
|
||||
materials_provider=aws_kms_cmp,
|
||||
attribute_actions=actions,
|
||||
auto_refresh_table_indexes=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from DynamoDB"""
|
||||
try:
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import botocore, please install with `pip install botocore`."
|
||||
) from e
|
||||
|
||||
response = None
|
||||
try:
|
||||
response = self.table.get_item(Key=self.key)
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "ResourceNotFoundException":
|
||||
logger.warning("No record found with session id: %s", self.session_id)
|
||||
else:
|
||||
logger.error(error)
|
||||
|
||||
if response and "Item" in response:
|
||||
items = response["Item"]["History"]
|
||||
else:
|
||||
items = []
|
||||
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in DynamoDB"""
|
||||
try:
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import botocore, please install with `pip install botocore`."
|
||||
) from e
|
||||
|
||||
messages = messages_to_dict(self.messages)
|
||||
_message = message_to_dict(message)
|
||||
messages.append(_message)
|
||||
|
||||
try:
|
||||
self.table.put_item(Item={**self.key, "History": messages})
|
||||
except ClientError as err:
|
||||
logger.error(err)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from DynamoDB"""
|
||||
try:
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import botocore, please install with `pip install botocore`."
|
||||
) from e
|
||||
|
||||
try:
|
||||
self.table.delete_item(Key=self.key)
|
||||
except ClientError as err:
|
||||
logger.error(err)
|
||||
@@ -0,0 +1,195 @@
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in Elasticsearch.
|
||||
|
||||
Args:
|
||||
es_url: URL of the Elasticsearch instance to connect to.
|
||||
es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
|
||||
es_user: Username to use when connecting to Elasticsearch.
|
||||
es_password: Password to use when connecting to Elasticsearch.
|
||||
es_api_key: API key to use when connecting to Elasticsearch.
|
||||
es_connection: Optional pre-existing Elasticsearch connection.
|
||||
index: Name of the index to use.
|
||||
session_id: Arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: str,
|
||||
session_id: str,
|
||||
*,
|
||||
es_connection: Optional["Elasticsearch"] = None,
|
||||
es_url: Optional[str] = None,
|
||||
es_cloud_id: Optional[str] = None,
|
||||
es_user: Optional[str] = None,
|
||||
es_api_key: Optional[str] = None,
|
||||
es_password: Optional[str] = None,
|
||||
):
|
||||
self.index: str = index
|
||||
self.session_id: str = session_id
|
||||
|
||||
# Initialize Elasticsearch client from passed client arg or connection info
|
||||
if es_connection is not None:
|
||||
self.client = es_connection.options(
|
||||
headers={"user-agent": self.get_user_agent()}
|
||||
)
|
||||
elif es_url is not None or es_cloud_id is not None:
|
||||
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
|
||||
es_url=es_url,
|
||||
username=es_user,
|
||||
password=es_password,
|
||||
cloud_id=es_cloud_id,
|
||||
api_key=es_api_key,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""Either provide a pre-existing Elasticsearch connection, \
|
||||
or valid credentials for creating a new connection."""
|
||||
)
|
||||
|
||||
if self.client.indices.exists(index=index):
|
||||
logger.debug(
|
||||
f"Chat history index {index} already exists, skipping creation."
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Creating index {index} for storing chat history.")
|
||||
|
||||
self.client.indices.create(
|
||||
index=index,
|
||||
mappings={
|
||||
"properties": {
|
||||
"session_id": {"type": "keyword"},
|
||||
"created_at": {"type": "date"},
|
||||
"history": {"type": "text"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain import __version__
|
||||
|
||||
return f"langchain-py-ms/{__version__}"
|
||||
|
||||
@staticmethod
|
||||
def connect_to_elasticsearch(
|
||||
*,
|
||||
es_url: Optional[str] = None,
|
||||
cloud_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> "Elasticsearch":
|
||||
try:
|
||||
import elasticsearch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import elasticsearch python package. "
|
||||
"Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
|
||||
if es_url and cloud_id:
|
||||
raise ValueError(
|
||||
"Both es_url and cloud_id are defined. Please provide only one."
|
||||
)
|
||||
|
||||
connection_params: Dict[str, Any] = {}
|
||||
|
||||
if es_url:
|
||||
connection_params["hosts"] = [es_url]
|
||||
elif cloud_id:
|
||||
connection_params["cloud_id"] = cloud_id
|
||||
else:
|
||||
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
|
||||
|
||||
if api_key:
|
||||
connection_params["api_key"] = api_key
|
||||
elif username and password:
|
||||
connection_params["basic_auth"] = (username, password)
|
||||
|
||||
es_client = elasticsearch.Elasticsearch(
|
||||
**connection_params,
|
||||
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
|
||||
)
|
||||
try:
|
||||
es_client.info()
|
||||
except Exception as err:
|
||||
logger.error(f"Error connecting to Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
return es_client
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
result = self.client.search(
|
||||
index=self.index,
|
||||
query={"term": {"session_id": self.session_id}},
|
||||
sort="created_at:asc",
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
if result and len(result["hits"]["hits"]) > 0:
|
||||
items = [
|
||||
json.loads(document["_source"]["history"])
|
||||
for document in result["hits"]["hits"]
|
||||
]
|
||||
else:
|
||||
items = []
|
||||
|
||||
return messages_from_dict(items)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a message to the chat session in Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
self.client.index(
|
||||
index=self.index,
|
||||
document={
|
||||
"session_id": self.session_id,
|
||||
"created_at": round(time() * 1000),
|
||||
"history": json.dumps(message_to_dict(message)),
|
||||
},
|
||||
refresh=True,
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not add message to Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory in Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
self.client.delete_by_query(
|
||||
index=self.index,
|
||||
query={"term": {"session_id": self.session_id}},
|
||||
refresh=True,
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not clear session memory in Elasticsearch: {err}")
|
||||
raise err
|
||||
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileChatMessageHistory(BaseChatMessageHistory):
|
||||
"""
|
||||
Chat message history that stores history in a local file.
|
||||
|
||||
Args:
|
||||
file_path: path of the local file to store the messages.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
self.file_path = Path(file_path)
|
||||
if not self.file_path.exists():
|
||||
self.file_path.touch()
|
||||
self.file_path.write_text(json.dumps([]))
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from the local file"""
|
||||
items = json.loads(self.file_path.read_text())
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in the local file"""
|
||||
messages = messages_to_dict(self.messages)
|
||||
messages.append(messages_to_dict([message])[0])
|
||||
self.file_path.write_text(json.dumps(messages))
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from the local file"""
|
||||
self.file_path.write_text(json.dumps([]))
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Firestore Chat Message History."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.cloud.firestore import Client, DocumentReference
|
||||
|
||||
|
||||
def _get_firestore_client() -> Client:
|
||||
try:
|
||||
import firebase_admin
|
||||
from firebase_admin import firestore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import firebase-admin python package. "
|
||||
"Please install it with `pip install firebase-admin`."
|
||||
)
|
||||
|
||||
# For multiple instances, only initialize the app once.
|
||||
try:
|
||||
firebase_admin.get_app()
|
||||
except ValueError as e:
|
||||
logger.debug("Initializing Firebase app: %s", e)
|
||||
firebase_admin.initialize_app()
|
||||
|
||||
return firestore.client()
|
||||
|
||||
|
||||
class FirestoreChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history backed by Google Firestore."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
firestore_client: Optional[Client] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of the FirestoreChatMessageHistory class.
|
||||
|
||||
:param collection_name: The name of the collection to use.
|
||||
:param session_id: The session ID for the chat..
|
||||
:param user_id: The user ID for the chat.
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self._document: Optional[DocumentReference] = None
|
||||
self.messages: List[BaseMessage] = []
|
||||
self.firestore_client = firestore_client or _get_firestore_client()
|
||||
self.prepare_firestore()
|
||||
|
||||
def prepare_firestore(self) -> None:
|
||||
"""Prepare the Firestore client.
|
||||
|
||||
Use this function to make sure your database is ready.
|
||||
"""
|
||||
self._document = self.firestore_client.collection(
|
||||
self.collection_name
|
||||
).document(self.session_id)
|
||||
self.load_messages()
|
||||
|
||||
def load_messages(self) -> None:
|
||||
"""Retrieve the messages from Firestore"""
|
||||
if not self._document:
|
||||
raise ValueError("Document not initialized")
|
||||
doc = self._document.get()
|
||||
if doc.exists:
|
||||
data = doc.to_dict()
|
||||
if "messages" in data and len(data["messages"]) > 0:
|
||||
self.messages = messages_from_dict(data["messages"])
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
self.messages.append(message)
|
||||
self.upsert_messages()
|
||||
|
||||
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
|
||||
"""Update the Firestore document."""
|
||||
if not self._document:
|
||||
raise ValueError("Document not initialized")
|
||||
self._document.set(
|
||||
{
|
||||
"id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"messages": messages_to_dict(self.messages),
|
||||
}
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from this memory and Firestore."""
|
||||
self.messages = []
|
||||
if self._document:
|
||||
self._document.delete()
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||
"""In memory implementation of chat message history.
|
||||
|
||||
Stores messages in an in memory list.
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage] = Field(default_factory=list)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a self-created message to the store"""
|
||||
self.messages.append(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.messages = []
|
||||
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
|
||||
|
||||
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
|
||||
"""Create cache if it doesn't exist.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
"""
|
||||
from momento.responses import CreateCache
|
||||
|
||||
create_cache_response = cache_client.create_cache(cache_name)
|
||||
if isinstance(create_cache_response, CreateCache.Success) or isinstance(
|
||||
create_cache_response, CreateCache.CacheAlreadyExists
|
||||
):
|
||||
return None
|
||||
elif isinstance(create_cache_response, CreateCache.Error):
|
||||
raise create_cache_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response cache creation: {create_cache_response}")
|
||||
|
||||
|
||||
class MomentoChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history cache that uses Momento as a backend.
|
||||
|
||||
See https://gomomento.com/"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
cache_client: momento.CacheClient,
|
||||
cache_name: str,
|
||||
*,
|
||||
key_prefix: str = "message_store:",
|
||||
ttl: Optional[timedelta] = None,
|
||||
ensure_cache_exists: bool = True,
|
||||
):
|
||||
"""Instantiate a chat message history cache that uses Momento as a backend.
|
||||
|
||||
Note: to instantiate the cache client passed to MomentoChatMessageHistory,
|
||||
you must have a Momento account at https://gomomento.com/.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID to use for this chat session.
|
||||
cache_client (CacheClient): The Momento cache client.
|
||||
cache_name (str): The name of the cache to use to store the messages.
|
||||
key_prefix (str, optional): The prefix to apply to the cache key.
|
||||
Defaults to "message_store:".
|
||||
ttl (Optional[timedelta], optional): The TTL to use for the messages.
|
||||
Defaults to None, ie the default TTL of the cache will be used.
|
||||
ensure_cache_exists (bool, optional): Create the cache if it doesn't exist.
|
||||
Defaults to True.
|
||||
|
||||
Raises:
|
||||
ImportError: Momento python package is not installed.
|
||||
TypeError: cache_client is not of type momento.CacheClientObject
|
||||
"""
|
||||
try:
|
||||
from momento import CacheClient
|
||||
from momento.requests import CollectionTtl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import momento python package. "
|
||||
"Please install it with `pip install momento`."
|
||||
)
|
||||
if not isinstance(cache_client, CacheClient):
|
||||
raise TypeError("cache_client must be a momento.CacheClient object.")
|
||||
if ensure_cache_exists:
|
||||
_ensure_cache_exists(cache_client, cache_name)
|
||||
self.key = key_prefix + session_id
|
||||
self.cache_client = cache_client
|
||||
self.cache_name = cache_name
|
||||
if ttl is not None:
|
||||
self.ttl = CollectionTtl.of(ttl)
|
||||
else:
|
||||
self.ttl = CollectionTtl.from_cache_ttl()
|
||||
|
||||
@classmethod
|
||||
def from_client_params(
|
||||
cls,
|
||||
session_id: str,
|
||||
cache_name: str,
|
||||
ttl: timedelta,
|
||||
*,
|
||||
configuration: Optional[momento.config.Configuration] = None,
|
||||
api_key: Optional[str] = None,
|
||||
auth_token: Optional[str] = None, # for backwards compatibility
|
||||
**kwargs: Any,
|
||||
) -> MomentoChatMessageHistory:
|
||||
"""Construct cache from CacheClient parameters."""
|
||||
try:
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import momento python package. "
|
||||
"Please install it with `pip install momento`."
|
||||
)
|
||||
if configuration is None:
|
||||
configuration = Configurations.Laptop.v1()
|
||||
|
||||
# Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility
|
||||
try:
|
||||
api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
|
||||
except ValueError:
|
||||
api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY")
|
||||
credentials = CredentialProvider.from_string(api_key)
|
||||
cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
|
||||
return cls(session_id, cache_client, cache_name, ttl=ttl, **kwargs)
|
||||
|
||||
@property
|
||||
def messages(self) -> list[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Momento.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
|
||||
Returns:
|
||||
list[BaseMessage]: List of cached messages
|
||||
"""
|
||||
from momento.responses import CacheListFetch
|
||||
|
||||
fetch_response = self.cache_client.list_fetch(self.cache_name, self.key)
|
||||
|
||||
if isinstance(fetch_response, CacheListFetch.Hit):
|
||||
items = [json.loads(m) for m in fetch_response.value_list_string]
|
||||
return messages_from_dict(items)
|
||||
elif isinstance(fetch_response, CacheListFetch.Miss):
|
||||
return []
|
||||
elif isinstance(fetch_response, CacheListFetch.Error):
|
||||
raise fetch_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {fetch_response}")
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Store a message in the cache.
|
||||
|
||||
Args:
|
||||
message (BaseMessage): The message object to store.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error.
|
||||
Exception: Unexpected response.
|
||||
"""
|
||||
from momento.responses import CacheListPushBack
|
||||
|
||||
item = json.dumps(message_to_dict(message))
|
||||
push_response = self.cache_client.list_push_back(
|
||||
self.cache_name, self.key, item, ttl=self.ttl
|
||||
)
|
||||
if isinstance(push_response, CacheListPushBack.Success):
|
||||
return None
|
||||
elif isinstance(push_response, CacheListPushBack.Error):
|
||||
raise push_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {push_response}")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove the session's messages from the cache.
|
||||
|
||||
Raises:
|
||||
SdkException: Momento service or network error.
|
||||
Exception: Unexpected response.
|
||||
"""
|
||||
from momento.responses import CacheDelete
|
||||
|
||||
delete_response = self.cache_client.delete(self.cache_name, self.key)
|
||||
if isinstance(delete_response, CacheDelete.Success):
|
||||
return None
|
||||
elif isinstance(delete_response, CacheDelete.Error):
|
||||
raise delete_response.inner_exception
|
||||
else:
|
||||
raise Exception(f"Unexpected response: {delete_response}")
|
||||
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_DBNAME = "chat_history"
|
||||
DEFAULT_COLLECTION_NAME = "message_store"
|
||||
|
||||
|
||||
class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in MongoDB.
|
||||
|
||||
Args:
|
||||
connection_string: connection string to connect to MongoDB
|
||||
session_id: arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
database_name: name of the database to use
|
||||
collection_name: name of the collection to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str,
|
||||
session_id: str,
|
||||
database_name: str = DEFAULT_DBNAME,
|
||||
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||
):
|
||||
from pymongo import MongoClient, errors
|
||||
|
||||
self.connection_string = connection_string
|
||||
self.session_id = session_id
|
||||
self.database_name = database_name
|
||||
self.collection_name = collection_name
|
||||
|
||||
try:
|
||||
self.client: MongoClient = MongoClient(connection_string)
|
||||
except errors.ConnectionFailure as error:
|
||||
logger.error(error)
|
||||
|
||||
self.db = self.client[database_name]
|
||||
self.collection = self.db[collection_name]
|
||||
self.collection.create_index("SessionId")
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from MongoDB"""
|
||||
from pymongo import errors
|
||||
|
||||
try:
|
||||
cursor = self.collection.find({"SessionId": self.session_id})
|
||||
except errors.OperationFailure as error:
|
||||
logger.error(error)
|
||||
|
||||
if cursor:
|
||||
items = [json.loads(document["History"]) for document in cursor]
|
||||
else:
|
||||
items = []
|
||||
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in MongoDB"""
|
||||
from pymongo import errors
|
||||
|
||||
try:
|
||||
self.collection.insert_one(
|
||||
{
|
||||
"SessionId": self.session_id,
|
||||
"History": json.dumps(message_to_dict(message)),
|
||||
}
|
||||
)
|
||||
except errors.WriteError as err:
|
||||
logger.error(err)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from MongoDB"""
|
||||
from pymongo import errors
|
||||
|
||||
try:
|
||||
self.collection.delete_many({"SessionId": self.session_id})
|
||||
except errors.WriteError as err:
|
||||
logger.error(err)
|
||||
@@ -0,0 +1,113 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage, messages_from_dict
|
||||
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
|
||||
class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Neo4j database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: Union[str, int],
|
||||
url: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
database: str = "neo4j",
|
||||
node_label: str = "Session",
|
||||
window: int = 3,
|
||||
):
|
||||
try:
|
||||
import neo4j
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import neo4j python package. "
|
||||
"Please install it with `pip install neo4j`."
|
||||
)
|
||||
|
||||
# Make sure session id is not null
|
||||
if not session_id:
|
||||
raise ValueError("Please ensure that the session_id parameter is provided")
|
||||
|
||||
url = get_from_env("url", "NEO4J_URI", url)
|
||||
username = get_from_env("username", "NEO4J_USERNAME", username)
|
||||
password = get_from_env("password", "NEO4J_PASSWORD", password)
|
||||
database = get_from_env("database", "NEO4J_DATABASE", database)
|
||||
|
||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||
self._database = database
|
||||
self._session_id = session_id
|
||||
self._node_label = node_label
|
||||
self._window = window
|
||||
|
||||
# Verify connection
|
||||
try:
|
||||
self._driver.verify_connectivity()
|
||||
except neo4j.exceptions.ServiceUnavailable:
|
||||
raise ValueError(
|
||||
"Could not connect to Neo4j database. "
|
||||
"Please ensure that the url is correct"
|
||||
)
|
||||
except neo4j.exceptions.AuthError:
|
||||
raise ValueError(
|
||||
"Could not connect to Neo4j database. "
|
||||
"Please ensure that the username and password are correct"
|
||||
)
|
||||
# Create session node
|
||||
self._driver.execute_query(
|
||||
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
|
||||
{"session_id": self._session_id},
|
||||
).summary
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from Neo4j"""
|
||||
query = (
|
||||
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
||||
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
|
||||
f"{self._window*2}]-() WITH p, length(p) AS length "
|
||||
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
|
||||
"RETURN {data:{content: node.content}, type:node.type} AS result"
|
||||
)
|
||||
records, _, _ = self._driver.execute_query(
|
||||
query, {"session_id": self._session_id}
|
||||
)
|
||||
|
||||
messages = messages_from_dict([el["result"] for el in records])
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Neo4j"""
|
||||
query = (
|
||||
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
|
||||
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
|
||||
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
|
||||
"SET new += {type:$type, content:$content} "
|
||||
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
|
||||
"CREATE (last_message)-[:NEXT]->(new) "
|
||||
"DELETE lm"
|
||||
)
|
||||
self._driver.execute_query(
|
||||
query,
|
||||
{
|
||||
"type": message.type,
|
||||
"content": message.content,
|
||||
"session_id": self._session_id,
|
||||
},
|
||||
).summary
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from Neo4j"""
|
||||
query = (
|
||||
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
||||
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
|
||||
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
|
||||
"UNWIND nodes(p) as node DETACH DELETE node;"
|
||||
)
|
||||
self._driver.execute_query(query, {"session_id": self._session_id}).summary
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._driver:
|
||||
self._driver.close()
|
||||
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_history"
|
||||
|
||||
|
||||
class PostgresChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Postgres database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_string: str = DEFAULT_CONNECTION_STRING,
|
||||
table_name: str = "message_store",
|
||||
):
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row
|
||||
|
||||
try:
|
||||
self.connection = psycopg.connect(connection_string)
|
||||
self.cursor = self.connection.cursor(row_factory=dict_row)
|
||||
except psycopg.OperationalError as error:
|
||||
logger.error(error)
|
||||
|
||||
self.session_id = session_id
|
||||
self.table_name = table_name
|
||||
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
message JSONB NOT NULL
|
||||
);"""
|
||||
self.cursor.execute(create_table_query)
|
||||
self.connection.commit()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from PostgreSQL"""
|
||||
query = (
|
||||
f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;"
|
||||
)
|
||||
self.cursor.execute(query, (self.session_id,))
|
||||
items = [record["message"] for record in self.cursor.fetchall()]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in PostgreSQL"""
|
||||
from psycopg import sql
|
||||
|
||||
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
|
||||
sql.Identifier(self.table_name)
|
||||
)
|
||||
self.cursor.execute(
|
||||
query, (self.session_id, json.dumps(message_to_dict(message)))
|
||||
)
|
||||
self.connection.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from PostgreSQL"""
|
||||
query = f"DELETE FROM {self.table_name} WHERE session_id = %s;"
|
||||
self.cursor.execute(query, (self.session_id,))
|
||||
self.connection.commit()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.cursor:
|
||||
self.cursor.close()
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
@@ -0,0 +1,65 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
from langchain_integrations.utilities.redis import get_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Redis database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
url: str = "redis://localhost:6379/0",
|
||||
key_prefix: str = "message_store:",
|
||||
ttl: Optional[int] = None,
|
||||
):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url)
|
||||
except redis.exceptions.ConnectionError as error:
|
||||
logger.error(error)
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Construct the record key to use"""
|
||||
return self.key_prefix + self.session_id
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from Redis"""
|
||||
_items = self.redis_client.lrange(self.key, 0, -1)
|
||||
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Redis"""
|
||||
self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
|
||||
if self.ttl:
|
||||
self.redis_client.expire(self.key, self.ttl)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from Redis"""
|
||||
self.redis_client.delete(self.key)
|
||||
@@ -0,0 +1,268 @@
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Any, Callable, List, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
|
||||
class RocksetChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Uses Rockset to store chat messages.
|
||||
|
||||
To use, ensure that the `rockset` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_integrations.chat_message_histories import (
|
||||
RocksetChatMessageHistory
|
||||
)
|
||||
from rockset import RocksetClient
|
||||
|
||||
history = RocksetChatMessageHistory(
|
||||
session_id="MySession",
|
||||
client=RocksetClient(),
|
||||
collection="langchain_demo",
|
||||
sync=True
|
||||
)
|
||||
|
||||
history.add_user_message("hi!")
|
||||
history.add_ai_message("whats up?")
|
||||
|
||||
print(history.messages)
|
||||
"""
|
||||
|
||||
# You should set these values based on your VI.
|
||||
# These values are configured for the typical
|
||||
# free VI. Read more about VIs here:
|
||||
# https://rockset.com/docs/instances
|
||||
SLEEP_INTERVAL_MS: int = 5
|
||||
ADD_TIMEOUT_MS: int = 5000
|
||||
CREATE_TIMEOUT_MS: int = 20000
|
||||
|
||||
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
|
||||
"""Sleeps until meth() evaluates to true. Passes kwargs into
|
||||
meth.
|
||||
"""
|
||||
start = datetime.now()
|
||||
while not method(**method_params):
|
||||
curr = datetime.now()
|
||||
if (curr - start).total_seconds() * 1000 > timeout:
|
||||
raise TimeoutError(f"{method} timed out at {timeout} ms")
|
||||
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
|
||||
|
||||
def _query(self, query: str, **query_params: Any) -> List[Any]:
|
||||
"""Executes an SQL statement and returns the result
|
||||
Args:
|
||||
- query: The SQL string
|
||||
- **query_params: Parameters to pass into the query
|
||||
"""
|
||||
return self.client.sql(query, params=query_params).results
|
||||
|
||||
def _create_collection(self) -> None:
|
||||
"""Creates a collection for this message history"""
|
||||
self.client.Collections.create_s3_collection(
|
||||
name=self.collection, workspace=self.workspace
|
||||
)
|
||||
|
||||
def _collection_exists(self) -> bool:
|
||||
"""Checks whether a collection exists for this message history"""
|
||||
try:
|
||||
self.client.Collections.get(collection=self.collection)
|
||||
except self.rockset.exceptions.NotFoundException:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _collection_is_ready(self) -> bool:
|
||||
"""Checks whether the collection for this message history is ready
|
||||
to be queried
|
||||
"""
|
||||
return (
|
||||
self.client.Collections.get(collection=self.collection).data.status
|
||||
== "READY"
|
||||
)
|
||||
|
||||
def _document_exists(self) -> bool:
|
||||
return (
|
||||
len(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT 1
|
||||
FROM {self.location}
|
||||
WHERE _id=:session_id
|
||||
LIMIT 1
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
)
|
||||
)
|
||||
!= 0
|
||||
)
|
||||
|
||||
def _wait_until_collection_created(self) -> None:
|
||||
"""Sleeps until the collection for this message history is ready
|
||||
to be queried
|
||||
"""
|
||||
self._wait_until(
|
||||
lambda: self._collection_is_ready(),
|
||||
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
|
||||
)
|
||||
|
||||
def _wait_until_message_added(self, message_id: str) -> None:
|
||||
"""Sleeps until a message is added to the messages list"""
|
||||
self._wait_until(
|
||||
lambda message_id: len(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM UNNEST((
|
||||
SELECT {self.messages_key}
|
||||
FROM {self.location}
|
||||
WHERE _id = :session_id
|
||||
)) AS message
|
||||
WHERE message.data.additional_kwargs.id = :message_id
|
||||
LIMIT 1
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
message_id=message_id,
|
||||
),
|
||||
)
|
||||
!= 0,
|
||||
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def _create_empty_doc(self) -> None:
|
||||
"""Creates or replaces a document for this message history with no
|
||||
messages"""
|
||||
self.client.Documents.add_documents(
|
||||
collection=self.collection,
|
||||
workspace=self.workspace,
|
||||
data=[{"_id": self.session_id, self.messages_key: []}],
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client: Any,
|
||||
collection: str,
|
||||
workspace: str = "commons",
|
||||
messages_key: str = "messages",
|
||||
sync: bool = False,
|
||||
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
|
||||
) -> None:
|
||||
"""Constructs a new RocksetChatMessageHistory.
|
||||
|
||||
Args:
|
||||
- session_id: The ID of the chat session
|
||||
- client: The RocksetClient object to use to query
|
||||
- collection: The name of the collection to use to store chat
|
||||
messages. If a collection with the given name
|
||||
does not exist in the workspace, it is created.
|
||||
- workspace: The workspace containing `collection`. Defaults
|
||||
to `"commons"`
|
||||
- messages_key: The DB column containing message history.
|
||||
Defaults to `"messages"`
|
||||
- sync: Whether to wait for messages to be added. Defaults
|
||||
to `False`. NOTE: setting this to `True` will slow
|
||||
down performance.
|
||||
- message_uuid_method: The method that generates message IDs.
|
||||
If set, all messages will have an `id` field within the
|
||||
`additional_kwargs` property. If this param is not set
|
||||
and `sync` is `False`, message IDs will not be created.
|
||||
If this param is not set and `sync` is `True`, the
|
||||
`uuid.uuid4` method will be used to create message IDs.
|
||||
"""
|
||||
try:
|
||||
import rockset
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import rockset client python package. "
|
||||
"Please install it with `pip install rockset`."
|
||||
)
|
||||
|
||||
if not isinstance(client, rockset.RocksetClient):
|
||||
raise ValueError(
|
||||
f"client should be an instance of rockset.RocksetClient, "
|
||||
f"got {type(client)}"
|
||||
)
|
||||
|
||||
self.session_id = session_id
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
self.workspace = workspace
|
||||
self.location = f'"{self.workspace}"."{self.collection}"'
|
||||
self.rockset = rockset
|
||||
self.messages_key = messages_key
|
||||
self.message_uuid_method = message_uuid_method
|
||||
self.sync = sync
|
||||
|
||||
try:
|
||||
self.client.set_application("langchain")
|
||||
except AttributeError:
|
||||
# ignore
|
||||
pass
|
||||
|
||||
if not self._collection_exists():
|
||||
self._create_collection()
|
||||
self._wait_until_collection_created()
|
||||
self._create_empty_doc()
|
||||
elif not self._document_exists():
|
||||
self._create_empty_doc()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Messages in this chat history."""
|
||||
return messages_from_dict(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM UNNEST ((
|
||||
SELECT "{self.messages_key}"
|
||||
FROM {self.location}
|
||||
WHERE _id = :session_id
|
||||
))
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
)
|
||||
)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the history.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
if self.sync and "id" not in message.additional_kwargs:
|
||||
message.additional_kwargs["id"] = self.message_uuid_method()
|
||||
self.client.Documents.patch_documents(
|
||||
collection=self.collection,
|
||||
workspace=self.workspace,
|
||||
data=[
|
||||
self.rockset.model.patch_document.PatchDocument(
|
||||
id=self.session_id,
|
||||
patch=[
|
||||
self.rockset.model.patch_operation.PatchOperation(
|
||||
op="ADD",
|
||||
path=f"/{self.messages_key}/-",
|
||||
value=message_to_dict(message),
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
if self.sync:
|
||||
self._wait_until_message_added(message.additional_kwargs["id"])
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Removes all messages from the chat history"""
|
||||
self._create_empty_doc()
|
||||
if self.sync:
|
||||
self._wait_until(
|
||||
lambda: not self.messages,
|
||||
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||
)
|
||||
@@ -0,0 +1,277 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
)
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a SingleStoreDB database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
table_name: str = "message_store",
|
||||
id_field: str = "id",
|
||||
session_id_field: str = "session_id",
|
||||
message_field: str = "message",
|
||||
pool_size: int = 5,
|
||||
max_overflow: int = 10,
|
||||
timeout: float = 30,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with necessary components.
|
||||
|
||||
Args:
|
||||
|
||||
|
||||
table_name (str, optional): Specifies the name of the table in use.
|
||||
Defaults to "message_store".
|
||||
id_field (str, optional): Specifies the name of the id field in the table.
|
||||
Defaults to "id".
|
||||
session_id_field (str, optional): Specifies the name of the session_id
|
||||
field in the table. Defaults to "session_id".
|
||||
message_field (str, optional): Specifies the name of the message field
|
||||
in the table. Defaults to "message".
|
||||
|
||||
Following arguments pertain to the connection pool:
|
||||
|
||||
pool_size (int, optional): Determines the number of active connections in
|
||||
the pool. Defaults to 5.
|
||||
max_overflow (int, optional): Determines the maximum number of connections
|
||||
allowed beyond the pool_size. Defaults to 10.
|
||||
timeout (float, optional): Specifies the maximum wait time in seconds for
|
||||
establishing a connection. Defaults to 30.
|
||||
|
||||
Following arguments pertain to the database connection:
|
||||
|
||||
host (str, optional): Specifies the hostname, IP address, or URL for the
|
||||
database connection. The default scheme is "mysql".
|
||||
user (str, optional): Database username.
|
||||
password (str, optional): Database password.
|
||||
port (int, optional): Database port. Defaults to 3306 for non-HTTP
|
||||
connections, 80 for HTTP connections, and 443 for HTTPS connections.
|
||||
database (str, optional): Database name.
|
||||
|
||||
Additional optional arguments provide further customization over the
|
||||
database connection:
|
||||
|
||||
pure_python (bool, optional): Toggles the connector mode. If True,
|
||||
operates in pure Python mode.
|
||||
local_infile (bool, optional): Allows local file uploads.
|
||||
charset (str, optional): Specifies the character set for string values.
|
||||
ssl_key (str, optional): Specifies the path of the file containing the SSL
|
||||
key.
|
||||
ssl_cert (str, optional): Specifies the path of the file containing the SSL
|
||||
certificate.
|
||||
ssl_ca (str, optional): Specifies the path of the file containing the SSL
|
||||
certificate authority.
|
||||
ssl_cipher (str, optional): Sets the SSL cipher list.
|
||||
ssl_disabled (bool, optional): Disables SSL usage.
|
||||
ssl_verify_cert (bool, optional): Verifies the server's certificate.
|
||||
Automatically enabled if ``ssl_ca`` is specified.
|
||||
ssl_verify_identity (bool, optional): Verifies the server's identity.
|
||||
conv (dict[int, Callable], optional): A dictionary of data conversion
|
||||
functions.
|
||||
credential_type (str, optional): Specifies the type of authentication to
|
||||
use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO.
|
||||
autocommit (bool, optional): Enables autocommits.
|
||||
results_type (str, optional): Determines the structure of the query results:
|
||||
tuples, namedtuples, dicts.
|
||||
results_format (str, optional): Deprecated. This option has been renamed to
|
||||
results_type.
|
||||
|
||||
Examples:
|
||||
Basic Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_integrations.chat_message_histories import (
|
||||
SingleStoreDBChatMessageHistory
|
||||
)
|
||||
|
||||
message_history = SingleStoreDBChatMessageHistory(
|
||||
session_id="my-session",
|
||||
host="https://user:password@127.0.0.1:3306/database"
|
||||
)
|
||||
|
||||
Advanced Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_integrations.chat_message_histories import (
|
||||
SingleStoreDBChatMessageHistory
|
||||
)
|
||||
|
||||
message_history = SingleStoreDBChatMessageHistory(
|
||||
session_id="my-session",
|
||||
host="127.0.0.1",
|
||||
port=3306,
|
||||
user="user",
|
||||
password="password",
|
||||
database="db",
|
||||
table_name="my_custom_table",
|
||||
pool_size=10,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
Using environment variables:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_integrations.chat_message_histories import (
|
||||
SingleStoreDBChatMessageHistory
|
||||
)
|
||||
|
||||
os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
|
||||
message_history = SingleStoreDBChatMessageHistory("my-session")
|
||||
"""
|
||||
|
||||
self.table_name = self._sanitize_input(table_name)
|
||||
self.session_id = self._sanitize_input(session_id)
|
||||
self.id_field = self._sanitize_input(id_field)
|
||||
self.session_id_field = self._sanitize_input(session_id_field)
|
||||
self.message_field = self._sanitize_input(message_field)
|
||||
|
||||
# Pass the rest of the kwargs to the connection.
|
||||
self.connection_kwargs = kwargs
|
||||
|
||||
# Add connection attributes to the connection kwargs.
|
||||
if "conn_attrs" not in self.connection_kwargs:
|
||||
self.connection_kwargs["conn_attrs"] = dict()
|
||||
|
||||
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
|
||||
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
|
||||
|
||||
# Create a connection pool.
|
||||
try:
|
||||
from sqlalchemy.pool import QueuePool
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import sqlalchemy.pool python package. "
|
||||
"Please install it with `pip install singlestoredb`."
|
||||
)
|
||||
|
||||
self.connection_pool = QueuePool(
|
||||
self._get_connection,
|
||||
max_overflow=max_overflow,
|
||||
pool_size=pool_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
self.table_created = False
|
||||
|
||||
def _sanitize_input(self, input_str: str) -> str:
|
||||
# Remove characters that are not alphanumeric or underscores
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
|
||||
|
||||
def _get_connection(self) -> Any:
|
||||
try:
|
||||
import singlestoredb as s2
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import singlestoredb python package. "
|
||||
"Please install it with `pip install singlestoredb`."
|
||||
)
|
||||
return s2.connect(**self.connection_kwargs)
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
"""Create table if it doesn't exist."""
|
||||
if self.table_created:
|
||||
return
|
||||
conn = self.connection_pool.connect()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS {}
|
||||
({} BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
{} TEXT NOT NULL,
|
||||
{} JSON NOT NULL);""".format(
|
||||
self.table_name,
|
||||
self.id_field,
|
||||
self.session_id_field,
|
||||
self.message_field,
|
||||
),
|
||||
)
|
||||
self.table_created = True
|
||||
finally:
|
||||
cur.close()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from SingleStoreDB"""
|
||||
self._create_table_if_not_exists()
|
||||
conn = self.connection_pool.connect()
|
||||
items = []
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"""SELECT {} FROM {} WHERE {} = %s""".format(
|
||||
self.message_field,
|
||||
self.table_name,
|
||||
self.session_id_field,
|
||||
),
|
||||
(self.session_id),
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
items.append(row[0])
|
||||
finally:
|
||||
cur.close()
|
||||
finally:
|
||||
conn.close()
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in SingleStoreDB"""
|
||||
self._create_table_if_not_exists()
|
||||
conn = self.connection_pool.connect()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"""INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format(
|
||||
self.table_name,
|
||||
self.session_id_field,
|
||||
self.message_field,
|
||||
),
|
||||
(self.session_id, json.dumps(message_to_dict(message))),
|
||||
)
|
||||
finally:
|
||||
cur.close()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from SingleStoreDB"""
|
||||
self._create_table_if_not_exists()
|
||||
conn = self.connection_pool.connect()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"""DELETE FROM {} WHERE {} = %s""".format(
|
||||
self.table_name,
|
||||
self.session_id_field,
|
||||
),
|
||||
(self.session_id),
|
||||
)
|
||||
finally:
|
||||
cur.close()
|
||||
finally:
|
||||
conn.close()
|
||||
@@ -0,0 +1,140 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, create_engine
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseMessageConverter(ABC):
|
||||
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
|
||||
|
||||
@abstractmethod
|
||||
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
||||
"""Convert a SQLAlchemy model to a BaseMessage instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
||||
"""Convert a BaseMessage instance to a SQLAlchemy model."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_sql_model_class(self) -> Any:
|
||||
"""Get the SQLAlchemy model class."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_message_model(table_name, DynamicBase): # type: ignore
|
||||
"""
|
||||
Create a message model for a given table name.
|
||||
|
||||
Args:
|
||||
table_name: The name of the table to use.
|
||||
DynamicBase: The base class to use for the model.
|
||||
|
||||
Returns:
|
||||
The model class.
|
||||
|
||||
"""
|
||||
|
||||
# Model decleared inside a function to have a dynamic table name
|
||||
class Message(DynamicBase):
|
||||
__tablename__ = table_name
|
||||
id = Column(Integer, primary_key=True)
|
||||
session_id = Column(Text)
|
||||
message = Column(Text)
|
||||
|
||||
return Message
|
||||
|
||||
|
||||
class DefaultMessageConverter(BaseMessageConverter):
|
||||
"""The default message converter for SQLChatMessageHistory."""
|
||||
|
||||
def __init__(self, table_name: str):
|
||||
self.model_class = create_message_model(table_name, declarative_base())
|
||||
|
||||
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
||||
return messages_from_dict([json.loads(sql_message.message)])[0]
|
||||
|
||||
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
||||
return self.model_class(
|
||||
session_id=session_id, message=json.dumps(message_to_dict(message))
|
||||
)
|
||||
|
||||
def get_sql_model_class(self) -> Any:
|
||||
return self.model_class
|
||||
|
||||
|
||||
class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in an SQL database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_string: str,
|
||||
table_name: str = "message_store",
|
||||
session_id_field_name: str = "session_id",
|
||||
custom_message_converter: Optional[BaseMessageConverter] = None,
|
||||
):
|
||||
self.connection_string = connection_string
|
||||
self.engine = create_engine(connection_string, echo=False)
|
||||
self.session_id_field_name = session_id_field_name
|
||||
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
|
||||
self.sql_model_class = self.converter.get_sql_model_class()
|
||||
if not hasattr(self.sql_model_class, session_id_field_name):
|
||||
raise ValueError("SQL model class must have session_id column")
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
self.session_id = session_id
|
||||
self.Session = sessionmaker(self.engine)
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
self.sql_model_class.metadata.create_all(self.engine)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve all messages from db"""
|
||||
with self.Session() as session:
|
||||
result = (
|
||||
session.query(self.sql_model_class)
|
||||
.where(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
)
|
||||
.order_by(self.sql_model_class.id.asc())
|
||||
)
|
||||
messages = []
|
||||
for record in result:
|
||||
messages.append(self.converter.from_sql_model(record))
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
session.add(self.converter.to_sql_model(message, self.session_id))
|
||||
session.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from db"""
|
||||
|
||||
with self.Session() as session:
|
||||
session.query(self.sql_model_class).filter(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
).delete()
|
||||
session.commit()
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class StreamlitChatMessageHistory(BaseChatMessageHistory):
|
||||
"""
|
||||
Chat message history that stores messages in Streamlit session state.
|
||||
|
||||
Args:
|
||||
key: The key to use in Streamlit session state for storing messages.
|
||||
"""
|
||||
|
||||
def __init__(self, key: str = "langchain_messages"):
|
||||
try:
|
||||
import streamlit as st
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import streamlit, please run `pip install streamlit`."
|
||||
) from e
|
||||
|
||||
if key not in st.session_state:
|
||||
st.session_state[key] = []
|
||||
self._messages = st.session_state[key]
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the current list of messages"""
|
||||
return self._messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a message to the session memory"""
|
||||
self._messages.append(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory"""
|
||||
self._messages.clear()
|
||||
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in an Upstash Redis database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
url: str = "",
|
||||
token: str = "",
|
||||
key_prefix: str = "message_store:",
|
||||
ttl: Optional[int] = None,
|
||||
):
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import upstash redis python package. "
|
||||
"Please install it with `pip install upstash_redis`."
|
||||
)
|
||||
|
||||
if url == "" or token == "":
|
||||
raise ValueError(
|
||||
"UPSTASH_REDIS_REST_URL and UPSTASH_REDIS_REST_TOKEN are needed."
|
||||
)
|
||||
|
||||
try:
|
||||
self.redis_client = Redis(url=url, token=token)
|
||||
except Exception:
|
||||
logger.error("Upstash Redis instance could not be initiated.")
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Construct the record key to use"""
|
||||
return self.key_prefix + self.session_id
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from Upstash Redis"""
|
||||
_items = self.redis_client.lrange(self.key, 0, -1)
|
||||
items = [json.loads(m) for m in _items[::-1]]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Upstash Redis"""
|
||||
self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
|
||||
if self.ttl:
|
||||
self.redis_client.expire(self.key, self.ttl)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from Upstash Redis"""
|
||||
self.redis_client.delete(self.key)
|
||||
@@ -0,0 +1,134 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
|
||||
class XataChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Xata database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
db_url: str,
|
||||
api_key: str,
|
||||
branch_name: str = "main",
|
||||
table_name: str = "messages",
|
||||
create_table: bool = True,
|
||||
) -> None:
|
||||
"""Initialize with Xata client."""
|
||||
try:
|
||||
from xata.client import XataClient # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import xata python package. "
|
||||
"Please install it with `pip install xata`."
|
||||
)
|
||||
self._client = XataClient(
|
||||
api_key=api_key, db_url=db_url, branch_name=branch_name
|
||||
)
|
||||
self._table_name = table_name
|
||||
self._session_id = session_id
|
||||
|
||||
if create_table:
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
r = self._client.table().get_schema(self._table_name)
|
||||
if r.status_code <= 299:
|
||||
return
|
||||
if r.status_code != 404:
|
||||
raise Exception(
|
||||
f"Error checking if table exists in Xata: {r.status_code} {r}"
|
||||
)
|
||||
r = self._client.table().create(self._table_name)
|
||||
if r.status_code > 299:
|
||||
raise Exception(f"Error creating table in Xata: {r.status_code} {r}")
|
||||
r = self._client.table().set_schema(
|
||||
self._table_name,
|
||||
payload={
|
||||
"columns": [
|
||||
{"name": "sessionId", "type": "string"},
|
||||
{"name": "type", "type": "string"},
|
||||
{"name": "role", "type": "string"},
|
||||
{"name": "content", "type": "text"},
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "additionalKwargs", "type": "json"},
|
||||
]
|
||||
},
|
||||
)
|
||||
if r.status_code > 299:
|
||||
raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the Xata table"""
|
||||
msg = message_to_dict(message)
|
||||
r = self._client.records().insert(
|
||||
self._table_name,
|
||||
{
|
||||
"sessionId": self._session_id,
|
||||
"type": msg["type"],
|
||||
"content": message.content,
|
||||
"additionalKwargs": json.dumps(message.additional_kwargs),
|
||||
"role": msg["data"].get("role"),
|
||||
"name": msg["data"].get("name"),
|
||||
},
|
||||
)
|
||||
if r.status_code > 299:
|
||||
raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
r = self._client.data().query(
|
||||
self._table_name,
|
||||
payload={
|
||||
"filter": {
|
||||
"sessionId": self._session_id,
|
||||
},
|
||||
"sort": {"xata.createdAt": "asc"},
|
||||
},
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise Exception(f"Error running query: {r.status_code} {r}")
|
||||
msgs = messages_from_dict(
|
||||
[
|
||||
{
|
||||
"type": m["type"],
|
||||
"data": {
|
||||
"content": m["content"],
|
||||
"role": m.get("role"),
|
||||
"name": m.get("name"),
|
||||
"additional_kwargs": json.loads(m["additionalKwargs"]),
|
||||
},
|
||||
}
|
||||
for m in r["records"]
|
||||
]
|
||||
)
|
||||
return msgs
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete session from Xata table."""
|
||||
while True:
|
||||
r = self._client.data().query(
|
||||
self._table_name,
|
||||
payload={
|
||||
"columns": ["id"],
|
||||
"filter": {
|
||||
"sessionId": self._session_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise Exception(f"Error running query: {r.status_code} {r}")
|
||||
ids = [rec["id"] for rec in r["records"]]
|
||||
if len(ids) == 0:
|
||||
break
|
||||
operations = [
|
||||
{"delete": {"table": self._table_name, "id": id}} for id in ids
|
||||
]
|
||||
self._client.records().transaction(payload={"operations": operations})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user