mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-29 21:30:18 +00:00
Use client from LCP-SDK (#5695)
- Remove the client implementation (this breaks backwards compatibility for existing testers. I could keep the stub in that file if we want, but not many people are using it yet - Add SDK as dependency - Update the 'run_on_dataset' method to be a function that optionally accepts a client as an argument - Remove the langchain plus server implementation (you get it for free with the SDK now) We could make the SDK optional for now, but the plan is to use w/in the tracer so it would likely become a hard dependency at some point.
This commit is contained in:
@@ -1,116 +0,0 @@
|
||||
"""LangChain+ langchain_client Integration Tests."""
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from tenacity import RetryError
|
||||
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.callbacks.manager import tracing_v2_enabled
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.client import LangChainPlusClient
|
||||
from langchain.tools.base import tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def langchain_client(monkeypatch: pytest.MonkeyPatch) -> LangChainPlusClient:
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
||||
return LangChainPlusClient()
|
||||
|
||||
|
||||
def test_sessions(
|
||||
langchain_client: LangChainPlusClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Test sessions."""
|
||||
session_names = set([session.name for session in langchain_client.list_sessions()])
|
||||
new_session = f"Session {uuid4()}"
|
||||
assert new_session not in session_names
|
||||
|
||||
@tool
|
||||
def example_tool() -> str:
|
||||
"""Call me, maybe."""
|
||||
return "test_tool"
|
||||
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
||||
with tracing_v2_enabled(session_name=new_session):
|
||||
example_tool({})
|
||||
session = langchain_client.read_session(session_name=new_session)
|
||||
assert session.name == new_session
|
||||
session_names = set([sess.name for sess in langchain_client.list_sessions()])
|
||||
assert new_session in session_names
|
||||
runs = list(langchain_client.list_runs(session_name=new_session))
|
||||
session_id_runs = list(langchain_client.list_runs(session_id=session.id))
|
||||
assert len(runs) == len(session_id_runs) == 1
|
||||
assert runs[0].id == session_id_runs[0].id
|
||||
langchain_client.delete_session(session_name=new_session)
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
langchain_client.read_session(session_name=new_session)
|
||||
assert new_session not in set(
|
||||
[sess.name for sess in langchain_client.list_sessions()]
|
||||
)
|
||||
with pytest.raises(RetryError):
|
||||
langchain_client.delete_session(session_name=new_session)
|
||||
with pytest.raises(RetryError):
|
||||
langchain_client.read_run(run_id=str(runs[0].id))
|
||||
|
||||
|
||||
def test_feedback_cycle(
|
||||
monkeypatch: pytest.MonkeyPatch, langchain_client: LangChainPlusClient
|
||||
) -> None:
|
||||
"""Test that feedback is correctly created and updated."""
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
|
||||
monkeypatch.setenv("LANGCHAIN_SESSION", f"Feedback Testing {uuid4()}")
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
||||
llm = ChatOpenAI(temperature=0)
|
||||
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
|
||||
)
|
||||
|
||||
agent.run(
|
||||
"What is the population of Kuala Lumpur as of January, 2023?"
|
||||
" What is it's square root?"
|
||||
)
|
||||
other_session_name = f"Feedback Testing {uuid4()}"
|
||||
with tracing_v2_enabled(session_name=other_session_name):
|
||||
try:
|
||||
agent.run("What is the square root of 3?")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
runs = list(
|
||||
langchain_client.list_runs(
|
||||
session_name=os.environ["LANGCHAIN_SESSION"], error=False, execution_order=1
|
||||
)
|
||||
)
|
||||
assert len(runs) == 1
|
||||
order_2 = list(
|
||||
langchain_client.list_runs(
|
||||
session_name=os.environ["LANGCHAIN_SESSION"], execution_order=2
|
||||
)
|
||||
)
|
||||
assert len(order_2) > 0
|
||||
langchain_client.create_feedback(str(order_2[0].id), "test score", score=0)
|
||||
feedback = langchain_client.create_feedback(str(runs[0].id), "test score", score=1)
|
||||
feedbacks = list(langchain_client.list_feedback(run_ids=[str(runs[0].id)]))
|
||||
assert len(feedbacks) == 1
|
||||
assert feedbacks[0].id == feedback.id
|
||||
|
||||
# Add feedback to other session
|
||||
other_runs = list(
|
||||
langchain_client.list_runs(session_name=other_session_name, execution_order=1)
|
||||
)
|
||||
assert len(other_runs) == 1
|
||||
langchain_client.create_feedback(
|
||||
run_id=str(other_runs[0].id), key="test score", score=0
|
||||
)
|
||||
all_runs = list(
|
||||
langchain_client.list_runs(session_name=os.environ["LANGCHAIN_SESSION"])
|
||||
) + list(langchain_client.list_runs(session_name=other_session_name))
|
||||
test_run_ids = [str(run.id) for run in all_runs]
|
||||
all_feedback = list(langchain_client.list_feedback(run_ids=test_run_ids))
|
||||
assert len(all_feedback) == 3
|
||||
for feedback in all_feedback:
|
||||
langchain_client.delete_feedback(str(feedback.id))
|
||||
feedbacks = list(langchain_client.list_feedback(run_ids=test_run_ids))
|
||||
assert len(feedbacks) == 0
|
||||
@@ -1,207 +0,0 @@
|
||||
"""Test the LangChain+ client."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Union
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.langchain import (
|
||||
LangChainPlusClient,
|
||||
_get_link_stem,
|
||||
_is_localhost,
|
||||
)
|
||||
from langchain.client.models import Dataset, Example
|
||||
|
||||
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
|
||||
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_url, expected_url",
|
||||
[
|
||||
("http://localhost:8000", "http://localhost"),
|
||||
("http://www.example.com", "http://www.example.com"),
|
||||
(
|
||||
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
|
||||
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
|
||||
),
|
||||
("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
|
||||
],
|
||||
)
|
||||
def test_link_split(api_url: str, expected_url: str) -> None:
|
||||
"""Test the link splitting handles both localhost and deployed urls."""
|
||||
assert _get_link_stem(api_url) == expected_url
|
||||
|
||||
|
||||
def test_is_localhost() -> None:
|
||||
assert _is_localhost("http://localhost:8000")
|
||||
assert _is_localhost("http://127.0.0.1:8000")
|
||||
assert _is_localhost("http://0.0.0.0:8000")
|
||||
assert not _is_localhost("http://example.com:8000")
|
||||
|
||||
|
||||
def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="API key must be provided"):
|
||||
LangChainPlusClient(api_url="http://www.example.com")
|
||||
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client.api_url == "http://localhost:8000"
|
||||
assert client.api_key is None
|
||||
|
||||
|
||||
def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
assert client._headers == {"x-api-key": "123"}
|
||||
|
||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client_no_key._headers == {}
|
||||
|
||||
|
||||
@mock.patch("langchain.client.langchain.requests.post")
|
||||
def test_upload_csv(mock_post: mock.Mock) -> None:
|
||||
mock_response = mock.Mock()
|
||||
dataset_id = str(uuid.uuid4())
|
||||
example_1 = Example(
|
||||
id=str(uuid.uuid4()),
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "1"},
|
||||
outputs={"output": "2"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
example_2 = Example(
|
||||
id=str(uuid.uuid4()),
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "3"},
|
||||
outputs={"output": "4"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
mock_response.json.return_value = {
|
||||
"id": dataset_id,
|
||||
"name": "test.csv",
|
||||
"description": "Test dataset",
|
||||
"owner_id": "the owner",
|
||||
"created_at": _CREATED_AT,
|
||||
"examples": [example_1, example_2],
|
||||
"tenant_id": _TENANT_ID,
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = LangChainPlusClient(
|
||||
api_url="http://localhost:8000",
|
||||
api_key="123",
|
||||
)
|
||||
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
||||
|
||||
dataset = client.upload_csv(
|
||||
csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
|
||||
)
|
||||
|
||||
assert dataset.id == uuid.UUID(dataset_id)
|
||||
assert dataset.name == "test.csv"
|
||||
assert dataset.description == "Test dataset"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = Dataset(
|
||||
id=uuid.uuid4(),
|
||||
name="test",
|
||||
description="Test dataset",
|
||||
owner_id="owner",
|
||||
created_at=_CREATED_AT,
|
||||
tenant_id=_TENANT_ID,
|
||||
)
|
||||
uuids = [
|
||||
"0c193153-2309-4704-9a47-17aee4fb25c8",
|
||||
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
|
||||
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
|
||||
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
|
||||
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
|
||||
]
|
||||
examples = [
|
||||
Example(
|
||||
id=uuids[0],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "1"},
|
||||
outputs={"output": "2"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[1],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "3"},
|
||||
outputs={"output": "4"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[2],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "5"},
|
||||
outputs={"output": "6"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[3],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "7"},
|
||||
outputs={"output": "8"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[4],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "9"},
|
||||
outputs={"output": "10"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
]
|
||||
|
||||
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
||||
return dataset
|
||||
|
||||
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
|
||||
return examples
|
||||
|
||||
async def mock_arun_chain(
|
||||
example: Example,
|
||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||
n_repetitions: int,
|
||||
tracer: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "list_examples", new=mock_list_examples
|
||||
), mock.patch(
|
||||
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
||||
):
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
chain = mock.MagicMock()
|
||||
num_repetitions = 3
|
||||
results = await client.arun_on_dataset(
|
||||
dataset_name="test",
|
||||
llm_or_chain_factory=lambda: chain,
|
||||
concurrency_level=2,
|
||||
session_name="test_session",
|
||||
num_repetitions=num_repetitions,
|
||||
)
|
||||
|
||||
expected = {
|
||||
uuid_: [
|
||||
{"result": f"Result for example {uuid.UUID(uuid_)}"}
|
||||
for _ in range(num_repetitions)
|
||||
]
|
||||
for uuid_ in uuids
|
||||
}
|
||||
assert results == expected
|
||||
@@ -1,18 +1,27 @@
|
||||
"""Test the LangChain+ client."""
|
||||
from typing import Any, Dict
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Union
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchainplus_sdk.client import LangChainPlusClient
|
||||
from langchainplus_sdk.schemas import Dataset, Example
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.runner_utils import (
|
||||
InputFormatError,
|
||||
_get_messages,
|
||||
_get_prompts,
|
||||
arun_on_dataset,
|
||||
run_llm,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
|
||||
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
|
||||
_EXAMPLE_MESSAGE = {
|
||||
"data": {"content": "Foo", "example": False, "additional_kwargs": {}},
|
||||
"type": "human",
|
||||
@@ -93,3 +102,103 @@ def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
|
||||
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
|
||||
llm = FakeChatModel()
|
||||
run_llm(llm, inputs, mock.MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = Dataset(
|
||||
id=uuid.uuid4(),
|
||||
name="test",
|
||||
description="Test dataset",
|
||||
owner_id="owner",
|
||||
created_at=_CREATED_AT,
|
||||
tenant_id=_TENANT_ID,
|
||||
)
|
||||
uuids = [
|
||||
"0c193153-2309-4704-9a47-17aee4fb25c8",
|
||||
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
|
||||
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
|
||||
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
|
||||
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
|
||||
]
|
||||
examples = [
|
||||
Example(
|
||||
id=uuids[0],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "1"},
|
||||
outputs={"output": "2"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[1],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "3"},
|
||||
outputs={"output": "4"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[2],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "5"},
|
||||
outputs={"output": "6"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[3],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "7"},
|
||||
outputs={"output": "8"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[4],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "9"},
|
||||
outputs={"output": "10"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
]
|
||||
|
||||
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
||||
return dataset
|
||||
|
||||
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
|
||||
return examples
|
||||
|
||||
async def mock_arun_chain(
|
||||
example: Example,
|
||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||
n_repetitions: int,
|
||||
tracer: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "list_examples", new=mock_list_examples
|
||||
), mock.patch(
|
||||
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
||||
):
|
||||
client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123")
|
||||
chain = mock.MagicMock()
|
||||
num_repetitions = 3
|
||||
results = await arun_on_dataset(
|
||||
dataset_name="test",
|
||||
llm_or_chain_factory=lambda: chain,
|
||||
concurrency_level=2,
|
||||
session_name="test_session",
|
||||
num_repetitions=num_repetitions,
|
||||
client=client,
|
||||
)
|
||||
|
||||
expected = {
|
||||
uuid_: [
|
||||
{"result": f"Result for example {uuid.UUID(uuid_)}"}
|
||||
for _ in range(num_repetitions)
|
||||
]
|
||||
for uuid_ in uuids
|
||||
}
|
||||
assert results == expected
|
||||
|
||||
@@ -38,6 +38,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
|
||||
"aiohttp",
|
||||
"async-timeout",
|
||||
"dataclasses-json",
|
||||
"langchainplus-sdk",
|
||||
"numexpr",
|
||||
"numpy",
|
||||
"openapi-schema-pydantic",
|
||||
Reference in New Issue
Block a user