community[minor]: add Cassandra Database Toolkit (#20246)

**Description**: Toolkit and tools for accessing data in a Cassandra
database, primarily for agent integration. Initially, this includes the
following tools:
- `cassandra_db_schema` Gathers all schema information for the connected
database or a specific schema. Critical for the agent when determining
actions.
- `cassandra_db_select_table_data` Selects data from a specific keyspace
and table. The agent can pass parameters for a predicate and limits on
the number of returned records.
- `cassandra_db_query` Experimental alternative to
`cassandra_db_select_table_data` which takes a query string completely
formed by the agent instead of parameters. May be removed in future
versions.
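As a hedged sketch of the parameterized path (the helper below is illustrative only — it approximates the bounded SELECT that `cassandra_db_select_table_data` might build from its keyspace/table/predicate/limit parameters, not the toolkit's actual implementation):

```python
from typing import Optional


def build_select_cql(
    keyspace: str, table: str, predicate: str = "", limit: Optional[int] = None
) -> str:
    """Illustrative only: approximate the bounded SELECT a
    select-table-data tool could issue from agent-supplied parameters."""
    cql = f"SELECT * FROM {keyspace}.{table}"
    if predicate:
        cql += f" WHERE {predicate}"
    if limit is not None:
        cql += f" LIMIT {limit}"
    return cql + ";"


print(build_select_cql("home", "iot_data", "day = '2020-07-14'", limit=5))
# → SELECT * FROM home.iot_data WHERE day = '2020-07-14' LIMIT 5;
```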

Includes unit tests and two notebooks that demonstrate usage.
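For orientation, the cookbook notebook buckets sensor readings into a composite `(day, device)` partition key before insertion; a minimal stdlib sketch of that bucketing (assuming `ts` is a Unix epoch in seconds and days are bucketed in UTC):

```python
from datetime import datetime, timezone


def partition_key(ts_epoch: float, device: str) -> tuple:
    """Mirror the cookbook's PRIMARY KEY ((day, device), ts) layout by
    bucketing a reading under its UTC calendar day."""
    day = datetime.fromtimestamp(ts_epoch, timezone.utc).strftime("%Y-%m-%d")
    return (day, device)


print(partition_key(1594735200.0, "b8:27:eb:bf:9d:51"))
# → ('2020-07-14', 'b8:27:eb:bf:9d:51')
```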

**Dependencies**: cassio
**Twitter handle**: @PatrickMcFadin

---------

Co-authored-by: Phil Miesle <phil.miesle@datastax.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Patrick McFadin 2024-04-29 08:51:43 -07:00 committed by GitHub
parent b3e74f2b98
commit 3331865f6b
11 changed files with 2007 additions and 0 deletions

cookbook/cql_agent.ipynb (new file, 557 lines)

@ -0,0 +1,557 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Environment"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Python Modules"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the following Python modules:\n",
"\n",
"```bash\n",
"pip install ipykernel python-dotenv cassio pandas langchain_openai langchain langchain-community langchainhub langchain_experimental openai-multi-tool-use-parallel-patch\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the `.env` File"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Connection is via `cassio` using the `auto=True` parameter, and the notebook uses OpenAI. You should create a `.env` file accordingly.\n",
"\n",
"For Cassandra, set:\n",
"```bash\n",
"CASSANDRA_CONTACT_POINTS\n",
"CASSANDRA_USERNAME\n",
"CASSANDRA_PASSWORD\n",
"CASSANDRA_KEYSPACE\n",
"```\n",
"\n",
"For Astra, set:\n",
"```bash\n",
"ASTRA_DB_APPLICATION_TOKEN\n",
"ASTRA_DB_DATABASE_ID\n",
"ASTRA_DB_KEYSPACE\n",
"```\n",
"\n",
"For example:\n",
"\n",
"```bash\n",
"# Connection to Astra:\n",
"ASTRA_DB_DATABASE_ID=a1b2c3d4-...\n",
"ASTRA_DB_APPLICATION_TOKEN=AstraCS:...\n",
"ASTRA_DB_KEYSPACE=notebooks\n",
"\n",
"# Also set \n",
"OPENAI_API_KEY=sk-....\n",
"```\n",
"\n",
"(You may also modify the code below to connect directly with `cassio`.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Connect to Cassandra"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import cassio\n",
"\n",
"cassio.init(auto=True)\n",
"session = cassio.config.resolve_session()\n",
"if not session:\n",
" raise Exception(\n",
" \"Check environment configuration or manually configure cassio connection parameters\"\n",
" )\n",
"\n",
"keyspace = os.environ.get(\n",
" \"ASTRA_DB_KEYSPACE\", os.environ.get(\"CASSANDRA_KEYSPACE\", None)\n",
")\n",
"if not keyspace:\n",
" raise ValueError(\"a KEYSPACE environment variable must be set\")\n",
"\n",
"session.set_keyspace(keyspace)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Database"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This needs to be done one time only!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset used is from Kaggle: the [Environmental Sensor Telemetry Data](https://www.kaggle.com/datasets/garystafford/environmental-sensor-data-132k?select=iot_telemetry_data.csv). The next cell downloads and unzips the data into a Pandas dataframe; the cell after that contains instructions for downloading manually.\n",
"\n",
"The net result of this section is that you have a Pandas dataframe variable `df`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Download Automatically"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from io import BytesIO\n",
"from zipfile import ZipFile\n",
"\n",
"import pandas as pd\n",
"import requests\n",
"\n",
"datasetURL = \"https://storage.googleapis.com/kaggle-data-sets/788816/1355729/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20240404%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240404T115828Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=2849f003b100eb9dcda8dd8535990f51244292f67e4f5fad36f14aa67f2d4297672d8fe6ff5a39f03a29cda051e33e95d36daab5892b8874dcd5a60228df0361fa26bae491dd4371f02dd20306b583a44ba85a4474376188b1f84765147d3b4f05c57345e5de883c2c29653cce1f3755cd8e645c5e952f4fb1c8a735b22f0c811f97f7bce8d0235d0d3731ca8ab4629ff381f3bae9e35fc1b181c1e69a9c7913a5e42d9d52d53e5f716467205af9c8a3cc6746fc5352e8fbc47cd7d18543626bd67996d18c2045c1e475fc136df83df352fa747f1a3bb73e6ba3985840792ec1de407c15836640ec96db111b173bf16115037d53fdfbfd8ac44145d7f9a546aa\"\n",
"\n",
"response = requests.get(datasetURL)\n",
"if response.status_code != 200:\n",
"    # Fail fast: the rest of the cell needs the zip archive\n",
"    raise RuntimeError(\"Failed to download the file\")\n",
"\n",
"zip_file = ZipFile(BytesIO(response.content))\n",
"csv_file_name = zip_file.namelist()[0]\n",
"\n",
"with zip_file.open(csv_file_name) as csv_file:\n",
"    df = pd.read_csv(csv_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Download Manually"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can download the `.zip` file and unpack the `.csv` contained within. Uncomment the next line, and adjust the path to the `.csv` file appropriately."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# df = pd.read_csv(\"/path/to/iot_telemetry_data.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Data into Cassandra"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This section assumes a dataframe `df` exists; the following cell validates its structure. The Download section above creates this object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert df is not None, \"Dataframe 'df' must be set\"\n",
"expected_columns = [\n",
" \"ts\",\n",
" \"device\",\n",
" \"co\",\n",
" \"humidity\",\n",
" \"light\",\n",
" \"lpg\",\n",
" \"motion\",\n",
" \"smoke\",\n",
" \"temp\",\n",
"]\n",
"assert all(\n",
" [column in df.columns for column in expected_columns]\n",
"), \"DataFrame does not have the expected columns\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create and load tables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime, timezone\n",
"\n",
"# datetime.UTC requires Python 3.11+; timezone.utc works on all supported versions\n",
"UTC = timezone.utc\n",
"\n",
"from cassandra.query import BatchStatement\n",
"\n",
"# Create sensors table\n",
"table_query = \"\"\"\n",
"CREATE TABLE IF NOT EXISTS iot_sensors (\n",
" device text,\n",
" conditions text,\n",
" room text,\n",
" PRIMARY KEY (device)\n",
")\n",
"WITH COMMENT = 'Environmental IoT room sensor metadata.';\n",
"\"\"\"\n",
"session.execute(table_query)\n",
"\n",
"pstmt = session.prepare(\n",
" \"\"\"\n",
"INSERT INTO iot_sensors (device, conditions, room)\n",
"VALUES (?, ?, ?)\n",
"\"\"\"\n",
")\n",
"\n",
"devices = [\n",
" (\"00:0f:00:70:91:0a\", \"stable conditions, cooler and more humid\", \"room 1\"),\n",
" (\"1c:bf:ce:15:ec:4d\", \"highly variable temperature and humidity\", \"room 2\"),\n",
" (\"b8:27:eb:bf:9d:51\", \"stable conditions, warmer and dryer\", \"room 3\"),\n",
"]\n",
"\n",
"for device, conditions, room in devices:\n",
" session.execute(pstmt, (device, conditions, room))\n",
"\n",
"print(\"Sensors inserted successfully.\")\n",
"\n",
"# Create data table\n",
"table_query = \"\"\"\n",
"CREATE TABLE IF NOT EXISTS iot_data (\n",
" day text,\n",
" device text,\n",
" ts timestamp,\n",
" co double,\n",
" humidity double,\n",
" light boolean,\n",
" lpg double,\n",
" motion boolean,\n",
" smoke double,\n",
" temp double,\n",
" PRIMARY KEY ((day, device), ts)\n",
")\n",
"WITH COMMENT = 'Data from environmental IoT room sensors. Columns include device identifier, timestamp (ts) of the data collection, carbon monoxide level (co), relative humidity, light presence, LPG concentration, motion detection, smoke concentration, and temperature (temp). Data is partitioned by day and device.';\n",
"\"\"\"\n",
"session.execute(table_query)\n",
"\n",
"pstmt = session.prepare(\n",
" \"\"\"\n",
"INSERT INTO iot_data (day, device, ts, co, humidity, light, lpg, motion, smoke, temp)\n",
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n",
"\"\"\"\n",
")\n",
"\n",
"\n",
"def insert_data_batch(name, group):\n",
" batch = BatchStatement()\n",
" day, device = name\n",
" print(f\"Inserting batch for day: {day}, device: {device}\")\n",
"\n",
" for _, row in group.iterrows():\n",
" timestamp = datetime.fromtimestamp(row[\"ts\"], UTC)\n",
" batch.add(\n",
" pstmt,\n",
" (\n",
" day,\n",
" row[\"device\"],\n",
" timestamp,\n",
" row[\"co\"],\n",
" row[\"humidity\"],\n",
" row[\"light\"],\n",
" row[\"lpg\"],\n",
" row[\"motion\"],\n",
" row[\"smoke\"],\n",
" row[\"temp\"],\n",
" ),\n",
" )\n",
"\n",
" session.execute(batch)\n",
"\n",
"\n",
"# Convert columns to appropriate types\n",
"df[\"light\"] = df[\"light\"] == \"true\"\n",
"df[\"motion\"] = df[\"motion\"] == \"true\"\n",
"df[\"ts\"] = df[\"ts\"].astype(float)\n",
"df[\"day\"] = df[\"ts\"].apply(\n",
" lambda x: datetime.fromtimestamp(x, UTC).strftime(\"%Y-%m-%d\")\n",
")\n",
"\n",
"grouped_df = df.groupby([\"day\", \"device\"])\n",
"\n",
"for name, group in grouped_df:\n",
" insert_data_batch(name, group)\n",
"\n",
"print(\"Data load complete\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(session.keyspace)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the Tools"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Python `import` statements for the demo:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
"from langchain_community.agent_toolkits.cassandra_database.toolkit import (\n",
" CassandraDatabaseToolkit,\n",
")\n",
"from langchain_community.tools.cassandra_database.prompt import QUERY_PATH_PROMPT\n",
"from langchain_community.tools.cassandra_database.tool import (\n",
" GetSchemaCassandraDatabaseTool,\n",
" GetTableDataCassandraDatabaseTool,\n",
" QueryCassandraDatabaseTool,\n",
")\n",
"from langchain_community.utilities.cassandra_database import CassandraDatabase\n",
"from langchain_openai import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `CassandraDatabase` object is loaded from `cassio`, though it does accept a `Session`-type parameter as an alternative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a CassandraDatabase instance\n",
"db = CassandraDatabase(include_tables=[\"iot_sensors\", \"iot_data\"])\n",
"\n",
"# Create the Cassandra Database tools\n",
"query_tool = QueryCassandraDatabaseTool(db=db)\n",
"schema_tool = GetSchemaCassandraDatabaseTool(db=db)\n",
"select_data_tool = GetTableDataCassandraDatabaseTool(db=db)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tools can be invoked directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the tools\n",
"print(\"Executing a CQL query:\")\n",
"query = \"SELECT * FROM iot_sensors LIMIT 5;\"\n",
"result = query_tool.run({\"query\": query})\n",
"print(result)\n",
"\n",
"print(\"\\nGetting the schema for a keyspace:\")\n",
"schema = schema_tool.run({\"keyspace\": keyspace})\n",
"print(schema)\n",
"\n",
"print(\"\\nGetting data from a table:\")\n",
"table = \"iot_data\"\n",
"predicate = \"day = '2020-07-14' and device = 'b8:27:eb:bf:9d:51'\"\n",
"data = select_data_tool.run(\n",
" {\"keyspace\": keyspace, \"table\": table, \"predicate\": predicate, \"limit\": 5}\n",
")\n",
"print(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import Tool\n",
"from langchain_experimental.utilities import PythonREPL\n",
"\n",
"python_repl = PythonREPL()\n",
"\n",
"repl_tool = Tool(\n",
" name=\"python_repl\",\n",
" description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n",
" func=python_repl.run,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain import hub\n",
"\n",
"llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
"toolkit = CassandraDatabaseToolkit(db=db)\n",
"\n",
"# context = toolkit.get_context()\n",
"# tools = toolkit.get_tools()\n",
"tools = [schema_tool, select_data_tool, repl_tool]\n",
"\n",
"input = (\n",
" QUERY_PATH_PROMPT\n",
" + f\"\"\"\n",
"\n",
"Here is your task: In the {keyspace} keyspace, find the total number of times the temperature of each device has exceeded 23 degrees on July 14, 2020.\n",
" Create a summary report including the name of the room. Use Pandas if helpful.\n",
"\"\"\"\n",
")\n",
"\n",
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
"\n",
"# messages = [\n",
"# HumanMessagePromptTemplate.from_template(input),\n",
"# AIMessage(content=QUERY_PATH_PROMPT),\n",
"# MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
"# ]\n",
"\n",
"# prompt = ChatPromptTemplate.from_messages(messages)\n",
"# print(prompt)\n",
"\n",
"# Choose the LLM that will drive the agent\n",
"# Only certain models support this\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
"\n",
"# Construct the OpenAI Tools agent\n",
"agent = create_openai_tools_agent(llm, tools, prompt)\n",
"\n",
"print(\"Available tools:\")\n",
"for tool in tools:\n",
" print(\"\\t\" + tool.name + \" - \" + tool.description + \" - \" + str(tool))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
"\n",
"response = agent_executor.invoke({\"input\": input})\n",
"\n",
"print(response[\"output\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@ -0,0 +1,481 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cassandra Database\n",
"\n",
"Apache Cassandra® is a widely used database for storing transactional application data. The introduction of functions and tooling in Large Language Models has opened up some exciting use cases for existing data in Generative AI applications. The Cassandra Database toolkit enables AI engineers to efficiently integrate Agents with Cassandra data, offering the following features: \n",
" - Fast data access through optimized queries. Most queries should run in single-digit milliseconds or less. \n",
" - Schema introspection to enhance LLM reasoning capabilities \n",
" - Compatibility with various Cassandra deployments, including Apache Cassandra®, DataStax Enterprise™, and DataStax Astra™ \n",
" - Currently, the toolkit is limited to SELECT queries and schema introspection operations. (Safety first)\n",
"\n",
"## Quick Start\n",
" - Install the cassio library\n",
" - Set environment variables for the Cassandra database you are connecting to\n",
" - Initialize CassandraDatabase\n",
" - Pass the tools to your agent with toolkit.get_tools()\n",
" - Sit back and watch it do all your work for you\n",
"\n",
"## Theory of Operation\n",
"Cassandra Query Language (CQL) is the primary *human-centric* way of interacting with a Cassandra database. While offering some flexibility when generating queries, it requires knowledge of Cassandra data modeling best practices. LLM function calling gives an agent the ability to reason and then choose a tool to satisfy the request. Agents using LLMs should reason using Cassandra-specific logic when choosing the appropriate toolkit or chain of toolkits. This reduces the randomness introduced when LLMs are forced to provide a top-down solution. Do you want an LLM to have complete unfettered access to your database? Yeah. Probably not. To accomplish this, we provide a prompt for use when constructing questions for the agent: \n",
"\n",
"```text\n",
"You are an Apache Cassandra expert query analysis bot with the following features \n",
"and rules:\n",
" - You will take a question from the end user about finding specific \n",
" data in the database.\n",
" - You will examine the schema of the database and create a query path. \n",
" - You will provide the user with the correct query to find the data they are looking \n",
" for, showing the steps provided by the query path.\n",
" - You will use best practices for querying Apache Cassandra using partition keys \n",
" and clustering columns.\n",
" - Avoid using ALLOW FILTERING in the query.\n",
" - The goal is to find a query path, so it may take querying other tables to get \n",
" to the final answer. \n",
"\n",
"The following is an example of a query path in JSON format:\n",
"\n",
" {\n",
" \"query_paths\": [\n",
" {\n",
" \"description\": \"Direct query to users table using email\",\n",
" \"steps\": [\n",
" {\n",
" \"table\": \"user_credentials\",\n",
" \"query\": \n",
" \"SELECT userid FROM user_credentials WHERE email = 'example@example.com';\"\n",
" },\n",
" {\n",
" \"table\": \"users\",\n",
" \"query\": \"SELECT * FROM users WHERE userid = ?;\"\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
"}\n",
"```\n",
"\n",
"## Tools Provided\n",
"\n",
"### `cassandra_db_schema`\n",
"Gathers all schema information for the connected database or a specific schema. Critical for the agent when determining actions. \n",
"\n",
"### `cassandra_db_select_table_data`\n",
"Selects data from a specific keyspace and table. The agent can pass parameters for a predicate and limits on the number of returned records. \n",
"\n",
"### `cassandra_db_query`\n",
"Experimental alternative to `cassandra_db_select_table_data` which takes a query string completely formed by the agent instead of parameters. *Warning*: This can lead to unusual queries that may not be as performant (or even work). This may be removed in future releases. If it does something cool, we want to know about that too. You never know!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment Setup\n",
"\n",
"Install the following Python modules:\n",
"\n",
"```bash\n",
"pip install ipykernel python-dotenv cassio langchain_openai langchain langchain-community langchainhub\n",
"```\n",
"\n",
"### .env file\n",
"Connection is via `cassio` using the `auto=True` parameter, and the notebook uses OpenAI. You should create a `.env` file accordingly.\n",
"\n",
"For Cassandra, set:\n",
"```bash\n",
"CASSANDRA_CONTACT_POINTS\n",
"CASSANDRA_USERNAME\n",
"CASSANDRA_PASSWORD\n",
"CASSANDRA_KEYSPACE\n",
"```\n",
"\n",
"For Astra, set:\n",
"```bash\n",
"ASTRA_DB_APPLICATION_TOKEN\n",
"ASTRA_DB_DATABASE_ID\n",
"ASTRA_DB_KEYSPACE\n",
"```\n",
"\n",
"For example:\n",
"\n",
"```bash\n",
"# Connection to Astra:\n",
"ASTRA_DB_DATABASE_ID=a1b2c3d4-...\n",
"ASTRA_DB_APPLICATION_TOKEN=AstraCS:...\n",
"ASTRA_DB_KEYSPACE=notebooks\n",
"\n",
"# Also set \n",
"OPENAI_API_KEY=sk-....\n",
"```\n",
"\n",
"(You may also modify the code below to connect directly with `cassio`.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries\n",
"import os\n",
"\n",
"import cassio\n",
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
"from langchain_community.agent_toolkits.cassandra_database.toolkit import (\n",
" CassandraDatabaseToolkit,\n",
")\n",
"from langchain_community.tools.cassandra_database.prompt import QUERY_PATH_PROMPT\n",
"from langchain_community.tools.cassandra_database.tool import (\n",
" GetSchemaCassandraDatabaseTool,\n",
" GetTableDataCassandraDatabaseTool,\n",
" QueryCassandraDatabaseTool,\n",
")\n",
"from langchain_community.utilities.cassandra_database import CassandraDatabase\n",
"from langchain_openai import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connect to a Cassandra Database"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"cassio.init(auto=True)\n",
"session = cassio.config.resolve_session()\n",
"if not session:\n",
" raise Exception(\n",
" \"Check environment configuration or manually configure cassio connection parameters\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Test data prep\n",
"\n",
"session = cassio.config.resolve_session()\n",
"\n",
"session.execute(\"\"\"DROP KEYSPACE IF EXISTS langchain_agent_test; \"\"\")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
"CREATE KEYSPACE if not exists langchain_agent_test \n",
"WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1};\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
" CREATE TABLE IF NOT EXISTS langchain_agent_test.user_credentials (\n",
" user_email text PRIMARY KEY,\n",
" user_id UUID,\n",
" password TEXT\n",
");\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
" CREATE TABLE IF NOT EXISTS langchain_agent_test.users (\n",
" id UUID PRIMARY KEY,\n",
" name TEXT,\n",
" email TEXT\n",
");\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
" CREATE TABLE IF NOT EXISTS langchain_agent_test.user_videos ( \n",
" user_id UUID,\n",
" video_id UUID,\n",
" title TEXT,\n",
" description TEXT,\n",
" PRIMARY KEY (user_id, video_id)\n",
");\n",
"\"\"\"\n",
")\n",
"\n",
"user_id = \"522b1fe2-2e36-4cef-a667-cd4237d08b89\"\n",
"video_id = \"27066014-bad7-9f58-5a30-f63fe03718f6\"\n",
"\n",
"session.execute(\n",
" f\"\"\"\n",
" INSERT INTO langchain_agent_test.user_credentials (user_id, user_email) \n",
" VALUES ({user_id}, 'patrick@datastax.com');\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" f\"\"\"\n",
" INSERT INTO langchain_agent_test.users (id, name, email) \n",
" VALUES ({user_id}, 'Patrick McFadin', 'patrick@datastax.com');\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" f\"\"\"\n",
" INSERT INTO langchain_agent_test.user_videos (user_id, video_id, title)\n",
" VALUES ({user_id}, {video_id}, 'Use Langflow to Build a LangChain LLM Application in 5 Minutes');\n",
"\"\"\"\n",
")\n",
"\n",
"session.set_keyspace(\"langchain_agent_test\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create a CassandraDatabase instance\n",
"# Uses the cassio session to connect to the database\n",
"db = CassandraDatabase()\n",
"\n",
"# Create the Cassandra Database tools\n",
"query_tool = QueryCassandraDatabaseTool(db=db)\n",
"schema_tool = GetSchemaCassandraDatabaseTool(db=db)\n",
"select_data_tool = GetTableDataCassandraDatabaseTool(db=db)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available tools:\n",
"cassandra_db_schema\t- \n",
" Input to this tool is a keyspace name, output is a table description \n",
" of Apache Cassandra tables.\n",
" If the query is not correct, an error message will be returned.\n",
" If an error is returned, report back to the user that the keyspace \n",
" doesn't exist and stop.\n",
" \n",
"cassandra_db_query\t- \n",
" Execute a CQL query against the database and get back the result.\n",
" If the query is not correct, an error message will be returned.\n",
" If an error is returned, rewrite the query, check the query, and try again.\n",
" \n",
"cassandra_db_select_table_data\t- \n",
" Tool for getting data from a table in an Apache Cassandra database. \n",
" Use the WHERE clause to specify the predicate for the query that uses the \n",
" primary key. A blank predicate will return all rows. Avoid this if possible. \n",
" Use the limit to specify the number of rows to return. A blank limit will \n",
" return all rows.\n",
" \n"
]
}
],
"source": [
"# Choose the LLM that will drive the agent\n",
"# Only certain models support this\n",
"llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
"toolkit = CassandraDatabaseToolkit(db=db)\n",
"\n",
"tools = toolkit.get_tools()\n",
"\n",
"print(\"Available tools:\")\n",
"for tool in tools:\n",
" print(tool.name + \"\\t- \" + tool.description)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
"\n",
"# Construct the OpenAI Tools agent\n",
"agent = create_openai_tools_agent(llm, tools, prompt)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `cassandra_db_schema` with `{'keyspace': 'langchain_agent_test'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mTable Name: user_credentials\n",
"- Keyspace: langchain_agent_test\n",
"- Columns\n",
" - password (text)\n",
" - user_email (text)\n",
" - user_id (uuid)\n",
"- Partition Keys: (user_email)\n",
"- Clustering Keys: \n",
"\n",
"Table Name: user_videos\n",
"- Keyspace: langchain_agent_test\n",
"- Columns\n",
" - description (text)\n",
" - title (text)\n",
" - user_id (uuid)\n",
" - video_id (uuid)\n",
"- Partition Keys: (user_id)\n",
"- Clustering Keys: (video_id asc)\n",
"\n",
"\n",
"Table Name: users\n",
"- Keyspace: langchain_agent_test\n",
"- Columns\n",
" - email (text)\n",
" - id (uuid)\n",
" - name (text)\n",
"- Partition Keys: (id)\n",
"- Clustering Keys: \n",
"\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `cassandra_db_select_table_data` with `{'keyspace': 'langchain_agent_test', 'table': 'user_credentials', 'predicate': \"user_email = 'patrick@datastax.com'\", 'limit': 1}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mRow(user_email='patrick@datastax.com', password=None, user_id=UUID('522b1fe2-2e36-4cef-a667-cd4237d08b89'))\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `cassandra_db_select_table_data` with `{'keyspace': 'langchain_agent_test', 'table': 'user_videos', 'predicate': 'user_id = 522b1fe2-2e36-4cef-a667-cd4237d08b89', 'limit': 10}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mRow(user_id=UUID('522b1fe2-2e36-4cef-a667-cd4237d08b89'), video_id=UUID('27066014-bad7-9f58-5a30-f63fe03718f6'), description='DataStax Academy is a free resource for learning Apache Cassandra.', title='DataStax Academy')\u001b[0m\u001b[32;1m\u001b[1;3mTo find all the videos that the user with the email address 'patrick@datastax.com' has uploaded to the `langchain_agent_test` keyspace, we can follow these steps:\n",
"\n",
"1. Query the `user_credentials` table to find the `user_id` associated with the email 'patrick@datastax.com'.\n",
"2. Use the `user_id` obtained from the first step to query the `user_videos` table to retrieve all the videos uploaded by the user.\n",
"\n",
"Here is the query path in JSON format:\n",
"\n",
"```json\n",
"{\n",
" \"query_paths\": [\n",
" {\n",
" \"description\": \"Find user_id from user_credentials and then query user_videos for all videos uploaded by the user\",\n",
" \"steps\": [\n",
" {\n",
" \"table\": \"user_credentials\",\n",
" \"query\": \"SELECT user_id FROM user_credentials WHERE user_email = 'patrick@datastax.com';\"\n",
" },\n",
" {\n",
" \"table\": \"user_videos\",\n",
" \"query\": \"SELECT * FROM user_videos WHERE user_id = 522b1fe2-2e36-4cef-a667-cd4237d08b89;\"\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
"}\n",
"```\n",
"\n",
"Following this query path, we found that the user with the user_id `522b1fe2-2e36-4cef-a667-cd4237d08b89` has uploaded at least one video with the title 'DataStax Academy' and the description 'DataStax Academy is a free resource for learning Apache Cassandra.' The video_id for this video is `27066014-bad7-9f58-5a30-f63fe03718f6`. If there are more videos, the same query can be used to retrieve them, possibly with an increased limit if necessary.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"To find all the videos that the user with the email address 'patrick@datastax.com' has uploaded to the `langchain_agent_test` keyspace, we can follow these steps:\n",
"\n",
"1. Query the `user_credentials` table to find the `user_id` associated with the email 'patrick@datastax.com'.\n",
"2. Use the `user_id` obtained from the first step to query the `user_videos` table to retrieve all the videos uploaded by the user.\n",
"\n",
"Here is the query path in JSON format:\n",
"\n",
"```json\n",
"{\n",
" \"query_paths\": [\n",
" {\n",
" \"description\": \"Find user_id from user_credentials and then query user_videos for all videos uploaded by the user\",\n",
" \"steps\": [\n",
" {\n",
" \"table\": \"user_credentials\",\n",
" \"query\": \"SELECT user_id FROM user_credentials WHERE user_email = 'patrick@datastax.com';\"\n",
" },\n",
" {\n",
" \"table\": \"user_videos\",\n",
" \"query\": \"SELECT * FROM user_videos WHERE user_id = 522b1fe2-2e36-4cef-a667-cd4237d08b89;\"\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
"}\n",
"```\n",
"\n",
"Following this query path, we found that the user with the user_id `522b1fe2-2e36-4cef-a667-cd4237d08b89` has uploaded at least one video with the title 'DataStax Academy' and the description 'DataStax Academy is a free resource for learning Apache Cassandra.' The video_id for this video is `27066014-bad7-9f58-5a30-f63fe03718f6`. If there are more videos, the same query can be used to retrieve them, possibly with an increased limit if necessary.\n"
]
}
],
"source": [
"input = (\n",
" QUERY_PATH_PROMPT\n",
" + \"\\n\\nHere is your task: Find all the videos that the user with the email address 'patrick@datastax.com' has uploaded to the langchain_agent_test keyspace.\"\n",
")\n",
"\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
"\n",
"response = agent_executor.invoke({\"input\": input})\n",
"\n",
"print(response[\"output\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a deep dive on creating a Cassandra DB agent, see the [CQL agent cookbook](https://github.com/langchain-ai/langchain/blob/master/cookbook/cql_agent.ipynb)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -18,6 +18,9 @@ if TYPE_CHECKING:
    from langchain_community.agent_toolkits.azure_cognitive_services import (
        AzureCognitiveServicesToolkit,  # noqa: F401
    )
    from langchain_community.agent_toolkits.cassandra_database.toolkit import (
        CassandraDatabaseToolkit,  # noqa: F401
    )
    from langchain_community.agent_toolkits.cogniswitch.toolkit import (
        CogniswitchToolkit,  # noqa: F401
    )

View File

@@ -0,0 +1 @@
"""Apache Cassandra Toolkit."""

View File

@@ -0,0 +1,32 @@
"""Apache Cassandra Toolkit."""
from typing import List

from langchain_core.pydantic_v1 import Field

from langchain_community.agent_toolkits.base import BaseToolkit
from langchain_community.tools import BaseTool
from langchain_community.tools.cassandra_database.tool import (
    GetSchemaCassandraDatabaseTool,
    GetTableDataCassandraDatabaseTool,
    QueryCassandraDatabaseTool,
)
from langchain_community.utilities.cassandra_database import CassandraDatabase


class CassandraDatabaseToolkit(BaseToolkit):
    """Toolkit for interacting with an Apache Cassandra database."""

    db: CassandraDatabase = Field(exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
        return [
            GetSchemaCassandraDatabaseTool(db=self.db),
            QueryCassandraDatabaseTool(db=self.db),
            GetTableDataCassandraDatabaseTool(db=self.db),
        ]

View File

@@ -70,6 +70,11 @@ if TYPE_CHECKING:
    from langchain_community.tools.brave_search.tool import (
        BraveSearch,  # noqa: F401
    )
    from langchain_community.tools.cassandra_database.tool import (
        GetSchemaCassandraDatabaseTool,  # noqa: F401
        GetTableDataCassandraDatabaseTool,  # noqa: F401
        QueryCassandraDatabaseTool,  # noqa: F401
    )
    from langchain_community.tools.cogniswitch.tool import (
        CogniswitchKnowledgeRequest,  # noqa: F401
        CogniswitchKnowledgeSourceFile,  # noqa: F401
View File

@@ -0,0 +1 @@
"""Cassandra Tool."""

View File

@@ -0,0 +1,36 @@
"""Tools for interacting with an Apache Cassandra database."""

QUERY_PATH_PROMPT = """
You are an Apache Cassandra expert query analysis bot with the following features
and rules:
 - You will take a question from the end user about finding certain
   data in the database.
 - You will examine the schema of the database and create a query path.
 - You will provide the user with the correct query to find the data they are
   looking for, showing the steps provided by the query path.
 - You will use best practices for querying Apache Cassandra using partition keys
   and clustering columns.
 - Avoid using ALLOW FILTERING in the query.
 - The goal is to find a query path, so it may take querying other tables to get
   to the final answer.

The following is an example of a query path in JSON format:

{
  "query_paths": [
    {
      "description": "Direct query to users table using email",
      "steps": [
        {
          "table": "user_credentials",
          "query": "SELECT userid FROM user_credentials WHERE email = 'example@example.com';"
        },
        {
          "table": "users",
          "query": "SELECT * FROM users WHERE userid = ?;"
        }
      ]
    }
  ]
}"""

View File

@@ -0,0 +1,126 @@
"""Tools for interacting with an Apache Cassandra database."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

from langchain_community.utilities.cassandra_database import CassandraDatabase

if TYPE_CHECKING:
    from cassandra.cluster import ResultSet


class BaseCassandraDatabaseTool(BaseModel):
    """Base tool for interacting with an Apache Cassandra database."""

    db: CassandraDatabase = Field(exclude=True)

    class Config(BaseTool.Config):
        pass


class _QueryCassandraDatabaseToolInput(BaseModel):
    query: str = Field(..., description="A detailed and correct CQL query.")


class QueryCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool):
    """Tool for querying an Apache Cassandra database with provided CQL."""

    name: str = "cassandra_db_query"
    description: str = """
    Execute a CQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    args_schema: Type[BaseModel] = _QueryCassandraDatabaseToolInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], ResultSet]:
        """Execute the query, return the results or an error message."""
        return self.db.run_no_throw(query)


class _GetSchemaCassandraDatabaseToolInput(BaseModel):
    keyspace: str = Field(
        ...,
        description=("The name of the keyspace for which to return the schema."),
    )


class GetSchemaCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool):
    """Tool for getting the schema of a keyspace in an Apache Cassandra database."""

    name: str = "cassandra_db_schema"
    description: str = """
    Input to this tool is a keyspace name, output is a table description
    of Apache Cassandra tables.
    If the query is not correct, an error message will be returned.
    If an error is returned, report back to the user that the keyspace
    doesn't exist and stop.
    """
    args_schema: Type[BaseModel] = _GetSchemaCassandraDatabaseToolInput

    def _run(
        self,
        keyspace: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for a keyspace."""
        return self.db.get_keyspace_tables_str_no_throw(keyspace)


class _GetTableDataCassandraDatabaseToolInput(BaseModel):
    keyspace: str = Field(
        ...,
        description=("The name of the keyspace containing the table."),
    )
    table: str = Field(
        ...,
        description=("The name of the table for which to return data."),
    )
    predicate: str = Field(
        ...,
        description=("The predicate for the query that uses the primary key."),
    )
    limit: int = Field(
        ...,
        description=("The maximum number of rows to return."),
    )


class GetTableDataCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool):
    """
    Tool for getting data from a table in an Apache Cassandra database.

    Use the WHERE clause to specify the predicate for the query that uses the
    primary key. A blank predicate will return all rows. Avoid this if possible.
    Use the limit to specify the number of rows to return. A blank limit will
    return all rows.
    """

    name: str = "cassandra_db_select_table_data"
    description: str = """
    Tool for getting data from a table in an Apache Cassandra database.
    Use the WHERE clause to specify the predicate for the query that uses the
    primary key. A blank predicate will return all rows. Avoid this if possible.
    Use the limit to specify the number of rows to return. A blank limit will
    return all rows.
    """
    args_schema: Type[BaseModel] = _GetTableDataCassandraDatabaseToolInput

    def _run(
        self,
        keyspace: str,
        table: str,
        predicate: str,
        limit: int,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get data from a table in a keyspace."""
        return self.db.get_table_data_no_throw(keyspace, table, predicate, limit)
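`GetTableDataCassandraDatabaseTool` delegates to `CassandraDatabase.get_table_data`, which assembles a plain SELECT statement from its arguments. A standalone sketch of that assembly (same logic, no driver required; the keyspace/table/predicate values are illustrative):

```python
def build_select(keyspace: str, table: str, predicate: str, limit: int) -> str:
    # Mirrors the string building in CassandraDatabase.get_table_data:
    # optional WHERE predicate, optional LIMIT, trailing semicolon.
    query = f"SELECT * FROM {keyspace}.{table}"
    if predicate:
        query += f" WHERE {predicate}"
    if limit:
        query += f" LIMIT {limit}"
    return query + ";"

print(build_select("langchain_agent_test", "user_videos", "userid = 42", 10))
# -> SELECT * FROM langchain_agent_test.user_videos WHERE userid = 42 LIMIT 10;
```

An empty predicate or a zero limit simply drops that clause, which is why the tool description warns that a blank predicate returns all rows.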

View File

@@ -0,0 +1,680 @@
"""Apache Cassandra database wrapper."""
from __future__ import annotations

import re
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

if TYPE_CHECKING:
    from cassandra.cluster import ResultSet, Session

IGNORED_KEYSPACES = [
    "system",
    "system_auth",
    "system_distributed",
    "system_schema",
    "system_traces",
    "system_views",
    "datastax_sla",
    "data_endpoint_auth",
]


class CassandraDatabase:
    """Apache Cassandra® database wrapper."""

    def __init__(
        self,
        session: Optional[Session] = None,
        exclude_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        cassio_init_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self._session = self._resolve_session(session, cassio_init_kwargs)
        if not self._session:
            raise ValueError("Session not provided and cannot be resolved")

        self._exclude_keyspaces = IGNORED_KEYSPACES
        self._exclude_tables = exclude_tables or []
        self._include_tables = include_tables or []

    def run(
        self,
        query: str,
        fetch: str = "all",
        include_columns: bool = False,
        **kwargs: Any,
    ) -> Union[str, Sequence[Dict[str, Any]], ResultSet]:
        """Execute a CQL query and return the results."""
        clean_query = self._validate_cql(query, "SELECT")
        result = self._session.execute(clean_query, **kwargs)
        if fetch == "all":
            return list(result)
        elif fetch == "one":
            return result.one()._asdict() if result else {}
        elif fetch == "cursor":
            return result
        else:
            raise ValueError(
                "Fetch parameter must be either 'one', 'all', or 'cursor'"
            )
    def run_no_throw(
        self,
        query: str,
        fetch: str = "all",
        include_columns: bool = False,
        **kwargs: Any,
    ) -> Union[str, Sequence[Dict[str, Any]], ResultSet]:
        """Execute a CQL query and return the results or an error message."""
        try:
            return self.run(query, fetch, include_columns, **kwargs)
        except Exception as e:
            # Format the error message
            return f"Error: {e}\n{traceback.format_exc()}"

    def get_keyspace_tables_str_no_throw(self, keyspace: str) -> str:
        """Get the tables for the specified keyspace."""
        try:
            schema_string = self.get_keyspace_tables_str(keyspace)
            return schema_string
        except Exception as e:
            # Format the error message
            return f"Error: {e}\n{traceback.format_exc()}"

    def get_keyspace_tables_str(self, keyspace: str) -> str:
        """Get the tables for the specified keyspace."""
        tables = self.get_keyspace_tables(keyspace)
        schema_string = ""
        for table in tables:
            schema_string += table.as_markdown() + "\n\n"
        return schema_string

    def get_keyspace_tables(self, keyspace: str) -> List[Table]:
        """Get the Table objects for the specified keyspace."""
        schema = self._resolve_schema([keyspace])
        if keyspace in schema:
            return schema[keyspace]
        else:
            return []

    def get_table_data_no_throw(
        self, keyspace: str, table: str, predicate: str, limit: int
    ) -> str:
        """Get data from the specified table in the specified keyspace. Optionally can
        take a predicate for the WHERE clause and a limit."""
        try:
            return self.get_table_data(keyspace, table, predicate, limit)
        except Exception as e:
            # Format the error message
            return f"Error: {e}\n{traceback.format_exc()}"

    # This is a more basic string building function that doesn't use a query builder
    # or prepared statements
    # TODO: Refactor to use prepared statements
    def get_table_data(
        self, keyspace: str, table: str, predicate: str, limit: int
    ) -> str:
        """Get data from the specified table in the specified keyspace."""
        query = f"SELECT * FROM {keyspace}.{table}"

        if predicate:
            query += f" WHERE {predicate}"
        if limit:
            query += f" LIMIT {limit}"

        query += ";"

        result = self.run(query, fetch="all")
        data = "\n".join(str(row) for row in result)
        return data

    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt."""
        keyspaces = self._fetch_keyspaces()
        return {"keyspaces": ", ".join(keyspaces)}
    def format_keyspace_to_markdown(
        self, keyspace: str, tables: Optional[List[Table]] = None
    ) -> str:
        """
        Generates a markdown representation of the schema for a specific keyspace
        by iterating over all tables within that keyspace and calling their
        as_markdown method.

        Parameters:
        - keyspace (str): The name of the keyspace to generate markdown
            documentation for.
        - tables (list[Table]): list of tables in the keyspace; it will be resolved
            if not provided.

        Returns:
            A string containing the markdown representation of the specified
            keyspace schema.
        """
        if not tables:
            tables = self.get_keyspace_tables(keyspace)

        if tables:
            output = f"## Keyspace: {keyspace}\n\n"
            if tables:
                for table in tables:
                    output += table.as_markdown(include_keyspace=False, header_level=3)
                    output += "\n\n"
            else:
                output += "No tables present in keyspace\n\n"
            return output
        else:
            return ""

    def format_schema_to_markdown(self) -> str:
        """
        Generates a markdown representation of the schema for all keyspaces and tables
        within the CassandraDatabase instance. This method utilizes the
        format_keyspace_to_markdown method to create markdown sections for each
        keyspace, assembling them into a comprehensive schema document.

        Iterates through each keyspace in the database, utilizing
        format_keyspace_to_markdown to generate markdown for each keyspace's schema,
        including details of its tables. These sections are concatenated to form a
        single markdown document that represents the schema of the entire database or
        the subset of keyspaces that have been resolved in this instance.

        Returns:
            A markdown string that documents the schema of all resolved keyspaces and
            their tables within this CassandraDatabase instance. This includes keyspace
            names, table names, comments, columns, partition keys, clustering keys,
            and indexes for each table.
        """
        schema = self._resolve_schema()
        output = "# Cassandra Database Schema\n\n"
        for keyspace, tables in schema.items():
            output += f"{self.format_keyspace_to_markdown(keyspace, tables)}\n\n"
        return output

    def _validate_cql(self, cql: str, type: str = "SELECT") -> str:
        """
        Validates a CQL query string for basic formatting and safety checks.
        Ensures that `cql` starts with the specified type (e.g., SELECT) and does
        not contain content that could indicate CQL injection vulnerabilities.

        Parameters:
        - cql (str): The CQL query string to be validated.
        - type (str): The expected starting keyword of the query, used to verify
            that the query begins with the correct operation type
            (e.g., "SELECT", "UPDATE"). Defaults to "SELECT".

        Returns:
        - str: The trimmed and validated CQL query string without a trailing semicolon.

        Raises:
        - ValueError: If the value of `type` is not supported
        - DatabaseError: If `cql` is considered unsafe
        """
        SUPPORTED_TYPES = ["SELECT"]
        if type and type.upper() not in SUPPORTED_TYPES:
            raise ValueError(
                f"""Unsupported CQL type: {type}. Supported types:
                {SUPPORTED_TYPES}"""
            )

        # Basic sanity checks
        cql_trimmed = cql.strip()
        if not cql_trimmed.upper().startswith(type.upper()):
            raise DatabaseError(f"CQL must start with {type.upper()}.")

        # Allow a trailing semicolon, but remove it (optional with the Python driver)
        cql_trimmed = cql_trimmed.rstrip(";")

        # Consider content within matching quotes to be "safe"
        # Remove single-quoted strings
        cql_sanitized = re.sub(r"'.*?'", "", cql_trimmed)

        # Remove double-quoted strings
        cql_sanitized = re.sub(r'".*?"', "", cql_sanitized)

        # Find unsafe content in the remaining CQL
        if ";" in cql_sanitized:
            raise DatabaseError(
                """Potentially unsafe CQL, as it contains a ; at a
                place other than the end or within quotation marks."""
            )

        # The trimmed query, before modifications
        return cql_trimmed
    def _fetch_keyspaces(self, keyspace_list: Optional[List[str]] = None) -> List[str]:
        """
        Fetches a list of keyspace names from the Cassandra database. The list can be
        filtered by a provided list of keyspace names or by excluding predefined
        keyspaces.

        Parameters:
        - keyspace_list (Optional[List[str]]): A list of keyspace names to specifically
            include. If provided and not empty, the method returns only the keyspaces
            present in this list. If not provided or empty, the method returns all
            keyspaces except those specified in the _exclude_keyspaces attribute.

        Returns:
        - List[str]: A list of keyspace names according to the filtering criteria.
        """
        all_keyspaces = self.run(
            "SELECT keyspace_name FROM system_schema.keyspaces", fetch="all"
        )

        # Type check to ensure 'all_keyspaces' is a sequence of dictionaries
        if not isinstance(all_keyspaces, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        # Filter keyspaces based on 'keyspace_list' and '_exclude_keyspaces'
        filtered_keyspaces = []
        for ks in all_keyspaces:
            if not isinstance(ks, Dict):
                continue  # Skip if the row is not a dictionary.

            keyspace_name = ks["keyspace_name"]
            if keyspace_list and keyspace_name in keyspace_list:
                filtered_keyspaces.append(keyspace_name)
            elif not keyspace_list and keyspace_name not in self._exclude_keyspaces:
                filtered_keyspaces.append(keyspace_name)

        return filtered_keyspaces

    def _fetch_schema_data(self, keyspace_list: List[str]) -> Tuple:
        """
        Fetches schema data, including tables, columns, and indexes, filtered by a
        list of keyspaces. This method constructs CQL queries to retrieve detailed
        schema information from the specified keyspaces and executes them to gather
        data about tables, columns, and indexes within those keyspaces.

        Parameters:
        - keyspace_list (List[str]): A list of keyspace names from which to fetch
            schema data.

        Returns:
        - Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: A
            tuple containing three lists:
            - The first list contains dictionaries of table details (keyspace name,
              table name, and comment).
            - The second list contains dictionaries of column details (keyspace name,
              table name, column name, type, kind, and position).
            - The third list contains dictionaries of index details (keyspace name,
              table name, index name, kind, and options).

        This method allows for efficiently fetching schema information for multiple
        keyspaces in a single operation, enabling applications to programmatically
        analyze or document the database schema.
        """
        # Construct IN clause for CQL query
        keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspace_list])

        # Fetch filtered table details
        tables_query = f"""SELECT keyspace_name, table_name, comment
                            FROM system_schema.tables
                            WHERE keyspace_name
                            IN ({keyspace_in_clause})"""

        tables_data = self.run(tables_query, fetch="all")

        # Fetch filtered column details
        columns_query = f"""SELECT keyspace_name, table_name, column_name, type,
                            kind, clustering_order, position
                            FROM system_schema.columns
                            WHERE keyspace_name
                            IN ({keyspace_in_clause})"""

        columns_data = self.run(columns_query, fetch="all")

        # Fetch filtered index details
        indexes_query = f"""SELECT keyspace_name, table_name, index_name,
                            kind, options
                            FROM system_schema.indexes
                            WHERE keyspace_name
                            IN ({keyspace_in_clause})"""

        indexes_data = self.run(indexes_query, fetch="all")

        return tables_data, columns_data, indexes_data
    def _resolve_schema(
        self, keyspace_list: Optional[List[str]] = None
    ) -> Dict[str, List[Table]]:
        """
        Efficiently fetches and organizes Cassandra table schema information,
        such as comments, columns, and indexes, into a dictionary mapping keyspace
        names to lists of Table objects.

        Returns:
            A dictionary with keyspace names as keys and lists of Table objects as
            values, where each Table object is populated with schema details
            appropriate for its keyspace and table name.
        """
        if not keyspace_list:
            keyspace_list = self._fetch_keyspaces()

        tables_data, columns_data, indexes_data = self._fetch_schema_data(keyspace_list)

        keyspace_dict: dict = {}
        for table_data in tables_data:
            keyspace = table_data.keyspace_name
            table_name = table_data.table_name
            comment = table_data.comment

            if self._include_tables and table_name not in self._include_tables:
                continue

            if self._exclude_tables and table_name in self._exclude_tables:
                continue

            # Filter columns and indexes for this table
            table_columns = [
                (c.column_name, c.type)
                for c in columns_data
                if c.keyspace_name == keyspace and c.table_name == table_name
            ]

            partition_keys = [
                c.column_name
                for c in columns_data
                if c.kind == "partition_key"
                and c.keyspace_name == keyspace
                and c.table_name == table_name
            ]

            clustering_keys = [
                (c.column_name, c.clustering_order)
                for c in columns_data
                if c.kind == "clustering"
                and c.keyspace_name == keyspace
                and c.table_name == table_name
            ]

            table_indexes = [
                (c.index_name, c.kind, c.options)
                for c in indexes_data
                if c.keyspace_name == keyspace and c.table_name == table_name
            ]

            table_obj = Table(
                keyspace=keyspace,
                table_name=table_name,
                comment=comment,
                columns=table_columns,
                partition=partition_keys,
                clustering=clustering_keys,
                indexes=table_indexes,
            )

            if keyspace not in keyspace_dict:
                keyspace_dict[keyspace] = []
            keyspace_dict[keyspace].append(table_obj)

        return keyspace_dict
    def _resolve_session(
        self,
        session: Optional[Session] = None,
        cassio_init_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Optional[Session]:
        """
        Attempts to resolve and return a Session object for use in database
        operations. This function follows a specific order of precedence to
        determine the appropriate session to use:
        1. `session` parameter if given,
        2. Existing `cassio` session,
        3. A new `cassio` session derived from `cassio_init_kwargs`,
        4. `None`

        Parameters:
        - session (Optional[Session]): An optional session to use directly.
        - cassio_init_kwargs (Optional[Dict[str, Any]]): An optional dictionary of
            keyword arguments to `cassio`.

        Returns:
        - Optional[Session]: The resolved session object if successful, or `None` if
            the session cannot be resolved.

        Raises:
        - ValueError: If `cassio_init_kwargs` is provided but is not a dictionary of
            keyword arguments.
        """
        # Prefer given session
        if session:
            return session

        # If a session is not provided, create one using cassio if available
        # dynamically import cassio to avoid circular imports
        try:
            import cassio.config  # noqa: F401
        except ImportError:
            raise ValueError(
                "cassio package not found, please install with `pip install cassio`"
            )

        # Use pre-existing session on cassio
        s = cassio.config.resolve_session()
        if s:
            return s

        # Try to init and return cassio session
        if cassio_init_kwargs:
            if isinstance(cassio_init_kwargs, dict):
                cassio.init(**cassio_init_kwargs)
                s = cassio.config.check_resolve_session()
                return s
            else:
                raise ValueError("cassio_init_kwargs must be a keyword dictionary")

        # return None if we're not able to resolve
        return None


class DatabaseError(Exception):
    """Exception raised for errors in the database schema.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)
class Table(BaseModel):
    keyspace: str
    """The keyspace in which the table exists."""

    table_name: str
    """The name of the table."""

    comment: Optional[str] = None
    """The comment associated with the table."""

    columns: List[Tuple[str, str]] = Field(default_factory=list)
    partition: List[str] = Field(default_factory=list)
    clustering: List[Tuple[str, str]] = Field(default_factory=list)
    indexes: List[Tuple[str, str, str]] = Field(default_factory=list)

    class Config:
        frozen = True

    @root_validator()
    def check_required_fields(cls, class_values: dict) -> dict:
        if not class_values["columns"]:
            raise ValueError("a non-empty column list must be provided")
        if not class_values["partition"]:
            raise ValueError("a non-empty partition list must be provided")
        return class_values

    @classmethod
    def from_database(
        cls, keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Table:
        columns, partition, clustering = cls._resolve_columns(keyspace, table_name, db)
        return cls(
            keyspace=keyspace,
            table_name=table_name,
            comment=cls._resolve_comment(keyspace, table_name, db),
            columns=columns,
            partition=partition,
            clustering=clustering,
            indexes=cls._resolve_indexes(keyspace, table_name, db),
        )

    def as_markdown(
        self, include_keyspace: bool = True, header_level: Optional[int] = None
    ) -> str:
        """
        Generates a Markdown representation of the Cassandra table schema, allowing
        for customizable header levels for the table name section.

        Parameters:
        - include_keyspace (bool): If True, includes the keyspace in the output.
            Defaults to True.
        - header_level (Optional[int]): Specifies the markdown header level for the
            table name. If None, the table name is included without a header.
            Defaults to None (no header level).

        Returns:
        - str: A string in Markdown format detailing the table name (with optional
            header level), keyspace (optional), comment, columns, partition keys,
            clustering keys (with optional clustering order), and indexes.
        """
        output = ""
        if header_level is not None:
            output += f"{'#' * header_level} "
        output += f"Table Name: {self.table_name}\n"

        if include_keyspace:
            output += f"- Keyspace: {self.keyspace}\n"

        if self.comment:
            output += f"- Comment: {self.comment}\n"

        output += "- Columns\n"
        for column, type in self.columns:
            output += f"  - {column} ({type})\n"

        output += f"- Partition Keys: ({', '.join(self.partition)})\n"

        output += "- Clustering Keys: "
        if self.clustering:
            cluster_list = []
            for column, clustering_order in self.clustering:
                if clustering_order.lower() == "none":
                    cluster_list.append(column)
                else:
                    cluster_list.append(f"{column} {clustering_order}")
            output += f"({', '.join(cluster_list)})\n"

        if self.indexes:
            output += "- Indexes\n"
            for name, kind, options in self.indexes:
                output += f"  - {name} : kind={kind}, options={options}\n"

        return output
    @staticmethod
    def _resolve_comment(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Optional[str]:
        result = db.run(
            f"""SELECT comment
                FROM system_schema.tables
                WHERE keyspace_name = '{keyspace}'
                AND table_name = '{table_name}';""",
            fetch="one",
        )

        if isinstance(result, dict):
            comment = result.get("comment")
            if comment:
                return comment
            else:
                return None  # Default comment if none is found
        else:
            raise ValueError(
                f"""Unexpected result type from db.run:
                {type(result).__name__}"""
            )

    @staticmethod
    def _resolve_columns(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Tuple[List[Tuple[str, str]], List[str], List[Tuple[str, str]]]:
        columns = []
        partition_info = []
        cluster_info = []

        results = db.run(
            f"""SELECT column_name, type, kind, clustering_order, position
                FROM system_schema.columns
                WHERE keyspace_name = '{keyspace}'
                AND table_name = '{table_name}';"""
        )

        # Type check to ensure 'results' is a sequence of dictionaries.
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for row in results:
            if not isinstance(row, Dict):
                continue  # Skip if the row is not a dictionary.

            columns.append((row["column_name"], row["type"]))
            if row["kind"] == "partition_key":
                partition_info.append((row["column_name"], row["position"]))
            elif row["kind"] == "clustering":
                cluster_info.append(
                    (row["column_name"], row["clustering_order"], row["position"])
                )

        partition = [
            column_name for column_name, _ in sorted(partition_info, key=lambda x: x[1])
        ]

        cluster = [
            (column_name, clustering_order)
            for column_name, clustering_order, _ in sorted(
                cluster_info, key=lambda x: x[2]
            )
        ]

        return columns, partition, cluster

    @staticmethod
    def _resolve_indexes(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> List[Tuple[str, str, str]]:
        indexes = []

        results = db.run(
            f"""SELECT index_name, kind, options
                FROM system_schema.indexes
                WHERE keyspace_name = '{keyspace}'
                AND table_name = '{table_name}';"""
        )

        # Type check to ensure 'results' is a sequence of dictionaries
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for row in results:
            if not isinstance(row, Dict):
                continue  # Skip if the row is not a dictionary.

            # Convert 'options' to string if it's not already,
            # assuming it's JSON-like and needs conversion
            index_options = row["options"]
            if not isinstance(index_options, str):
                # Assuming index_options needs to be serialized or simply converted
                index_options = str(index_options)

            indexes.append((row["index_name"], row["kind"], index_options))

        return indexes
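The safety gate in `_validate_cql` above works by treating quoted content as opaque, then rejecting any leftover semicolon that is not the trailing one. A standalone sketch of just that check (stdlib `re` only; not the full method, which also raises typed exceptions):

```python
import re

def is_safe_select(cql: str) -> bool:
    # Sketch of the check in CassandraDatabase._validate_cql: strip quoted
    # strings, then reject any semicolon other than the trailing one.
    trimmed = cql.strip()
    if not trimmed.upper().startswith("SELECT"):
        return False
    # A single trailing semicolon is allowed (optional with the Python driver).
    trimmed = trimmed.rstrip(";")
    # Content inside matching quotes is considered "safe".
    sanitized = re.sub(r"'.*?'", "", trimmed)
    sanitized = re.sub(r'".*?"', "", sanitized)
    return ";" not in sanitized

print(is_safe_select("SELECT * FROM t WHERE name = 'a;b';"))  # quoted ; is fine
print(is_safe_select("SELECT * FROM t; DROP TABLE t;"))       # stacked statement
```

This is why `test_validate_cql_unsafe` below expects `DatabaseError` for the stacked `SELECT ...; DROP TABLE ...;` query while a semicolon inside a string literal passes.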

View File

@@ -0,0 +1,85 @@
from collections import namedtuple
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from langchain_community.utilities.cassandra_database import (
    CassandraDatabase,
    DatabaseError,
    Table,
)

# Define a namedtuple type
MockRow = namedtuple("MockRow", ["col1", "col2"])


class TestCassandraDatabase(object):
    def setup_method(self) -> None:
        self.mock_session = MagicMock()
        self.cassandra_db = CassandraDatabase(session=self.mock_session)

    def test_init_without_session(self) -> None:
        with pytest.raises(ValueError):
            CassandraDatabase()

    def test_run_query(self) -> None:
        # Mock the execute method to return an iterable of dictionaries directly
        self.mock_session.execute.return_value = iter(
            [{"col1": "val1", "col2": "val2"}]
        )

        # Execute the query
        result = self.cassandra_db.run("SELECT * FROM table")

        # Assert that the result is as expected
        assert result == [{"col1": "val1", "col2": "val2"}]

        # Verify that execute was called with the expected CQL query
        self.mock_session.execute.assert_called_with("SELECT * FROM table")

    def test_run_query_cursor(self) -> None:
        mock_result_set = MagicMock()
        self.mock_session.execute.return_value = mock_result_set
        result = self.cassandra_db.run("SELECT * FROM table;", fetch="cursor")
        assert result == mock_result_set

    def test_run_query_invalid_fetch(self) -> None:
        with pytest.raises(ValueError):
            self.cassandra_db.run("SELECT * FROM table;", fetch="invalid")

    def test_validate_cql_select(self) -> None:
        query = "SELECT * FROM table;"
        result = self.cassandra_db._validate_cql(query, "SELECT")
        assert result == "SELECT * FROM table"

    def test_validate_cql_unsupported_type(self) -> None:
        query = "UPDATE table SET col=val;"
        with pytest.raises(ValueError):
            self.cassandra_db._validate_cql(query, "UPDATE")

    def test_validate_cql_unsafe(self) -> None:
        query = "SELECT * FROM table; DROP TABLE table;"
        with pytest.raises(DatabaseError):
            self.cassandra_db._validate_cql(query, "SELECT")

    @patch(
        "langchain_community.utilities.cassandra_database.CassandraDatabase._resolve_schema"
    )
    def test_format_schema_to_markdown(self, mock_resolve_schema: Any) -> None:
        mock_table1 = MagicMock(spec=Table)
        mock_table1.as_markdown.return_value = "## Keyspace: keyspace1"
        mock_table2 = MagicMock(spec=Table)
        mock_table2.as_markdown.return_value = "## Keyspace: keyspace2"
        mock_resolve_schema.return_value = {
            "keyspace1": [mock_table1],
            "keyspace2": [mock_table2],
        }
        markdown = self.cassandra_db.format_schema_to_markdown()
        assert markdown.startswith("# Cassandra Database Schema")
        assert "## Keyspace: keyspace1" in markdown
        assert "## Keyspace: keyspace2" in markdown


if __name__ == "__main__":
    pytest.main()