langchain/libs/community/tests/integration_tests/embeddings/test_titan_takeoff.py
pjb157 479be3cc91
community[minor]: Unify Titan Takeoff Integrations and Adding Embedding Support (#18775)
**Community: Unify Titan Takeoff Integrations and Adding Embedding
Support**

 **Description:** 
Neither of the integrations in the community folder accurately reflects
Titan Takeoff any longer. The two integrations (TitanTakeoffPro and
TitanTakeoff) were causing confusion with clients, so we have moved the
code into one place and created an alias for backwards compatibility.
Added the Takeoff Client Python package to do the bulk of the work with
the requests, because this package is actively updated with new versions
of Takeoff. This integration will therefore be far more robust and will
not degrade as badly over time.

**Issue:**
Fixes bugs in the old Titan integrations and unifies the code, with added
unit test coverage to avoid future problems.

**Dependencies:**
Added the optional dependency takeoff-client. All imports, including the
Titan Takeoff classes, still work without it, but initialisation will
fail if takeoff-client has not been pip installed.

**Twitter**
@MeryemArik9

Thanks all :)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-04-17 01:43:35 +00:00

179 lines
6.0 KiB
Python

"""Test Titan Takeoff Embedding wrapper."""
import json
from typing import Any
import pytest
from langchain_community.embeddings import TitanTakeoffEmbed
from langchain_community.embeddings.titan_takeoff import MissingConsumerGroup
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
def test_titan_takeoff_call(httpx_mock: Any) -> None:
    """Check that embedding with an explicit consumer group posts to /embed."""
    port = 2345
    embed_url = f"http://localhost:{port}/embed"
    prompt = "What is 2 + 2?"
    httpx_mock.add_response(
        method="POST",
        url=embed_url,
        json={"result": [0.46635, 0.234, -0.8521]},
    )

    embedding = TitanTakeoffEmbed(port=port)
    documents_output = embedding.embed_documents(prompt, "primary")
    query_output = embedding.embed_query(prompt, "primary")

    assert isinstance(documents_output, list)
    assert isinstance(query_output, list)

    # Exactly one recorded request is expected per embedding call above.
    sent = httpx_mock.get_requests()
    assert len(sent) == 2
    for request in sent[:2]:
        assert request.url == embed_url
        assert json.loads(request.content)["text"] == prompt
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
def test_no_consumer_group_fails(httpx_mock: Any) -> None:
    """Check that omitting the consumer group raises MissingConsumerGroup."""
    port = 2345
    prompt = "What is 2 + 2?"
    httpx_mock.add_response(
        method="POST",
        url=f"http://localhost:{port}/embed",
        json={"result": [0.46635, 0.234, -0.8521]},
    )
    embedding = TitanTakeoffEmbed(port=port)
    embed_calls = (embedding.embed_documents, embedding.embed_query)

    # No models were passed at construction, so a consumer group is required.
    for embed in embed_calls:
        with pytest.raises(MissingConsumerGroup):
            embed(prompt)

    # Naming the consumer group explicitly lets the same calls go through.
    for embed in embed_calls:
        embed(prompt, "primary")
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
def test_takeoff_initialization(httpx_mock: Any) -> None:
    """Check start-up with two readers that share a single consumer group.

    Both readers must be posted to the management endpoint, after which the
    embedding calls are expected to work without naming a consumer group.
    """
    mgnt_port = 36452
    inf_port = 46253
    mgnt_url = f"http://localhost:{mgnt_port}/reader"
    embed_url = f"http://localhost:{inf_port}/embed"
    prompt = "What is 2 + 2?"

    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "embed",
    }
    reader_2 = {**reader_1, "model_name": "test2", "device": "cuda"}

    httpx_mock.add_response(
        method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
    )
    httpx_mock.add_response(
        method="POST",
        url=embed_url,
        json={"result": [0.34, 0.43, -0.934532]},
        status_code=200,
    )

    embedding = TitanTakeoffEmbed(
        port=inf_port, mgmt_port=mgnt_port, models=[reader_1, reader_2]
    )

    # Only one consumer group was given at initialization, so it does not
    # have to be passed explicitly here.
    documents_output = embedding.embed_documents(prompt)
    query_output = embedding.embed_query(prompt)
    assert isinstance(documents_output, list)
    assert isinstance(query_output, list)

    # Two management requests (one per reader) plus two embedding requests.
    sent = httpx_mock.get_requests()
    assert len(sent) == 4

    # The first two requests should create reader 1 and reader 2, in order.
    for position, reader in enumerate((reader_1, reader_2)):
        payload = json.loads(sent[position].content)
        for key, value in reader.items():
            assert payload[key] == value
        assert sent[position].url == mgnt_url

    # The remaining two requests should go to the embed endpoint.
    for request in sent[2:4]:
        assert request.url == embed_url
        assert json.loads(request.content)["text"] == prompt
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
def test_takeoff_initialization_with_more_than_one_consumer_group(
    httpx_mock: Any,
) -> None:
    """Check start-up with two readers that use different consumer groups.

    With more than one consumer group configured, embedding calls that omit
    the group must raise MissingConsumerGroup, while calls that name one of
    the configured groups must succeed.
    """
    mgnt_port = 36452
    inf_port = 46253
    mgnt_url = f"http://localhost:{mgnt_port}/reader"
    embed_url = f"http://localhost:{inf_port}/embed"
    prompt = "What is 2 + 2?"

    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "embed",
    }
    reader_2 = {
        **reader_1,
        "model_name": "test2",
        "device": "cuda",
        "consumer_group": "embed2",
    }

    httpx_mock.add_response(
        method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
    )
    httpx_mock.add_response(
        method="POST",
        url=embed_url,
        json={"result": [0.34, 0.43, -0.934532]},
        status_code=200,
    )

    embedding = TitanTakeoffEmbed(
        port=inf_port, mgmt_port=mgnt_port, models=[reader_1, reader_2]
    )

    # Two consumer groups were given at initialization, so the caller has to
    # say which one to use.
    with pytest.raises(MissingConsumerGroup):
        embedding.embed_documents(prompt)
    with pytest.raises(MissingConsumerGroup):
        embedding.embed_query(prompt)

    documents_output = embedding.embed_documents(prompt, "embed")
    query_output = embedding.embed_query(prompt, "embed2")
    assert isinstance(documents_output, list)
    assert isinstance(query_output, list)

    # Two management requests (one per reader) plus two embedding requests.
    sent = httpx_mock.get_requests()
    assert len(sent) == 4

    # The first two requests should create reader 1 and reader 2, in order.
    for position, reader in enumerate((reader_1, reader_2)):
        payload = json.loads(sent[position].content)
        for key, value in reader.items():
            assert payload[key] == value
        assert sent[position].url == mgnt_url

    # The remaining two requests should go to the embed endpoint.
    for request in sent[2:4]:
        assert request.url == embed_url
        assert json.loads(request.content)["text"] == prompt