Compare commits

...

1 Commit

Author: Harrison Chase
SHA1: 7a918dadb4
Message: use output parsers
Date: 2023-03-13 11:16:16 -07:00


@@ -3,7 +3,7 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, validator
 
 from langchain.chains.base import Chain
 from langchain.input import get_colored_text
@@ -29,8 +29,21 @@ class LLMChain(Chain, BaseModel):
     prompt: BasePromptTemplate
     """Prompt object to use."""
     llm: BaseLanguageModel
     """LLM wrapper to use."""
+    output_parsing_mode: str = "validate"
+    """Output parsing mode, should be one of `validate`, `off`, `parse`."""
     output_key: str = "text"  #: :meta private:
 
+    @validator("output_parsing_mode")
+    def valid_output_parsing_mode(cls, v: str) -> str:
+        """Validate output parsing mode."""
+        _valid_modes = {"off", "validate", "parse"}
+        if v not in _valid_modes:
+            raise ValueError(
+                f"Got `{v}` for output_parsing_mode, should be one of {_valid_modes}"
+            )
+        return v
+
     class Config:
         """Configuration for this pydantic object."""
@@ -125,11 +138,20 @@ class LLMChain(Chain, BaseModel):
     def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
         """Create outputs from response."""
-        return [
+        outputs = []
+        _should_parse = self.output_parsing_mode != "off"
+        for generation in response.generations:
             # Get the text of the top generated string.
-            {self.output_key: generation[0].text}
-            for generation in response.generations
-        ]
+            response_item = generation[0].text
+            if self.prompt.output_parser is not None and _should_parse:
+                try:
+                    parsed_output = self.prompt.output_parser.parse(response_item)
+                except Exception as e:
+                    raise ValueError("Output of LLM not as expected") from e
+                if self.output_parsing_mode == "parse":
+                    response_item = parsed_output
+            outputs.append({self.output_key: response_item})
+        return outputs
 
     async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         return (await self.aapply([inputs]))[0]
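
For context, a minimal, self-contained sketch (not part of the commit) of how the three output_parsing_mode values behave. UpperCaseParser is a hypothetical stand-in for a prompt's output_parser, and the create_outputs function below mirrors the LLMChain.create_outputs logic for plain strings; no real LangChain objects are imported.

# Hypothetical sketch of the three parsing modes introduced above.
# UpperCaseParser stands in for prompt.output_parser.

class UpperCaseParser:
    def parse(self, text: str) -> str:
        # Toy parser: reject empty output, otherwise upper-case it.
        if not text.strip():
            raise ValueError("empty output")
        return text.upper()

def create_outputs(texts, parser, mode="validate"):
    # Mirrors LLMChain.create_outputs, with strings in place of generations.
    outputs = []
    _should_parse = mode != "off"
    for text in texts:
        response_item = text
        if parser is not None and _should_parse:
            try:
                parsed_output = parser.parse(response_item)
            except Exception as e:
                raise ValueError("Output of LLM not as expected") from e
            if mode == "parse":
                # Only "parse" substitutes the parsed value; "validate"
                # runs the parser purely as an error check.
                response_item = parsed_output
        outputs.append({"text": response_item})
    return outputs

texts = ["hello world"]
print(create_outputs(texts, UpperCaseParser(), mode="off"))       # [{'text': 'hello world'}]
print(create_outputs(texts, UpperCaseParser(), mode="validate"))  # [{'text': 'hello world'}]
print(create_outputs(texts, UpperCaseParser(), mode="parse"))     # [{'text': 'HELLO WORLD'}]

Under "validate" malformed output still raises ValueError, but the returned text stays raw; "parse" additionally replaces it with the parsed value, and "off" skips the parser entirely.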