"""Test Xinference embeddings."""

import time
from typing import AsyncGenerator, Tuple

import pytest_asyncio

from langchain_community.embeddings import XinferenceEmbeddings


@pytest_asyncio.fixture
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
    """Start a local Xinference supervisor and worker for the tests.

    Creates a worker actor pool on ``127.0.0.1``, starts the supervisor and
    worker components on that pool, waits briefly for the API, and keeps the
    pool open for the duration of the test.

    Yields:
        A ``(endpoint, pool_address)`` tuple, where ``endpoint`` is the value
        returned by ``start_supervisor_components`` and ``pool_address`` is
        the pool's external address.
    """
    # asyncio is imported locally because this module does not import it at
    # file level.
    import asyncio

    import xoscar as xo
    from xinference.deploy.supervisor import start_supervisor_components
    from xinference.deploy.utils import create_worker_actor_pool
    from xinference.deploy.worker import start_worker_components

    pool = await create_worker_actor_pool(
        f"test://127.0.0.1:{xo.utils.get_next_port()}"
    )
    print(f"Pool running on localhost:{pool.external_address}")  # noqa: T201

    endpoint = await start_supervisor_components(
        pool.external_address, "127.0.0.1", xo.utils.get_next_port()
    )
    await start_worker_components(
        address=pool.external_address, supervisor_address=pool.external_address
    )

    # Wait for the API. A blocking time.sleep() here would freeze the event
    # loop this coroutine runs on; asyncio.sleep() yields control instead.
    await asyncio.sleep(3)
    async with pool:
        yield endpoint, pool.external_address


def test_xinference_embedding_documents(setup: Tuple[str, str]) -> None:
    """Embed two documents and check the result count and vector length."""
    from xinference.client import RESTfulClient

    # Only the endpoint is needed; the pool address is ignored.
    server_url, _pool_address = setup

    # Launch the model through the REST client and keep its uid.
    launched_uid = RESTfulClient(server_url).launch_model(
        model_name="vicuna-v1.3",
        model_size_in_billions=7,
        model_format="ggmlv3",
        quantization="q4_0",
    )

    embedder = XinferenceEmbeddings(server_url=server_url, model_uid=launched_uid)

    vectors = embedder.embed_documents(["foo bar", "bar foo"])

    # One vector per input document, each with 4096 entries.
    assert len(vectors) == 2
    assert len(vectors[0]) == 4096


def test_xinference_embedding_query(setup: Tuple[str, str]) -> None:
    """Embed a single query string and check the vector length."""
    from xinference.client import RESTfulClient

    # Only the endpoint is needed; the pool address is ignored.
    server_url, _pool_address = setup

    # Launch the model through the REST client and keep its uid.
    launched_uid = RESTfulClient(server_url).launch_model(
        model_name="vicuna-v1.3",
        model_size_in_billions=7,
        quantization="q4_0",
    )

    embedder = XinferenceEmbeddings(server_url=server_url, model_uid=launched_uid)

    query_vector = embedder.embed_query("foo bar")

    # The query embedding is expected to have 4096 entries.
    assert len(query_vector) == 4096


def test_xinference_embedding() -> None:
    """Call ``embed_documents`` against a fixed server URL.

    The returned embeddings are not inspected; the test only exercises the
    call itself.
    """
    sample_texts = ["hello", "i'm trying to upgrade xinference embedding"]

    embedder = XinferenceEmbeddings(
        server_url="http://xinference-hostname:9997",
        model_uid="foo",
    )
    embedder.embed_documents(texts=sample_texts)