diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 6b1ef324218..a11269664fb 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -48,13 +48,24 @@ def check_valid_template( raise ValueError("Invalid prompt schema.") -class BaseOutputParser(ABC): +class BaseOutputParser(BaseModel, ABC): """Class to parse the output of an LLM call.""" @abstractmethod def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]: """Parse the output of an LLM call.""" + @property + def _type(self) -> str: + """Return the type key.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of output parser.""" + output_parser_dict = super().dict() + output_parser_dict["_type"] = self._type + return output_parser_dict + class ListOutputParser(BaseOutputParser): """Class to parse the output of an LLM call to a list.""" @@ -79,6 +90,11 @@ class RegexParser(BaseOutputParser, BaseModel): output_keys: List[str] default_output_key: Optional[str] = None + @property + def _type(self) -> str: + """Return the type key.""" + return "regex_parser" + def parse(self, text: str) -> Dict[str, str]: """Parse the output of an LLM call.""" match = re.search(self.regex, text) @@ -142,7 +158,7 @@ class BasePromptTemplate(BaseModel, ABC): def dict(self, **kwargs: Any) -> Dict: """Return dictionary representation of prompt.""" - prompt_dict = super().dict() + prompt_dict = super().dict(**kwargs) prompt_dict["_type"] = self._prompt_type return prompt_dict diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index 651525ee0ab..eaae6ad6c5f 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -9,7 +9,7 @@ from typing import Union import requests import yaml -from langchain.prompts.base import BasePromptTemplate +from langchain.prompts.base import BasePromptTemplate, RegexParser from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate @@ -69,6 +69,20 @@ def _load_examples(config: dict) -> dict: return config +def _load_output_parser(config: dict) -> dict: + """Load output parser.""" + if "output_parser" in config: + if config["output_parser"] is not None: + _config = config["output_parser"] + output_parser_type = _config["_type"] + if output_parser_type == "regex_parser": + output_parser = RegexParser(**_config) + else: + raise ValueError(f"Unsupported output parser {output_parser_type}") + config["output_parser"] = output_parser + return config + + def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: """Load the few shot prompt from the config.""" # Load the suffix and prefix templates. @@ -86,6 +100,7 @@ def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: config["example_prompt"] = load_prompt_from_config(config["example_prompt"]) # Load the examples. config = _load_examples(config) + config = _load_output_parser(config) return FewShotPromptTemplate(**config) @@ -93,6 +108,7 @@ def _load_prompt(config: dict) -> PromptTemplate: """Load the prompt template from config.""" # Load the template from disk if necessary. config = _load_template("template", config) + config = _load_output_parser(config) return PromptTemplate(**config)