[Hub|tracing] Tag hub prompts (#14720)

If you're using the hub, you'll likely be interested in tracking the
commit/object when tracing. This PR adds it to the config
This commit is contained in:
William FH 2023-12-14 10:04:18 -08:00 committed by GitHub
parent 79ae6c2a9e
commit 852b9ca494
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 1 deletions

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.runnables import Runnable
if TYPE_CHECKING:
from langchainhub import Client
@ -78,4 +79,7 @@ def pull(
"""
client = _get_client(api_url=api_url, api_key=api_key)
resp: str = client.pull(owner_repo_commit)
return loads(resp)
loaded = loads(resp)
if isinstance(loaded, Runnable):
return loaded.with_config(metadata={"hub_owner_repo_commit": owner_repo_commit})
return loaded

View File

@ -0,0 +1,61 @@
from typing import Any, List
from unittest.mock import MagicMock, Mock, patch
from langchain_core.load.dump import dumps
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.base import RunnableBinding
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
from langchain import hub
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Run] = []
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.runs.append(run)
repo_dict = {
"wfh/my-prompt-1": ChatPromptTemplate.from_messages(
[("system", "a"), ("user", "1")]
),
"wfh/my-random-object": {"Hi": "there"},
}
def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> str:
return dumps(repo_dict[owner_repo_commit])
@patch("langchain.hub._get_client")
def test_hub_pull_metadata(mock_get_client: Mock) -> None:
mock_client = MagicMock()
mock_client.pull = repo_lookup
mock_get_client.return_value = mock_client
res = hub.pull("wfh/my-prompt-1")
assert isinstance(res, RunnableBinding)
tracer = FakeTracer()
result = res.invoke({}, {"callbacks": [tracer]})
assert result.messages[0].content == "a"
assert result.messages[1].content == "1"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.extra is not None
assert run.extra["metadata"]["hub_owner_repo_commit"] == "wfh/my-prompt-1" # type: ignore
@patch("langchain.hub._get_client")
def test_hub_pull_random_object(mock_get_client: Mock) -> None:
mock_client = MagicMock()
mock_client.pull = repo_lookup
mock_get_client.return_value = mock_client
res = hub.pull("wfh/my-random-object")
assert res == {"Hi": "there"}