mirror of
https://github.com/hwchase17/langchain.git
synced 2025-11-05 10:45:45 +00:00
Updates docs and cookbooks to import ChatOpenAI, OpenAI, and OpenAI Embeddings from `langchain_openai` There are likely more --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
693 lines
30 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "82f3f65d-fbcb-4e8e-b04b-959856283643",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Causal program-aided language (CPAL) chain\n",
|
|
"\n",
|
|
"The CPAL chain builds on the recent PAL to stop LLM hallucination. The problem with the PAL approach is that it hallucinates on a math problem with a nested chain of dependence. The innovation here is that this new CPAL approach includes causal structure to fix hallucination.\n",
|
|
"\n",
|
|
"The original [PR's description](https://github.com/langchain-ai/langchain/pull/6255) contains a full overview.\n",
|
|
"\n",
|
|
"Using the CPAL chain, the LLM translated this\n",
|
|
"\n",
|
|
" \"Tim buys the same number of pets as Cindy and Boris.\"\n",
|
|
" \"Cindy buys the same number of pets as Bill plus Bob.\"\n",
|
|
" \"Boris buys the same number of pets as Ben plus Beth.\"\n",
|
|
" \"Bill buys the same number of pets as Obama.\"\n",
|
|
" \"Bob buys the same number of pets as Obama.\"\n",
|
|
" \"Ben buys the same number of pets as Obama.\"\n",
|
|
" \"Beth buys the same number of pets as Obama.\"\n",
|
|
" \"If Obama buys one pet, how many pets total does everyone buy?\"\n",
|
|
"\n",
|
|
"\n",
|
|
"into this\n",
|
|
"\n",
|
|
".\n",
|
|
"\n",
|
|
"Outline of code examples demoed in this notebook.\n",
|
|
"\n",
|
|
"1. CPAL's value against hallucination: CPAL vs PAL \n",
|
|
" 1.1 Complex narrative \n",
|
|
" 1.2 Unanswerable math word problem \n",
|
|
"2. CPAL's three types of causal diagrams ([The Book of Why](https://en.wikipedia.org/wiki/The_Book_of_Why)). \n",
|
|
" 2.1 Mediator \n",
|
|
" 2.2 Collider \n",
|
|
" 2.3 Confounder "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1370e40f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from IPython.display import SVG\n",
|
|
"from langchain_experimental.cpal.base import CPALChain\n",
|
|
"from langchain_experimental.pal_chain import PALChain\n",
|
|
"from langchain_openai import OpenAI\n",
|
|
"\n",
|
|
"llm = OpenAI(temperature=0, max_tokens=512)\n",
|
|
"cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)\n",
|
|
"pal_chain = PALChain.from_math_prompt(llm=llm, verbose=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "858a87d9-a9bd-4850-9687-9af4b0856b62",
|
|
"metadata": {},
|
|
"source": [
|
|
"## CPAL's value against hallucination: CPAL vs PAL\n",
|
|
"\n",
|
|
"Like PAL, CPAL intends to reduce large language model (LLM) hallucination.\n",
|
|
"\n",
|
|
"The CPAL chain is different from the PAL chain for a couple of reasons.\n",
|
|
"\n",
|
|
"CPAL adds a causal structure (or DAG) to link entity actions (or math expressions).\n",
|
|
"The CPAL math expressions are modeling a chain of cause and effect relations, which can be intervened upon, whereas for the PAL chain math expressions are projected math identities.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "496403c5-d268-43ae-8852-2bd9903ce444",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 1.1 Complex narrative\n",
|
|
"\n",
|
|
"Takeaway: PAL hallucinates, CPAL does not hallucinate."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "d5dad768-2892-4825-8093-9b840f643a8a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Tim buys the same number of pets as Cindy and Boris.\"\n",
|
|
" \"Cindy buys the same number of pets as Bill plus Bob.\"\n",
|
|
" \"Boris buys the same number of pets as Ben plus Beth.\"\n",
|
|
" \"Bill buys the same number of pets as Obama.\"\n",
|
|
" \"Bob buys the same number of pets as Obama.\"\n",
|
|
" \"Ben buys the same number of pets as Obama.\"\n",
|
|
" \"Beth buys the same number of pets as Obama.\"\n",
|
|
" \"If Obama buys one pet, how many pets total does everyone buy?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "bbffa7a0-3c22-4a1d-ab2d-f230973073b0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Tim buys the same number of pets as Cindy and Boris.Cindy buys the same number of pets as Bill plus Bob.Boris buys the same number of pets as Ben plus Beth.Bill buys the same number of pets as Obama.Bob buys the same number of pets as Obama.Ben buys the same number of pets as Obama.Beth buys the same number of pets as Obama.If Obama buys one pet, how many pets total does everyone buy?\"\"\"\n",
|
|
" obama_pets = 1\n",
|
|
" tim_pets = obama_pets\n",
|
|
" cindy_pets = obama_pets + obama_pets\n",
|
|
" boris_pets = obama_pets + obama_pets\n",
|
|
" total_pets = tim_pets + cindy_pets + boris_pets\n",
|
|
" result = total_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'5'"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "35a70d1d-86f8-4abc-b818-fbd083f072e9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 obama pass 1.0 []\n",
|
|
"1 bill bill.value = obama.value 1.0 [obama]\n",
|
|
"2 bob bob.value = obama.value 1.0 [obama]\n",
|
|
"3 ben ben.value = obama.value 1.0 [obama]\n",
|
|
"4 beth beth.value = obama.value 1.0 [obama]\n",
|
|
"5 cindy cindy.value = bill.value + bob.value 2.0 [bill, bob]\n",
|
|
"6 boris boris.value = ben.value + beth.value 2.0 [ben, beth]\n",
|
|
"7 tim tim.value = cindy.value + boris.value 4.0 [cindy, boris]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many pets total does everyone buy?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"13.0"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "ccb6b2b0-9de6-4f66-a8fb-fc59229ee316",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"292pt\" height=\"260pt\" viewBox=\"0.00 0.00 292.00 260.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 288,-256 288,4 -4,4\"/>\n<!-- obama -->\n<g id=\"node1\" class=\"node\">\n<title>obama</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"137\" cy=\"-234\" rx=\"41.69\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"137\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">obama</text>\n</g>\n<!-- bill -->\n<g id=\"node2\" class=\"node\">\n<title>bill</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">bill</text>\n</g>\n<!-- obama->bill -->\n<g id=\"edge1\" class=\"edge\">\n<title>obama->bill</title>\n<path fill=\"none\" stroke=\"black\" d=\"M114.47,-218.67C97.08,-207.6 72.94,-192.23 54.42,-180.45\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"56.15,-177.4 45.84,-174.99 52.4,-183.31 56.15,-177.4\"/>\n</g>\n<!-- bob -->\n<g id=\"node3\" class=\"node\">\n<title>bob</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"100\" cy=\"-162\" rx=\"28\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"100\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">bob</text>\n</g>\n<!-- obama->bob -->\n<g id=\"edge2\" class=\"edge\">\n<title>obama->bob</title>\n<path fill=\"none\" stroke=\"black\" d=\"M128.04,-216.05C123.66,-207.77 118.3,-197.62 113.44,-188.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"116.39,-186.51 108.62,-179.31 110.2,-189.79 116.39,-186.51\"/>\n</g>\n<!-- ben -->\n<g id=\"node4\" class=\"node\">\n<title>ben</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"174\" cy=\"-162\" rx=\"28\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"174\" y=\"-158.3\" 
font-family=\"Times,serif\" font-size=\"14.00\">ben</text>\n</g>\n<!-- obama->ben -->\n<g id=\"edge3\" class=\"edge\">\n<title>obama->ben</title>\n<path fill=\"none\" stroke=\"black\" d=\"M145.96,-216.05C150.34,-207.77 155.7,-197.62 160.56,-188.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"163.8,-189.79 165.38,-179.31 157.61,-186.51 163.8,-189.79\"/>\n</g>\n<!-- beth -->\n<g id=\"node5\" class=\"node\">\n<title>beth</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"252\" cy=\"-162\" rx=\"32\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"252\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">beth</text>\n</g>\n<!-- obama->beth -->\n<g id=\"edge4\" class=\"edge\">\n<title>obama->beth</title>\n<path fill=\"none\" stroke=\"black\" d=\"M160.27,-218.83C178.18,-207.94 203.04,-192.8 222.37,-181.04\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"224.36,-183.92 231.08,-175.73 220.72,-177.95 224.36,-183.92\"/>\n</g>\n<!-- cindy -->\n<g id=\"node6\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"93\" cy=\"-90\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"93\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- bill->cindy -->\n<g id=\"edge5\" class=\"edge\">\n<title>bill->cindy</title>\n<path fill=\"none\" stroke=\"black\" d=\"M41,-146.15C49.77,-136.85 61.25,-124.67 71.2,-114.12\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"73.79,-116.47 78.11,-106.8 68.7,-111.67 73.79,-116.47\"/>\n</g>\n<!-- bob->cindy -->\n<g id=\"edge6\" class=\"edge\">\n<title>bob->cindy</title>\n<path fill=\"none\" stroke=\"black\" d=\"M98.27,-143.7C97.5,-135.98 96.57,-126.71 95.71,-118.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"99.19,-117.7 94.71,-108.1 92.22,-118.4 99.19,-117.7\"/>\n</g>\n<!-- boris -->\n<g id=\"node7\" class=\"node\">\n<title>boris</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"181\" cy=\"-90\" rx=\"34.5\" ry=\"18\"/>\n<text 
text-anchor=\"middle\" x=\"181\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">boris</text>\n</g>\n<!-- ben->boris -->\n<g id=\"edge7\" class=\"edge\">\n<title>ben->boris</title>\n<path fill=\"none\" stroke=\"black\" d=\"M175.73,-143.7C176.5,-135.98 177.43,-126.71 178.29,-118.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"181.78,-118.4 179.29,-108.1 174.81,-117.7 181.78,-118.4\"/>\n</g>\n<!-- beth->boris -->\n<g id=\"edge8\" class=\"edge\">\n<title>beth->boris</title>\n<path fill=\"none\" stroke=\"black\" d=\"M236.59,-145.81C227.01,-136.36 214.51,-124.04 203.8,-113.48\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"205.96,-110.69 196.38,-106.16 201.04,-115.67 205.96,-110.69\"/>\n</g>\n<!-- tim -->\n<g id=\"node8\" class=\"node\">\n<title>tim</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"137\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"137\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">tim</text>\n</g>\n<!-- cindy->tim -->\n<g id=\"edge9\" class=\"edge\">\n<title>cindy->tim</title>\n<path fill=\"none\" stroke=\"black\" d=\"M103.43,-72.41C108.82,-63.83 115.51,-53.19 121.49,-43.67\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"124.59,-45.32 126.95,-34.99 118.66,-41.59 124.59,-45.32\"/>\n</g>\n<!-- boris->tim -->\n<g id=\"edge10\" class=\"edge\">\n<title>boris->tim</title>\n<path fill=\"none\" stroke=\"black\" d=\"M170.79,-72.77C165.41,-64.19 158.68,-53.49 152.65,-43.9\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"155.43,-41.75 147.15,-35.15 149.51,-45.48 155.43,-41.75\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1f6f345a-bb16-4e64-83c4-cbbc789a8325",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Unanswerable math\n",
|
|
"\n",
|
|
"Takeaway: PAL hallucinates, where CPAL, rather than hallucinate, answers with _\"unanswerable, narrative question and plot are incoherent\"_"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "068afd79-fd41-4ec2-b4d0-c64140dc413f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has three times the number of pets as Marcia.\"\n",
|
|
" \"Marcia has two more pets than Cindy.\"\n",
|
|
" \"If Cindy has ten pets, how many pets does Barak have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "02f77db2-72e8-46c2-90b3-5e37ca42f80d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Jan has three times the number of pets as Marcia.Marcia has two more pets than Cindy.If Cindy has ten pets, how many pets does Barak have?\"\"\"\n",
|
|
" cindy_pets = 10\n",
|
|
" marcia_pets = cindy_pets + 2\n",
|
|
" jan_pets = marcia_pets * 3\n",
|
|
" result = jan_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'36'"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "925958de-e998-4ffa-8b2e-5a00ddae5026",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 10.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 12.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 36.0 [marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many pets does barak have?\",\n",
|
|
" \"expression\": \"SELECT name, value FROM df WHERE name = 'barak'\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"unanswerable, query and outcome are incoherent\n",
|
|
"\n",
|
|
"outcome:\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 10.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 12.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 36.0 [marcia]\n",
|
|
"query:\n",
|
|
"{'question': 'how many pets does barak have?', 'expression': \"SELECT name, value FROM df WHERE name = 'barak'\", 'llm_error_msg': ''}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"try:\n",
|
|
" cpal_chain.run(question)\n",
|
|
"except Exception as e_msg:\n",
|
|
" print(e_msg)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "095adc76",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Basic math\n",
|
|
"\n",
|
|
"#### Causal mediator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "3ecf03fa-8350-4c4e-8080-84a307ba6ad4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has three times the number of pets as Marcia. \"\n",
|
|
" \"Marcia has two more pets than Cindy. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "74e49c47-3eed-4abe-98b7-8e97bcd15944",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"PAL"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "2e88395f-d014-4362-abb0-88f6800860bb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?\"\"\"\n",
|
|
" cindy_pets = 4\n",
|
|
" marcia_pets = cindy_pets + 2\n",
|
|
" jan_pets = marcia_pets * 3\n",
|
|
" total_pets = cindy_pets + marcia_pets + jan_pets\n",
|
|
" result = total_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'28'"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "20ba6640-3d17-4b59-8101-aaba89d68cf4",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"CPAL"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "312a0943-a482-4ed0-a064-1e7a72e9479b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 4.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 6.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 18.0 [marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"28.0"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "4466b975-ae2b-4252-972b-b3182a089ade",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"92pt\" height=\"188pt\" viewBox=\"0.00 0.00 92.49 188.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 88.49,-184 88.49,4 -4,4\"/>\n<!-- cindy -->\n<g id=\"node1\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-162\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- marcia -->\n<g id=\"node2\" class=\"node\">\n<title>marcia</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n</g>\n<!-- cindy->marcia -->\n<g id=\"edge1\" class=\"edge\">\n<title>cindy->marcia</title>\n<path fill=\"none\" stroke=\"black\" d=\"M42.25,-143.7C42.25,-135.98 42.25,-126.71 42.25,-118.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"45.75,-118.1 42.25,-108.1 38.75,-118.1 45.75,-118.1\"/>\n</g>\n<!-- jan -->\n<g id=\"node3\" class=\"node\">\n<title>jan</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n</g>\n<!-- marcia->jan -->\n<g id=\"edge2\" class=\"edge\">\n<title>marcia->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M42.25,-71.7C42.25,-63.98 42.25,-54.71 42.25,-46.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"45.75,-46.1 42.25,-36.1 38.75,-46.1 45.75,-46.1\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "29fa7b8a-75a3-4270-82a2-2c31939cd7e0",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Causal collider"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "618eddac-f0ef-4ab5-90ed-72e880fdeba3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has the number of pets as Marcia plus the number of pets as Cindy. \"\n",
|
|
" \"Marcia has no pets. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "a01563f3-7974-4de4-8bd9-0b7d710aa0d3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 marcia pass 0.0 []\n",
|
|
"1 cindy pass 4.0 []\n",
|
|
"2 jan jan.value = marcia.value + cindy.value 4.0 [marcia, cindy]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"8.0"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "0fbe7243-0522-4946-b9a2-6e21e7c49a42",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"182pt\" height=\"116pt\" viewBox=\"0.00 0.00 182.00 116.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 112)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-112 178,-112 178,4 -4,4\"/>\n<!-- marcia -->\n<g id=\"node1\" class=\"node\">\n<title>marcia</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n</g>\n<!-- jan -->\n<g id=\"node2\" class=\"node\">\n<title>jan</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"90.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"90.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n</g>\n<!-- marcia->jan -->\n<g id=\"edge1\" class=\"edge\">\n<title>marcia->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M53.62,-72.41C59.57,-63.74 66.95,-52.97 73.53,-43.38\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"76.51,-45.21 79.28,-34.99 70.74,-41.26 76.51,-45.21\"/>\n</g>\n<!-- cindy -->\n<g id=\"node3\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"138.25\" cy=\"-90\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"138.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- cindy->jan -->\n<g id=\"edge2\" class=\"edge\">\n<title>cindy->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M127.11,-72.77C121.09,-63.98 113.54,-52.96 106.83,-43.19\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"109.53,-40.94 100.99,-34.67 103.75,-44.89 109.53,-40.94\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4082538-ec03-44f0-aac3-07e03aad7555",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Causal confounder"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "83932c30-950b-435a-b328-7993ce8cc6bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has the number of pets as Marcia plus the number of pets as Cindy. \"\n",
|
|
" \"Marcia has two more pets than Cindy. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "570de307-7c6b-4fdc-80c3-4361daa8a629",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 4.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 6.0 [cindy]\n",
|
|
"2 jan jan.value = cindy.value + marcia.value 10.0 [cindy, marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"20.0"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "00375615-6b6d-4357-bdb8-f64f682f7605",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"121pt\" height=\"188pt\" viewBox=\"0.00 0.00 120.99 188.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 116.99,-184 116.99,4 -4,4\"/>\n<!-- cindy -->\n<g id=\"node1\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"77.25\" cy=\"-162\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"77.25\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- marcia -->\n<g id=\"node2\" class=\"node\">\n<title>marcia</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n</g>\n<!-- cindy->marcia -->\n<g id=\"edge1\" class=\"edge\">\n<title>cindy->marcia</title>\n<path fill=\"none\" stroke=\"black\" d=\"M68.95,-144.41C64.87,-136.25 59.86,-126.22 55.28,-117.07\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"58.33,-115.34 50.72,-107.96 52.07,-118.47 58.33,-115.34\"/>\n</g>\n<!-- jan -->\n<g id=\"node3\" class=\"node\">\n<title>jan</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"77.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"77.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n</g>\n<!-- cindy->jan -->\n<g id=\"edge2\" class=\"edge\">\n<title>cindy->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M83.73,-144.1C87.32,-133.84 91.42,-120.36 93.25,-108 95.58,-92.17 95.58,-87.83 93.25,-72 91.95,-63.21 89.5,-53.86 86.91,-45.5\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"90.19,-44.29 83.73,-35.9 83.55,-46.49 90.19,-44.29\"/>\n</g>\n<!-- marcia->jan -->\n<g id=\"edge3\" class=\"edge\">\n<title>marcia->jan</title>\n<path fill=\"none\" stroke=\"black\" 
d=\"M50.72,-72.06C54.86,-63.77 59.94,-53.62 64.53,-44.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"67.75,-45.82 69.09,-35.31 61.49,-42.69 67.75,-45.82\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "255683de-0c1c-4131-b277-99d09f5ac1fc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
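The complex narrative in section 1.1 can be checked by hand with a small causal model. The sketch below is not the `CPALChain` implementation; it is a minimal, standard-library illustration that assumes each entity is a node whose value is a function of its parents, evaluated in dependency order. The names `STORY` and `evaluate` are hypothetical and chosen only for this example.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Hypothetical encoding of the story: name -> (parents, function of parent values).
STORY = {
    "obama": ([], lambda: 1),
    "bill": (["obama"], lambda obama: obama),
    "bob": (["obama"], lambda obama: obama),
    "ben": (["obama"], lambda obama: obama),
    "beth": (["obama"], lambda obama: obama),
    "cindy": (["bill", "bob"], lambda bill, bob: bill + bob),
    "boris": (["ben", "beth"], lambda ben, beth: ben + beth),
    "tim": (["cindy", "boris"], lambda cindy, boris: cindy + boris),
}


def evaluate(story, interventions=None):
    """Evaluate every node in topological order.

    `interventions` maps a node name to a forced value; a forced node
    ignores its parents, while its descendants are still recomputed.
    """
    interventions = interventions or {}
    # TopologicalSorter expects node -> set of predecessors.
    graph = {name: set(parents) for name, (parents, _) in story.items()}
    values = {}
    for name in TopologicalSorter(graph).static_order():
        parents, fn = story[name]
        if name in interventions:
            values[name] = interventions[name]
        else:
            values[name] = fn(*(values[p] for p in parents))
    return values


values = evaluate(STORY)
total = sum(values.values())
print(values["tim"], total)
```

With these definitions, `tim` evaluates to 4 and the total to 13, which agrees with the CPAL output recorded in the notebook. Passing `interventions={"obama": 2}` to `evaluate` illustrates what "can be intervened upon" means: one upstream value changes and every downstream value is recomputed.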
|
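The three causal diagrams in the notebook (mediator, collider, confounder) differ only in which entities depend on which. As a rough illustration, and again not the library's own code, the sketch below encodes the three word problems as tiny dependency graphs and sums their values. The helper names `run` and `lookup` are hypothetical; `lookup` mimics the idea behind the "unanswerable" case by refusing to answer for an entity the story never mentions.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+


def run(model):
    """Evaluate {name: (parents, fn)} in dependency order and return {name: value}."""
    graph = {name: set(parents) for name, (parents, _) in model.items()}
    values = {}
    for name in TopologicalSorter(graph).static_order():
        parents, fn = model[name]
        values[name] = fn(*(values[p] for p in parents))
    return values


# Mediator: cindy -> marcia -> jan
mediator = {
    "cindy": ([], lambda: 4),
    "marcia": (["cindy"], lambda c: c + 2),
    "jan": (["marcia"], lambda m: m * 3),
}

# Collider: marcia -> jan <- cindy
collider = {
    "marcia": ([], lambda: 0),
    "cindy": ([], lambda: 4),
    "jan": (["marcia", "cindy"], lambda m, c: m + c),
}

# Confounder: cindy -> marcia, cindy -> jan, marcia -> jan
confounder = {
    "cindy": ([], lambda: 4),
    "marcia": (["cindy"], lambda c: c + 2),
    "jan": (["cindy", "marcia"], lambda c, m: c + m),
}

totals = {
    label: sum(run(model).values())
    for label, model in [
        ("mediator", mediator),
        ("collider", collider),
        ("confounder", confounder),
    ]
}
print(totals)


def lookup(values, name):
    """Return the value for `name`, or raise if the story never defined it."""
    if name not in values:
        raise ValueError(f"unanswerable: {name!r} does not appear in the story")
    return values[name]
```

The totals come out to 28, 8, and 20, matching the notebook's recorded CPAL results for the three examples. Calling `lookup(run(mediator), "barak")` raises a `ValueError`, which is the same kind of refusal the notebook shows when the question asks about an entity absent from the narrative.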