Files
langchain/tests/unit_tests/client/test_langchain.py
Ankush Gola aa73a888fa Some notebook and client fixes (add retries, clean up docs, etc) (#4820)
# Your PR Title (What it does)

<!--
Thank you for contributing to LangChain! Your PR will appear in our next
release under the title you set. Please make sure it highlights your
valuable contribution.

Replace this with a description of the change, the issue it fixes (if
applicable), and relevant context. List any dependencies required for
this change.

After you're done, someone will review your PR. They may suggest
improvements. If no one reviews your PR within a few days, feel free to
@-mention the same people again, as notifications can get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

## Before submitting

<!-- If you're adding a new integration, include an integration test and
an example notebook showing its use! -->

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

<!-- For a quicker response, figure out the right person to tag with @

        @hwchase17 - project lead

        Tracing / Callbacks
        - @agola11

        Async
        - @agola11

        DataLoaders
        - @eyurtsev

        Models
        - @hwchase17
        - @agola11

        Agents / Tools / Toolkits
        - @vowelparrot
        
        VectorStores / Retrievers / Memory
        - @dev2049
        
 -->
2023-05-16 20:23:00 -07:00

233 lines
7.4 KiB
Python

"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from unittest import mock
import pytest
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.langchain import (
LangChainPlusClient,
_get_link_stem,
_is_localhost,
)
from langchain.client.models import Dataset, Example
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@pytest.mark.parametrize(
    "api_url, expected_url",
    [
        # Localhost with a port: the port is expected to be dropped.
        ("http://localhost:8000", "http://localhost"),
        # Plain hosted URL: expected back unchanged.
        ("http://www.example.com", "http://www.example.com"),
        # Multi-label gateway hostname: expected back unchanged.
        (
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
        ),
        # URL with a path: only scheme and host are expected to remain.
        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
    ],
)
def test_link_split(api_url: str, expected_url: str) -> None:
    """Check that ``_get_link_stem`` maps each API URL to its expected stem.

    Args:
        api_url: URL passed to ``_get_link_stem``.
        expected_url: Stem the helper is expected to return for ``api_url``.
    """
    link_stem = _get_link_stem(api_url)
    assert link_stem == expected_url
def test_is_localhost() -> None:
    """Check ``_is_localhost`` on three local-style URLs and one remote URL."""
    local_urls = (
        "http://localhost:8000",
        "http://127.0.0.1:8000",
        "http://0.0.0.0:8000",
    )
    for url in local_urls:
        assert _is_localhost(url)

    # A regular domain name must not be classified as local.
    assert not _is_localhost("http://example.com:8000")
def test_validate_api_key_if_hosted() -> None:
    """Check that an API key is demanded for a hosted URL but not for localhost.

    ``LangChainPlusClient._get_seeded_tenant_id`` is replaced by a stub that
    returns ``_TENANT_ID`` while the clients are constructed.
    """

    def _fixed_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        # Stand-in for the real tenant lookup; always yields the test tenant.
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=_fixed_tenant_id
    ):
        # Hosted URL without a key: construction must fail.
        with pytest.raises(ValueError, match="API key must be provided"):
            LangChainPlusClient(api_url="http://www.example.com")

        # Localhost URL without a key: construction must succeed.
        local_client = LangChainPlusClient(api_url="http://localhost:8000")
        assert local_client.api_url == "http://localhost:8000"
        assert local_client.api_key is None
def test_headers() -> None:
    """Check the ``_headers`` value of a client built with and without an API key.

    ``LangChainPlusClient._get_seeded_tenant_id`` is replaced by a stub that
    returns ``_TENANT_ID`` while the clients are constructed.
    """

    def _fixed_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        # Stand-in for the real tenant lookup; always yields the test tenant.
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=_fixed_tenant_id
    ):
        # With a key, the key is expected under the "x-api-key" header.
        keyed_client = LangChainPlusClient(
            api_url="http://localhost:8000", api_key="123"
        )
        assert keyed_client._headers == {"x-api-key": "123"}

        # Without a key, no headers are expected at all.
        keyless_client = LangChainPlusClient(api_url="http://localhost:8000")
        assert keyless_client._headers == {}
