diff --git a/langchain/guards/restriction.py b/langchain/guards/restriction.py index 45eb58abc37..bba4713913e 100644 --- a/langchain/guards/restriction.py +++ b/langchain/guards/restriction.py @@ -1,11 +1,13 @@ """Check if chain or agent violates one or more restrictions.""" -from typing import Any, Callable, List, Tuple +from __future__ import annotations + +from typing import Any, List, Tuple from langchain.chains.llm import LLMChain from langchain.guards.base import BaseGuard from langchain.guards.restriction_prompt import RESTRICTION_PROMPT from langchain.llms.base import BaseLLM -from langchain.output_parsing.boolean import BooleanOutputParser +from langchain.output_parsers.boolean import BooleanOutputParser from langchain.prompts.base import BasePromptTemplate @@ -60,7 +62,7 @@ class RestrictionGuard(BaseGuard): llm: BaseLLM, prompt: BasePromptTemplate = RESTRICTION_PROMPT, **kwargs: Any, - ): + ) -> RestrictionGuard: """Load from llm and prompt.""" guard_chain = LLMChain(llm=llm, prompt=prompt) return cls(guard_chain=guard_chain, **kwargs) diff --git a/langchain/guards/string.py b/langchain/guards/string.py index 91d8d383c95..fe4d045c838 100644 --- a/langchain/guards/string.py +++ b/langchain/guards/string.py @@ -38,7 +38,8 @@ class StringGuard(BaseGuard): called recursively if the output violates the restrictions. Defaults to 0. Raises: - Exception: If the output violates the restrictions and the maximum number of retries has been exceeded. + Exception: If the output violates the restrictions and the maximum number of + retries has been exceeded. Example: .. code-block:: python diff --git a/langchain/output_parsers/__init__.py b/langchain/output_parsers/__init__.py index 8509b6f2384..a5720178af7 100644 --- a/langchain/output_parsers/__init__.py +++ b/langchain/output_parsers/__init__.py @@ -1,4 +1,5 @@ from langchain.output_parsers.base import BaseOutputParser +from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.list import ( CommaSeparatedListOutputParser, ListOutputParser, @@ -10,4 +11,5 @@ __all__ = [ "ListOutputParser", "CommaSeparatedListOutputParser", "BaseOutputParser", + "BooleanOutputParser", ] diff --git a/langchain/output_parsing/boolean.py b/langchain/output_parsers/boolean.py similarity index 97% rename from langchain/output_parsing/boolean.py rename to langchain/output_parsers/boolean.py index fdfeecf8f9e..cd96ff5e84a 100644 --- a/langchain/output_parsing/boolean.py +++ b/langchain/output_parsers/boolean.py @@ -4,7 +4,7 @@ from typing import Dict, List from pydantic import Field, root_validator -from langchain.output_parsing.base import BaseOutputParser +from langchain.output_parsers.base import BaseOutputParser class BooleanOutputParser(BaseOutputParser): diff --git a/langchain/output_parsing/__init__.py b/langchain/output_parsing/__init__.py deleted file mode 100644 index 011ea33cfb2..00000000000 --- a/langchain/output_parsing/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Classes to parse the output of an LLM call.""" -from langchain.output_parsing.base import BaseOutputParser -from langchain.output_parsing.boolean import BooleanOutputParser -from langchain.output_parsing.json import JsonOutputParser -from langchain.output_parsing.list import ListOutputParser -from langchain.output_parsing.regex import RegexParser diff --git a/langchain/output_parsing/base.py b/langchain/output_parsing/base.py deleted file mode 100644 index 37e95bf6d28..00000000000 --- a/langchain/output_parsing/base.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Class to parse the output of an LLM call.""" -from abc import ABC, abstractmethod -from typing import Any, Dict, List - -from pydantic import BaseModel - - -class BaseOutputParser(BaseModel, ABC): - """Class to parse the output of an LLM call.""" - - @abstractmethod - def parse(self, text: str) -> Any: - """Parse the output of an LLM call.""" - - @property - def _type(self) -> str: - """Return the type key.""" - raise NotImplementedError - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of output parser.""" - output_parser_dict = super().dict() - output_parser_dict["_type"] = self._type - return output_parser_dict - - -class ListOutputParser(BaseOutputParser): - """Class to parse the output of an LLM call to a list.""" - - @abstractmethod - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" diff --git a/langchain/output_parsing/json.py b/langchain/output_parsing/json.py deleted file mode 100644 index 19ff82f8f53..00000000000 --- a/langchain/output_parsing/json.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Parse json output.""" -import json -from typing import Dict, List, Union - -from langchain.output_parsing.base import BaseOutputParser - - -class JsonOutputParser(BaseOutputParser): - """Parse json output.""" - - def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]: - """Parse json string.""" - return json.loads(text) diff --git a/langchain/output_parsing/list.py b/langchain/output_parsing/list.py deleted file mode 100644 index 1bae127704a..00000000000 --- a/langchain/output_parsing/list.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Parse out comma separated lists.""" -from typing import List - -from langchain.output_parsing.base import ListOutputParser - - -class CommaSeparatedListOutputParser(ListOutputParser): - """Parse out comma separated lists.""" - - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - return text.strip().split(", ") diff --git a/langchain/output_parsing/regex.py b/langchain/output_parsing/regex.py deleted file mode 100644 index 294006c7b5e..00000000000 --- a/langchain/output_parsing/regex.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Class to parse the output into a dictionary.""" -import re -from typing import Dict, List, Optional - -from pydantic import BaseModel - -from langchain.output_parsing.base import BaseOutputParser - - -class RegexParser(BaseOutputParser, BaseModel): - """Class to parse the output into a dictionary.""" - - regex: str - output_keys: List[str] - default_output_key: Optional[str] = None - - @property - def _type(self) -> str: - """Return the type key.""" - return "regex_parser" - - def parse(self, text: str) -> Dict[str, str]: - """Parse the output of an LLM call.""" - match = re.search(self.regex, text) - if match: - return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)} - else: - if self.default_output_key is None: - raise ValueError(f"Could not parse output: {text}") - else: - return { - key: text if key == self.default_output_key else "" - for key in self.output_keys - } diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 4a63b6209d3..b85f31613aa 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -59,6 +59,18 @@ def check_valid_template( ) +class StringPromptValue(PromptValue): + text: str + + def to_string(self) -> str: + """Return prompt as string.""" + return self.text + + def to_messages(self) -> List[BaseMessage]: + """Return prompt as messages.""" + return [HumanMessage(content=self.text)] + + class BasePromptTemplate(BaseModel, ABC): """Base prompt should expose the format method, returning a prompt.""" diff --git a/tests/unit_tests/guards/test_restriction.py b/tests/unit_tests/guards/test_restriction.py index e802e4d8aeb..4ad3afe780d 100644 --- a/tests/unit_tests/guards/test_restriction.py +++ b/tests/unit_tests/guards/test_restriction.py @@ -3,6 +3,7 @@ from typing import List import pytest from langchain.guards.restriction import RestrictionGuard +from langchain.guards.restriction_prompt import RESTRICTION_PROMPT from tests.unit_tests.llms.fake_llm import FakeLLM @@ -19,7 +20,7 @@ def test_restriction_guard() -> None: ) -> str: concatenated_restrictions = ", ".join(restrictions) queries = { - RestrictionGuard.prompt.format( + RESTRICTION_PROMPT.format( restrictions=concatenated_restrictions, function_output=llm_input_output ): "restricted because I said so :) (¥)" if restricted @@ -27,7 +28,7 @@ def test_restriction_guard() -> None: } restriction_guard_llm = FakeLLM(queries=queries) - @RestrictionGuard( + @RestrictionGuard.from_llm( restrictions=restrictions, llm=restriction_guard_llm, retries=0 ) def example_func(prompt: str) -> str: diff --git a/tests/unit_tests/output_parsing/test_boolean.py b/tests/unit_tests/output_parsing/test_boolean.py index 482f4c2f058..c7b804a2ba6 100644 --- a/tests/unit_tests/output_parsing/test_boolean.py +++ b/tests/unit_tests/output_parsing/test_boolean.py @@ -2,7 +2,7 @@ from typing import List import pytest -from langchain.output_parsing.boolean import BooleanOutputParser +from langchain.output_parsers.boolean import BooleanOutputParser GOOD_EXAMPLES = [ ("0", False, ["1"], ["0"]),