mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-24 04:36:46 +00:00
simplify guardrail abstraction; clean up parsing, guardrails in LLMChain
This commit is contained in:
@@ -7,16 +7,16 @@ from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.output_parsers.base import OutputGuardrail
|
||||
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
Guardrail,
|
||||
LLMResult,
|
||||
PromptValue,
|
||||
ValidationError,
|
||||
)
|
||||
from langchain.guardrails import Guardrail
|
||||
from langchain.guardrails.utils import dumb_davinci_retry
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
@@ -37,7 +37,8 @@ class LLMChain(Chain, BaseModel):
|
||||
"""Prompt object to use."""
|
||||
llm: BaseLanguageModel
|
||||
output_key: str = "text" #: :meta private:
|
||||
output_parser: Optional[OutputGuardrail] = None
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
output_parser_retry_enabled: bool = False
|
||||
guardrails: List[Guardrail] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
@@ -136,20 +137,39 @@ class LLMChain(Chain, BaseModel):
|
||||
response, prompts = await self.agenerate(input_list)
|
||||
return self.create_outputs(response, prompts)
|
||||
|
||||
def _get_final_output(self, text: str, prompt_value: PromptValue) -> Any:
|
||||
result: Any = text
|
||||
def _get_final_output(self, completion: str, prompt_value: PromptValue) -> Any:
|
||||
"""Validate raw completion (guardrails) + extract structured data from it (parser).
|
||||
|
||||
We may want to apply guardrails not just to raw string completions, but also to structured parsed completions.
|
||||
For this 1st attempt, we'll keep this simple.
|
||||
"""
|
||||
for guardrail in self.guardrails:
|
||||
if isinstance(guardrail, OutputGuardrail):
|
||||
try:
|
||||
result = guardrail.output_parser.parse(result)
|
||||
error = None
|
||||
except Exception as e:
|
||||
error = ValidationError(text=e)
|
||||
else:
|
||||
error = guardrail.check(prompt_value, result)
|
||||
if error is not None:
|
||||
result = guardrail.fix(prompt_value, result, error)
|
||||
return result
|
||||
evaluation, ok = guardrail.evaluate(prompt_value, completion)
|
||||
if not ok:
|
||||
# TODO: consider associating customer exception w/ guardrail
|
||||
# as suggested in https://github.com/hwchase17/langchain/pull/1683/files#r1139987185
|
||||
raise RuntimeError(evaluation.error_msg)
|
||||
if evaluation.revised_output:
|
||||
assert isinstance(evaluation.revised_output, str)
|
||||
completion = evaluation.revised_output
|
||||
|
||||
if self.output_parser:
|
||||
try:
|
||||
parsed_completion = self.output_parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if self.output_parser_retry_enabled:
|
||||
_text = f"Uh-oh! Got {e}. Retrying with DaVinci."
|
||||
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
|
||||
retried_completion = dumb_davinci_retry(prompt_value.to_string(), completion)
|
||||
parsed_completion = self.output_parser.parse(retried_completion)
|
||||
else:
|
||||
raise e
|
||||
|
||||
completion = parsed_completion
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def create_outputs(
|
||||
self, response: LLMResult, prompts: List[PromptValue]
|
||||
@@ -196,6 +216,7 @@ class LLMChain(Chain, BaseModel):
|
||||
"""
|
||||
return (await self.acall(kwargs))[self.output_key]
|
||||
|
||||
# TODO: if an output_parser is provided, it should always be applied. remove these methods.
|
||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(**kwargs)
|
||||
|
||||
1
langchain/guardrails/__init__.py
Normal file
1
langchain/guardrails/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from langchain.guardrails.base import Guardrail, GuardrailEvaluation
|
||||
24
langchain/guardrails/base.py
Normal file
24
langchain/guardrails/base.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GuardrailEvaluation(BaseModel):
    """The result of applying a guardrail to a function's output.

    Encapsulates both a failure message and an optional revised output
    produced by a retry. NOTE(review): callers (see ``LLMChain``) receive
    this alongside a separate ``ok`` flag and only read ``error_msg`` when
    the evaluation failed — confirm ``error_msg`` is meaningful on success.
    """

    # It may fail.
    error_msg: str
    # Optionally, it may retry upon failure. The retry can also fail.
    # A truthy value here replaces the original completion downstream.
    revised_output: Any
|
||||
|
||||
|
||||
class Guardrail(ABC, BaseModel):
    """Abstract validation/retry hook applied to a function's input/output pair."""

    @abstractmethod
    def evaluate(self, input: Any, output: Any) -> Tuple[Optional[GuardrailEvaluation], bool]:
        """A generic guardrail on any function (a function that gets human input, an LM call, a chain, an agent, etc.)
        is evaluated against that function's input and output.

        Evaluation includes a validation/verification step. It may also include a retry to generate a satisfactory revised output.

        Returns:
            A ``(evaluation, ok)`` tuple. When ``ok`` is False the caller
            treats ``evaluation.error_msg`` as a fatal error; when ``ok`` is
            True and ``evaluation.revised_output`` is truthy, the revised
            output replaces the original.
        """
|
||||
16
langchain/guardrails/utils.py
Normal file
16
langchain/guardrails/utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
# TODO: perhaps prompt str -> PromptValue
|
||||
def dumb_davinci_retry(
    prompt: str,
    completion: str,
    model_name: str = "text-davinci-003",
    temperature: float = 0.5,
) -> str:
    """Ask a (typically larger) model to revise an unsatisfactory completion.

    Big model go brrrr.

    Args:
        prompt: The original prompt, as a plain string.
        completion: The completion that failed to satisfy the prompt's
            constraints.
        model_name: OpenAI model to use for the retry. Defaults to the
            previously hard-coded "text-davinci-003".
        temperature: Sampling temperature for the retry call. Defaults to
            the previously hard-coded 0.5.

    Returns:
        The revised completion string produced by the retry model.
    """
    retry_llm = OpenAI(model_name=model_name, temperature=temperature)
    retry_prompt = PromptTemplate(
        template=(
            "Prompt:\n{prompt}\nCompletion:\n{completion}\n\n"
            "Above, the Completion did not satisfy the constraints given in "
            "the Prompt. Please try again:"
        ),
        input_variables=["prompt", "completion"],
    )
    retry_prompt_str = retry_prompt.format_prompt(
        prompt=prompt, completion=completion
    ).to_string()
    return retry_llm(retry_prompt_str)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.schema import Fixer, Guardrail, PromptValue, ValidationError
|
||||
|
||||
|
||||
class BaseOutputParser(BaseModel, ABC):
|
||||
"""Class to parse the output of an LLM call."""
|
||||
@@ -30,20 +28,5 @@ class BaseOutputParser(BaseModel, ABC):
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
class OutputGuardrail(Guardrail, BaseModel):
    """Guardrail that validates a result by parsing it and repairs via a Fixer.

    ``check`` succeeds iff ``output_parser`` can parse the result; on
    failure, ``fix`` delegates repair to the configured ``fixer``.
    """

    # Parser whose success/failure defines validity of a result.
    output_parser: BaseOutputParser
    # Strategy used to repair a result that failed validation.
    fixer: Fixer

    def check(
        self, prompt_value: PromptValue, result: Any
    ) -> Optional[ValidationError]:
        """Return a ValidationError if ``result`` fails to parse, else None."""
        try:
            self.output_parser.parse(result)
            return None
        except Exception as e:
            # BUG FIX: ValidationError declares an ``error_message: str``
            # field; the original passed ``text=e`` (wrong field name, and
            # an exception object rather than a string).
            return ValidationError(error_message=str(e))

    def fix(
        self, prompt_value: PromptValue, result: Any, error: ValidationError
    ) -> Any:
        """Attempt to repair ``result`` given the validation ``error``.

        BUG FIX: the original called ``self.fix(...)`` — unconditional
        infinite recursion. Repair must be delegated to ``self.fixer``.
        """
        return self.fixer.fix(prompt_value, result, error)
|
||||
class OutputParserException(Exception):
    """Raised when an output parser fails to parse an LLM completion."""

    pass
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
|
||||
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class PydanticOutputParser(BaseOutputParser):
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
name = self.pydantic_object.__name__
|
||||
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
|
||||
raise ValueError(msg)
|
||||
raise OutputParserException(msg)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
schema = self.pydantic_object.schema()
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
|
||||
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
||||
|
||||
line_template = '\t"{name}": {type} // {description}'
|
||||
@@ -42,7 +42,7 @@ class StructuredOutputParser(BaseOutputParser):
|
||||
json_obj = json.loads(json_string)
|
||||
for schema in self.response_schemas:
|
||||
if schema.name not in json_obj:
|
||||
raise ValueError(
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{schema.name}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, NamedTuple, Optional
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
@@ -244,29 +244,3 @@ class BaseMemory(BaseModel, ABC):
|
||||
|
||||
|
||||
Memory = BaseMemory
|
||||
|
||||
|
||||
class ValidationError(BaseModel):
    """A structured validation failure produced by a Guardrail check."""

    # Human-readable description of what failed validation.
    error_message: str
|
||||
|
||||
|
||||
class Guardrail(ABC):
    """Abstract check/fix pair applied to a prompt and its result."""

    @abstractmethod
    def check(
        self, prompt_value: PromptValue, result: Any
    ) -> Optional[ValidationError]:
        """Check whether there's a validation error."""

    @abstractmethod
    def fix(
        self, prompt_value: PromptValue, result: Any, error: ValidationError
    ) -> Any:
        """Return a repaired version of ``result`` for the given ``error``."""
|
||||
|
||||
|
||||
class Fixer(ABC):
    """Strategy object that attempts to repair a result that failed validation."""

    @abstractmethod
    def fix(
        self, prompt_value: PromptValue, result: Any, error: ValidationError
    ) -> Any:
        """Return a repaired version of ``result`` for the given ``error``."""
|
||||
|
||||
Reference in New Issue
Block a user