mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
Add Tenant ID to V2 Tracer (#4135)
Update the V2 tracer to - use UUIDs instead of int's - load a tenant ID and use that when saving sessions
This commit is contained in:
@@ -6,7 +6,7 @@ import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
@@ -21,6 +21,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.base import TracerSession
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionV2
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
@@ -28,7 +29,7 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
|
||||
"openai_callback", default=None
|
||||
)
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
||||
"tracing_callback", default=None
|
||||
)
|
||||
|
||||
@@ -48,7 +49,7 @@ def tracing_enabled(
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
"""Get Tracer in a context manager."""
|
||||
cb = LangChainTracer()
|
||||
session = cb.load_session(session_name)
|
||||
session = cast(TracerSession, cb.load_session(session_name))
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
tracing_callback_var.set(None)
|
||||
@@ -57,7 +58,7 @@ def tracing_enabled(
|
||||
@contextmanager
|
||||
def tracing_v2_enabled(
|
||||
session_name: str = "default",
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
) -> Generator[TracerSessionV2, None, None]:
|
||||
"""Get the experimental tracer handler in a context manager."""
|
||||
# Issue a warning that this is experimental
|
||||
warnings.warn(
|
||||
|
@@ -12,7 +12,9 @@ from langchain.callbacks.tracers.schemas import (
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionBase,
|
||||
TracerSessionCreate,
|
||||
TracerSessionV2,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -27,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
||||
self.session: Optional[TracerSession] = None
|
||||
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
@@ -49,22 +51,31 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Persist a run."""
|
||||
|
||||
@abstractmethod
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(
|
||||
self, session: TracerSessionBase
|
||||
) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Persist a tracing session."""
|
||||
|
||||
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
|
||||
def _get_session_create(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> TracerSessionBase:
|
||||
return TracerSessionCreate(name=name, extra=kwargs)
|
||||
|
||||
def new_session(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""NOT thread safe, do not call this method from multiple threads."""
|
||||
session_create = TracerSessionCreate(name=name, extra=kwargs)
|
||||
session_create = self._get_session_create(name=name, **kwargs)
|
||||
session = self._persist_session(session_create)
|
||||
self.session = session
|
||||
return session
|
||||
|
||||
@abstractmethod
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load a tracing session and set it as the Tracer's session."""
|
||||
|
||||
@abstractmethod
|
||||
def load_default_session(self) -> TracerSession:
|
||||
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
|
||||
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
|
@@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import (
|
||||
Run,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
TracerSessionBase,
|
||||
TracerSessionV2,
|
||||
TracerSessionV2Create,
|
||||
)
|
||||
|
||||
|
||||
def _get_headers() -> Dict[str, Any]:
|
||||
"""Get the headers for the LangChain API."""
|
||||
headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
return headers
|
||||
|
||||
|
||||
def _get_endpoint() -> str:
|
||||
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
self.session = self.load_session(session_name)
|
||||
self._endpoint = _get_endpoint()
|
||||
self._headers = _get_headers()
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
@@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer):
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
||||
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(
|
||||
self, session_create: TracerSessionBase
|
||||
) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Persist a session."""
|
||||
try:
|
||||
r = requests.post(
|
||||
@@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer):
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load a session with the given name from the tracer."""
|
||||
return self._load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
return self._load_session("default")
|
||||
|
||||
|
||||
def _get_tenant_id() -> Optional[str]:
|
||||
"""Get the tenant ID for the LangChain API."""
|
||||
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
endpoint = _get_endpoint()
|
||||
headers = _get_headers()
|
||||
response = requests.get(endpoint + "/tenants", headers=headers)
|
||||
response.raise_for_status()
|
||||
tenants: List[Dict[str, Any]] = response.json()
|
||||
if not tenants:
|
||||
raise ValueError(f"No tenants found for URL {endpoint}")
|
||||
return tenants[0]["id"]
|
||||
|
||||
|
||||
class LangChainTracerV2(LangChainTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
@staticmethod
|
||||
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
||||
"""Convert a run to a Run."""
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint = _get_endpoint()
|
||||
self._headers = _get_headers()
|
||||
self.tenant_id = _get_tenant_id()
|
||||
|
||||
def _get_session_create(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> TracerSessionBase:
|
||||
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
|
||||
|
||||
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
||||
"""Persist a session."""
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{self._endpoint}/sessions",
|
||||
data=session_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to create session, using default session: {e}")
|
||||
session = self.load_session("default")
|
||||
return session
|
||||
|
||||
def _get_default_query_params(self) -> Dict[str, Any]:
|
||||
"""Get the query params for the LangChain API."""
|
||||
return {"tenant_id": self.tenant_id}
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSessionV2:
|
||||
"""Load a session with the given name from the tracer."""
|
||||
try:
|
||||
url = f"{self._endpoint}/sessions"
|
||||
params = {"tenant_id": self.tenant_id}
|
||||
if session_name:
|
||||
params["name"] = session_name
|
||||
r = requests.get(url, headers=self._headers, params=params)
|
||||
tracer_session = TracerSessionV2(**r.json()[0])
|
||||
except Exception as e:
|
||||
session_type = "default" if not session_name else session_name
|
||||
logging.warning(
|
||||
f"Failed to load {session_type} session, using empty session: {e}"
|
||||
)
|
||||
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
|
||||
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_default_session(self) -> TracerSessionV2:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
return self.load_session("default")
|
||||
|
||||
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
||||
"""Convert a run to a Run."""
|
||||
session = self.session or self.load_default_session()
|
||||
inputs: Dict[str, Any] = {}
|
||||
outputs: Optional[Dict[str, Any]] = None
|
||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
@@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer):
|
||||
|
||||
return Run(
|
||||
id=run.uuid,
|
||||
name=run.serialized.get("name"),
|
||||
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
|
||||
start_time=run.start_time,
|
||||
end_time=run.end_time,
|
||||
extra=run.extra,
|
||||
extra=run.extra or {},
|
||||
error=run.error,
|
||||
execution_order=run.execution_order,
|
||||
serialized=run.serialized,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
session_id=run.session_id,
|
||||
session_id=session.id,
|
||||
run_type=run_type,
|
||||
parent_run_id=run.parent_uuid,
|
||||
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
|
||||
child_runs=[self._convert_run(child) for child in child_runs],
|
||||
)
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
run_create = self._convert_run(run)
|
||||
|
||||
try:
|
||||
requests.post(
|
||||
result = requests.post(
|
||||
f"{self._endpoint}/runs",
|
||||
data=run_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
@@ -31,6 +31,24 @@ class TracerSession(TracerSessionBase):
|
||||
id: int
|
||||
|
||||
|
||||
class TracerSessionV2Base(TracerSessionBase):
|
||||
"""A creation class for TracerSessionV2."""
|
||||
|
||||
tenant_id: UUID
|
||||
|
||||
|
||||
class TracerSessionV2Create(TracerSessionBase):
|
||||
"""A creation class for TracerSessionV2."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TracerSessionV2(TracerSessionV2Base):
|
||||
"""TracerSession schema for the V2 API."""
|
||||
|
||||
id: UUID
|
||||
|
||||
|
||||
class BaseRun(BaseModel):
|
||||
"""Base class for Run."""
|
||||
|
||||
@@ -93,9 +111,9 @@ class Run(BaseModel):
|
||||
serialized: dict
|
||||
inputs: dict
|
||||
outputs: Optional[dict]
|
||||
session_id: int
|
||||
session_id: UUID
|
||||
parent_run_id: Optional[UUID]
|
||||
example_id: Optional[UUID]
|
||||
reference_example_id: Optional[UUID]
|
||||
run_type: RunTypeEnum
|
||||
child_runs: List[Run] = Field(default_factory=list)
|
||||
|
||||
|
@@ -2,8 +2,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Union
|
||||
from uuid import uuid4
|
||||
from typing import List, Tuple, Union
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
@@ -16,7 +17,8 @@ from langchain.callbacks.tracers.base import (
|
||||
TracerException,
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
@@ -27,7 +29,7 @@ def load_session(session_name: str) -> TracerSession:
|
||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
||||
|
||||
|
||||
def _persist_session(session: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(session: TracerSessionBase) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return TracerSession(id=TEST_SESSION_ID, **session.dict())
|
||||
|
||||
@@ -49,7 +51,7 @@ class FakeTracer(BaseTracer):
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(self, session: TracerSessionBase) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
@@ -473,3 +475,125 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
)
|
||||
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
|
||||
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||
tracer = LangChainTracerV2()
|
||||
return tracer
|
||||
|
||||
|
||||
# Mock a sample TracerSessionV2 object
|
||||
@pytest.fixture
|
||||
def sample_tracer_session_v2() -> TracerSessionV2:
|
||||
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
||||
|
||||
|
||||
# Mock a sample LLMRun, ChainRun, and ToolRun objects
|
||||
@pytest.fixture
|
||||
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
||||
llm_run = LLMRun(
|
||||
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
serialized={},
|
||||
extra={},
|
||||
)
|
||||
chain_run = ChainRun(
|
||||
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_llm_runs=[llm_run],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
tool_run = ToolRun(
|
||||
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
tool_input="test",
|
||||
action="{}",
|
||||
serialized={},
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
return llm_run, chain_run, tool_run
|
||||
|
||||
|
||||
# Test _get_default_query_params method
|
||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
expected = {"tenant_id": "test-tenant-id"}
|
||||
result = lang_chain_tracer_v2._get_default_query_params()
|
||||
assert result == expected
|
||||
|
||||
|
||||
# Test load_session method
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
||||
def test_load_session(
|
||||
mock_requests_get: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test that load_session method returns a TracerSessionV2 object."""
|
||||
mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()]
|
||||
result = lang_chain_tracer_v2.load_session("test-session-name")
|
||||
mock_requests_get.assert_called_with(
|
||||
"http://test-endpoint.com/sessions",
|
||||
headers={"Content-Type": "application/json", "x-api-key": "foo"},
|
||||
params={"tenant_id": "test-tenant-id", "name": "test-session-name"},
|
||||
)
|
||||
assert result == sample_tracer_session_v2
|
||||
|
||||
|
||||
def test_convert_run(
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run)
|
||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, Run)
|
||||
assert isinstance(converted_chain_run, Run)
|
||||
assert isinstance(converted_tool_run, Run)
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_run(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
mock_requests_post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
|
||||
assert mock_requests_post.call_count == 3
|
||||
|
@@ -70,7 +70,7 @@ def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
|
||||
assert file_contents is None
|
||||
file_contents = Path(file_path).read_text()
|
||||
|
||||
mocked_responses.get(
|
||||
mocked_responses.get( # type: ignore
|
||||
urljoin(URL_BASE.format(ref=ref), path),
|
||||
body=body,
|
||||
status=200,
|
||||
@@ -86,7 +86,9 @@ def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
|
||||
path = "chains/path/chain.json"
|
||||
loader = Mock()
|
||||
|
||||
mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500)
|
||||
mocked_responses.get( # type: ignore
|
||||
urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
|
||||
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})
|
||||
|
Reference in New Issue
Block a user