mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Add prompt metadata + tags (#17054)
This commit is contained in:
parent
d8f41d0521
commit
3d5e988c55
@ -27,6 +27,7 @@ from langchain_core.prompt_values import (
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import ensure_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
@ -48,6 +49,14 @@ class BasePromptTemplate(
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
partial_variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
"""A dictionary of the partial variables the prompt template carries.
|
||||
|
||||
Partial variables populate the template so that you don't need to
|
||||
pass them in every time you call the prompt."""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Metadata to be used for tracing."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@ -95,6 +104,11 @@ class BasePromptTemplate(
|
||||
def invoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||
) -> PromptValue:
|
||||
config = ensure_config(config)
|
||||
if self.metadata:
|
||||
config["metadata"].update(self.metadata)
|
||||
if self.tags:
|
||||
config["tags"].extend(self.tags)
|
||||
return self._call_with_config(
|
||||
self._format_prompt_with_error_handling,
|
||||
input,
|
||||
|
@ -1,9 +1,11 @@
|
||||
"""Test functionality related to prompts."""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||
|
||||
|
||||
def test_prompt_valid() -> None:
|
||||
@ -328,3 +330,22 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_invoke_with_metadata() -> None:
|
||||
"""Test prompt can be invoked with metadata."""
|
||||
template = "This is a {foo} test."
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["foo"],
|
||||
template=template,
|
||||
metadata={"version": "1"},
|
||||
tags=["tag1", "tag2"],
|
||||
)
|
||||
tracer = RunCollectorCallbackHandler()
|
||||
result = prompt.invoke(
|
||||
{"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]}
|
||||
)
|
||||
assert result.to_string() == "This is a bar test."
|
||||
assert len(tracer.traced_runs) == 1
|
||||
assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore
|
||||
assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore
|
||||
|
Loading…
Reference in New Issue
Block a user