diff --git a/docs/extras/modules/chains/additional/llm_symbolic_math.ipynb b/docs/extras/modules/chains/additional/llm_symbolic_math.ipynb
new file mode 100644
index 00000000000..4aca644c37b
--- /dev/null
+++ b/docs/extras/modules/chains/additional/llm_symbolic_math.ipynb
@@ -0,0 +1,160 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# LLM Symbolic Math\n",
+    "This notebook showcases using LLMs and Python's SymPy library to solve symbolic math problems such as limits, integrals, and algebraic equations."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Calculating the limit of a function"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new LLMSymbolicMathChain chain...\u001b[0m\n",
+      "What is the limit of sin(x) / x as x goes to 0?\u001b[32;1m\u001b[1;3mAnswer: 1\u001b[0m\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'Answer: 1'"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain.llms import OpenAI\n",
+    "from langchain.chains.llm_symbolic_math.base import LLMSymbolicMathChain\n",
+    "\n",
+    "llm = OpenAI(temperature=0)\n",
+    "llm_symbolic_math = LLMSymbolicMathChain.from_llm(llm, verbose=True)\n",
+    "\n",
+    "llm_symbolic_math.run(\"What is the limit of sin(x) / x as x goes to 0?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Calculating an integral"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new LLMSymbolicMathChain chain...\u001b[0m\n",
+      "What is the integral of e^-x from 0 to infinity?\u001b[32;1m\u001b[1;3mAnswer: 1\u001b[0m\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'Answer: 1'"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm_symbolic_math.run(\"What is the integral of e^-x from 0 to infinity?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Solving an algebraic equation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new LLMSymbolicMathChain chain...\u001b[0m\n",
+      "What are the solutions to this equation x**2 - x?\u001b[32;1m\u001b[1;3mAnswer: 0 and 1.\u001b[0m\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'Answer: 0 and 1.'"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm_symbolic_math.run(\"What are the solutions to this equation x**2 - x?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/langchain/chains/llm_symbolic_math/__init__.py b/langchain/chains/llm_symbolic_math/__init__.py
new file mode 100644
index 00000000000..d6cde9105ad
--- /dev/null
+++ b/langchain/chains/llm_symbolic_math/__init__.py
@@ -0,0 +1,4 @@
+"""Chain that interprets a prompt and executes python code to do symbolic math.
+
+Heavily borrowed from llm_math; wraps SymPy to evaluate symbolic expressions.
+"""
diff --git a/langchain/chains/llm_symbolic_math/base.py b/langchain/chains/llm_symbolic_math/base.py
new file mode 100644
index 00000000000..c2092765343
--- /dev/null
+++ b/langchain/chains/llm_symbolic_math/base.py
@@ -0,0 +1,157 @@
+"""Chain that interprets a prompt and executes python code to do symbolic math."""
+from __future__ import annotations
+
+import re
+from typing import Any, Dict, List, Optional
+
+from pydantic import Extra
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.chains.llm_symbolic_math.prompt import PROMPT
+from langchain.prompts.base import BasePromptTemplate
+
+
+class LLMSymbolicMathChain(Chain):
+    """Chain that interprets a prompt and executes python code to do symbolic math.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OpenAI
+            from langchain.chains.llm_symbolic_math.base import LLMSymbolicMathChain
+            llm_symbolic_math = LLMSymbolicMathChain.from_llm(OpenAI())
+    """
+
+    llm_chain: LLMChain
+    input_key: str = "question"  #: :meta private:
+    output_key: str = "answer"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect output key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _evaluate_expression(self, expression: str) -> str:
+        try:
+            import sympy
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import sympy, please install it with `pip install sympy`."
+            ) from e
+        try:
+            output = str(sympy.sympify(expression, evaluate=True))
+        except Exception as e:
+            raise ValueError(
+                f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
+ " Please try again with a valid numerical expression" + ) + + # Remove any leading and trailing brackets from the output + return re.sub(r"^\[|\]$", "", output) + + def _process_llm_result( + self, llm_output: str, run_manager: CallbackManagerForChainRun + ) -> Dict[str, str]: + run_manager.on_text(llm_output, color="green", verbose=self.verbose) + llm_output = llm_output.strip() + text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: + expression = text_match.group(1) + output = self._evaluate_expression(expression) + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + raise ValueError(f"unknown format from LLM: {llm_output}") + return {self.output_key: answer} + + async def _aprocess_llm_result( + self, + llm_output: str, + run_manager: AsyncCallbackManagerForChainRun, + ) -> Dict[str, str]: + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) + llm_output = llm_output.strip() + text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: + expression = text_match.group(1) + output = self._evaluate_expression(expression) + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + raise ValueError(f"unknown format from LLM: {llm_output}") + return {self.output_key: answer} + + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key]) + llm_output = self.llm_chain.predict( + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return self._process_llm_result(llm_output, _run_manager) + + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + await _run_manager.on_text(inputs[self.input_key]) + llm_output = await self.llm_chain.apredict( + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return await self._aprocess_llm_result(llm_output, _run_manager) + + @property + def _chain_type(self) -> str: + return "llm_symbolic_math_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, + **kwargs: Any, + ) -> LLMSymbolicMathChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) diff --git a/langchain/chains/llm_symbolic_math/prompt.py b/langchain/chains/llm_symbolic_math/prompt.py new file mode 100644 index 00000000000..576dd1f9dc2 --- /dev/null +++ b/langchain/chains/llm_symbolic_math/prompt.py @@ -0,0 +1,51 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +_PROMPT_TEMPLATE = """Translate a math problem into a expression that can be executed using Python's SymPy library. 
+
+Question: ${{Question with math problem.}}
+```text
+${{single line sympy expression that solves the problem}}
+```
+...sympy.sympify(text, evaluate=True)...
+```output
+${{Output of running the code}}
+```
+Answer: ${{Answer}}
+
+Begin.
+
+Question: What is the limit of sin(x) / x as x goes to 0
+```text
+limit(sin(x)/x, x, 0)
+```
+...sympy.sympify("limit(sin(x)/x, x, 0)")...
+```output
+1
+```
+Answer: 1
+
+Question: What is the integral of e^-x from 0 to infinity
+```text
+integrate(exp(-x), (x, 0, oo))
+```
+...sympy.sympify("integrate(exp(-x), (x, 0, oo))")...
+```output
+1
+```
+Answer: 1
+
+Question: What are the solutions to this equation x**2 - x?
+```text
+solveset(x**2 - x, x)
+```
+...sympy.sympify("solveset(x**2 - x, x)")...
+```output
+[0, 1]
+```
+Answer: [0, 1]
+
+Question: {question}
+"""
+
+PROMPT = PromptTemplate(
+    input_variables=["question"],
+    template=_PROMPT_TEMPLATE,
+)
diff --git a/poetry.lock b/poetry.lock
index 432f751ef7f..cb3877df0ad 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -12690,7 +12690,7 @@ clarifai = ["clarifai"]
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "telethon", "tqdm", "zep-python"]
+extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
 javascript = ["esprima"]
 llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
@@ -12700,4 +12700,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "8bb95a90cfff1af5cd7e485a3f271db72de49f4af95bcf0a907ae00384ac35ed"
+content-hash = "cd49db5debee164e0fbb17b1d096b5ee7bae992e4dce91567525572d8dc4205e"
diff --git a/pyproject.toml b/pyproject.toml
index d6c6eb7b3f0..67622dcab37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,6 +117,7 @@ streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >
 psychicapi = {version = "^0.8.0", optional = true}
 cassio = {version = "^0.0.7", optional = true}
 rdflib = {version = "^6.3.2", optional = true}
+sympy = {version = "^1.12", optional = true}
 rapidfuzz = {version = "^3.1.1", optional = true}
 
 [tool.poetry.group.docs.dependencies]
@@ -352,8 +353,8 @@ extended_testing = [
     "streamlit",
     "pyspark",
     "openai",
-    "rapidfuzz"
-
+    "sympy",
+    "rapidfuzz",
 ]
 
 [[tool.poetry.source]]
diff --git a/tests/unit_tests/chains/test_llm_symbolic_math.py b/tests/unit_tests/chains/test_llm_symbolic_math.py
new file mode 100644
index 00000000000..351dcbddfb0
--- /dev/null
+++ b/tests/unit_tests/chains/test_llm_symbolic_math.py
@@ -0,0 +1,82 @@
+"""Test LLM Symbolic Math functionality."""
+
+import pytest
+
+from langchain.chains.llm_symbolic_math.base import LLMSymbolicMathChain
+from langchain.chains.llm_symbolic_math.prompt import _PROMPT_TEMPLATE
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+@pytest.fixture
+@pytest.mark.requires("sympy")
+def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
+    """Fake LLM Symbolic Math chain for testing."""
+    queries = {
+        _PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
+        _PROMPT_TEMPLATE.format(
+            question="What is the square root of 2?"
+        ): "```text\nsqrt(2)\n```",
+        _PROMPT_TEMPLATE.format(
+            question="What is the limit of sin(x) / x as x goes to 0?"
+        ): "```text\nlimit(sin(x)/x,x,0)\n```",
+        _PROMPT_TEMPLATE.format(
+            question="What is the integral of e^-x from 0 to infinity?"
+        ): "```text\nintegrate(exp(-x), (x, 0, oo))\n```",
+        _PROMPT_TEMPLATE.format(
+            question="What are the solutions to this equation x**2 - x?"
+        ): "```text\nsolveset(x**2 - x, x)\n```",
+        _PROMPT_TEMPLATE.format(question="foo"): "foo",
+    }
+    fake_llm = FakeLLM(queries=queries)
+    return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
+
+
+@pytest.mark.requires("sympy")
+def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test simple question that should not need sympy."""
+    question = "What is 1 plus 1?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: 2"
+
+
+@pytest.mark.requires("sympy")
+def test_root_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test irrational number that should need sympy."""
+    import sympy
+
+    question = "What is the square root of 2?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == f"Answer: {sympy.sqrt(2)}"
+
+
+@pytest.mark.requires("sympy")
+def test_limit_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test question about limits that needs sympy."""
+    question = "What is the limit of sin(x) / x as x goes to 0?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: 1"
+
+
+@pytest.mark.requires("sympy")
+def test_integration_question(
+    fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
+) -> None:
+    """Test question about integration that needs sympy."""
+    question = "What is the integral of e^-x from 0 to infinity?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: 1"
+
+
+@pytest.mark.requires("sympy")
+def test_solver_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test question about solving algebraic equations that needs sympy."""
+    question = "What are the solutions to this equation x**2 - x?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: {0, 1}"
+
+
+@pytest.mark.requires("sympy")
+def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test question that raises error."""
+    with pytest.raises(ValueError):
+        fake_llm_symbolic_math_chain.run("foo")
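
The core behaviour this diff adds in `base.py` is: extract the fenced ```text block from the LLM completion and evaluate it with `sympy.sympify`. Below is a minimal standalone sketch of that step, with a hard-coded completion standing in for a real model; the `extract_and_evaluate` helper and `fake_completion` string are illustrative names only and are not part of this change set.

```python
import re

import sympy  # optional dependency added to pyproject.toml in this diff


def extract_and_evaluate(llm_output: str) -> str:
    """Pull the ```text block out of an LLM completion and evaluate it with SymPy."""
    llm_output = llm_output.strip()
    text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
    if text_match:
        expression = text_match.group(1).strip()
        # Same evaluation call used by LLMSymbolicMathChain._evaluate_expression
        # (the chain passes the raw regex group; it is stripped here for readability).
        output = str(sympy.sympify(expression, evaluate=True))
        # Mirror the chain's removal of leading/trailing brackets from list-like output.
        return "Answer: " + re.sub(r"^\[|\]$", "", output)
    if "Answer:" in llm_output:
        # Fall back to a direct answer, similar to _process_llm_result.
        return "Answer: " + llm_output.split("Answer:")[-1].strip()
    raise ValueError(f"unknown format from LLM: {llm_output}")


# A completion shaped like the few-shot examples in prompt.py, hard-coded
# here so the sketch runs without an API key or a FakeLLM.
fake_completion = """```text
limit(sin(x)/x, x, 0)
```"""

print(extract_and_evaluate(fake_completion))  # -> Answer: 1
```

Running the sketch prints `Answer: 1`, matching the notebook output above; the chain itself additionally wires in callbacks, an async path, and the `Answer:` passthrough for questions the LLM answers directly.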