This commit is contained in:
Harrison Chase
2023-03-13 09:20:14 -07:00
parent 735d465abf
commit c8dca75ae3
12 changed files with 26 additions and 105 deletions

View File

@@ -1,11 +1,13 @@
"""Check if chain or agent violates one or more restrictions."""
from typing import Any, Callable, List, Tuple
from __future__ import annotations
from typing import Any, List, Tuple
from langchain.chains.llm import LLMChain
from langchain.guards.base import BaseGuard
from langchain.guards.restriction_prompt import RESTRICTION_PROMPT
from langchain.llms.base import BaseLLM
from langchain.output_parsing.boolean import BooleanOutputParser
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.prompts.base import BasePromptTemplate
@@ -60,7 +62,7 @@ class RestrictionGuard(BaseGuard):
llm: BaseLLM,
prompt: BasePromptTemplate = RESTRICTION_PROMPT,
**kwargs: Any,
):
) -> RestrictionGuard:
"""Load from llm and prompt."""
guard_chain = LLMChain(llm=llm, prompt=prompt)
return cls(guard_chain=guard_chain, **kwargs)

View File

@@ -38,7 +38,8 @@ class StringGuard(BaseGuard):
called recursively if the output violates the restrictions. Defaults to 0.
Raises:
Exception: If the output violates the restrictions and the maximum number of retries has been exceeded.
Exception: If the output violates the restrictions and the maximum number of
retries has been exceeded.
Example:
.. code-block:: python

View File

@@ -1,4 +1,5 @@
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
@@ -10,4 +11,5 @@ __all__ = [
"ListOutputParser",
"CommaSeparatedListOutputParser",
"BaseOutputParser",
"BooleanOutputParser",
]

View File

@@ -4,7 +4,7 @@ from typing import Dict, List
from pydantic import Field, root_validator
from langchain.output_parsing.base import BaseOutputParser
from langchain.output_parsers.base import BaseOutputParser
class BooleanOutputParser(BaseOutputParser):

View File

@@ -1,6 +0,0 @@
"""Classes to parse the output of an LLM call."""
from langchain.output_parsing.base import BaseOutputParser
from langchain.output_parsing.boolean import BooleanOutputParser
from langchain.output_parsing.json import JsonOutputParser
from langchain.output_parsing.list import ListOutputParser
from langchain.output_parsing.regex import RegexParser

View File

@@ -1,32 +0,0 @@
"""Class to parse the output of an LLM call."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
class ListOutputParser(BaseOutputParser):
"""Class to parse the output of an LLM call to a list."""
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""

View File

@@ -1,13 +0,0 @@
"""Parse json output."""
import json
from typing import Dict, List, Union
from langchain.output_parsing.base import BaseOutputParser
class JsonOutputParser(BaseOutputParser):
"""Parse json output."""
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
"""Parse json string."""
return json.loads(text)

View File

@@ -1,12 +0,0 @@
"""Parse out comma separated lists."""
from typing import List
from langchain.output_parsing.base import ListOutputParser
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse out comma separated lists."""
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")

View File

@@ -1,34 +0,0 @@
"""Class to parse the output into a dictionary."""
import re
from typing import Dict, List, Optional
from pydantic import BaseModel
from langchain.output_parsing.base import BaseOutputParser
class RegexParser(BaseOutputParser, BaseModel):
"""Class to parse the output into a dictionary."""
regex: str
output_keys: List[str]
default_output_key: Optional[str] = None
@property
def _type(self) -> str:
"""Return the type key."""
return "regex_parser"
def parse(self, text: str) -> Dict[str, str]:
"""Parse the output of an LLM call."""
match = re.search(self.regex, text)
if match:
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
else:
if self.default_output_key is None:
raise ValueError(f"Could not parse output: {text}")
else:
return {
key: text if key == self.default_output_key else ""
for key in self.output_keys
}

View File

@@ -59,6 +59,18 @@ def check_valid_template(
)
class StringPromptValue(PromptValue):
text: str
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt."""

View File

@@ -3,6 +3,7 @@ from typing import List
import pytest
from langchain.guards.restriction import RestrictionGuard
from langchain.guards.restriction_prompt import RESTRICTION_PROMPT
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -19,7 +20,7 @@ def test_restriction_guard() -> None:
) -> str:
concatenated_restrictions = ", ".join(restrictions)
queries = {
RestrictionGuard.prompt.format(
RESTRICTION_PROMPT.format(
restrictions=concatenated_restrictions, function_output=llm_input_output
): "restricted because I said so :) (¥)"
if restricted
@@ -27,7 +28,7 @@ def test_restriction_guard() -> None:
}
restriction_guard_llm = FakeLLM(queries=queries)
@RestrictionGuard(
@RestrictionGuard.from_llm(
restrictions=restrictions, llm=restriction_guard_llm, retries=0
)
def example_func(prompt: str) -> str:

View File

@@ -2,7 +2,7 @@ from typing import List
import pytest
from langchain.output_parsing.boolean import BooleanOutputParser
from langchain.output_parsers.boolean import BooleanOutputParser
GOOD_EXAMPLES = [
("0", False, ["1"], ["0"]),