{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "245065c6",
   "metadata": {},
   "source": [
    "# Vector SQL Retriever with MyScale\n",
    "\n",
    ">[MyScale](https://docs.myscale.com/en/) is an integrated vector database. You can access your database with SQL and also from LangChain. MyScale can make use of [various data types and functions for filters](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints). It can boost your LLM app whether you are scaling up your data or expanding your system to broader applications.\n",
    "\n",
    "This notebook shows how to use `VectorSQLDatabaseChain` to let an LLM generate vector SQL queries against MyScale, and then how to wrap that chain as a retriever for question answering with sources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0246c5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# langchain, langchain-community and langchain-openai provide the chain, SQLDatabase and OpenAI classes imported below.\n",
    "# %pip (unlike !pip3) installs into the environment of the running kernel.\n",
    "%pip install --upgrade --quiet clickhouse-sqlalchemy InstructorEmbedding sentence_transformers openai pandas langchain langchain-community langchain-openai langchain-experimental"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7585d2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "from os import environ\n",
    "from urllib.parse import quote_plus\n",
    "\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
    "from langchain_openai import OpenAI\n",
    "from sqlalchemy import MetaData, create_engine\n",
    "\n",
    "# Defaults are the demo credentials this notebook ships with; export\n",
    "# MYSCALE_HOST / MYSCALE_PORT / MYSCALE_USER / MYSCALE_PASSWORD to use your own cluster.\n",
    "MYSCALE_HOST = environ.get(\"MYSCALE_HOST\", \"msc-4a9e710a.us-east-1.aws.staging.myscale.cloud\")\n",
    "MYSCALE_PORT = int(environ.get(\"MYSCALE_PORT\", 443))\n",
    "MYSCALE_USER = environ.get(\"MYSCALE_USER\", \"chatdata\")\n",
    "MYSCALE_PASSWORD = environ.get(\"MYSCALE_PASSWORD\", \"myscale_rocks\")\n",
    "\n",
    "# Only prompt for the key when it is not already exported.\n",
    "OPENAI_API_KEY = environ.get(\"OPENAI_API_KEY\") or getpass.getpass(\"OpenAI API Key:\")\n",
    "environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
    "\n",
    "# quote_plus keeps credentials containing URL-reserved characters (@, :, /)\n",
    "# from corrupting the connection URL.\n",
    "engine = create_engine(\n",
    "    f\"clickhouse://{quote_plus(MYSCALE_USER)}:{quote_plus(MYSCALE_PASSWORD)}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https\"\n",
    ")\n",
    "# MetaData(bind=...) was removed in SQLAlchemy 2.0. The engine is always passed\n",
    "# to SQLDatabase alongside `metadata`, so an unbound MetaData is enough.\n",
    "metadata = MetaData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e08d9ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.embeddings import HuggingFaceInstructEmbeddings\n",
    "from langchain_experimental.sql.vector_sql import VectorSQLOutputParser\n",
    "\n",
    "# Instructor-XL embedding model, pinned to CPU so the notebook runs without a GPU.\n",
    "instructor_embeddings = HuggingFaceInstructEmbeddings(\n",
    "    model_name=\"hkunlp/instructor-xl\",\n",
    "    model_kwargs={\"device\": \"cpu\"},\n",
    ")\n",
    "\n",
    "# Parser handed to VectorSQLDatabaseChain as `sql_cmd_parser`; its embedding\n",
    "# model is also reused later via `output_parser.model`.\n",
    "output_parser = VectorSQLOutputParser.from_embeddings(model=instructor_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b705b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from langchain.callbacks import StdOutCallbackHandler\n",
    "from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
    "\n",
    "# LLMChain, OpenAI, SQLDatabase and VectorSQLDatabaseChain are imported in the setup cell.\n",
    "chain = VectorSQLDatabaseChain(\n",
    "    llm_chain=LLMChain(\n",
    "        llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
    "        prompt=MYSCALE_PROMPT,\n",
    "    ),\n",
    "    top_k=10,\n",
    "    return_direct=True,\n",
    "    sql_cmd_parser=output_parser,\n",
    "    database=SQLDatabase(engine, None, metadata),\n",
    ")\n",
    "\n",
    "# `Chain.run` is deprecated; `invoke` returns a dict keyed by the chain's output key.\n",
    "query_result = chain.invoke(\n",
    "    \"Please give me 10 papers to ask what is PageRank?\",\n",
    "    config={\"callbacks\": [StdOutCallbackHandler()]},\n",
    ")\n",
    "pd.DataFrame(query_result[chain.output_key])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c09cda0",
   "metadata": {},
   "source": [
    "## SQL Database as Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "734d7ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain\n",
    "from langchain_experimental.retrievers.vector_sql_database import (\n",
    "    VectorSQLDatabaseChainRetriever,\n",
    ")\n",
    "from langchain_experimental.sql.vector_sql import VectorSQLRetrieveAllOutputParser\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Reuse the embedding model already loaded into `output_parser` instead of\n",
    "# loading instructor-xl a second time.\n",
    "output_parser_retrieve_all = VectorSQLRetrieveAllOutputParser.from_embeddings(\n",
    "    output_parser.model\n",
    ")\n",
    "\n",
    "# A separate name keeps this chain from shadowing the `chain` built above.\n",
    "vector_sql_chain = VectorSQLDatabaseChain.from_llm(\n",
    "    llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
    "    prompt=MYSCALE_PROMPT,\n",
    "    top_k=10,\n",
    "    return_direct=True,\n",
    "    db=SQLDatabase(engine, None, metadata),\n",
    "    sql_cmd_parser=output_parser_retrieve_all,\n",
    "    native_format=True,\n",
    ")\n",
    "\n",
    "# page_content_key selects the `abstract` column as each document's text.\n",
    "retriever = VectorSQLDatabaseChainRetriever(\n",
    "    sql_db_chain=vector_sql_chain, page_content_key=\"abstract\"\n",
    ")\n",
    "\n",
    "# The prompt below formats id, title, authors, pubdate and categories for every\n",
    "# retrieved document, so those values must be available on each document\n",
    "# (presumably as selected columns carried into document metadata -- confirm against the table schema).\n",
    "document_with_metadata_prompt = PromptTemplate(\n",
    "    input_variables=[\"page_content\", \"id\", \"title\", \"authors\", \"pubdate\", \"categories\"],\n",
    "    template=\"Content:\\n\\tTitle: {title}\\n\\tAbstract: {page_content}\\n\\tAuthors: {authors}\\n\\tDate of Publication: {pubdate}\\n\\tCategories: {categories}\\nSOURCE: {id}\",\n",
    ")\n",
    "\n",
    "# NOTE(review): OpenAI has retired some gpt-3.5-turbo-16k snapshots; if this model\n",
    "# name is rejected, switch to a currently available chat model.\n",
    "qa_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(\n",
    "        model_name=\"gpt-3.5-turbo-16k\", openai_api_key=OPENAI_API_KEY, temperature=0.6\n",
    "    ),\n",
    "    retriever=retriever,\n",
    "    chain_type=\"stuff\",\n",
    "    chain_type_kwargs={\n",
    "        \"document_prompt\": document_with_metadata_prompt,\n",
    "    },\n",
    "    return_source_documents=True,\n",
    ")\n",
    "\n",
    "# Calling a chain directly is deprecated; `invoke` returns the same output dict.\n",
    "ans = qa_chain.invoke(\n",
    "    \"Please give me 10 papers to ask what is PageRank?\",\n",
    "    config={\"callbacks\": [StdOutCallbackHandler()]},\n",
    ")\n",
    "print(ans[\"answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4948ff25",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}