@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
    """Check that ``upload_csv`` returns a dataset built from the POST response.

    ``requests.post`` (as referenced from ``langchain.client.langchain``) is
    patched so the call returns a canned JSON payload.

    Args:
        mock_post: Mock injected by ``mock.patch`` in place of ``requests.post``.
    """
    dataset_id = str(uuid.uuid4())

    # Two examples belonging to the dataset, mirroring the CSV rows below.
    dataset_examples = [
        Example(
            id=str(uuid.uuid4()),
            created_at=_CREATED_AT,
            inputs={"input": given},
            outputs={"output": wanted},
            dataset_id=dataset_id,
        )
        for given, wanted in (("1", "2"), ("3", "4"))
    ]

    response_payload = {
        "id": dataset_id,
        "name": "test.csv",
        "description": "Test dataset",
        "owner_id": "the owner",
        "created_at": _CREATED_AT,
        "examples": dataset_examples,
        "tenant_id": _TENANT_ID,
    }
    fake_response = mock.Mock()
    fake_response.json.return_value = response_payload
    mock_post.return_value = fake_response

    client = LangChainPlusClient(
        api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
    )
    csv_bytes = BytesIO(b"input,output\n1,2\n3,4\n")
    dataset = client.upload_csv(
        ("test.csv", csv_bytes),
        "Test dataset",
        input_keys=["input"],
        output_keys=["output"],
    )

    assert dataset.id == uuid.UUID(dataset_id)
    assert dataset.name == "test.csv"
    assert dataset.description == "Test dataset"
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``arun_on_dataset`` returns per-example results keyed by id.

    The client's ``read_dataset``, ``list_examples`` and ``_arun_llm_or_chain``
    as well as ``LangChainTracer.ensure_session`` are replaced by local
    stand-ins, so the test only exercises how results are collected.

    Args:
        monkeypatch: pytest fixture used to set ``LANGCHAIN_TENANT_ID``.
    """
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    # Example k (0-based) carries input 2k+1 and output 2k+2, i.e. the pairs
    # 1/2, 3/4, 5/6, 7/8 and 9/10; each gets its own random dataset_id.
    examples = [
        Example(
            id=example_id,
            created_at=_CREATED_AT,
            inputs={"input": str(2 * position + 1)},
            outputs={"output": str(2 * position + 2)},
            dataset_id=str(uuid.uuid4()),
        )
        for position, example_id in enumerate(uuids)
    ]

    def _stub_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def _stub_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
        return examples

    async def _stub_arun_chain(
        example: Example,
        tracer: Any,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
    ) -> List[Dict[str, Any]]:
        # One identical result dict per requested repetition.
        outcome = {"result": f"Result for example {example.id}"}
        return [dict(outcome) for _ in range(n_repetitions)]

    def _stub_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
        return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())

    with mock.patch.object(
        LangChainPlusClient, "read_dataset", new=_stub_read_dataset
    ), mock.patch.object(
        LangChainPlusClient, "list_examples", new=_stub_list_examples
    ), mock.patch.object(
        LangChainPlusClient, "_arun_llm_or_chain", new=_stub_arun_chain
    ), mock.patch.object(
        LangChainTracer, "ensure_session", new=_stub_ensure_session
    ):
        monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
        client = LangChainPlusClient(
            api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
        )
        chain = mock.MagicMock()
        num_repetitions = 3
        results = await client.arun_on_dataset(
            dataset_name="test",
            llm_or_chain_factory=lambda: chain,
            concurrency_level=2,
            session_name="test_session",
            num_repetitions=num_repetitions,
        )

        # Every example id should map to num_repetitions copies of its result.
        expected = {}
        for uuid_ in uuids:
            expected[uuid_] = [
                {"result": f"Result for example {uuid.UUID(uuid_)}"}
                for _ in range(num_repetitions)
            ]
        assert results == expected