This commit is contained in:
Harrison Chase 2022-12-01 16:27:54 -08:00
parent 3ef44f41b7
commit 9966fd0e05
2 changed files with 24 additions and 5 deletions

View File

@ -1,10 +1,11 @@
"""Chain that first uses an LLM to generate multiple items then loops over them."""
from typing import Any, Dict, List
from typing import Dict, List, Any from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts.base import ListOutputParser
class LLMForLoopChain(Chain, BaseModel): class LLMForLoopChain(Chain, BaseModel):
@ -38,6 +39,17 @@ class LLMForLoopChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
@root_validator()
def validate_output_parser(cls, values: Dict) -> Dict:
"""Validate that the correct inputs exist for all chains."""
chain = values["llm_chain"]
if not isinstance(chain.prompt.output_parser, ListOutputParser):
raise ValueError(
f"The OutputParser on the base prompt should be of type "
f"ListOutputParser, got {type(chain.prompt.output_parser)}"
)
return values
def run_list(self, **kwargs: Any) -> List[str]: def run_list(self, **kwargs: Any) -> List[str]:
"""Get list from LLM chain and then run chain on each item.""" """Get list from LLM chain and then run chain on each item."""
output_items = self.llm_chain.predict_and_parse(**kwargs) output_items = self.llm_chain.predict_and_parse(**kwargs)

View File

@ -1,8 +1,8 @@
"""BasePrompt schema definition.""" """BasePrompt schema definition."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator, Extra from pydantic import BaseModel, Extra, root_validator
from langchain.formatting import formatter from langchain.formatting import formatter
@ -37,6 +37,13 @@ class BaseOutputParser(ABC):
"""Parse the output of an LLM call.""" """Parse the output of an LLM call."""
class ListOutputParser(ABC):
"""Class to parse the output of an LLM call to a list."""
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
class BasePromptTemplate(BaseModel, ABC): class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt.""" """Base prompt should expose the format method, returning a prompt."""