mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-06 23:24:48 +00:00
Add Delete Session Method (#5193)
This commit is contained in:
parent
66113c2a62
commit
e76e68b211
@ -200,7 +200,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
return Dataset(**result)
|
return Dataset(**result)
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
||||||
def read_run(self, run_id: str) -> Run:
|
def read_run(self, run_id: Union[str, UUID]) -> Run:
|
||||||
"""Read a run from the LangChain+ API."""
|
"""Read a run from the LangChain+ API."""
|
||||||
response = self._get(f"/runs/{run_id}")
|
response = self._get(f"/runs/{run_id}")
|
||||||
raise_for_status_with_text(response)
|
raise_for_status_with_text(response)
|
||||||
@ -268,6 +268,22 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
raise_for_status_with_text(response)
|
raise_for_status_with_text(response)
|
||||||
yield from [TracerSession(**session) for session in response.json()]
|
yield from [TracerSession(**session) for session in response.json()]
|
||||||
|
|
||||||
|
@xor_args(("session_name", "session_id"))
|
||||||
|
def delete_session(
|
||||||
|
self, *, session_name: Optional[str] = None, session_id: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Delete a session from the LangChain+ API."""
|
||||||
|
if session_name is not None:
|
||||||
|
session_id = self.read_session(session_name=session_name).id
|
||||||
|
elif session_id is None:
|
||||||
|
raise ValueError("Must provide session_name or session_id")
|
||||||
|
response = requests.delete(
|
||||||
|
self.api_url + f"/sessions/{session_id}",
|
||||||
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
raise_for_status_with_text(response)
|
||||||
|
return None
|
||||||
|
|
||||||
def create_dataset(self, dataset_name: str, description: str) -> Dataset:
|
def create_dataset(self, dataset_name: str, description: str) -> Dataset:
|
||||||
"""Create a dataset in the LangChain+ API."""
|
"""Create a dataset in the LangChain+ API."""
|
||||||
dataset = DatasetCreate(
|
dataset = DatasetCreate(
|
||||||
@ -360,7 +376,7 @@ class LangChainPlusClient(BaseSettings):
|
|||||||
return Example(**result)
|
return Example(**result)
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
||||||
def read_example(self, example_id: str) -> Example:
|
def read_example(self, example_id: Union[str, UUID]) -> Example:
|
||||||
"""Read an example from the LangChain+ API."""
|
"""Read an example from the LangChain+ API."""
|
||||||
response = self._get(f"/examples/{example_id}")
|
response = self._get(f"/examples/{example_id}")
|
||||||
raise_for_status_with_text(response)
|
raise_for_status_with_text(response)
|
||||||
|
0
tests/integration_tests/client/__init__.py
Normal file
0
tests/integration_tests/client/__init__.py
Normal file
52
tests/integration_tests/client/test_client.py
Normal file
52
tests/integration_tests/client/test_client.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"""LangChain+ langchain_client Integration Tests."""
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tenacity import RetryError
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import tracing_v2_enabled
|
||||||
|
from langchain.client import LangChainPlusClient
|
||||||
|
from langchain.tools.base import tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def langchain_client(monkeypatch: pytest.MonkeyPatch) -> LangChainPlusClient:
|
||||||
|
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
||||||
|
return LangChainPlusClient()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sessions(
|
||||||
|
langchain_client: LangChainPlusClient, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""Test sessions."""
|
||||||
|
session_names = set([session.name for session in langchain_client.list_sessions()])
|
||||||
|
new_session = f"Session {uuid4()}"
|
||||||
|
assert new_session not in session_names
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def example_tool() -> str:
|
||||||
|
"""Call me, maybe."""
|
||||||
|
return "test_tool"
|
||||||
|
|
||||||
|
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
||||||
|
with tracing_v2_enabled(session_name=new_session):
|
||||||
|
example_tool({})
|
||||||
|
session = langchain_client.read_session(session_name=new_session)
|
||||||
|
assert session.name == new_session
|
||||||
|
session_names = set([sess.name for sess in langchain_client.list_sessions()])
|
||||||
|
assert new_session in session_names
|
||||||
|
runs = list(langchain_client.list_runs(session_name=new_session))
|
||||||
|
session_id_runs = list(langchain_client.list_runs(session_id=session.id))
|
||||||
|
assert len(runs) == len(session_id_runs) == 1
|
||||||
|
assert runs[0].id == session_id_runs[0].id
|
||||||
|
langchain_client.delete_session(session_name=new_session)
|
||||||
|
|
||||||
|
with pytest.raises(RetryError):
|
||||||
|
langchain_client.read_session(session_name=new_session)
|
||||||
|
assert new_session not in set(
|
||||||
|
[sess.name for sess in langchain_client.list_sessions()]
|
||||||
|
)
|
||||||
|
with pytest.raises(RetryError):
|
||||||
|
langchain_client.delete_session(session_name=new_session)
|
||||||
|
with pytest.raises(RetryError):
|
||||||
|
langchain_client.read_run(run_id=str(runs[0].id))
|
Loading…
Reference in New Issue
Block a user