mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 04:58:46 +00:00
Update Tracer Auth / Reduce Num Calls (#5517)
Update the session creation and calls --------- Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
parent
949729ff5c
commit
20ec1173f4
@ -23,7 +23,6 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||
from langchain.callbacks.tracers.schemas import TracerSession
|
||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||
from langchain.schema import (
|
||||
@ -99,26 +98,21 @@ def tracing_v2_enabled(
|
||||
session_name: Optional[str] = None,
|
||||
*,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
session_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
) -> Generator[None, None, None]:
|
||||
"""Get the experimental tracer handler in a context manager."""
|
||||
# Issue a warning that this is experimental
|
||||
warnings.warn(
|
||||
"The experimental tracing v2 is in development. "
|
||||
"The tracing v2 API is in development. "
|
||||
"This is not yet stable and may change in the future."
|
||||
)
|
||||
if isinstance(example_id, str):
|
||||
example_id = UUID(example_id)
|
||||
cb = LangChainTracer(
|
||||
tenant_id=tenant_id,
|
||||
session_name=session_name,
|
||||
example_id=example_id,
|
||||
session_extra=session_extra,
|
||||
session_name=session_name,
|
||||
)
|
||||
session = cb.ensure_session()
|
||||
tracing_v2_callback_var.set(cb)
|
||||
yield session
|
||||
yield
|
||||
tracing_v2_callback_var.set(None)
|
||||
|
||||
|
||||
@ -919,7 +913,6 @@ def _configure(
|
||||
else:
|
||||
try:
|
||||
handler = LangChainTracer(session_name=tracer_session)
|
||||
handler.ensure_session()
|
||||
callback_manager.add_handler(handler, True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
@ -25,10 +25,8 @@ from langchain.callbacks.tracers.schemas import (
|
||||
RunTypeEnum,
|
||||
RunUpdate,
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
)
|
||||
from langchain.schema import BaseMessage, messages_to_dict
|
||||
from langchain.utils import raise_for_status_with_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -65,49 +63,13 @@ retry_decorator = retry(
|
||||
)
|
||||
|
||||
|
||||
@retry_decorator
|
||||
def _get_tenant_id(
|
||||
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
|
||||
) -> str:
|
||||
"""Get the tenant ID for the LangChain API."""
|
||||
tenant_id_: Optional[str] = tenant_id or os.getenv("LANGCHAIN_TENANT_ID")
|
||||
if tenant_id_:
|
||||
return tenant_id_
|
||||
endpoint_ = endpoint or get_endpoint()
|
||||
headers_ = headers or get_headers()
|
||||
response = None
|
||||
try:
|
||||
response = requests.get(endpoint_ + "/tenants", headers=headers_)
|
||||
raise_for_status_with_text(response)
|
||||
except HTTPError as e:
|
||||
if response is not None and response.status_code == 500:
|
||||
raise LangChainTracerAPIError(
|
||||
f"Failed to get tenant ID from LangChain API. {e}"
|
||||
)
|
||||
else:
|
||||
raise LangChainTracerUserError(
|
||||
f"Failed to get tenant ID from LangChain API. {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise LangChainTracerError(
|
||||
f"Failed to get tenant ID from LangChain API. {e}"
|
||||
) from e
|
||||
|
||||
tenants: List[Dict[str, Any]] = response.json()
|
||||
if not tenants:
|
||||
raise ValueError(f"No tenants found for URL {endpoint_}")
|
||||
return tenants[0]["id"]
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: Optional[str] = None,
|
||||
example_id: Optional[UUID] = None,
|
||||
session_name: Optional[str] = None,
|
||||
session_extra: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
@ -115,10 +77,8 @@ class LangChainTracer(BaseTracer):
|
||||
self.session: Optional[TracerSession] = None
|
||||
self._endpoint = get_endpoint()
|
||||
self._headers = get_headers()
|
||||
self.tenant_id = tenant_id
|
||||
self.example_id = example_id
|
||||
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
||||
self.session_extra = session_extra
|
||||
# set max_workers to 1 to process tasks in order
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@ -149,62 +109,20 @@ class LangChainTracer(BaseTracer):
|
||||
self._start_trace(chat_model_run)
|
||||
self._on_chat_model_start(chat_model_run)
|
||||
|
||||
def ensure_tenant_id(self) -> str:
|
||||
"""Load or use the tenant ID."""
|
||||
tenant_id = self.tenant_id or _get_tenant_id(
|
||||
self.tenant_id, self._endpoint, self._headers
|
||||
)
|
||||
self.tenant_id = tenant_id
|
||||
return tenant_id
|
||||
|
||||
@retry_decorator
|
||||
def ensure_session(self) -> TracerSession:
|
||||
"""Upsert a session."""
|
||||
if self.session is not None:
|
||||
return self.session
|
||||
tenant_id = self.ensure_tenant_id()
|
||||
url = f"{self._endpoint}/sessions?upsert=true"
|
||||
session_create = TracerSessionCreate(
|
||||
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
|
||||
)
|
||||
response = None
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
data=session_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
if response is not None and response.status_code == 500:
|
||||
raise LangChainTracerAPIError(
|
||||
f"Failed to upsert session to LangChain API. {e}"
|
||||
)
|
||||
else:
|
||||
raise LangChainTracerUserError(
|
||||
f"Failed to upsert session to LangChain API. {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise LangChainTracerError(
|
||||
f"Failed to upsert session to LangChain API. {e}"
|
||||
) from e
|
||||
self.session = TracerSession(**response.json())
|
||||
return self.session
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
"""The Langchain Tracer uses Post/Patch rather than persist."""
|
||||
|
||||
@retry_decorator
|
||||
def _persist_run_single(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
session = self.ensure_session()
|
||||
if run.parent_run_id is None:
|
||||
run.reference_example_id = self.example_id
|
||||
run_dict = run.dict()
|
||||
del run_dict["child_runs"]
|
||||
run_create = RunCreate(**run_dict, session_id=session.id)
|
||||
run_create = RunCreate(**run_dict, session_name=self.session_name)
|
||||
response = None
|
||||
try:
|
||||
# TODO: Add retries when async
|
||||
response = requests.post(
|
||||
f"{self._endpoint}/runs",
|
||||
data=run_create.json(),
|
||||
|
@ -36,12 +36,6 @@ class TracerSessionBase(TracerSessionV1Base):
|
||||
tenant_id: UUID
|
||||
|
||||
|
||||
class TracerSessionCreate(TracerSessionBase):
|
||||
"""A creation class for TracerSession."""
|
||||
|
||||
id: Optional[UUID]
|
||||
|
||||
|
||||
class TracerSession(TracerSessionBase):
|
||||
"""TracerSessionV1 schema for the V2 API."""
|
||||
|
||||
@ -136,7 +130,7 @@ class Run(RunBase):
|
||||
|
||||
class RunCreate(RunBase):
|
||||
name: str
|
||||
session_id: UUID
|
||||
session_name: Optional[str] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
@ -10,7 +10,6 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -26,7 +25,8 @@ from requests import Response
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSession
|
||||
from langchain.callbacks.tracers.schemas import Run as TracerRun
|
||||
from langchain.callbacks.tracers.schemas import TracerSession
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.models import (
|
||||
APIFeedbackSource,
|
||||
@ -54,6 +54,10 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
||||
|
||||
|
||||
class Run(TracerRun):
|
||||
id: UUID
|
||||
|
||||
|
||||
def _get_link_stem(url: str) -> str:
|
||||
scheme = urlsplit(url).scheme
|
||||
netloc_prefix = urlsplit(url).netloc.split(":")[0]
|
||||
@ -75,7 +79,6 @@ class LangChainPlusClient(BaseSettings):
|
||||
|
||||
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
|
||||
api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")
|
||||
tenant_id: Optional[str] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@ -87,31 +90,8 @@ class LangChainPlusClient(BaseSettings):
|
||||
raise ValueError(
|
||||
"API key must be provided when using hosted LangChain+ API"
|
||||
)
|
||||
tenant_id = values.get("tenant_id")
|
||||
if not tenant_id:
|
||||
values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
|
||||
api_url, api_key
|
||||
)
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
||||
def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
||||
"""Get the tenant ID from the seeded tenant."""
|
||||
url = f"{api_url}/tenants"
|
||||
headers = {"x-api-key": api_key} if api_key else {}
|
||||
response = requests.get(url, headers=headers)
|
||||
try:
|
||||
raise_for_status_with_text(response)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Unable to get default tenant ID. Please manually provide."
|
||||
) from e
|
||||
results: List[dict] = response.json()
|
||||
if len(results) == 0:
|
||||
raise ValueError("No seeded tenant found")
|
||||
return results[0]["id"]
|
||||
|
||||
@staticmethod
|
||||
def _get_session_name(
|
||||
session_name: Optional[str],
|
||||
@ -149,18 +129,10 @@ class LangChainPlusClient(BaseSettings):
|
||||
headers["x-api-key"] = self.api_key
|
||||
return headers
|
||||
|
||||
@property
|
||||
def query_params(self) -> Dict[str, Any]:
|
||||
"""Get the headers for the API request."""
|
||||
return {"tenant_id": self.tenant_id}
|
||||
|
||||
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
|
||||
"""Make a GET request."""
|
||||
query_params = self.query_params
|
||||
if params:
|
||||
query_params.update(params)
|
||||
return requests.get(
|
||||
f"{self.api_url}{path}", headers=self._headers, params=query_params
|
||||
f"{self.api_url}{path}", headers=self._headers, params=params
|
||||
)
|
||||
|
||||
def upload_dataframe(
|
||||
@ -192,7 +164,6 @@ class LangChainPlusClient(BaseSettings):
|
||||
"input_keys": ",".join(input_keys),
|
||||
"output_keys": ",".join(output_keys),
|
||||
"description": description,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
response = requests.post(
|
||||
self.api_url + "/datasets/upload",
|
||||
@ -244,7 +215,7 @@ class LangChainPlusClient(BaseSettings):
|
||||
) -> TracerSession:
|
||||
"""Read a session from the LangChain+ API."""
|
||||
path = "/sessions"
|
||||
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
||||
params: Dict[str, Any] = {"limit": 1}
|
||||
if session_id is not None:
|
||||
path += f"/{session_id}"
|
||||
elif session_name is not None:
|
||||
@ -291,7 +262,6 @@ class LangChainPlusClient(BaseSettings):
|
||||
) -> Dataset:
|
||||
"""Create a dataset in the LangChain+ API."""
|
||||
dataset = DatasetCreate(
|
||||
tenant_id=self.tenant_id,
|
||||
name=dataset_name,
|
||||
description=description,
|
||||
)
|
||||
@ -309,7 +279,7 @@ class LangChainPlusClient(BaseSettings):
|
||||
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
|
||||
) -> Dataset:
|
||||
path = "/datasets"
|
||||
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
||||
params: Dict[str, Any] = {"limit": 1}
|
||||
if dataset_id is not None:
|
||||
path += f"/{dataset_id}"
|
||||
elif dataset_name is not None:
|
||||
|
@ -49,7 +49,6 @@ class ExampleUpdate(BaseModel):
|
||||
class DatasetBase(BaseModel):
|
||||
"""Dataset base model."""
|
||||
|
||||
tenant_id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
@ -68,6 +67,7 @@ class Dataset(DatasetBase):
|
||||
"""Dataset ORM model."""
|
||||
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
created_at: datetime
|
||||
modified_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
|
@ -214,7 +214,6 @@ async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChain
|
||||
"""
|
||||
if session_name:
|
||||
tracer = LangChainTracer(session_name=session_name)
|
||||
tracer.ensure_session()
|
||||
return tracer
|
||||
else:
|
||||
return None
|
||||
|
@ -148,8 +148,7 @@ def test_tracing_v2_context_manager() -> None:
|
||||
)
|
||||
if "LANGCHAIN_TRACING_V2" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING_V2"]
|
||||
with tracing_v2_enabled() as session:
|
||||
assert session
|
||||
with tracing_v2_enabled():
|
||||
agent.run(questions[0]) # this should be traced
|
||||
|
||||
agent.run(questions[0]) # this should not be traced
|
||||
|
@ -2,14 +2,12 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.schemas import TracerSession
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.langchain import (
|
||||
LangChainPlusClient,
|
||||
@ -46,39 +44,23 @@ def test_is_localhost() -> None:
|
||||
assert not _is_localhost("http://example.com:8000")
|
||||
|
||||
|
||||
def test_validate_api_key_if_hosted() -> None:
|
||||
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
||||
return _TENANT_ID
|
||||
def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="API key must be provided"):
|
||||
LangChainPlusClient(api_url="http://www.example.com")
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
with pytest.raises(ValueError, match="API key must be provided"):
|
||||
LangChainPlusClient(api_url="http://www.example.com")
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client.api_url == "http://localhost:8000"
|
||||
assert client.api_key is None
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client.api_url == "http://localhost:8000"
|
||||
assert client.api_key is None
|
||||
|
||||
|
||||
def test_headers() -> None:
|
||||
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
||||
return _TENANT_ID
|
||||
def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
assert client._headers == {"x-api-key": "123"}
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
assert client._headers == {"x-api-key": "123"}
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client_no_key._headers == {}
|
||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client_no_key._headers == {}
|
||||
|
||||
|
||||
@mock.patch("langchain.client.langchain.requests.post")
|
||||
@ -112,7 +94,8 @@ def test_upload_csv(mock_post: mock.Mock) -> None:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = LangChainPlusClient(
|
||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
||||
api_url="http://localhost:8000",
|
||||
api_key="123",
|
||||
)
|
||||
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
||||
|
||||
@ -196,22 +179,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
|
||||
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
|
||||
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "list_examples", new=mock_list_examples
|
||||
), mock.patch(
|
||||
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
||||
), mock.patch.object(
|
||||
LangChainTracer, "ensure_session", new=mock_ensure_session
|
||||
):
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
|
||||
client = LangChainPlusClient(
|
||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
||||
)
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
chain = mock.MagicMock()
|
||||
num_repetitions = 3
|
||||
results = await client.arun_on_dataset(
|
||||
|
Loading…
Reference in New Issue
Block a user