mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-10 15:33:11 +00:00
Update Tracer Auth / Reduce Num Calls (#5517)
Update the session creation and calls --------- Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
@@ -2,14 +2,12 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.schemas import TracerSession
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.langchain import (
|
||||
LangChainPlusClient,
|
||||
@@ -46,39 +44,23 @@ def test_is_localhost() -> None:
|
||||
assert not _is_localhost("http://example.com:8000")
|
||||
|
||||
|
||||
def test_validate_api_key_if_hosted() -> None:
|
||||
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
||||
return _TENANT_ID
|
||||
def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="API key must be provided"):
|
||||
LangChainPlusClient(api_url="http://www.example.com")
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
with pytest.raises(ValueError, match="API key must be provided"):
|
||||
LangChainPlusClient(api_url="http://www.example.com")
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client.api_url == "http://localhost:8000"
|
||||
assert client.api_key is None
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client.api_url == "http://localhost:8000"
|
||||
assert client.api_key is None
|
||||
|
||||
|
||||
def test_headers() -> None:
|
||||
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
||||
return _TENANT_ID
|
||||
def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
assert client._headers == {"x-api-key": "123"}
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
assert client._headers == {"x-api-key": "123"}
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
||||
):
|
||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client_no_key._headers == {}
|
||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client_no_key._headers == {}
|
||||
|
||||
|
||||
@mock.patch("langchain.client.langchain.requests.post")
|
||||
@@ -112,7 +94,8 @@ def test_upload_csv(mock_post: mock.Mock) -> None:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = LangChainPlusClient(
|
||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
||||
api_url="http://localhost:8000",
|
||||
api_key="123",
|
||||
)
|
||||
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
||||
|
||||
@@ -196,22 +179,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
|
||||
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
|
||||
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "list_examples", new=mock_list_examples
|
||||
), mock.patch(
|
||||
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
||||
), mock.patch.object(
|
||||
LangChainTracer, "ensure_session", new=mock_ensure_session
|
||||
):
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
|
||||
client = LangChainPlusClient(
|
||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
||||
)
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
chain = mock.MagicMock()
|
||||
num_repetitions = 3
|
||||
results = await client.arun_on_dataset(
|
||||
|
Reference in New Issue
Block a user