Compare commits

...

2 Commits

Author SHA1 Message Date
Harrison Chase
b789f8696d Merge branch 'master' into harrison/comp-prompt 2023-05-30 15:26:39 -07:00
Harrison Chase
e0259daee6 stash 2023-05-30 06:27:51 -07:00
2 changed files with 28 additions and 9 deletions

View File

@@ -12,6 +12,8 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain.formatting import formatter
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
ACCEPTABLE_PARTIAL_TYPES = Union[str, Callable[[], str], "BasePromptTemplate"]
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
@@ -87,6 +89,9 @@ def check_valid_template(
+ str(e)
)
def filter_keys(_dict: Dict, keys: List[str]):
return {k: _dict[k] for k in keys}
class StringPromptValue(PromptValue):
text: str
@@ -107,7 +112,7 @@ class BasePromptTemplate(BaseModel, ABC):
"""A list of the names of the variables the prompt template expects."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
partial_variables: Mapping[str, ACCEPTABLE_PARTIAL_TYPES] = Field(
default_factory=dict
)
@@ -144,21 +149,29 @@ class BasePromptTemplate(BaseModel, ABC):
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
def partial(self, **kwargs: ACCEPTABLE_PARTIAL_TYPES) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
input_variables = set(self.input_variables).difference(kwargs)
for v in kwargs.items():
if isinstance(v, BasePromptTemplate):
input_variables |= set(v.input_variables)
prompt_dict["input_variables"] = list(input_variables)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
partial_kwargs = {}
for k, v in self.partial_variables.items():
if isinstance(v, str):
partial_kwargs[k] = v
elif isinstance(v, BasePromptTemplate):
partial_kwargs[k] = v.format(**filter_keys(kwargs, v.input_variables))
elif isinstance(v, Callable):
partial_kwargs[k] = v()
else:
raise ValueError("Got unexpected partial type")
return {**partial_kwargs, **kwargs}
@abstractmethod

View File

@@ -145,6 +145,12 @@ def test_partial() -> None:
assert new_result == "This is a 3 test."
result = prompt.format(foo="foo")
assert result == "This is a foo test."
# Test partialing with prompt templates
compose_prompt = PromptTemplate(input_variables=["bar"], template="sad and {bar}")
new_prompt = prompt.partial(foo=compose_prompt)
result = new_prompt.format(bar="happy")
assert result == "This is a sad and happy test."
def test_prompt_from_jinja2_template() -> None: