Compare commits

...

1 Commits

Author SHA1 Message Date
Mason Daugherty
69992c32f4 rules and fixes 2025-07-08 22:07:55 -04:00
14 changed files with 868 additions and 797 deletions

View File

@@ -33,13 +33,13 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --all-groups mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff --fix $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check --fix $(PYTHON_FILES)
spell_check:
uv run --all-groups codespell --toml pyproject.toml

View File

@@ -110,24 +110,20 @@ class Prompty(BaseModel):
return json.dumps(d)
@staticmethod
def normalize(attribute: Any, parent: Path, env_error: bool = True) -> Any:
def normalize(attribute: Any, parent: Path, env_error: bool = True) -> Any: # noqa: FBT001, FBT002
if isinstance(attribute, str):
attribute = attribute.strip()
if attribute.startswith("${") and attribute.endswith("}"):
variable = attribute[2:-1].split(":")
if variable[0] in os.environ.keys():
if variable[0] in os.environ:
return os.environ[variable[0]]
else:
if len(variable) > 1:
return variable[1]
else:
if env_error:
raise ValueError(
f"Variable {variable[0]} not found in environment"
)
else:
return ""
elif (
if len(variable) > 1:
return variable[1]
if env_error:
msg = f"Variable {variable[0]} not found in environment"
raise ValueError(msg)
return ""
if (
attribute.startswith("file:")
and Path(parent / attribute.split(":")[1]).exists()
):
@@ -135,13 +131,12 @@ class Prompty(BaseModel):
items = json.load(f)
if isinstance(items, list):
return [Prompty.normalize(value, parent) for value in items]
elif isinstance(items, dict):
if isinstance(items, dict):
return {
key: Prompty.normalize(value, parent)
for key, value in items.items()
}
else:
return items
return items
else:
return attribute
elif isinstance(attribute, list):
@@ -167,11 +162,9 @@ def param_hoisting(
Returns:
The merged dictionary.
"""
if top_key:
new_dict = {**top[top_key]} if top_key in top else {}
else:
new_dict = {**top}
new_dict = ({**top[top_key]} if top_key in top else {}) if top_key else {**top}
for key, value in bottom.items():
if key not in new_dict:
new_dict[key] = value
@@ -208,7 +201,7 @@ class InvokerFactory:
_executors: dict[str, type[Invoker]] = {}
_processors: dict[str, type[Invoker]] = {}
def __new__(cls) -> InvokerFactory:
def __new__(cls) -> InvokerFactory: # noqa: PYI034
if cls._instance is None:
cls._instance = super().__new__(cls)
# Add NOOP invokers
@@ -220,7 +213,7 @@ class InvokerFactory:
def register(
self,
type: Literal["renderer", "parser", "executor", "processor"],
type: Literal["renderer", "parser", "executor", "processor"], # noqa: A002
name: str,
invoker: type[Invoker],
) -> None:
@@ -233,7 +226,8 @@ class InvokerFactory:
elif type == "processor":
self._processors[name] = invoker
else:
raise ValueError(f"Invalid type {type}")
msg = f"Invalid type {type}"
raise ValueError(msg)
def register_renderer(self, name: str, renderer_class: Any) -> None:
self.register("renderer", name, renderer_class)
@@ -249,21 +243,21 @@ class InvokerFactory:
def __call__(
self,
type: Literal["renderer", "parser", "executor", "processor"],
type_: Literal["renderer", "parser", "executor", "processor"],
name: str,
prompty: Prompty,
data: BaseModel,
) -> Any:
if type == "renderer":
if type_ == "renderer":
return self._renderers[name](prompty)(data)
elif type == "parser":
if type_ == "parser":
return self._parsers[name](prompty)(data)
elif type == "executor":
if type_ == "executor":
return self._executors[name](prompty)(data)
elif type == "processor":
if type_ == "processor":
return self._processors[name](prompty)(data)
else:
raise ValueError(f"Invalid type {type}")
msg = f"Invalid type {type_}"
raise ValueError(msg)
def to_dict(self) -> dict[str, Any]:
return {
@@ -296,7 +290,8 @@ class Frontmatter:
@classmethod
def read_file(cls, path: str) -> dict[str, Any]:
"""Reads file at path and returns dict with separated frontmatter.
"""Read file at path and returns dict with separated frontmatter.
See read() for more info on dict return value.
"""
with open(path, encoding="utf-8") as file:
@@ -305,7 +300,7 @@ class Frontmatter:
@classmethod
def read(cls, string: str) -> dict[str, Any]:
"""Returns dict with separated frontmatter from string.
"""Return dict with separated frontmatter from string.
Returned dict keys:
- attributes: extracted YAML attributes in dict form.

View File

@@ -29,7 +29,6 @@ def create_chat_prompt(
) # type: ignore[arg-type]
)
lc_p = ChatPromptTemplate.from_messages(lc_messages)
lc_p = lc_p.partial(**p.inputs)
return lc_p
return lc_p.partial(**p.inputs)
return RunnableLambda(runnable_chat_lambda)

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import re
from typing import Union
@@ -40,28 +42,25 @@ class PromptyChatParser(Invoker):
def inline_image(self, image_item: str) -> str:
# pass through if it's a url or base64 encoded
if image_item.startswith("http") or image_item.startswith("data"):
if image_item.startswith(("http", "data")):
return image_item
# otherwise, it's a local file - need to base64 encode it
else:
image_path = self.path / image_item
with open(image_path, "rb") as f:
base64_image = base64.b64encode(f.read()).decode("utf-8")
image_path = self.path / image_item
with open(image_path, "rb") as f:
base64_image = base64.b64encode(f.read()).decode("utf-8")
if image_path.suffix == ".png":
return f"data:image/png;base64,{base64_image}"
elif image_path.suffix == ".jpg":
return f"data:image/jpeg;base64,{base64_image}"
elif image_path.suffix == ".jpeg":
return f"data:image/jpeg;base64,{base64_image}"
else:
raise ValueError(
f"Invalid image format {image_path.suffix} - currently only .png "
"and .jpg / .jpeg are supported."
)
if image_path.suffix == ".png":
return f"data:image/png;base64,{base64_image}"
if image_path.suffix == ".jpg" or image_path.suffix == ".jpeg":
return f"data:image/jpeg;base64,{base64_image}"
msg = (
f"Invalid image format {image_path.suffix} - currently only .png "
"and .jpg / .jpeg are supported."
)
raise ValueError(msg)
def parse_content(self, content: str) -> Union[str, list]:
"""for parsing inline images"""
"""For parsing inline images."""
# regular expression to parse markdown images
image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)"
matches = re.findall(image, content, flags=re.MULTILINE)
@@ -98,12 +97,12 @@ class PromptyChatParser(Invoker):
{"type": "text", "text": content_chunks[i].strip()}
)
return content_items
else:
return content
return content
def invoke(self, data: BaseModel) -> BaseModel:
if not isinstance(data, SimpleModel):
raise ValueError("data must be an instance of SimpleModel")
msg = "data must be an instance of SimpleModel"
raise ValueError(msg)
messages = []
separator = r"(?i)^\s*#?\s*(" + "|".join(self.roles) + r")\s*:\s*\n"
@@ -123,7 +122,8 @@ class PromptyChatParser(Invoker):
chunks.pop()
if len(chunks) % 2 != 0:
raise ValueError("Invalid prompt format")
msg = "Invalid prompt format"
raise ValueError(msg)
# create messages
for i in range(0, len(chunks), 2):

