diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 2038f3015dd..c0c747ae491 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -18,7 +18,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str: try: from jinja2 import Template except ImportError: - raise ValueError( + raise ImportError( "jinja2 not installed, which is needed to use the jinja2_formatter. " "Please install it with `pip install jinja2`." ) diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 18a18514891..af6ff29e8d7 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -5,7 +5,6 @@ from pathlib import Path from string import Formatter from typing import Any, Dict, List, Set, Union -from jinja2 import Environment, meta from pydantic import Extra, root_validator from langchain.prompts.base import ( @@ -16,6 +15,13 @@ from langchain.prompts.base import ( def _get_jinja2_variables_from_template(template: str) -> Set[str]: + try: + from jinja2 import Environment, meta + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) env = Environment() ast = env.parse(template) variables = meta.find_undeclared_variables(ast)