Fixes error loading Obsidian templates (#13888)

- **Description:** Obsidian templates can include
[variables](https://help.obsidian.md/Plugins/Templates#Template+variables)
using double curly braces. `ObsidianLoader` uses PyYaml to parse the
frontmatter of documents. This parsing throws an error when encountering
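variables' curly braces. This is avoided by temporarily substituting
safe placeholder strings for the template variables before parsing and
restoring the original variables in the parsed result afterwards.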
  - **Issue:** #13887
  - **Tag maintainer:** @hwchase17
This commit is contained in:
ealt 2023-12-04 20:55:37 +00:00 committed by GitHub
parent f6d68d78f3
commit e09b876863
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 4 deletions

View File

@ -1,7 +1,8 @@
import functools
import logging import logging
import re import re
from pathlib import Path from pathlib import Path
from typing import List from typing import Any, Dict, List
import yaml import yaml
from langchain_core.documents import Document from langchain_core.documents import Document
@ -15,6 +16,7 @@ class ObsidianLoader(BaseLoader):
"""Load `Obsidian` files from directory.""" """Load `Obsidian` files from directory."""
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL) FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)") TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE) DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE) DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
@ -35,6 +37,27 @@ class ObsidianLoader(BaseLoader):
self.encoding = encoding self.encoding = encoding
self.collect_metadata = collect_metadata self.collect_metadata = collect_metadata
def _replace_template_var(
    self, placeholders: Dict[str, str], match: re.Match
) -> str:
    """Swap one matched template variable for a placeholder token.

    Intended as the replacement callback for a regex substitution: each
    call mints a new token numbered by how many placeholders are already
    recorded, and stores the captured variable text under that token.

    Args:
        placeholders: Mapping of token -> captured variable text. It is
            mutated in place with the new entry.
        match: Regex match whose first group is the text found between
            the double curly braces.

    Returns:
        The placeholder token that stands in for the matched variable.
    """
    # The current size of the mapping doubles as the next free index.
    index = len(placeholders)
    token = f"__TEMPLATE_VAR_{index}__"
    captured = match.group(1)
    placeholders[token] = captured
    return token
def _restore_template_vars(self, obj: Any, placeholders: Dict[str, str]) -> Any:
    """Restore template variables replaced with placeholders to original values.

    Walks ``obj`` recursively. In strings, every placeholder token is
    swapped back to its ``{{variable}}`` form. Dicts and lists are updated
    in place; for dicts both the values and any string keys are restored,
    so a template variable used as a key does not leave a placeholder
    token behind. Values of any other type are returned unchanged.

    Args:
        obj: The value to restore (string, dict, list, or anything else).
        placeholders: Mapping of placeholder token -> the text that was
            originally between the double curly braces.

    Returns:
        The restored value. Dicts and lists are the same objects that were
        passed in; strings are new string objects.
    """
    if isinstance(obj, str):
        for placeholder, value in placeholders.items():
            # Quadrupled braces render as literal "{{" / "}}" around value.
            obj = obj.replace(placeholder, f"{{{{{value}}}}}")
    elif isinstance(obj, dict):
        # Snapshot the restored items first: restoring a key renames it,
        # and a dict must not be re-keyed while it is being iterated.
        restored_items = [
            (
                self._restore_template_vars(key, placeholders)
                if isinstance(key, str)
                else key,
                self._restore_template_vars(value, placeholders),
            )
            for key, value in obj.items()
        ]
        # Refill the same dict so callers holding a reference see the
        # result and the original key order is kept.
        obj.clear()
        obj.update(restored_items)
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            obj[i] = self._restore_template_vars(item, placeholders)
    return obj
def _parse_front_matter(self, content: str) -> dict: def _parse_front_matter(self, content: str) -> dict:
"""Parse front matter metadata from the content and return it as a dict.""" """Parse front matter metadata from the content and return it as a dict."""
if not self.collect_metadata: if not self.collect_metadata:
@ -44,8 +67,17 @@ class ObsidianLoader(BaseLoader):
if not match: if not match:
return {} return {}
placeholders: Dict[str, str] = {}
replace_template_var = functools.partial(
self._replace_template_var, placeholders
)
front_matter_text = self.TEMPLATE_VARIABLE_REGEX.sub(
replace_template_var, match.group(1)
)
try: try:
front_matter = yaml.safe_load(match.group(1)) front_matter = yaml.safe_load(front_matter_text)
front_matter = self._restore_template_vars(front_matter, placeholders)
# If tags are a string, split them into a list # If tags are a string, split them into a list
if "tags" in front_matter and isinstance(front_matter["tags"], str): if "tags" in front_matter and isinstance(front_matter["tags"], str):

View File

@ -0,0 +1,12 @@
---
aString: {{var}}
anArray:
- element
- {{varElement}}
aDict:
dictId1: 'val'
dictId2: '{{varVal}}'
tags: [ 'tag', '{{varTag}}' ]
---
Frontmatter contains template variables.

View File

@ -17,7 +17,7 @@ docs = loader.load()
def test_page_content_loaded() -> None: def test_page_content_loaded() -> None:
"""Verify that all docs have page_content""" """Verify that all docs have page_content"""
assert len(docs) == 5 assert len(docs) == 6
assert all(doc.page_content for doc in docs) assert all(doc.page_content for doc in docs)
@ -27,7 +27,7 @@ def test_disable_collect_metadata() -> None:
str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False
) )
docs_wo = loader_without_metadata.load() docs_wo = loader_without_metadata.load()
assert len(docs_wo) == 5 assert len(docs_wo) == 6
assert all(doc.page_content for doc in docs_wo) assert all(doc.page_content for doc in docs_wo)
assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo) assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo)
@ -45,6 +45,24 @@ def test_metadata_with_frontmatter() -> None:
assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"} assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"}
def test_metadata_with_template_vars_in_frontmatter() -> None:
    """Check that front matter containing template variables is loaded.

    Looks up the document loaded from ``template_var_frontmatter.md`` and
    asserts that its metadata has the expected keys and that each
    ``{{variable}}`` appears in its original double-brace form.
    """
    doc = next(
        candidate
        for candidate in docs
        if candidate.metadata["source"] == "template_var_frontmatter.md"
    )
    metadata = doc.metadata

    # Keys contributed by the fixture's front matter, on top of the
    # standard fields.
    expected_keys = {"aString", "anArray", "aDict", "tags"} | STANDARD_METADATA_FIELDS
    assert set(metadata) == expected_keys

    assert metadata["aString"] == "{{var}}"
    assert metadata["anArray"] == "['element', '{{varElement}}']"
    assert metadata["aDict"] == "{'dictId1': 'val', 'dictId2': '{{varVal}}'}"

    # Tags are compared as a set, so their order in the joined string is
    # not checked.
    tag_set = set(metadata["tags"].split(","))
    assert tag_set == {"tag", "{{varTag}}"}
def test_metadata_with_bad_frontmatter() -> None: def test_metadata_with_bad_frontmatter() -> None:
"""Verify a doc with non-yaml frontmatter.""" """Verify a doc with non-yaml frontmatter."""
doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md") doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md")