mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-17 20:38:56 +00:00
Update Tracer Auth / Reduce Num Calls (#5517)
Update the session creation and calls --------- Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
parent
949729ff5c
commit
20ec1173f4
@ -23,7 +23,6 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
|
|||||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||||
from langchain.callbacks.tracers.schemas import TracerSession
|
|
||||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
from langchain.callbacks.tracers.wandb import WandbTracer
|
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
@ -99,26 +98,21 @@ def tracing_v2_enabled(
|
|||||||
session_name: Optional[str] = None,
|
session_name: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
tenant_id: Optional[str] = None,
|
) -> Generator[None, None, None]:
|
||||||
session_extra: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Generator[TracerSession, None, None]:
|
|
||||||
"""Get the experimental tracer handler in a context manager."""
|
"""Get the experimental tracer handler in a context manager."""
|
||||||
# Issue a warning that this is experimental
|
# Issue a warning that this is experimental
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The experimental tracing v2 is in development. "
|
"The tracing v2 API is in development. "
|
||||||
"This is not yet stable and may change in the future."
|
"This is not yet stable and may change in the future."
|
||||||
)
|
)
|
||||||
if isinstance(example_id, str):
|
if isinstance(example_id, str):
|
||||||
example_id = UUID(example_id)
|
example_id = UUID(example_id)
|
||||||
cb = LangChainTracer(
|
cb = LangChainTracer(
|
||||||
tenant_id=tenant_id,
|
|
||||||
session_name=session_name,
|
|
||||||
example_id=example_id,
|
example_id=example_id,
|
||||||
session_extra=session_extra,
|
session_name=session_name,
|
||||||
)
|
)
|
||||||
session = cb.ensure_session()
|
|
||||||
tracing_v2_callback_var.set(cb)
|
tracing_v2_callback_var.set(cb)
|
||||||
yield session
|
yield
|
||||||
tracing_v2_callback_var.set(None)
|
tracing_v2_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
@ -919,7 +913,6 @@ def _configure(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
handler = LangChainTracer(session_name=tracer_session)
|
handler = LangChainTracer(session_name=tracer_session)
|
||||||
handler.ensure_session()
|
|
||||||
callback_manager.add_handler(handler, True)
|
callback_manager.add_handler(handler, True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -25,10 +25,8 @@ from langchain.callbacks.tracers.schemas import (
|
|||||||
RunTypeEnum,
|
RunTypeEnum,
|
||||||
RunUpdate,
|
RunUpdate,
|
||||||
TracerSession,
|
TracerSession,
|
||||||
TracerSessionCreate,
|
|
||||||
)
|
)
|
||||||
from langchain.schema import BaseMessage, messages_to_dict
|
from langchain.schema import BaseMessage, messages_to_dict
|
||||||
from langchain.utils import raise_for_status_with_text
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -65,49 +63,13 @@ retry_decorator = retry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry_decorator
|
|
||||||
def _get_tenant_id(
|
|
||||||
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
|
|
||||||
) -> str:
|
|
||||||
"""Get the tenant ID for the LangChain API."""
|
|
||||||
tenant_id_: Optional[str] = tenant_id or os.getenv("LANGCHAIN_TENANT_ID")
|
|
||||||
if tenant_id_:
|
|
||||||
return tenant_id_
|
|
||||||
endpoint_ = endpoint or get_endpoint()
|
|
||||||
headers_ = headers or get_headers()
|
|
||||||
response = None
|
|
||||||
try:
|
|
||||||
response = requests.get(endpoint_ + "/tenants", headers=headers_)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
except HTTPError as e:
|
|
||||||
if response is not None and response.status_code == 500:
|
|
||||||
raise LangChainTracerAPIError(
|
|
||||||
f"Failed to get tenant ID from LangChain API. {e}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise LangChainTracerUserError(
|
|
||||||
f"Failed to get tenant ID from LangChain API. {e}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise LangChainTracerError(
|
|
||||||
f"Failed to get tenant ID from LangChain API. {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
tenants: List[Dict[str, Any]] = response.json()
|
|
||||||
if not tenants:
|
|
||||||
raise ValueError(f"No tenants found for URL {endpoint_}")
|
|
||||||
return tenants[0]["id"]
|
|
||||||
|
|
||||||
|
|
||||||
class LangChainTracer(BaseTracer):
|
class LangChainTracer(BaseTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tenant_id: Optional[str] = None,
|
|
||||||
example_id: Optional[UUID] = None,
|
example_id: Optional[UUID] = None,
|
||||||
session_name: Optional[str] = None,
|
session_name: Optional[str] = None,
|
||||||
session_extra: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the LangChain tracer."""
|
"""Initialize the LangChain tracer."""
|
||||||
@ -115,10 +77,8 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.session: Optional[TracerSession] = None
|
self.session: Optional[TracerSession] = None
|
||||||
self._endpoint = get_endpoint()
|
self._endpoint = get_endpoint()
|
||||||
self._headers = get_headers()
|
self._headers = get_headers()
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
||||||
self.session_extra = session_extra
|
|
||||||
# set max_workers to 1 to process tasks in order
|
# set max_workers to 1 to process tasks in order
|
||||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
@ -149,62 +109,20 @@ class LangChainTracer(BaseTracer):
|
|||||||
self._start_trace(chat_model_run)
|
self._start_trace(chat_model_run)
|
||||||
self._on_chat_model_start(chat_model_run)
|
self._on_chat_model_start(chat_model_run)
|
||||||
|
|
||||||
def ensure_tenant_id(self) -> str:
|
|
||||||
"""Load or use the tenant ID."""
|
|
||||||
tenant_id = self.tenant_id or _get_tenant_id(
|
|
||||||
self.tenant_id, self._endpoint, self._headers
|
|
||||||
)
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
return tenant_id
|
|
||||||
|
|
||||||
@retry_decorator
|
|
||||||
def ensure_session(self) -> TracerSession:
|
|
||||||
"""Upsert a session."""
|
|
||||||
if self.session is not None:
|
|
||||||
return self.session
|
|
||||||
tenant_id = self.ensure_tenant_id()
|
|
||||||
url = f"{self._endpoint}/sessions?upsert=true"
|
|
||||||
session_create = TracerSessionCreate(
|
|
||||||
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
|
|
||||||
)
|
|
||||||
response = None
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
url,
|
|
||||||
data=session_create.json(),
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
except HTTPError as e:
|
|
||||||
if response is not None and response.status_code == 500:
|
|
||||||
raise LangChainTracerAPIError(
|
|
||||||
f"Failed to upsert session to LangChain API. {e}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise LangChainTracerUserError(
|
|
||||||
f"Failed to upsert session to LangChain API. {e}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise LangChainTracerError(
|
|
||||||
f"Failed to upsert session to LangChain API. {e}"
|
|
||||||
) from e
|
|
||||||
self.session = TracerSession(**response.json())
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
def _persist_run(self, run: Run) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""The Langchain Tracer uses Post/Patch rather than persist."""
|
||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
def _persist_run_single(self, run: Run) -> None:
|
def _persist_run_single(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
session = self.ensure_session()
|
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
run_dict = run.dict()
|
run_dict = run.dict()
|
||||||
del run_dict["child_runs"]
|
del run_dict["child_runs"]
|
||||||
run_create = RunCreate(**run_dict, session_id=session.id)
|
run_create = RunCreate(**run_dict, session_name=self.session_name)
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
# TODO: Add retries when async
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self._endpoint}/runs",
|
f"{self._endpoint}/runs",
|
||||||
data=run_create.json(),
|
data=run_create.json(),
|
||||||
|
@ -36,12 +36,6 @@ class TracerSessionBase(TracerSessionV1Base):
|
|||||||
tenant_id: UUID
|
tenant_id: UUID
|
||||||
|
|
||||||
|
|
||||||
class TracerSessionCreate(TracerSessionBase):
|
|
||||||
"""A creation class for TracerSession."""
|
|
||||||
|
|
||||||
id: Optional[UUID]
|
|
||||||
|
|
||||||
|
|
||||||
class TracerSession(TracerSessionBase):
|
class TracerSession(TracerSessionBase):
|
||||||
"""TracerSessionV1 schema for the V2 API."""
|
"""TracerSessionV1 schema for the V2 API."""
|
||||||
|
|
||||||
@ -136,7 +130,7 @@ class Run(RunBase):
|
|||||||
|
|
||||||
class RunCreate(RunBase):
|
class RunCreate(RunBase):
|
||||||
name: str
|
name: str
|
||||||
session_id: UUID
|
session_name: Optional[str] = None
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
@ -10,7 +10,6 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
@ -26,7 +25,8 @@ from requests import Response
|
|||||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.tracers.schemas import Run, TracerSession
|
from langchain.callbacks.tracers.schemas import Run as TracerRun
|
||||||
|
from langchain.callbacks.tracers.schemas import TracerSession
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.client.models import (
|
from langchain.client.models import (
|
||||||
APIFeedbackSource,
|
APIFeedbackSource,
|
||||||
@ -54,6 +54,10 @@ logger = logging.getLogger(__name__)
|
|||||||
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
||||||
|
|
||||||
|
|
||||||
|
class Run(TracerRun):
|
||||||
|
id: UUID
|
||||||
|
|
||||||
|
|
||||||
def _get_link_stem(url: str) -> str:
|
def _get_link_stem(url: str) -> str:
|
||||||
scheme = urlsplit(url).scheme
|
scheme = urlsplit(url).scheme
|
||||||
netloc_prefix = urlsplit(url).netloc.split(":")[0]
|
netloc_prefix = urlsplit(url).netloc.split(":")[0]
|
||||||
@ -75,7 +79,6 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
|
|
||||||
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
|
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
|
||||||
api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")
|
api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")
|
||||||
tenant_id: Optional[str] = None
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@ -87,31 +90,8 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"API key must be provided when using hosted LangChain+ API"
|
"API key must be provided when using hosted LangChain+ API"
|
||||||
)
|
)
|
||||||
tenant_id = values.get("tenant_id")
|
|
||||||
if not tenant_id:
|
|
||||||
values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
|
|
||||||
api_url, api_key
|
|
||||||
)
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
|
||||||
"""Get the tenant ID from the seeded tenant."""
|
|
||||||
url = f"{api_url}/tenants"
|
|
||||||
headers = {"x-api-key": api_key} if api_key else {}
|
|
||||||
response = requests.get(url, headers=headers)
|
|
||||||
try:
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(
|
|
||||||
"Unable to get default tenant ID. Please manually provide."
|
|
||||||
) from e
|
|
||||||
results: List[dict] = response.json()
|
|
||||||
if len(results) == 0:
|
|
||||||
raise ValueError("No seeded tenant found")
|
|
||||||
return results[0]["id"]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_session_name(
|
def _get_session_name(
|
||||||
session_name: Optional[str],
|
session_name: Optional[str],
|
||||||
@ -149,18 +129,10 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
headers["x-api-key"] = self.api_key
|
headers["x-api-key"] = self.api_key
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
@property
|
|
||||||
def query_params(self) -> Dict[str, Any]:
|
|
||||||
"""Get the headers for the API request."""
|
|
||||||
return {"tenant_id": self.tenant_id}
|
|
||||||
|
|
||||||
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
|
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
|
||||||
"""Make a GET request."""
|
"""Make a GET request."""
|
||||||
query_params = self.query_params
|
|
||||||
if params:
|
|
||||||
query_params.update(params)
|
|
||||||
return requests.get(
|
return requests.get(
|
||||||
f"{self.api_url}{path}", headers=self._headers, params=query_params
|
f"{self.api_url}{path}", headers=self._headers, params=params
|
||||||
)
|
)
|
||||||
|
|
||||||
def upload_dataframe(
|
def upload_dataframe(
|
||||||
@ -192,7 +164,6 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
"input_keys": ",".join(input_keys),
|
"input_keys": ",".join(input_keys),
|
||||||
"output_keys": ",".join(output_keys),
|
"output_keys": ",".join(output_keys),
|
||||||
"description": description,
|
"description": description,
|
||||||
"tenant_id": self.tenant_id,
|
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.api_url + "/datasets/upload",
|
self.api_url + "/datasets/upload",
|
||||||
@ -244,7 +215,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
) -> TracerSession:
|
) -> TracerSession:
|
||||||
"""Read a session from the LangChain+ API."""
|
"""Read a session from the LangChain+ API."""
|
||||||
path = "/sessions"
|
path = "/sessions"
|
||||||
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
params: Dict[str, Any] = {"limit": 1}
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
path += f"/{session_id}"
|
path += f"/{session_id}"
|
||||||
elif session_name is not None:
|
elif session_name is not None:
|
||||||
@ -291,7 +262,6 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Create a dataset in the LangChain+ API."""
|
"""Create a dataset in the LangChain+ API."""
|
||||||
dataset = DatasetCreate(
|
dataset = DatasetCreate(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
name=dataset_name,
|
name=dataset_name,
|
||||||
description=description,
|
description=description,
|
||||||
)
|
)
|
||||||
@ -309,7 +279,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
|
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
path = "/datasets"
|
path = "/datasets"
|
||||||
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
params: Dict[str, Any] = {"limit": 1}
|
||||||
if dataset_id is not None:
|
if dataset_id is not None:
|
||||||
path += f"/{dataset_id}"
|
path += f"/{dataset_id}"
|
||||||
elif dataset_name is not None:
|
elif dataset_name is not None:
|
||||||
|
@ -49,7 +49,6 @@ class ExampleUpdate(BaseModel):
|
|||||||
class DatasetBase(BaseModel):
|
class DatasetBase(BaseModel):
|
||||||
"""Dataset base model."""
|
"""Dataset base model."""
|
||||||
|
|
||||||
tenant_id: UUID
|
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
@ -68,6 +67,7 @@ class Dataset(DatasetBase):
|
|||||||
"""Dataset ORM model."""
|
"""Dataset ORM model."""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
|
tenant_id: UUID
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: Optional[datetime] = Field(default=None)
|
modified_at: Optional[datetime] = Field(default=None)
|
||||||
|
|
||||||
|
@ -214,7 +214,6 @@ async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChain
|
|||||||
"""
|
"""
|
||||||
if session_name:
|
if session_name:
|
||||||
tracer = LangChainTracer(session_name=session_name)
|
tracer = LangChainTracer(session_name=session_name)
|
||||||
tracer.ensure_session()
|
|
||||||
return tracer
|
return tracer
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
@ -148,8 +148,7 @@ def test_tracing_v2_context_manager() -> None:
|
|||||||
)
|
)
|
||||||
if "LANGCHAIN_TRACING_V2" in os.environ:
|
if "LANGCHAIN_TRACING_V2" in os.environ:
|
||||||
del os.environ["LANGCHAIN_TRACING_V2"]
|
del os.environ["LANGCHAIN_TRACING_V2"]
|
||||||
with tracing_v2_enabled() as session:
|
with tracing_v2_enabled():
|
||||||
assert session
|
|
||||||
agent.run(questions[0]) # this should be traced
|
agent.run(questions[0]) # this should be traced
|
||||||
|
|
||||||
agent.run(questions[0]) # this should not be traced
|
agent.run(questions[0]) # this should not be traced
|
||||||
|
@ -2,14 +2,12 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
|
||||||
from langchain.callbacks.tracers.schemas import TracerSession
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.client.langchain import (
|
from langchain.client.langchain import (
|
||||||
LangChainPlusClient,
|
LangChainPlusClient,
|
||||||
@ -46,39 +44,23 @@ def test_is_localhost() -> None:
|
|||||||
assert not _is_localhost("http://example.com:8000")
|
assert not _is_localhost("http://example.com:8000")
|
||||||
|
|
||||||
|
|
||||||
def test_validate_api_key_if_hosted() -> None:
|
def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||||
return _TENANT_ID
|
with pytest.raises(ValueError, match="API key must be provided"):
|
||||||
|
LangChainPlusClient(api_url="http://www.example.com")
|
||||||
|
|
||||||
with mock.patch.object(
|
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
assert client.api_url == "http://localhost:8000"
|
||||||
):
|
assert client.api_key is None
|
||||||
with pytest.raises(ValueError, match="API key must be provided"):
|
|
||||||
LangChainPlusClient(api_url="http://www.example.com")
|
|
||||||
|
|
||||||
with mock.patch.object(
|
|
||||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
|
||||||
):
|
|
||||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
|
||||||
assert client.api_url == "http://localhost:8000"
|
|
||||||
assert client.api_key is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_headers() -> None:
|
def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||||
return _TENANT_ID
|
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||||
|
assert client._headers == {"x-api-key": "123"}
|
||||||
|
|
||||||
with mock.patch.object(
|
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
assert client_no_key._headers == {}
|
||||||
):
|
|
||||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
|
||||||
assert client._headers == {"x-api-key": "123"}
|
|
||||||
|
|
||||||
with mock.patch.object(
|
|
||||||
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
|
|
||||||
):
|
|
||||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
|
||||||
assert client_no_key._headers == {}
|
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("langchain.client.langchain.requests.post")
|
@mock.patch("langchain.client.langchain.requests.post")
|
||||||
@ -112,7 +94,8 @@ def test_upload_csv(mock_post: mock.Mock) -> None:
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
client = LangChainPlusClient(
|
client = LangChainPlusClient(
|
||||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
api_url="http://localhost:8000",
|
||||||
|
api_key="123",
|
||||||
)
|
)
|
||||||
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
||||||
|
|
||||||
@ -196,22 +179,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||||
]
|
]
|
||||||
|
|
||||||
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
|
|
||||||
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
|
|
||||||
|
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||||
), mock.patch.object(
|
), mock.patch.object(
|
||||||
LangChainPlusClient, "list_examples", new=mock_list_examples
|
LangChainPlusClient, "list_examples", new=mock_list_examples
|
||||||
), mock.patch(
|
), mock.patch(
|
||||||
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
||||||
), mock.patch.object(
|
|
||||||
LangChainTracer, "ensure_session", new=mock_ensure_session
|
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
|
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||||
client = LangChainPlusClient(
|
|
||||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
|
||||||
)
|
|
||||||
chain = mock.MagicMock()
|
chain = mock.MagicMock()
|
||||||
num_repetitions = 3
|
num_repetitions = 3
|
||||||
results = await client.arun_on_dataset(
|
results = await client.arun_on_dataset(
|
||||||
|
Loading…
Reference in New Issue
Block a user