diff --git a/langchain/chains/llm_for_loop.py b/langchain/chains/llm_for_loop.py
index 7c2b63d060a..e47813bc78b 100644
--- a/langchain/chains/llm_for_loop.py
+++ b/langchain/chains/llm_for_loop.py
@@ -1,10 +1,11 @@
+"""Chain that first uses an LLM to generate multiple items then loops over them."""
+from typing import Any, Dict, List
-from typing import Dict, List, Any
-
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, root_validator
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
+from langchain.prompts.base import ListOutputParser
 
 
 class LLMForLoopChain(Chain, BaseModel):
@@ -38,6 +39,17 @@ class LLMForLoopChain(Chain, BaseModel):
         """
         return [self.output_key]
 
+    @root_validator()
+    def validate_output_parser(cls, values: Dict) -> Dict:
+        """Validate that the prompt's output parser is a ListOutputParser."""
+        chain = values["llm_chain"]
+        if not isinstance(chain.prompt.output_parser, ListOutputParser):
+            raise ValueError(
+                f"The OutputParser on the base prompt should be of type "
+                f"ListOutputParser, got {type(chain.prompt.output_parser)}"
+            )
+        return values
+
     def run_list(self, **kwargs: Any) -> List[str]:
         """Get list from LLM chain and then run chain on each item."""
         output_items = self.llm_chain.predict_and_parse(**kwargs)
diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py
index 849ea926e62..22a385ec1fa 100644
--- a/langchain/prompts/base.py
+++ b/langchain/prompts/base.py
@@ -1,8 +1,8 @@
 """BasePrompt schema definition."""
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator, Extra
+from pydantic import BaseModel, Extra, root_validator
 
 from langchain.formatting import formatter
 
@@ -37,6 +37,13 @@ class BaseOutputParser(ABC):
         """Parse the output of an LLM call."""
 
 
+class ListOutputParser(ABC):
+    """Class to parse the output of an LLM call to a list."""
+
+    @abstractmethod
+    def parse(self, text: str) -> List[str]:
+        """Parse the output of an LLM call."""
+
+
 class BasePromptTemplate(BaseModel, ABC):
     """Base prompt should expose the format method, returning a prompt."""
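
For reviewers, a minimal sketch of how the new ListOutputParser hook could be exercised. Only the ListOutputParser base class added to langchain/prompts/base.py above is taken from the diff; the NewlineListOutputParser subclass, its splitting rule, and the sample text are illustrative assumptions, not part of this PR.

"""Illustrative sketch only -- not part of the diff above."""
from typing import List

from langchain.prompts.base import ListOutputParser  # added in this PR


class NewlineListOutputParser(ListOutputParser):
    """Hypothetical parser that splits an LLM completion on newlines."""

    def parse(self, text: str) -> List[str]:
        # Drop blank lines and strip surrounding whitespace from the completion.
        return [line.strip() for line in text.splitlines() if line.strip()]


if __name__ == "__main__":
    parser = NewlineListOutputParser()
    print(parser.parse("apples\nbananas\n\ncherries\n"))
    # ['apples', 'bananas', 'cherries']

A prompt whose output_parser is an instance of a subclass like this would satisfy the new validate_output_parser root validator on LLMForLoopChain; per the diff, run_list then calls predict_and_parse on the inner LLMChain and runs the chain once per parsed item.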