Compare commits

3 Commits

Author          SHA1        Date
Nuno Campos     21ea1866de  2023-03-17 09:23:59 +00:00
    Suggested diff for guardrails error handling
    - ValidationError is now a subclass of Exception
    - Each OutputParser can now declare its own subclass of ValidationError
      that people can reference in a catch statement

Harrison Chase  e61e37a40a  2023-03-16 00:26:45 -07:00
    Update langchain/output_parsers/base.py
    Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

Harrison Chase  f267f59186  2023-03-14 21:58:29 -07:00
    wip guardrails
6 changed files with 108 additions and 16 deletions
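The design this change set introduces: ValidationError is a plain Exception subclass, and each output parser can declare its own subclass that callers reference in an except clause. Below is a minimal sketch of a parser using that pattern; `CommaListParser` and `CommaListError` are invented for illustration and are not part of the diff.

```python
from typing import List

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import ValidationError


class CommaListError(ValidationError):
    """Hypothetical parser-specific ValidationError subclass."""


class CommaListParser(BaseOutputParser):
    # Declare the error type callers can catch, mirroring the pattern
    # StructuredOutputParser uses in this diff.
    Exception = CommaListError

    def parse(self, text: str) -> List[str]:
        if "," not in text:
            raise CommaListError(error_message="expected a comma-separated list")
        return [part.strip() for part in text.split(",")]
```

A caller can then write `except CommaListParser.Exception:` (which resolves to `CommaListError`) and let unrelated failures propagate normally.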

View File

@@ -52,7 +52,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embedded it."""
         var_name = self.llm_chain.input_keys[0]
-        result = self.llm_chain.generate([{var_name: text}])
+        result, _ = self.llm_chain.generate([{var_name: text}])
         documents = [generation.text for generation in result.generations[0]]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)
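This call-site change is purely mechanical: `LLMChain.generate` now returns a `(LLMResult, List[PromptValue])` tuple, so callers with no use for the prompts unpack and discard them. An illustrative pair of call sites (the chain and inputs are assumed):

```python
# New call convention for LLMChain.generate (illustrative names).
result, _ = llm_chain.generate([{"text": "some input"}])        # prompts unused
result, prompts = llm_chain.generate([{"text": "some input"}])  # prompts kept for guardrails
top_texts = [generation[0].text for generation in result.generations]
```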

View File

@@ -3,13 +3,20 @@ from __future__ import annotations

 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, Field

 from langchain.chains.base import Chain
 from langchain.input import get_colored_text
+from langchain.output_parsers.base import OutputGuardrail
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
+from langchain.schema import (
+    BaseLanguageModel,
+    Guardrail,
+    LLMResult,
+    PromptValue,
+    ValidationError,
+)


 class LLMChain(Chain, BaseModel):
@@ -30,6 +37,8 @@ class LLMChain(Chain, BaseModel):
     """Prompt object to use."""
     llm: BaseLanguageModel
     output_key: str = "text"  #: :meta private:
+    output_parser: Optional[OutputGuardrail] = None
+    guardrails: List[Guardrail] = Field(default_factory=list)

     class Config:
         """Configuration for this pydantic object."""
@@ -56,15 +65,19 @@ class LLMChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         return self.apply([inputs])[0]

-    def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
+    def generate(
+        self, input_list: List[Dict[str, Any]]
+    ) -> Tuple[LLMResult, List[PromptValue]]:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list)
-        return self.llm.generate_prompt(prompts, stop)
+        return self.llm.generate_prompt(prompts, stop), prompts

-    async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
+    async def agenerate(
+        self, input_list: List[Dict[str, Any]]
+    ) -> Tuple[LLMResult, List[PromptValue]]:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list)
-        return await self.llm.agenerate_prompt(prompts, stop)
+        return await self.llm.agenerate_prompt(prompts, stop), prompts

     def prep_prompts(
         self, input_list: List[Dict[str, Any]]
@@ -115,20 +128,37 @@ class LLMChain(Chain, BaseModel):
     def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
-        response = self.generate(input_list)
-        return self.create_outputs(response)
+        response, prompts = self.generate(input_list)
+        return self.create_outputs(response, prompts)

     async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
-        response = await self.agenerate(input_list)
-        return self.create_outputs(response)
+        response, prompts = await self.agenerate(input_list)
+        return self.create_outputs(response, prompts)

-    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
+    def _get_final_output(self, text: str, prompt_value: PromptValue) -> Any:
+        result: Any = text
+        for guardrail in self.guardrails:
+            if isinstance(guardrail, OutputGuardrail):
+                try:
+                    result = guardrail.output_parser.parse(result)
+                    error = None
+                except Exception as e:
+                    error = ValidationError(text=e)
+            else:
+                error = guardrail.check(prompt_value, result)
+            if error is not None:
+                result = guardrail.fix(prompt_value, result, error)
+        return result
+
+    def create_outputs(
+        self, response: LLMResult, prompts: List[PromptValue]
+    ) -> List[Dict[str, str]]:
         """Create outputs from response."""
         return [
             # Get the text of the top generated string.
-            {self.output_key: generation[0].text}
-            for generation in response.generations
+            {self.output_key: self._get_final_output(generation[0].text, prompts[i])}
+            for i, generation in enumerate(response.generations)
         ]

     async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
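`_get_final_output` applies each guardrail in order: an `OutputGuardrail` re-parses the running result (wrapping parse failures in a `ValidationError`), any other guardrail runs `check`, and a non-`None` error routes through that guardrail's `fix` before the next guardrail sees the result. A minimal custom guardrail compatible with this loop might look like the sketch below; `MaxLengthGuardrail` is invented for illustration, and it constructs `ValidationError` with the `error_message` field the schema declares (note the diff's own `ValidationError(text=e)` passes a keyword the dataclass does not define).

```python
from typing import Any, Optional

from langchain.schema import Guardrail, PromptValue, ValidationError


class MaxLengthGuardrail(Guardrail):
    """Illustrative guardrail: flags and truncates overlong string outputs."""

    def __init__(self, max_chars: int) -> None:
        self.max_chars = max_chars

    def check(
        self, prompt_value: PromptValue, result: Any
    ) -> Optional[ValidationError]:
        # Returning None means "no error"; anything else triggers fix().
        if isinstance(result, str) and len(result) > self.max_chars:
            return ValidationError(error_message="output exceeds max_chars")
        return None

    def fix(
        self, prompt_value: PromptValue, result: Any, error: ValidationError
    ) -> Any:
        # Cheapest possible repair: truncate rather than re-prompt the LLM.
        return result[: self.max_chars]
```

It would be attached through the new field, e.g. `LLMChain(llm=llm, prompt=prompt, guardrails=[MaxLengthGuardrail(500)])`, which works as long as the chain's pydantic config allows arbitrary types.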

View File

@@ -47,7 +47,7 @@ class QAGenerationChain(Chain):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         docs = self.text_splitter.create_documents([inputs[self.input_key]])
-        results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
+        results, _ = self.llm_chain.generate([{"text": d.page_content} for d in docs])
         qa = [json.loads(res[0].text) for res in results.generations]
         return {self.output_key: qa}

View File

@@ -1,14 +1,22 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from pydantic import BaseModel

+from langchain.schema import Fixer, Guardrail, PromptValue, ValidationError
+

 class BaseOutputParser(BaseModel, ABC):
     """Class to parse the output of an LLM call."""

+    Exception: ValidationError = ValidationError
+
+    @property
+    def Exception(self) -> ValidationError:
+        return self.__class__.Exception
+
     @abstractmethod
     def parse(self, text: str) -> Any:
         """Parse the output of an LLM call."""
@@ -26,3 +34,22 @@ class BaseOutputParser(BaseModel, ABC):
         output_parser_dict = super().dict()
         output_parser_dict["_type"] = self._type
         return output_parser_dict
+
+
+class OutputGuardrail(Guardrail, BaseModel):
+    output_parser: BaseOutputParser
+    fixer: Fixer
+
+    def check(
+        self, prompt_value: PromptValue, result: Any
+    ) -> Optional[ValidationError]:
+        try:
+            self.output_parser.parse(result)
+            return None
+        except Exception as e:
+            return self.output_parser.Exception(text=e)
+
+    def fix(
+        self, prompt_value: PromptValue, result: Any, error: ValidationError
+    ) -> Any:
+        return self.fixer(prompt_value, result, error)
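`OutputGuardrail` pairs a parser with a `Fixer`: `check` converts any parse failure into the parser's declared exception type, and `fix` delegates to the fixer. Note that `fix` invokes `self.fixer(...)` as a callable even though the `Fixer` ABC only defines `fix`, so a concrete fixer needs `__call__` as well. A minimal sketch under that assumption (`PassThroughFixer` is invented for illustration):

```python
from typing import Any

from langchain.schema import Fixer, PromptValue, ValidationError


class PassThroughFixer(Fixer):
    """Illustrative fixer that gives up and returns the raw result unchanged."""

    def fix(
        self, prompt_value: PromptValue, result: Any, error: ValidationError
    ) -> Any:
        return result

    # OutputGuardrail.fix calls the fixer object directly, so alias __call__.
    __call__ = fix
```

Wiring it up would look like `OutputGuardrail(output_parser=my_parser, fixer=PassThroughFixer())`, assuming the model's pydantic config accepts the arbitrary `Fixer` type.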

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel

 from langchain.output_parsers.base import BaseOutputParser
 from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
+from langchain.schema import ValidationError

 line_template = '\t"{name}": {type} // {description}'

@@ -22,7 +23,13 @@ def _get_sub_string(schema: ResponseSchema) -> str:
     )


+class StructuredOutputParserException(ValidationError):
+    pass
+
+
 class StructuredOutputParser(BaseOutputParser):
+    Exception = StructuredOutputParserException
+
     response_schemas: List[ResponseSchema]

     @classmethod
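With `Exception = StructuredOutputParserException`, callers can target structured-output failures specifically, either by importing the subclass or by referencing `StructuredOutputParser.Exception`. An illustrative round trip follows; the WIP `parse` body may still raise generic errors on malformed input, so the except clause shows the intended pattern rather than guaranteed current behavior, and at this commit `parse` expects the model output to wrap its JSON in a markdown code fence.

```python
from langchain.output_parsers.structured import (
    ResponseSchema,
    StructuredOutputParser,
    StructuredOutputParserException,
)

parser = StructuredOutputParser.from_response_schemas(
    [ResponseSchema(name="answer", description="the final answer")]
)

fence = "```"  # assumed: parse() wants the JSON inside a fenced json block
llm_text = f'{fence}json\n{{"answer": "42"}}\n{fence}'

try:
    parsed = parser.parse(llm_text)
except StructuredOutputParserException:
    parsed = None  # only this parser's validation failures are handled here
```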

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List, NamedTuple, Optional

 from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic.dataclasses import dataclass


 class AgentAction(NamedTuple):

@@ -220,3 +221,30 @@ class BaseMemory(BaseModel, ABC):


 Memory = BaseMemory
+
+
+@dataclass
+class ValidationError(Exception):
+    error_message: str
+
+
+class Guardrail(ABC):
+    @abstractmethod
+    def check(
+        self, prompt_value: PromptValue, result: Any
+    ) -> Optional[ValidationError]:
+        """Check whether there's a validation error."""
+
+    @abstractmethod
+    def fix(
+        self, prompt_value: PromptValue, result: Any, error: ValidationError
+    ) -> Any:
+        """Fix the result given the validation error."""
+
+
+class Fixer(ABC):
+    @abstractmethod
+    def fix(
+        self, prompt_value: PromptValue, result: Any, error: ValidationError
+    ) -> Any:
+        """Fix the result given the validation error."""