Mirror of https://github.com/hwchase17/langchain.git (synced 2025-11-04 10:10:09 +00:00).
			
		
		
		
	Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
		
			
				
	
	
		
			162 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Chain that interprets a prompt and executes python code to do symbolic math."""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import re
 | 
						|
from typing import Any, Dict, List, Optional
 | 
						|
 | 
						|
from langchain.base_language import BaseLanguageModel
 | 
						|
from langchain.chains.base import Chain
 | 
						|
from langchain.chains.llm import LLMChain
 | 
						|
from langchain_core.callbacks.manager import (
 | 
						|
    AsyncCallbackManagerForChainRun,
 | 
						|
    CallbackManagerForChainRun,
 | 
						|
)
 | 
						|
from langchain_core.prompts.base import BasePromptTemplate
 | 
						|
from pydantic import ConfigDict
 | 
						|
 | 
						|
from langchain_experimental.llm_symbolic_math.prompt import PROMPT
 | 
						|
 | 
						|
 | 
						|
class LLMSymbolicMathChain(Chain):
    """Chain that interprets a prompt and executes python code to do symbolic math.

    It is based on the sympy library and can be used to evaluate
    mathematical expressions.
    See https://www.sympy.org/ for more information.

    The LLM reply is handled in one of two ways:

    * if it starts with a fenced code block tagged ``text``, the block's
      content is evaluated with ``sympy.sympify`` and returned as
      ``"Answer: <result>"``;
    * otherwise the text following an ``Answer:`` marker is returned.

    .. warning::
        ``sympy.sympify`` evaluates its input with Python ``eval``. The
        evaluated expression is produced by the LLM, so treat it as
        untrusted and only use this chain in a trusted environment.

    Example:
        .. code-block:: python

            from langchain_experimental.llm_symbolic_math.base import (
                LLMSymbolicMathChain,
            )
            from langchain_community.llms import OpenAI
            llm_symbolic_math = LLMSymbolicMathChain.from_llm(OpenAI())
    """

    llm_chain: LLMChain
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, expression: str) -> str:
        """Evaluate ``expression`` with sympy and return the result as text.

        Args:
            expression: Expression text extracted from the LLM reply.

        Returns:
            ``str()`` of the sympified result, with one leading ``[`` and one
            trailing ``]`` removed (e.g. when the result is a list).

        Raises:
            ImportError: If sympy is not installed.
            ValueError: If sympy fails to parse or evaluate the expression.
        """
        try:
            import sympy
        except ImportError as e:
            raise ImportError(
                "Unable to import sympy, please install it with `pip install sympy`."
            ) from e
        try:
            # SECURITY NOTE: sympify uses eval() internally and `expression`
            # is LLM output, i.e. untrusted input. Surrounding whitespace is
            # stripped because the fenced block content normally begins with
            # a newline and may be indented.
            output = str(sympy.sympify(expression.strip(), evaluate=True))
        except Exception as e:
            raise ValueError(
                f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
                " Please try again with a valid numerical expression"
            ) from e

        # Remove any leading and trailing brackets from the output
        return re.sub(r"^\[|\]$", "", output)

    @staticmethod
    def _extract_expression(llm_output: str) -> Optional[str]:
        """Return the content of a leading ```text fenced block, or None.

        Args:
            llm_output: Stripped LLM reply.
        """
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        return text_match.group(1) if text_match else None

    @staticmethod
    def _extract_answer(llm_output: str) -> str:
        """Build the ``Answer: ...`` string from a reply without an expression.

        Args:
            llm_output: Stripped LLM reply.

        Returns:
            The reply itself if it starts with ``Answer:``; otherwise
            ``"Answer: "`` followed by the text after the last ``Answer:``.

        Raises:
            ValueError: If the reply contains no ``Answer:`` marker.
        """
        if llm_output.startswith("Answer:"):
            return llm_output
        if "Answer:" in llm_output:
            # Keep only what follows the last "Answer:" marker.
            return "Answer: " + llm_output.split("Answer:")[-1]
        raise ValueError(f"unknown format from LLM: {llm_output}")

    def _process_llm_result(
        self, llm_output: str, run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Turn the raw LLM reply into the chain output.

        Args:
            llm_output: Raw text returned by the LLM.
            run_manager: Callback manager used to report intermediate text.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the answer.

        Raises:
            ValueError: If the reply has an unknown format or the extracted
                expression cannot be evaluated.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        expression = self._extract_expression(llm_output)
        if expression is None:
            return {self.output_key: self._extract_answer(llm_output)}
        output = self._evaluate_expression(expression)
        run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: "Answer: " + output}

    async def _aprocess_llm_result(
        self,
        llm_output: str,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async counterpart of ``_process_llm_result`` (same parsing rules).

        Args:
            llm_output: Raw text returned by the LLM.
            run_manager: Async callback manager used to report intermediate text.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the answer.

        Raises:
            ValueError: If the reply has an unknown format or the extracted
                expression cannot be evaluated.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        expression = self._extract_expression(llm_output)
        if expression is None:
            return {self.output_key: self._extract_answer(llm_output)}
        output = self._evaluate_expression(expression)
        await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        await run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: "Answer: " + output}

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Ask the LLM about the question and post-process its reply."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        llm_output = self.llm_chain.predict(
            question=inputs[self.input_key],
            # Stop generation before the LLM writes an ```output block itself.
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, _run_manager)

    async def _acall(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async counterpart of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        llm_output = await self.llm_chain.apredict(
            question=inputs[self.input_key],
            # Stop generation before the LLM writes an ```output block itself.
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_symbolic_math_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = PROMPT,
        **kwargs: Any,
    ) -> LLMSymbolicMathChain:
        """Build the chain from a language model.

        Args:
            llm: Language model used to interpret the question.
            prompt: Prompt template wrapped in the inner ``LLMChain``.
            **kwargs: Extra fields forwarded to the chain constructor.

        Returns:
            A new ``LLMSymbolicMathChain``.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
 |