Compare commits

...

2 Commits

Author          SHA1        Message  Date
Harrison Chase  159cc166c2  cr       2023-03-19 16:43:14 -07:00
Harrison Chase  075de91675  dbpedia  2023-03-19 16:42:48 -07:00
4 changed files with 219 additions and 0 deletions

View File

@@ -0,0 +1,106 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "546fd3f7",
"metadata": {},
"source": [
"# DBPedia\n",
"\n",
"This example shows how you can use LLMs to interact in natural language with a SPARKQL database."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6f3bf955",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains.dbpedia.base import DBPediaChain\n",
"model = ChatOpenAI(model_name=\"gpt-4\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "09dbd131",
"metadata": {},
"outputs": [],
"source": [
"chain = DBPediaChain.from_llm(model, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "408cb57d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new DBPediaChain chain...\u001b[0m\n",
"Query written:\n",
"\u001b[32;1m\u001b[1;3mSELECT ?capital WHERE { \n",
" ?country rdfs:label \"Wakanda\"@en . \n",
" ?country dbo:capital ?capital . \n",
" ?capital rdfs:label ?capitalLabel .\n",
" FILTER (LANG(?capitalLabel) = 'en') \n",
"}\u001b[0m\n",
"Response gotten:\n",
"\u001b[32;1m\u001b[1;3m{'head': {'link': [], 'vars': ['capital']}, 'results': {'distinct': False, 'ordered': True, 'bindings': []}}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'There is no capital information available for Wakanda in the provided SPARQL query response. This is because Wakanda is a fictional country in the Marvel Cinematic Universe and does not have a real-world capital.'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"what is the capital of wakanda?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "950e2472",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
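
Since DBPediaChain declares output_key = "answer" (see base.py below), calling the chain object directly returns a dictionary, while run() returns the bare answer string. A minimal sketch of the difference, reusing the chain built in the notebook above:

# run() returns the answer string; calling the chain returns a dict
# that carries the answer under the "answer" output key.
result = chain({"question": "what is the capital of wakanda?"})
print(result["answer"])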

View File

View File: langchain/chains/dbpedia/base.py

@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.dbpedia.prompt import ANSWER_PROMPT_SELECTOR, PROMPT_SELECTOR
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class DBPediaChain(Chain):
    """Answer questions by writing a SPARQL query, running it against DBPedia,
    and summarizing the response in plain English."""

    query_chain: LLMChain
    answer_chain: LLMChain
    input_key: str = "question"
    output_key: str = "answer"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        query_prompt: Optional[BasePromptTemplate] = None,
        answer_prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> DBPediaChain:
        query_prompt = query_prompt or PROMPT_SELECTOR.get_prompt(llm)
        query_chain = LLMChain(llm=llm, prompt=query_prompt)
        answer_prompt = answer_prompt or ANSWER_PROMPT_SELECTOR.get_prompt(llm)
        answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
        return cls(query_chain=query_chain, answer_chain=answer_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Imported lazily so the module loads without SPARQLWrapper installed.
        from SPARQLWrapper import JSON, SPARQLWrapper

        sparql = SPARQLWrapper("http://dbpedia.org/sparql")
        sparql.setReturnFormat(JSON)
        query = self.query_chain.run(inputs[self.input_key])
        self.callback_manager.on_text("Query written:", end="\n", verbose=self.verbose)
        self.callback_manager.on_text(
            query, color="green", end="\n", verbose=self.verbose
        )
        sparql.setQuery(query)
        result = sparql.query().convert()
        self.callback_manager.on_text(
            "Response received:", end="\n", verbose=self.verbose
        )
        self.callback_manager.on_text(
            # convert() returns a dict; stringify it before handing it to callbacks.
            str(result), color="green", end="\n", verbose=self.verbose
        )
        answer = self.answer_chain.run(
            question=inputs[self.input_key], query=query, response=result
        )
        return {self.output_key: answer}
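
For context, the _call above reduces to a plain SPARQLWrapper round trip against the public DBPedia endpoint. The standalone sketch below shows that round trip on its own; the query text is illustrative, not something the chain generated:

# Minimal standalone sketch of the SPARQL round trip performed in _call.
from SPARQLWrapper import JSON, SPARQLWrapper

sparql = SPARQLWrapper("http://dbpedia.org/sparql")
sparql.setReturnFormat(JSON)
sparql.setQuery(
    """SELECT ?capital WHERE {
        ?country rdfs:label "France"@en .
        ?country dbo:capital ?capital .
        ?capital rdfs:label ?capitalLabel .
        FILTER (LANG(?capitalLabel) = 'en')
    }"""
)
result = sparql.query().convert()  # parsed JSON response as a Python dict
for binding in result["results"]["bindings"]:
    print(binding["capital"]["value"])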

View File: langchain/chains/dbpedia/prompt.py

@@ -0,0 +1,52 @@
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import HumanMessage

TEMPLATE = """Write a SPARQL query to execute against DBPedia to answer the following question.
Question: {question}
SPARQL Query:"""
PROMPT = PromptTemplate.from_template(TEMPLATE)

INSTRUCTIONS_TEMPLATE = """Write a SPARQL query to execute against DBPedia to answer the following question.
Your answer should be a valid SPARQL query and NOTHING else.
Always return just a SPARQL query."""
INSTRUCTIONS = HumanMessage(content=INSTRUCTIONS_TEMPLATE)
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [INSTRUCTIONS, HumanMessagePromptTemplate.from_template("{question}")]
)

PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
)

ANSWER_TEMPLATE = """Given a question, the SPARQL query written to answer it, and the response DBPedia returned, write a final answer.
Question: {question}
SPARQL Query: {query}
SPARQL Response: {response}
Final Answer (in plain English):"""
ANSWER_PROMPT = PromptTemplate.from_template(ANSWER_TEMPLATE)

ANSWER_INSTRUCTIONS_TEMPLATE = """I wrote this SPARQL query:
----------
{query}
----------
I got this response:
----------
{response}
----------
Now, use the above information to answer my next question."""
ANSWER_INSTRUCTIONS = HumanMessagePromptTemplate.from_template(
    ANSWER_INSTRUCTIONS_TEMPLATE
)
ANSWER_CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [ANSWER_INSTRUCTIONS, HumanMessagePromptTemplate.from_template("{question}")]
)

ANSWER_PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=ANSWER_PROMPT, conditionals=[(is_chat_model, ANSWER_CHAT_PROMPT)]
)