community[minor]: add Cassandra Database Toolkit (#20246)

**Description**: ToolKit and Tools for accessing data in a Cassandra
Database primarily for Agent integration. Initially, this includes the
following tools:
- `cassandra_db_schema` Gathers all schema information for the connected
database or a specific schema. Critical for the agent when determining
actions.
- `cassandra_db_select_table_data` Selects data from a specific keyspace
and table. The agent can pass parameters for a predicate and limits on
the number of returned records.
- `cassandra_db_query` Experimental alternative to
`cassandra_db_select_table_data` which takes a query string completely
formed by the agent instead of parameters. May be removed in future
versions.
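
As a rough illustration of the parameter-based approach (not the toolkit's actual implementation), the sketch below shows how a keyspace, table, predicate, and limit could be assembled into a single CQL `SELECT`. The helper name `build_select` and its identifier check are assumptions made for this example only.

```python
import re
from typing import Optional

# Hypothetical rule for this sketch: plain, unquoted CQL identifiers only.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def build_select(
    keyspace: str,
    table: str,
    predicate: Optional[str] = None,
    limit: Optional[int] = None,
) -> str:
    """Assemble a SELECT statement from agent-supplied parameters.

    Illustrative helper, not part of langchain-community.
    """
    for name in (keyspace, table):
        if not _IDENTIFIER.fullmatch(name):
            raise ValueError(f"Invalid identifier: {name!r}")

    query = f"SELECT * FROM {keyspace}.{table}"
    if predicate:
        # The predicate is passed through as-is; it should use the primary key.
        query += f" WHERE {predicate}"
    if limit is not None:
        query += f" LIMIT {int(limit)}"
    return query + ";"


print(build_select("langchain_agent_test", "user_videos", limit=10))
```

Keeping the predicate and limit as separate parameters, rather than accepting a free-form query string, narrows what the agent can ask the database to do.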

Includes unit test and two notebooks to demonstrate usage. 

**Dependencies**: cassio
**Twitter handle**: @PatrickMcFadin

---------

Co-authored-by: Phil Miesle <phil.miesle@datastax.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Patrick McFadin
2024-04-29 08:51:43 -07:00
committed by GitHub
parent b3e74f2b98
commit 3331865f6b
11 changed files with 2007 additions and 0 deletions

@@ -0,0 +1,481 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cassandra Database\n",
"\n",
"Apache Cassandra® is a widely used database for storing transactional application data. The introduction of functions and tooling in Large Language Models has opened up some exciting use cases for existing data in Generative AI applications. The Cassandra Database toolkit enables AI engineers to efficiently integrate Agents with Cassandra data, offering the following features: \n",
" - Fast data access through optimized queries. Most queries should run in single-digit ms or less. \n",
" - Schema introspection to enhance LLM reasoning capabilities \n",
" - Compatibility with various Cassandra deployments, including Apache Cassandra®, DataStax Enterprise™, and DataStax Astra™ \n",
" - Currently, the toolkit is limited to SELECT queries and schema introspection operations. (Safety first)\n",
"\n",
"## Quick Start\n",
" - Install the cassio library\n",
" - Set environment variables for the Cassandra database you are connecting to\n",
" - Initialize CassandraDatabase\n",
" - Pass the tools to your agent with toolkit.get_tools()\n",
" - Sit back and watch it do all your work for you\n",
"\n",
"## Theory of Operation\n",
"Cassandra Query Language (CQL) is the primary *human-centric* way of interacting with a Cassandra database. While it offers some flexibility when generating queries, it requires knowledge of Cassandra data modeling best practices. LLM function calling gives an agent the ability to reason and then choose a tool to satisfy the request. Agents using LLMs should reason using Cassandra-specific logic when choosing the appropriate tool or chain of tools. This reduces the randomness introduced when LLMs are forced to provide a top-down solution. Do you want an LLM to have complete, unfettered access to your database? Yeah. Probably not. To steer the agent toward Cassandra-specific reasoning, we provide a prompt for use when constructing questions for the agent: \n",
"\n",
"```json\n",
"You are an Apache Cassandra expert query analysis bot with the following features \n",
"and rules:\n",
" - You will take a question from the end user about finding specific \n",
" data in the database.\n",
" - You will examine the schema of the database and create a query path. \n",
" - You will provide the user with the correct query to find the data they are looking \n",
" for, showing the steps provided by the query path.\n",
" - You will use best practices for querying Apache Cassandra using partition keys \n",
" and clustering columns.\n",
" - Avoid using ALLOW FILTERING in the query.\n",
" - The goal is to find a query path, so it may take querying other tables to get \n",
" to the final answer. \n",
"\n",
"The following is an example of a query path in JSON format:\n",
"\n",
" {\n",
" \"query_paths\": [\n",
" {\n",
" \"description\": \"Direct query to users table using email\",\n",
" \"steps\": [\n",
" {\n",
" \"table\": \"user_credentials\",\n",
" \"query\": \n",
" \"SELECT userid FROM user_credentials WHERE email = 'example@example.com';\"\n",
" },\n",
" {\n",
" \"table\": \"users\",\n",
" \"query\": \"SELECT * FROM users WHERE userid = ?;\"\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
"}\n",
"```\n",
"\n",
"## Tools Provided\n",
"\n",
"### `cassandra_db_schema`\n",
"Gathers all schema information for the connected database or a specific schema. Critical for the agent when determining actions. \n",
"\n",
"### `cassandra_db_select_table_data`\n",
"Selects data from a specific keyspace and table. The agent can pass parameters for a predicate and limits on the number of returned records. \n",
"\n",
"### `cassandra_db_query`\n",
"Experimental alternative to `cassandra_db_select_table_data` which takes a query string completely formed by the agent instead of parameters. *Warning*: This can lead to unusual queries that may not be as performant (or even work). This may be removed in future releases. If it does something cool, we want to know about that too. You never know!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment Setup\n",
"\n",
"Install the following Python modules:\n",
"\n",
"```bash\n",
"pip install ipykernel python-dotenv cassio langchain_openai langchain langchain-community langchainhub\n",
"```\n",
"\n",
"### .env file\n",
"The connection is made via `cassio` using the `auto=True` parameter, and the notebook uses OpenAI. You should create a `.env` file accordingly.\n",
"\n",
"For Cassandra, set:\n",
"```bash\n",
"CASSANDRA_CONTACT_POINTS\n",
"CASSANDRA_USERNAME\n",
"CASSANDRA_PASSWORD\n",
"CASSANDRA_KEYSPACE\n",
"```\n",
"\n",
"For Astra, set:\n",
"```bash\n",
"ASTRA_DB_APPLICATION_TOKEN\n",
"ASTRA_DB_DATABASE_ID\n",
"ASTRA_DB_KEYSPACE\n",
"```\n",
"\n",
"For example:\n",
"\n",
"```bash\n",
"# Connection to Astra:\n",
"ASTRA_DB_DATABASE_ID=a1b2c3d4-...\n",
"ASTRA_DB_APPLICATION_TOKEN=AstraCS:...\n",
"ASTRA_DB_KEYSPACE=notebooks\n",
"\n",
"# Also set \n",
"OPENAI_API_KEY=sk-....\n",
"```\n",
"\n",
"(You may also modify the below code to directly connect with `cassio`.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries\n",
"import os\n",
"\n",
"import cassio\n",
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
"from langchain_community.agent_toolkits.cassandra_database.toolkit import (\n",
" CassandraDatabaseToolkit,\n",
")\n",
"from langchain_community.tools.cassandra_database.prompt import QUERY_PATH_PROMPT\n",
"from langchain_community.tools.cassandra_database.tool import (\n",
" GetSchemaCassandraDatabaseTool,\n",
" GetTableDataCassandraDatabaseTool,\n",
" QueryCassandraDatabaseTool,\n",
")\n",
"from langchain_community.utilities.cassandra_database import CassandraDatabase\n",
"from langchain_openai import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connect to a Cassandra Database"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"cassio.init(auto=True)\n",
"session = cassio.config.resolve_session()\n",
"if not session:\n",
" raise Exception(\n",
" \"Check environment configuration or manually configure cassio connection parameters\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Test data prep\n",
"\n",
"session = cassio.config.resolve_session()\n",
"\n",
"session.execute(\"\"\"DROP KEYSPACE IF EXISTS langchain_agent_test; \"\"\")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
"CREATE KEYSPACE IF NOT EXISTS langchain_agent_test \n",
"WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1};\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
" CREATE TABLE IF NOT EXISTS langchain_agent_test.user_credentials (\n",
" user_email text PRIMARY KEY,\n",
" user_id UUID,\n",
" password TEXT\n",
");\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
" CREATE TABLE IF NOT EXISTS langchain_agent_test.users (\n",
" id UUID PRIMARY KEY,\n",
" name TEXT,\n",
" email TEXT\n",
");\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" \"\"\"\n",
" CREATE TABLE IF NOT EXISTS langchain_agent_test.user_videos ( \n",
" user_id UUID,\n",
" video_id UUID,\n",
" title TEXT,\n",
" description TEXT,\n",
" PRIMARY KEY (user_id, video_id)\n",
");\n",
"\"\"\"\n",
")\n",
"\n",
"user_id = \"522b1fe2-2e36-4cef-a667-cd4237d08b89\"\n",
"video_id = \"27066014-bad7-9f58-5a30-f63fe03718f6\"\n",
"\n",
"session.execute(\n",
" f\"\"\"\n",
" INSERT INTO langchain_agent_test.user_credentials (user_id, user_email) \n",
" VALUES ({user_id}, 'patrick@datastax.com');\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" f\"\"\"\n",
" INSERT INTO langchain_agent_test.users (id, name, email) \n",
" VALUES ({user_id}, 'Patrick McFadin', 'patrick@datastax.com');\n",
"\"\"\"\n",
")\n",
"\n",
"session.execute(\n",
" f\"\"\"\n",
" INSERT INTO langchain_agent_test.user_videos (user_id, video_id, title)\n",
" VALUES ({user_id}, {video_id}, 'Use Langflow to Build a LangChain LLM Application in 5 Minutes');\n",
"\"\"\"\n",
")\n",
"\n",
"session.set_keyspace(\"langchain_agent_test\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create a CassandraDatabase instance\n",
"# Uses the cassio session to connect to the database\n",
"db = CassandraDatabase()\n",
"\n",
"# Create the Cassandra Database tools\n",
"query_tool = QueryCassandraDatabaseTool(db=db)\n",
"schema_tool = GetSchemaCassandraDatabaseTool(db=db)\n",
"select_data_tool = GetTableDataCassandraDatabaseTool(db=db)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available tools:\n",
"cassandra_db_schema\t- \n",
" Input to this tool is a keyspace name, output is a table description \n",
" of Apache Cassandra tables.\n",
" If the query is not correct, an error message will be returned.\n",
" If an error is returned, report back to the user that the keyspace \n",
" doesn't exist and stop.\n",
" \n",
"cassandra_db_query\t- \n",
" Execute a CQL query against the database and get back the result.\n",
" If the query is not correct, an error message will be returned.\n",
" If an error is returned, rewrite the query, check the query, and try again.\n",
" \n",
"cassandra_db_select_table_data\t- \n",
" Tool for getting data from a table in an Apache Cassandra database. \n",
" Use the WHERE clause to specify the predicate for the query that uses the \n",
" primary key. A blank predicate will return all rows. Avoid this if possible. \n",
" Use the limit to specify the number of rows to return. A blank limit will \n",
" return all rows.\n",
" \n"
]
}
],
"source": [
"# Choose the LLM that will drive the agent\n",
"# Only certain models support this\n",
"llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
"toolkit = CassandraDatabaseToolkit(db=db)\n",
"\n",
"tools = toolkit.get_tools()\n",
"\n",
"print(\"Available tools:\")\n",
"for tool in tools:\n",
" print(tool.name + \"\\t- \" + tool.description)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
"\n",
"# Construct the OpenAI Tools agent\n",
"agent = create_openai_tools_agent(llm, tools, prompt)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `cassandra_db_schema` with `{'keyspace': 'langchain_agent_test'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mTable Name: user_credentials\n",
"- Keyspace: langchain_agent_test\n",
"- Columns\n",
" - password (text)\n",
" - user_email (text)\n",
" - user_id (uuid)\n",
"- Partition Keys: (user_email)\n",
"- Clustering Keys: \n",
"\n",
"Table Name: user_videos\n",
"- Keyspace: langchain_agent_test\n",
"- Columns\n",
" - description (text)\n",
" - title (text)\n",
" - user_id (uuid)\n",
" - video_id (uuid)\n",
"- Partition Keys: (user_id)\n",
"- Clustering Keys: (video_id asc)\n",
"\n",
"\n",
"Table Name: users\n",
"- Keyspace: langchain_agent_test\n",
"- Columns\n",
" - email (text)\n",
" - id (uuid)\n",
" - name (text)\n",
"- Partition Keys: (id)\n",
"- Clustering Keys: \n",
"\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `cassandra_db_select_table_data` with `{'keyspace': 'langchain_agent_test', 'table': 'user_credentials', 'predicate': \"user_email = 'patrick@datastax.com'\", 'limit': 1}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mRow(user_email='patrick@datastax.com', password=None, user_id=UUID('522b1fe2-2e36-4cef-a667-cd4237d08b89'))\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `cassandra_db_select_table_data` with `{'keyspace': 'langchain_agent_test', 'table': 'user_videos', 'predicate': 'user_id = 522b1fe2-2e36-4cef-a667-cd4237d08b89', 'limit': 10}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mRow(user_id=UUID('522b1fe2-2e36-4cef-a667-cd4237d08b89'), video_id=UUID('27066014-bad7-9f58-5a30-f63fe03718f6'), description='DataStax Academy is a free resource for learning Apache Cassandra.', title='DataStax Academy')\u001b[0m\u001b[32;1m\u001b[1;3mTo find all the videos that the user with the email address 'patrick@datastax.com' has uploaded to the `langchain_agent_test` keyspace, we can follow these steps:\n",
"\n",
"1. Query the `user_credentials` table to find the `user_id` associated with the email 'patrick@datastax.com'.\n",
"2. Use the `user_id` obtained from the first step to query the `user_videos` table to retrieve all the videos uploaded by the user.\n",
"\n",
"Here is the query path in JSON format:\n",
"\n",
"```json\n",
"{\n",
" \"query_paths\": [\n",
" {\n",
" \"description\": \"Find user_id from user_credentials and then query user_videos for all videos uploaded by the user\",\n",
" \"steps\": [\n",
" {\n",
" \"table\": \"user_credentials\",\n",
" \"query\": \"SELECT user_id FROM user_credentials WHERE user_email = 'patrick@datastax.com';\"\n",
" },\n",
" {\n",
" \"table\": \"user_videos\",\n",
" \"query\": \"SELECT * FROM user_videos WHERE user_id = 522b1fe2-2e36-4cef-a667-cd4237d08b89;\"\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
"}\n",
"```\n",
"\n",
"Following this query path, we found that the user with the user_id `522b1fe2-2e36-4cef-a667-cd4237d08b89` has uploaded at least one video with the title 'DataStax Academy' and the description 'DataStax Academy is a free resource for learning Apache Cassandra.' The video_id for this video is `27066014-bad7-9f58-5a30-f63fe03718f6`. If there are more videos, the same query can be used to retrieve them, possibly with an increased limit if necessary.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"To find all the videos that the user with the email address 'patrick@datastax.com' has uploaded to the `langchain_agent_test` keyspace, we can follow these steps:\n",
"\n",
"1. Query the `user_credentials` table to find the `user_id` associated with the email 'patrick@datastax.com'.\n",
"2. Use the `user_id` obtained from the first step to query the `user_videos` table to retrieve all the videos uploaded by the user.\n",
"\n",
"Here is the query path in JSON format:\n",
"\n",
"```json\n",
"{\n",
" \"query_paths\": [\n",
" {\n",
" \"description\": \"Find user_id from user_credentials and then query user_videos for all videos uploaded by the user\",\n",
" \"steps\": [\n",
" {\n",
" \"table\": \"user_credentials\",\n",
" \"query\": \"SELECT user_id FROM user_credentials WHERE user_email = 'patrick@datastax.com';\"\n",
" },\n",
" {\n",
" \"table\": \"user_videos\",\n",
" \"query\": \"SELECT * FROM user_videos WHERE user_id = 522b1fe2-2e36-4cef-a667-cd4237d08b89;\"\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
"}\n",
"```\n",
"\n",
"Following this query path, we found that the user with the user_id `522b1fe2-2e36-4cef-a667-cd4237d08b89` has uploaded at least one video with the title 'DataStax Academy' and the description 'DataStax Academy is a free resource for learning Apache Cassandra.' The video_id for this video is `27066014-bad7-9f58-5a30-f63fe03718f6`. If there are more videos, the same query can be used to retrieve them, possibly with an increased limit if necessary.\n"
]
}
],
"source": [
"input = (\n",
" QUERY_PATH_PROMPT\n",
" + \"\\n\\nHere is your task: Find all the videos that the user with the email address 'patrick@datastax.com' has uploaded to the langchain_agent_test keyspace.\"\n",
")\n",
"\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
"\n",
"response = agent_executor.invoke({\"input\": input})\n",
"\n",
"print(response[\"output\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a deep dive on creating a Cassandra DB agent, see the [CQL agent cookbook](https://github.com/langchain-ai/langchain/blob/master/cookbook/cql_agent.ipynb)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}