simplify guardrail abstraction; clean up parsing, guardrails in LLMChain

This commit is contained in:
jerwelborn
2023-03-18 17:51:03 -07:00
parent 77398c3c67
commit 3e756b75b3
8 changed files with 87 additions and 68 deletions

View File

@@ -7,16 +7,16 @@ from pydantic import BaseModel, Extra, Field
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.output_parsers.base import OutputGuardrail
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
Guardrail,
LLMResult,
PromptValue,
ValidationError,
)
from langchain.guardrails import Guardrail
from langchain.guardrails.utils import dumb_davinci_retry
class LLMChain(Chain, BaseModel):
@@ -37,7 +37,8 @@ class LLMChain(Chain, BaseModel):
"""Prompt object to use."""
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
output_parser: Optional[OutputGuardrail] = None
output_parser: Optional[BaseOutputParser] = None
output_parser_retry_enabled: bool = False
guardrails: List[Guardrail] = Field(default_factory=list)
class Config:
@@ -136,20 +137,39 @@ class LLMChain(Chain, BaseModel):
response, prompts = await self.agenerate(input_list)
return self.create_outputs(response, prompts)
def _get_final_output(self, text: str, prompt_value: PromptValue) -> Any:
result: Any = text
def _get_final_output(self, completion: str, prompt_value: PromptValue) -> Any:
"""Validate raw completion (guardrails) + extract structured data from it (parser).
We may want to apply guardrails not just to raw string completions, but also to structured parsed completions.
For this 1st attempt, we'll keep this simple.
"""
for guardrail in self.guardrails:
if isinstance(guardrail, OutputGuardrail):
try:
result = guardrail.output_parser.parse(result)
error = None
except Exception as e:
error = ValidationError(text=e)
else:
error = guardrail.check(prompt_value, result)
if error is not None:
result = guardrail.fix(prompt_value, result, error)
return result
evaluation, ok = guardrail.evaluate(prompt_value, completion)
if not ok:
# TODO: consider associating a custom exception w/ guardrail
# as suggested in https://github.com/hwchase17/langchain/pull/1683/files#r1139987185
raise RuntimeError(evaluation.error_msg)
if evaluation.revised_output:
assert isinstance(evaluation.revised_output, str)
completion = evaluation.revised_output
if self.output_parser:
try:
parsed_completion = self.output_parser.parse(completion)
except OutputParserException as e:
if self.output_parser_retry_enabled:
_text = f"Uh-oh! Got {e}. Retrying with DaVinci."
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
retried_completion = dumb_davinci_retry(prompt_value.to_string(), completion)
parsed_completion = self.output_parser.parse(retried_completion)
else:
raise e
completion = parsed_completion
return completion
def create_outputs(
self, response: LLMResult, prompts: List[PromptValue]
@@ -196,6 +216,7 @@ class LLMChain(Chain, BaseModel):
"""
return (await self.acall(kwargs))[self.output_key]
# TODO: if an output_parser is provided, it should always be applied. remove these methods.
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results."""
result = self.predict(**kwargs)

View File

@@ -0,0 +1 @@
from langchain.guardrails.base import Guardrail, GuardrailEvaluation

View File

@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
from pydantic import BaseModel
class GuardrailEvaluation(BaseModel):
    """Result of applying a guardrail to a function's output.

    Encapsulates the outcome of a guardrail check: an error message when the
    check fails, and optionally a revised output produced by a retry.
    """

    # Human-readable description of why the guardrail failed.
    # NOTE(review): required on every instance, even a passing one — callers
    # appear to rely on the (evaluation, ok) tuple for pass/fail; confirm.
    error_msg: str

    # Output produced by an optional retry after a failure; the retry itself
    # may also fail. Consumers treat a truthy value as a replacement for the
    # original completion (presumably a str — verify against callers).
    revised_output: Any
class Guardrail(ABC, BaseModel):
    """Abstract base class for guardrails on arbitrary function calls."""

    @abstractmethod
    def evaluate(self, input: Any, output: Any) -> Tuple[Optional[GuardrailEvaluation], bool]:
        """A generic guardrail on any function (a function that gets human input, an LM call, a chain, an agent, etc.)
        is evaluated against that function's input and output.
        Evaluation includes a validation/verification step. It may also include a retry to generate a satisfactory revised output.

        Returns:
            A ``(evaluation, ok)`` tuple. ``ok`` is True when the output
            passes the guardrail; ``evaluation`` carries the error message
            and any revised output (may be None — TODO confirm exactly when).
        """

View File

@@ -0,0 +1,16 @@
from langchain.llms import OpenAI
from langchain.prompts.prompt import PromptTemplate
# TODO: perhaps prompt str -> PromptValue
def dumb_davinci_retry(
    prompt: str,
    completion: str,
    *,
    model_name: str = "text-davinci-003",
    temperature: float = 0.5,
) -> str:
    """Ask a larger model to retry a completion that failed validation.

    Builds a retry prompt containing the original prompt plus the
    unsatisfactory completion, and asks the (by default) davinci model for
    a fresh attempt.

    Args:
        prompt: The original prompt string sent to the model.
        completion: The completion that failed to satisfy the prompt's
            constraints.
        model_name: OpenAI model to use for the retry. Keyword-only; the
            default preserves the previously hard-coded model.
        temperature: Sampling temperature for the retry call. Keyword-only;
            the default preserves the previously hard-coded value.

    Returns:
        The retry model's completion string.
    """
    llm = OpenAI(model_name=model_name, temperature=temperature)
    retry_prompt = PromptTemplate(
        # String split across adjacent literals; bytes identical to the original template.
        template=(
            "Prompt:\n{prompt}\nCompletion:\n{completion}\n\n"
            "Above, the Completion did not satisfy the constraints given in "
            "the Prompt. Please try again:"
        ),
        input_variables=["prompt", "completion"],
    )
    retry_prompt_str = retry_prompt.format_prompt(prompt=prompt, completion=completion).to_string()
    return llm(retry_prompt_str)

View File

@@ -1,12 +1,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Dict
from pydantic import BaseModel
from langchain.schema import Fixer, Guardrail, PromptValue, ValidationError
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@@ -30,20 +28,5 @@ class BaseOutputParser(BaseModel, ABC):
return output_parser_dict
class OutputGuardrail(Guardrail, BaseModel):
output_parser: BaseOutputParser
fixer: Fixer
def check(
self, prompt_value: PromptValue, result: Any
) -> Optional[ValidationError]:
try:
self.output_parser.parse(result)
return None
except Exception as e:
return ValidationError(text=e)
def fix(
self, prompt_value: PromptValue, result: Any, error: ValidationError
) -> Any:
return self.fix(prompt_value, result, error)
class OutputParserException(Exception):
    """Raised when an output parser fails to parse an LLM completion.

    Callers (e.g. ``LLMChain._get_final_output``) catch this to trigger a
    retry with a stronger model when ``output_parser_retry_enabled`` is set.
    """

    pass

View File

@@ -4,7 +4,7 @@ from typing import Any
from pydantic import BaseModel, ValidationError
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
@@ -24,7 +24,7 @@ class PydanticOutputParser(BaseOutputParser):
except (json.JSONDecodeError, ValidationError) as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
raise ValueError(msg)
raise OutputParserException(msg)
def get_format_instructions(self) -> str:
schema = self.pydantic_object.schema()

View File

@@ -5,7 +5,7 @@ from typing import List
from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
line_template = '\t"{name}": {type} // {description}'
@@ -42,7 +42,7 @@ class StructuredOutputParser(BaseOutputParser):
json_obj = json.loads(json_string)
for schema in self.response_schemas:
if schema.name not in json_obj:
raise ValueError(
raise OutputParserException(
f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}"
)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator
@@ -244,29 +244,3 @@ class BaseMemory(BaseModel, ABC):
Memory = BaseMemory
class ValidationError(BaseModel):
error_message: str
class Guardrail(ABC):
@abstractmethod
def check(
self, prompt_value: PromptValue, result: Any
) -> Optional[ValidationError]:
"""Check whether there's a validation error."""
@abstractmethod
def fix(
self, prompt_value: PromptValue, result: Any, error: ValidationError
) -> Any:
""""""
class Fixer(ABC):
@abstractmethod
def fix(
self, prompt_value: PromptValue, result: Any, error: ValidationError
) -> Any:
""""""