Use client from LCP-SDK (#5695)

- Remove the client implementation (this breaks backwards compatibility
for existing testers; I could keep a stub in that file if we want, but
not many people are using it yet)
- Add SDK as dependency
- Update the 'run_on_dataset' method to be a standalone function that
optionally accepts a client as an argument (see the usage sketch below)
- Remove the LangChain+ server implementation (you get it for free
with the SDK now)

We could make the SDK optional for now, but the plan is to use it within
the tracer, so it would likely become a hard dependency at some point.
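
For reference, a minimal sketch of calling the new function-style entry point, mirroring the updated unit test further down. The MagicMock chain, dataset name, and session name are placeholders, it assumes a LangChain+ backend at the localhost endpoint the tests use, and the explicit client can be dropped since the argument is optional:

import asyncio
from unittest import mock

from langchainplus_sdk.client import LangChainPlusClient

from langchain.client.runner_utils import arun_on_dataset


async def main() -> None:
    # Explicitly constructed SDK client; arun_on_dataset also works without one.
    client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123")
    chain = mock.MagicMock()  # placeholder for a real chain
    results = await arun_on_dataset(
        dataset_name="test",  # placeholder dataset name
        llm_or_chain_factory=lambda: chain,
        concurrency_level=2,
        session_name="test_session",  # placeholder session name
        num_repetitions=1,
        client=client,  # optional per this change
    )
    print(results)  # dict keyed by example id


asyncio.run(main())
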
Authored by Zander Chase on 2023-06-06 06:51:05 -07:00, committed by GitHub
parent 08e2352f7b
commit 204a73c1d9
17 changed files with 1446 additions and 2527 deletions

View File

@@ -1,116 +0,0 @@
"""LangChain+ langchain_client Integration Tests."""
import os
from uuid import uuid4
import pytest
from tenacity import RetryError
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chat_models import ChatOpenAI
from langchain.client import LangChainPlusClient
from langchain.tools.base import tool


@pytest.fixture
def langchain_client(monkeypatch: pytest.MonkeyPatch) -> LangChainPlusClient:
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
return LangChainPlusClient()


def test_sessions(
langchain_client: LangChainPlusClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test sessions."""
session_names = set([session.name for session in langchain_client.list_sessions()])
new_session = f"Session {uuid4()}"
assert new_session not in session_names
@tool
def example_tool() -> str:
"""Call me, maybe."""
return "test_tool"
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
with tracing_v2_enabled(session_name=new_session):
example_tool({})
session = langchain_client.read_session(session_name=new_session)
assert session.name == new_session
session_names = set([sess.name for sess in langchain_client.list_sessions()])
assert new_session in session_names
runs = list(langchain_client.list_runs(session_name=new_session))
session_id_runs = list(langchain_client.list_runs(session_id=session.id))
assert len(runs) == len(session_id_runs) == 1
assert runs[0].id == session_id_runs[0].id
langchain_client.delete_session(session_name=new_session)
with pytest.raises(RetryError):
langchain_client.read_session(session_name=new_session)
assert new_session not in set(
[sess.name for sess in langchain_client.list_sessions()]
)
with pytest.raises(RetryError):
langchain_client.delete_session(session_name=new_session)
with pytest.raises(RetryError):
langchain_client.read_run(run_id=str(runs[0].id))


def test_feedback_cycle(
monkeypatch: pytest.MonkeyPatch, langchain_client: LangChainPlusClient
) -> None:
"""Test that feedback is correctly created and updated."""
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
monkeypatch.setenv("LANGCHAIN_SESSION", f"Feedback Testing {uuid4()}")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
llm = ChatOpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
)
agent.run(
"What is the population of Kuala Lumpur as of January, 2023?"
" What is it's square root?"
)
other_session_name = f"Feedback Testing {uuid4()}"
with tracing_v2_enabled(session_name=other_session_name):
try:
agent.run("What is the square root of 3?")
except Exception as e:
print(e)
runs = list(
langchain_client.list_runs(
session_name=os.environ["LANGCHAIN_SESSION"], error=False, execution_order=1
)
)
assert len(runs) == 1
order_2 = list(
langchain_client.list_runs(
session_name=os.environ["LANGCHAIN_SESSION"], execution_order=2
)
)
assert len(order_2) > 0
langchain_client.create_feedback(str(order_2[0].id), "test score", score=0)
feedback = langchain_client.create_feedback(str(runs[0].id), "test score", score=1)
feedbacks = list(langchain_client.list_feedback(run_ids=[str(runs[0].id)]))
assert len(feedbacks) == 1
assert feedbacks[0].id == feedback.id
# Add feedback to other session
other_runs = list(
langchain_client.list_runs(session_name=other_session_name, execution_order=1)
)
assert len(other_runs) == 1
langchain_client.create_feedback(
run_id=str(other_runs[0].id), key="test score", score=0
)
all_runs = list(
langchain_client.list_runs(session_name=os.environ["LANGCHAIN_SESSION"])
) + list(langchain_client.list_runs(session_name=other_session_name))
test_run_ids = [str(run.id) for run in all_runs]
all_feedback = list(langchain_client.list_feedback(run_ids=test_run_ids))
assert len(all_feedback) == 3
for feedback in all_feedback:
langchain_client.delete_feedback(str(feedback.id))
feedbacks = list(langchain_client.list_feedback(run_ids=test_run_ids))
assert len(feedbacks) == 0

View File

@@ -1,207 +0,0 @@
"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Union
from unittest import mock
import pytest
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.client.langchain import (
LangChainPlusClient,
_get_link_stem,
_is_localhost,
)
from langchain.client.models import Dataset, Example

_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"


@pytest.mark.parametrize(
"api_url, expected_url",
[
("http://localhost:8000", "http://localhost"),
("http://www.example.com", "http://www.example.com"),
(
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
),
("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
],
)
def test_link_split(api_url: str, expected_url: str) -> None:
"""Test the link splitting handles both localhost and deployed urls."""
assert _get_link_stem(api_url) == expected_url


def test_is_localhost() -> None:
assert _is_localhost("http://localhost:8000")
assert _is_localhost("http://127.0.0.1:8000")
assert _is_localhost("http://0.0.0.0:8000")
assert not _is_localhost("http://example.com:8000")


def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
with pytest.raises(ValueError, match="API key must be provided"):
LangChainPlusClient(api_url="http://www.example.com")
client = LangChainPlusClient(api_url="http://localhost:8000")
assert client.api_url == "http://localhost:8000"
assert client.api_key is None


def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
assert client._headers == {"x-api-key": "123"}
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
assert client_no_key._headers == {}


@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
mock_response = mock.Mock()
dataset_id = str(uuid.uuid4())
example_1 = Example(
id=str(uuid.uuid4()),
created_at=_CREATED_AT,
inputs={"input": "1"},
outputs={"output": "2"},
dataset_id=dataset_id,
)
example_2 = Example(
id=str(uuid.uuid4()),
created_at=_CREATED_AT,
inputs={"input": "3"},
outputs={"output": "4"},
dataset_id=dataset_id,
)
mock_response.json.return_value = {
"id": dataset_id,
"name": "test.csv",
"description": "Test dataset",
"owner_id": "the owner",
"created_at": _CREATED_AT,
"examples": [example_1, example_2],
"tenant_id": _TENANT_ID,
}
mock_post.return_value = mock_response
client = LangChainPlusClient(
api_url="http://localhost:8000",
api_key="123",
)
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
dataset = client.upload_csv(
csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
)
assert dataset.id == uuid.UUID(dataset_id)
assert dataset.name == "test.csv"
assert dataset.description == "Test dataset"


@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = Dataset(
id=uuid.uuid4(),
name="test",
description="Test dataset",
owner_id="owner",
created_at=_CREATED_AT,
tenant_id=_TENANT_ID,
)
uuids = [
"0c193153-2309-4704-9a47-17aee4fb25c8",
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
]
examples = [
Example(
id=uuids[0],
created_at=_CREATED_AT,
inputs={"input": "1"},
outputs={"output": "2"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[1],
created_at=_CREATED_AT,
inputs={"input": "3"},
outputs={"output": "4"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[2],
created_at=_CREATED_AT,
inputs={"input": "5"},
outputs={"output": "6"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[3],
created_at=_CREATED_AT,
inputs={"input": "7"},
outputs={"output": "8"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[4],
created_at=_CREATED_AT,
inputs={"input": "9"},
outputs={"output": "10"},
dataset_id=str(uuid.uuid4()),
),
]
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
return dataset
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
return examples
async def mock_arun_chain(
example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
tracer: Any,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
with mock.patch.object(
LangChainPlusClient, "read_dataset", new=mock_read_dataset
), mock.patch.object(
LangChainPlusClient, "list_examples", new=mock_list_examples
), mock.patch(
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
):
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
chain = mock.MagicMock()
num_repetitions = 3
results = await client.arun_on_dataset(
dataset_name="test",
llm_or_chain_factory=lambda: chain,
concurrency_level=2,
session_name="test_session",
num_repetitions=num_repetitions,
)
expected = {
uuid_: [
{"result": f"Result for example {uuid.UUID(uuid_)}"}
for _ in range(num_repetitions)
]
for uuid_ in uuids
}
assert results == expected

View File

@@ -1,18 +1,27 @@
"""Test the LangChain+ client."""
from typing import Any, Dict
import uuid
from datetime import datetime
from typing import Any, Dict, List, Union
from unittest import mock
import pytest
from langchainplus_sdk.client import LangChainPlusClient
from langchainplus_sdk.schemas import Dataset, Example
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.client.runner_utils import (
InputFormatError,
_get_messages,
_get_prompts,
arun_on_dataset,
run_llm,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM

_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
_EXAMPLE_MESSAGE = {
"data": {"content": "Foo", "example": False, "additional_kwargs": {}},
"type": "human",
@@ -93,3 +102,103 @@ def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
llm = FakeChatModel()
run_llm(llm, inputs, mock.MagicMock())


@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = Dataset(
id=uuid.uuid4(),
name="test",
description="Test dataset",
owner_id="owner",
created_at=_CREATED_AT,
tenant_id=_TENANT_ID,
)
uuids = [
"0c193153-2309-4704-9a47-17aee4fb25c8",
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
]
examples = [
Example(
id=uuids[0],
created_at=_CREATED_AT,
inputs={"input": "1"},
outputs={"output": "2"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[1],
created_at=_CREATED_AT,
inputs={"input": "3"},
outputs={"output": "4"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[2],
created_at=_CREATED_AT,
inputs={"input": "5"},
outputs={"output": "6"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[3],
created_at=_CREATED_AT,
inputs={"input": "7"},
outputs={"output": "8"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[4],
created_at=_CREATED_AT,
inputs={"input": "9"},
outputs={"output": "10"},
dataset_id=str(uuid.uuid4()),
),
]
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
return dataset
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
return examples
async def mock_arun_chain(
example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
tracer: Any,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
with mock.patch.object(
LangChainPlusClient, "read_dataset", new=mock_read_dataset
), mock.patch.object(
LangChainPlusClient, "list_examples", new=mock_list_examples
), mock.patch(
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
):
client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123")
chain = mock.MagicMock()
num_repetitions = 3
results = await arun_on_dataset(
dataset_name="test",
llm_or_chain_factory=lambda: chain,
concurrency_level=2,
session_name="test_session",
num_repetitions=num_repetitions,
client=client,
)
expected = {
uuid_: [
{"result": f"Result for example {uuid.UUID(uuid_)}"}
for _ in range(num_repetitions)
]
for uuid_ in uuids
}
assert results == expected

View File

@@ -38,6 +38,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"aiohttp",
"async-timeout",
"dataclasses-json",
"langchainplus-sdk",
"numexpr",
"numpy",
"openapi-schema-pydantic",