View File

@@ -12,6 +12,7 @@ class MustacheRenderer(Invoker):
def invoke(self, data: BaseModel) -> BaseModel:
if not isinstance(data, SimpleModel):
raise ValueError("Expected data to be an instance of SimpleModel")
msg = "Expected data to be an instance of SimpleModel"
raise ValueError(msg)
generated = mustache.render(self.prompty.content, data.item)
return SimpleModel[str](item=generated)

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import traceback
from pathlib import Path
from typing import Any, Union
from typing import Any, Optional, Union
from .core import (
Frontmatter,
@@ -23,6 +25,7 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
Returns:
The Prompty object.
"""
file_path = Path(prompt_path)
if not file_path.is_absolute():
@@ -46,7 +49,8 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
try:
model = ModelSettings(**attributes.pop("model"))
except Exception as e:
raise ValueError(f"Error in model settings: {e}")
msg = f"Error in model settings: {e}"
raise ValueError(msg) from e
# pull template settings
try:
@@ -60,7 +64,8 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
else:
template = TemplateSettings(type="mustache", parser="prompty")
except Exception as e:
raise ValueError(f"Error in template loader: {e}")
msg = f"Error in template loader: {e}"
raise ValueError(msg) from e
# formalize inputs and outputs
if "inputs" in attributes:
@@ -69,7 +74,8 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
k: PropertySettings(**v) for (k, v) in attributes.pop("inputs").items()
}
except Exception as e:
raise ValueError(f"Error in inputs: {e}")
msg = f"Error in inputs: {e}"
raise ValueError(msg) from e
else:
inputs = {}
if "outputs" in attributes:
@@ -78,7 +84,8 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
k: PropertySettings(**v) for (k, v) in attributes.pop("outputs").items()
}
except Exception as e:
raise ValueError(f"Error in outputs: {e}")
msg = f"Error in outputs: {e}"
raise ValueError(msg) from e
else:
outputs = {}
@@ -120,7 +127,7 @@ def load(prompt_path: str, configuration: str = "default") -> Prompty:
def prepare(
prompt: Prompty,
inputs: dict[str, Any] = {},
inputs: Optional[dict[str, Any]] = None,
) -> Any:
"""Prepare the inputs for the prompty.
@@ -130,7 +137,10 @@ def prepare(
Returns:
The prepared inputs.
"""
if inputs is None:
inputs = {}
invoker = InvokerFactory()
inputs = param_hoisting(inputs, prompt.sample)
@@ -160,16 +170,15 @@ def prepare(
if isinstance(result, SimpleModel):
return result.item
else:
return result
return result
def run(
prompt: Prompty,
content: Union[dict, list, str],
configuration: dict[str, Any] = {},
parameters: dict[str, Any] = {},
raw: bool = False,
configuration: Optional[dict[str, Any]] = None,
parameters: Optional[dict[str, Any]] = None,
raw: bool = False, # noqa: FBT001, FBT002
) -> Any:
"""Run the prompty.
@@ -182,7 +191,12 @@ def run(
Returns:
The result of running the prompty.
"""
if parameters is None:
parameters = {}
if configuration is None:
configuration = {}
invoker = InvokerFactory()
if configuration != {}:
@@ -213,16 +227,15 @@ def run(
if isinstance(result, SimpleModel):
return result.item
else:
return result
return result
def execute(
prompt: Union[str, Prompty],
configuration: dict[str, Any] = {},
parameters: dict[str, Any] = {},
inputs: dict[str, Any] = {},
raw: bool = False,
configuration: Optional[dict[str, Any]] = None,
parameters: Optional[dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
raw: bool = False, # noqa: FBT001, FBT002
connection: str = "default",
) -> Any:
"""Execute a prompty.
@@ -238,8 +251,14 @@ def execute(
Returns:
The result of executing the prompty.
"""
"""
if inputs is None:
inputs = {}
if parameters is None:
parameters = {}
if configuration is None:
configuration = {}
if isinstance(prompt, str):
prompt = load(prompt, connection)
@@ -247,6 +266,4 @@ def execute(
content = prepare(prompt, inputs)
# run LLM model
result = run(prompt, content, configuration, parameters, raw)
return result
return run(prompt, content, configuration, parameters, raw)

View File

@@ -49,8 +49,63 @@ langchain = { path = "../../langchain", editable = true }
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP", "S"]
ignore = [ "UP007", ]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"DOC", # pydoclint
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
    "ISC", # flake8-implicit-str-concat
"PGH", # pygrep-hooks
"PIE", # flake8-pie
    "PERF", # Perflint
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
    "RSE", # flake8-raise
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D203", # Messes with the formatter
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
[tool.mypy]
disallow_untyped_defs = "True"

View File

@@ -4,4 +4,3 @@ import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -1,5 +1,7 @@
"""A fake callback handler for testing purposes."""
from __future__ import annotations
from itertools import chain
from typing import Any, Optional, Union
from uuid import UUID
@@ -260,7 +262,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
self.on_retriever_error_common()
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
def __deepcopy__(self, memo: dict) -> FakeCallbackHandler: # type: ignore[override]
return self
@@ -393,5 +395,5 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
self.on_text_common()
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
def __deepcopy__(self, memo: dict) -> FakeAsyncCallbackHandler: # type: ignore[override]
return self

View File

@@ -1,5 +1,7 @@
"""Fake Chat Model wrapper for testing purposes."""
from __future__ import annotations
import json
from typing import Any, Optional

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Optional, Union
from langchain.agents import AgentOutputParser
@@ -24,13 +26,13 @@ def extract_action_details(text: str) -> tuple[Optional[str], Optional[str]]:
class FakeOutputParser(AgentOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
action, input = extract_action_details(text)
action, input_ = extract_action_details(text)
if action:
log = f"\nInvoking: `{action}` with `{input}"
log = f"\nInvoking: `{action}` with `{input_}"
return AgentAction(tool=action, tool_input=(input or ""), log=log)
elif "Final Answer" in text:
return AgentAction(tool=action, tool_input=(input_ or ""), log=log)
if "Final Answer" in text:
return AgentFinish({"output": text}, text)
return AgentAction(

View File

@@ -48,25 +48,25 @@ def test_prompty_basic_chain() -> None:
user_message = msgs[1]
# Check the types of the messages
assert (
system_message["type"] == "system"
), "The first message should be of type 'system'."
assert (
user_message["type"] == "human"
), "The second message should be of type 'human'."
assert system_message["type"] == "system", (
"The first message should be of type 'system'."
)
assert user_message["type"] == "human", (
"The second message should be of type 'human'."
)
# Test for existence of fakeFirstName and fakeLastName in the system message
assert (
"fakeFirstName" in system_message["content"]
), "The string 'fakeFirstName' should be in the system message content."
assert (
"fakeLastName" in system_message["content"]
), "The string 'fakeLastName' should be in the system message content."
assert "fakeFirstName" in system_message["content"], (
"The string 'fakeFirstName' should be in the system message content."
)
assert "fakeLastName" in system_message["content"], (
"The string 'fakeLastName' should be in the system message content."
)
# Test for existence of fakeQuestion in the user message
assert (
"fakeQuestion" in user_message["content"]
), "The string 'fakeQuestion' should be in the user message content."
assert "fakeQuestion" in user_message["content"], (
"The string 'fakeQuestion' should be in the user message content."
)
def test_prompty_used_in_agent() -> None:

View File

@@ -8,11 +8,10 @@ PROMPT_DIR = Path(__file__).parent / "prompts"
def test_double_templating() -> None:
"""
Assess whether double templating occurs when invoking a chat prompt.
"""Assess whether double templating occurs when invoking a chat prompt.
If it does, an error is thrown and the test fails.
"""
prompt_path = PROMPT_DIR / "double_templating.prompty"
templated_prompt = create_chat_prompt(str(prompt_path))
query = "What do you think of this JSON object: {'key': 7}?"

File diff suppressed because it is too large Load Diff