diff --git a/libs/experimental/langchain_experimental/llm_symbolic_math/base.py b/libs/experimental/langchain_experimental/llm_symbolic_math/base.py
index 3f107d0d913..ca6e912f0fe 100644
--- a/libs/experimental/langchain_experimental/llm_symbolic_math/base.py
+++ b/libs/experimental/langchain_experimental/llm_symbolic_math/base.py
@@ -42,6 +42,44 @@ class LLMSymbolicMathChain(Chain):
         extra="forbid",
     )
 
+    allow_dangerous_requests: bool  # Assign no default.
+    """Must be set by the user to allow or disallow dangerous requests.
+
+    We recommend a default of False to allow only pre-defined symbolic operations.
+
+    When set to True, the chain will allow any kind of input. This is
+    STRONGLY DISCOURAGED unless you fully trust the input (and believe that
+    the LLM itself cannot behave in a malicious way).
+    You should absolutely NOT be deploying this in a production environment
+    with allow_dangerous_requests=True, as this would allow a malicious actor
+    to execute arbitrary code on your system.
+    Use allow_dangerous_requests=True at your own risk.
+
+
+    When set to False, the chain will only allow pre-defined symbolic operations.
+    If some symbolic expressions are failing to evaluate, you can open a PR
+    to extend the list of allowed operations.
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        if "allow_dangerous_requests" not in kwargs:
+            raise ValueError(
+                "LLMSymbolicMathChain requires allow_dangerous_requests to be set. "
+                "We recommend that you set `allow_dangerous_requests=False` to allow "
+                "only pre-defined symbolic operations. "
+                "If some symbolic expressions are failing to evaluate, you can "
+                "open a PR to extend the list of allowed operations. "
+                "Alternatively, you can set `allow_dangerous_requests=True` to allow "
+                "any kind of input, but this is STRONGLY DISCOURAGED unless you "
+                "fully trust the input (and believe that the LLM itself cannot behave "
+                "in a malicious way). "
+                "You should absolutely NOT be deploying this in a production "
+                "environment with allow_dangerous_requests=True, as "
+                "this would allow a malicious actor to execute arbitrary code on "
+                "your system."
+            )
+        super().__init__(**kwargs)
+
     @property
     def input_keys(self) -> List[str]:
         """Expect input key.
@@ -65,8 +103,59 @@ class LLMSymbolicMathChain(Chain):
             raise ImportError(
                 "Unable to import sympy, please install it with `pip install sympy`."
             ) from e
+
         try:
-            output = str(sympy.sympify(expression, evaluate=True))
+            if self.allow_dangerous_requests:
+                output = str(sympy.sympify(expression, evaluate=True))
+            else:
+                allowed_symbols = {
+                    # Basic arithmetic and trigonometry
+                    "sin": sympy.sin,
+                    "cos": sympy.cos,
+                    "tan": sympy.tan,
+                    "cot": sympy.cot,
+                    "sec": sympy.sec,
+                    "csc": sympy.csc,
+                    "asin": sympy.asin,
+                    "acos": sympy.acos,
+                    "atan": sympy.atan,
+                    # Hyperbolic functions
+                    "sinh": sympy.sinh,
+                    "cosh": sympy.cosh,
+                    "tanh": sympy.tanh,
+                    "asinh": sympy.asinh,
+                    "acosh": sympy.acosh,
+                    "atanh": sympy.atanh,
+                    # Exponentials and logarithms
+                    "exp": sympy.exp,
+                    "log": sympy.log,
+                    "ln": sympy.log,  # sympy.log is the natural log by default
+                    "log10": lambda x: sympy.log(x, 10),  # log base 10 via sympy.log
+                    # Powers and roots
+                    "sqrt": sympy.sqrt,
+                    "cbrt": lambda x: sympy.Pow(x, sympy.Rational(1, 3)),
+                    # Combinatorics and other math functions
+                    "factorial": sympy.factorial,
+                    "binomial": sympy.binomial,
+                    "gcd": sympy.gcd,
+                    "lcm": sympy.lcm,
+                    "abs": sympy.Abs,
+                    "sign": sympy.sign,
+                    "mod": sympy.Mod,
+                    # Constants
+                    "pi": sympy.pi,
+                    "e": sympy.E,
+                    "I": sympy.I,
+                    "oo": sympy.oo,
+                    "NaN": sympy.nan,
+                }
+
+                # Use parse_expr restricted to the allowed symbols above
+                output = str(
+                    sympy.parse_expr(
+                        expression, local_dict=allowed_symbols, evaluate=True
+                    )
+                )
         except Exception as e:
             raise ValueError(
                 f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
diff --git a/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py b/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
index 306d4ec1673..b5d4a91c748 100644
--- a/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
+++ b/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
@@ -34,9 +34,20 @@ def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
             question="What are the solutions to this equation x**2 - x?"
         ): "```text\nsolveset(x**2 - x, x)\n```",
         _PROMPT_TEMPLATE.format(question="foo"): "foo",
+        _PROMPT_TEMPLATE.format(question="__import__('os')"): "__import__('os')",
     }
     fake_llm = FakeLLM(queries=queries)
-    return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
+    return LLMSymbolicMathChain.from_llm(
+        fake_llm, input_key="q", output_key="a", allow_dangerous_requests=False
+    )
+
+
+def test_require_allow_dangerous_requests_to_be_set() -> None:
+    """Test that allow_dangerous_requests must be set."""
+    fake_llm = FakeLLM(queries={})
+
+    with pytest.raises(ValueError):
+        LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
 
 
 def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
@@ -80,3 +91,15 @@ def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
    """Test question that raises error."""
     with pytest.raises(ValueError):
         fake_llm_symbolic_math_chain.run("foo")
+
+
+def test_security_vulnerability(
+    fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
+) -> None:
+    """Test for potential security vulnerability with malicious input."""
+    # Example of a code injection attempt
+    malicious_input = "__import__('os')"
+
+    # Run the chain with the malicious input and ensure it raises an error
+    with pytest.raises(ValueError):
+        fake_llm_symbolic_math_chain.run(malicious_input)
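Usage sketch for reviewers (not part of the diff): the example below assumes the separate `langchain-openai` package, its `ChatOpenAI` wrapper, a `gpt-4o-mini` model name, and an `OPENAI_API_KEY` in the environment; any LLM accepted by `from_llm` behaves the same. It shows that the chain must now be constructed with an explicit `allow_dangerous_requests`, and that `False` routes evaluation through the allow-listed `parse_expr` path.

```python
# Minimal sketch of the new flag; ChatOpenAI and the model name are assumptions,
# not part of this PR. Any LLM usable with from_llm works identically.
from langchain_experimental.llm_symbolic_math.base import LLMSymbolicMathChain
from langchain_openai import ChatOpenAI  # assumed model wrapper

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Omitting allow_dangerous_requests now raises ValueError at construction time.
# False keeps evaluation restricted to the allow-listed sympy symbols.
chain = LLMSymbolicMathChain.from_llm(llm, allow_dangerous_requests=False)
print(chain.run("What is 2 to the power of 10?"))

# allow_dangerous_requests=True restores the old unrestricted sympify()
# behaviour and is strongly discouraged outside fully trusted inputs.
```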