mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 11:37:12 +00:00
cr
This commit is contained in:
parent
3ef44f41b7
commit
9966fd0e05
@ -1,10 +1,11 @@
|
||||
"""Chain that first uses an LLM to generate multiple items then loops over them."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import ListOutputParser
|
||||
|
||||
|
||||
class LLMForLoopChain(Chain, BaseModel):
|
||||
@ -38,6 +39,17 @@ class LLMForLoopChain(Chain, BaseModel):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@root_validator()
|
||||
def validate_output_parser(cls, values: Dict) -> Dict:
|
||||
"""Validate that the correct inputs exist for all chains."""
|
||||
chain = values["llm_chain"]
|
||||
if not isinstance(chain.prompt.output_parser, ListOutputParser):
|
||||
raise ValueError(
|
||||
f"The OutputParser on the base prompt should be of type "
|
||||
f"ListOutputParser, got {type(chain.prompt.output_parser)}"
|
||||
)
|
||||
return values
|
||||
|
||||
def run_list(self, **kwargs: Any) -> List[str]:
|
||||
"""Get list from LLM chain and then run chain on each item."""
|
||||
output_items = self.llm_chain.predict_and_parse(**kwargs)
|
||||
|
@ -1,8 +1,8 @@
|
||||
"""BasePrompt schema definition."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator, Extra
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.formatting import formatter
|
||||
|
||||
@ -37,6 +37,13 @@ class BaseOutputParser(ABC):
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
|
||||
class ListOutputParser(ABC):
|
||||
"""Class to parse the output of an LLM call to a list."""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
|
||||
class BasePromptTemplate(BaseModel, ABC):
|
||||
"""Base prompt should expose the format method, returning a prompt."""
|
||||
|
Loading…
Reference in New Issue
Block a user