diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py
index 9a65e5e32db..51eddf0f771 100644
--- a/libs/core/langchain_core/prompts/base.py
+++ b/libs/core/langchain_core/prompts/base.py
@@ -27,6 +27,7 @@ from langchain_core.prompt_values import (
 )
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
 from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables.config import ensure_config
 
 if TYPE_CHECKING:
     from langchain_core.documents import Document
@@ -48,6 +49,14 @@ class BasePromptTemplate(
     output_parser: Optional[BaseOutputParser] = None
     """How to parse the output of calling an LLM on this formatted prompt."""
     partial_variables: Mapping[str, Any] = Field(default_factory=dict)
+    """A dictionary of the partial variables the prompt template carries.
+
+    Partial variables populate the template so that you don't need to
+    pass them in every time you call the prompt."""
+    metadata: Optional[Dict[str, Any]] = None
+    """Metadata to be used for tracing."""
+    tags: Optional[List[str]] = None
+    """Tags to be used for tracing."""
 
     @classmethod
     def get_lc_namespace(cls) -> List[str]:
@@ -95,6 +104,11 @@ class BasePromptTemplate(
     def invoke(
         self, input: Dict, config: Optional[RunnableConfig] = None
     ) -> PromptValue:
+        config = ensure_config(config)
+        if self.metadata:
+            config["metadata"].update(self.metadata)
+        if self.tags:
+            config["tags"].extend(self.tags)
         return self._call_with_config(
             self._format_prompt_with_error_handling,
             input,
diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py
index f50b71bea80..5075c5c0bb5
--- a/libs/core/tests/unit_tests/prompts/test_prompt.py
+++ b/libs/core/tests/unit_tests/prompts/test_prompt.py
@@ -1,9 +1,11 @@
 """Test functionality related to prompts."""
+
 from unittest import mock
 
 import pytest
 
 from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
 
 
 def test_prompt_valid() -> None:
@@ -328,3 +330,22 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
     assert PromptTemplate(
         input_variables=input_variables, template=template, template_format="jinja2"
     ).input_variables == ["foo"]
+
+
+def test_prompt_invoke_with_metadata() -> None:
+    """Test prompt can be invoked with metadata."""
+    template = "This is a {foo} test."
+    prompt = PromptTemplate(
+        input_variables=["foo"],
+        template=template,
+        metadata={"version": "1"},
+        tags=["tag1", "tag2"],
+    )
+    tracer = RunCollectorCallbackHandler()
+    result = prompt.invoke(
+        {"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]}
+    )
+    assert result.to_string() == "This is a bar test."
+    assert len(tracer.traced_runs) == 1
+    assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"}  # type: ignore
+    assert tracer.traced_runs[0].tags == ["tag1", "tag2"]  # type: ignore
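
A minimal usage sketch of the behaviour this patch introduces, assuming a langchain_core build with the change applied: metadata and tags set on the template are merged into the RunnableConfig on every invoke, so callbacks and tracers attached to the call see both the template-level and the per-call values. The snippet mirrors the new unit test above.

from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler

# Template-level metadata/tags travel with the prompt itself.
prompt = PromptTemplate(
    input_variables=["foo"],
    template="This is a {foo} test.",
    metadata={"version": "1"},
    tags=["tag1", "tag2"],
)

# Per-call metadata and callbacks are supplied through the config dict.
tracer = RunCollectorCallbackHandler()
prompt.invoke({"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]})

# The traced run carries the merged values: template metadata updated into
# the per-call metadata, and template tags appended to the per-call tags.
run = tracer.traced_runs[0]
assert run.extra["metadata"] == {"version": "1", "foo": "bar"}
assert run.tags == ["tag1", "tag2"]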