diff --git a/langchain/__init__.py b/langchain/__init__.py index ebf276a4e98..32bbba17737 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -11,7 +11,6 @@ from langchain.chains import ( LLMChain, LLMMathChain, PALChain, - PythonChain, QAWithSourcesChain, SQLDatabaseChain, VectorDBQA, @@ -32,7 +31,6 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch __all__ = [ "LLMChain", "LLMMathChain", - "PythonChain", "SelfAskWithSearchChain", "SerpAPIWrapper", "SerpAPIChain", diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index d63839bc70d..1152bc1930a 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -4,7 +4,6 @@ from langchain.chains.conversation.base import ConversationChain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.pal.base import PALChain -from langchain.chains.python import PythonChain from langchain.chains.qa_with_sources.base import QAWithSourcesChain from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain @@ -14,7 +13,6 @@ from langchain.chains.vector_db_qa.base import VectorDBQA __all__ = [ "LLMChain", "LLMMathChain", - "PythonChain", "SQLDatabaseChain", "VectorDBQA", "SequentialChain", diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index e9d22e32c28..5ebd9051e3b 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -6,9 +6,9 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT -from langchain.chains.python import PythonChain from langchain.input import print_text from langchain.llms.base import LLM +from langchain.python import PythonREPL class LLMMathChain(Chain, BaseModel): @@ -50,7 +50,7 @@ class LLMMathChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) - python_executor = PythonChain() + python_executor = PythonREPL() if self.verbose: print_text(inputs[self.input_key]) t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"]) diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 04573625a0a..1077c9580e7 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -12,10 +12,10 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT -from langchain.chains.python import PythonChain from langchain.input import print_text from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate +from langchain.python import PythonREPL class PALChain(Chain, BaseModel): @@ -54,7 +54,7 @@ class PALChain(Chain, BaseModel): code = llm_chain.predict(stop=[self.stop], **inputs) if self.verbose: print_text(code, color="green", end="\n") - repl = PythonChain() + repl = PythonREPL() res = repl.run(code + f"\n{self.get_answer_expr}") return {self.output_key: res.strip()} diff --git a/langchain/chains/python.py b/langchain/chains/python.py deleted file mode 100644 index 74e598c24e4..00000000000 --- a/langchain/chains/python.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Chain that runs python code. - -Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py -""" -import sys -from io import StringIO -from typing import Dict, List - -from pydantic import BaseModel - -from langchain.chains.base import Chain -from langchain.python import PythonREPL - - -class PythonChain(Chain, BaseModel): - """Chain to run python code. - - Example: - .. code-block:: python - - from langchain import PythonChain - python_chain = PythonChain() - """ - - input_key: str = "code" #: :meta private: - output_key: str = "output" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return output key. - - :meta private: - """ - return [self.output_key] - - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - python_repl = PythonREPL() - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - python_repl.run(inputs[self.input_key]) - sys.stdout = old_stdout - output = mystdout.getvalue() - return {self.output_key: output} diff --git a/langchain/python.py b/langchain/python.py index b4a40cff178..88ad4db05d8 100644 --- a/langchain/python.py +++ b/langchain/python.py @@ -1,4 +1,6 @@ """Mock Python REPL.""" +import sys +from io import StringIO from typing import Dict, Optional @@ -10,6 +12,11 @@ class PythonREPL: self._globals = _globals if _globals is not None else {} self._locals = _locals if _locals is not None else {} - def run(self, command: str) -> None: - """Run command with own globals/locals.""" + def run(self, command: str) -> str: + """Run command with own globals/locals and returns anything printed.""" + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() exec(command, self._globals, self._locals) + sys.stdout = old_stdout + output = mystdout.getvalue() + return output diff --git a/tests/unit_tests/chains/test_python.py b/tests/unit_tests/chains/test_python.py deleted file mode 100644 index 1677a76a472..00000000000 --- a/tests/unit_tests/chains/test_python.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Test python chain.""" - -from langchain.chains.python import PythonChain - - -def test_functionality() -> None: - """Test correct functionality.""" - chain = PythonChain(input_key="code1", output_key="output1") - code = "print(1 + 1)" - output = chain({"code1": code}) - assert output == {"code1": code, "output1": "2\n"} - - # Test with the more user-friendly interface. - simple_output = chain.run(code) - assert simple_output == "2\n" diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index 419f13cae33..3cdd4dc4d9a 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -32,3 +32,11 @@ def test_python_repl_pass_in_locals() -> None: repl = PythonREPL(_locals=_locals) repl.run("bar = foo * 2") assert repl._locals["bar"] == 8 + + +def test_functionality() -> None: + """Test correct functionality.""" + chain = PythonREPL() + code = "print(1 + 1)" + output = chain.run(code) + assert output == "2\n"