diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py
index 499f79453a2..200741fc6e0 100644
--- a/libs/community/langchain_community/embeddings/__init__.py
+++ b/libs/community/langchain_community/embeddings/__init__.py
@@ -203,6 +203,9 @@ if TYPE_CHECKING:
     from langchain_community.embeddings.tensorflow_hub import (
         TensorflowHubEmbeddings,
     )
+    from langchain_community.embeddings.titan_takeoff import (
+        TitanTakeoffEmbed,
+    )
     from langchain_community.embeddings.vertexai import (
         VertexAIEmbeddings,
     )
@@ -288,6 +291,7 @@ __all__ = [
     "SpacyEmbeddings",
     "SparkLLMTextEmbeddings",
     "TensorflowHubEmbeddings",
+    "TitanTakeoffEmbed",
    "VertexAIEmbeddings",
     "VolcanoEmbeddings",
     "VoyageEmbeddings",
@@ -380,8 +384,6 @@ def __getattr__(name: str) -> Any:
     raise AttributeError(f"module {__name__} has no attribute {name}")
 
 
-__all__ = list(_module_lookup.keys())
-
 logger = logging.getLogger(__name__)
diff --git a/libs/community/langchain_community/embeddings/titan_takeoff.py b/libs/community/langchain_community/embeddings/titan_takeoff.py
index 81966c6739e..bf82f54936e 100644
--- a/libs/community/langchain_community/embeddings/titan_takeoff.py
+++ b/libs/community/langchain_community/embeddings/titan_takeoff.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel
@@ -142,11 +142,12 @@ class TitanTakeoffEmbed(Embeddings):
 
     def _embed(
         self, input: Union[List[str], str], consumer_group: Optional[str]
-    ) -> dict:
+    ) -> Dict[str, Any]:
         """Embed text.
 
         Args:
-            input (List[str]): prompt/document or list of prompts/documents to embed
+            input (Union[List[str], str]): prompt/document or list of prompts/documents
+                to embed
             consumer_group (Optional[str]): what consumer group to send the embedding
                 request to. If not specified and there is only one consumer group
                 specified during initialization, it will be used. If there
diff --git a/libs/community/tests/integration_tests/embeddings/test_titan_takeoff.py b/libs/community/tests/integration_tests/embeddings/test_titan_takeoff.py
index 884f1a120ab..5eb78bf374c 100644
--- a/libs/community/tests/integration_tests/embeddings/test_titan_takeoff.py
+++ b/libs/community/tests/integration_tests/embeddings/test_titan_takeoff.py
@@ -7,7 +7,11 @@ from typing import Any
 import pytest
 
 from langchain_community.embeddings import TitanTakeoffEmbed
-from langchain_community.embeddings.titan_takeoff import MissingConsumerGroup
+from langchain_community.embeddings.titan_takeoff import (
+    Device,
+    MissingConsumerGroup,
+    ReaderConfig,
+)
 
 
 @pytest.mark.requires("pytest_httpx")
@@ -24,7 +28,7 @@ def test_titan_takeoff_call(httpx_mock: Any) -> None:
 
     embedding = TitanTakeoffEmbed(port=port)
 
-    output_1 = embedding.embed_documents("What is 2 + 2?", "primary")
+    output_1 = embedding.embed_documents(["What is 2 + 2?"], "primary")
     output_2 = embedding.embed_query("What is 2 + 2?", "primary")
 
     assert isinstance(output_1, list)
@@ -53,12 +57,12 @@ def test_no_consumer_group_fails(httpx_mock: Any) -> None:
 
     embedding = TitanTakeoffEmbed(port=port)
 
     with pytest.raises(MissingConsumerGroup):
-        embedding.embed_documents("What is 2 + 2?")
+        embedding.embed_documents(["What is 2 + 2?"])
     with pytest.raises(MissingConsumerGroup):
         embedding.embed_query("What is 2 + 2?")
 
     # Check specifying a consumer group works
-    embedding.embed_documents("What is 2 + 2?", "primary")
+    embedding.embed_documents(["What is 2 + 2?"], "primary")
     embedding.embed_query("What is 2 + 2?", "primary")
@@ -70,14 +74,16 @@ def test_takeoff_initialization(httpx_mock: Any) -> None:
     inf_port = 46253
     mgnt_url = f"http://localhost:{mgnt_port}/reader"
     embed_url = f"http://localhost:{inf_port}/embed"
-    reader_1 = {
-        "model_name": "test",
-        "device": "cpu",
-        "consumer_group": "embed",
-    }
-    reader_2 = reader_1.copy()
-    reader_2["model_name"] = "test2"
-    reader_2["device"] = "cuda"
+    reader_1 = ReaderConfig(
+        model_name="test",
+        device=Device.cpu,
+        consumer_group="embed",
+    )
+    reader_2 = ReaderConfig(
+        model_name="test2",
+        device=Device.cuda,
+        consumer_group="embed",
+    )
 
     httpx_mock.add_response(
         method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
@@ -94,18 +100,18 @@ def test_takeoff_initialization(httpx_mock: Any) -> None:
     )
     # Shouldn't need to specify consumer group as there is only one specified during
     # initialization
-    output_1 = llm.embed_documents("What is 2 + 2?")
+    output_1 = llm.embed_documents(["What is 2 + 2?"])
     output_2 = llm.embed_query("What is 2 + 2?")
 
     assert isinstance(output_1, list)
     assert isinstance(output_2, list)
     # Ensure the management api was called to create the reader
     assert len(httpx_mock.get_requests()) == 4
-    for key, value in reader_1.items():
+    for key, value in reader_1.dict().items():
         assert json.loads(httpx_mock.get_requests()[0].content)[key] == value
     assert httpx_mock.get_requests()[0].url == mgnt_url
     # Also second call should be made to spin uo reader 2
-    for key, value in reader_2.items():
+    for key, value in reader_2.dict().items():
         assert json.loads(httpx_mock.get_requests()[1].content)[key] == value
     assert httpx_mock.get_requests()[1].url == mgnt_url
     # Ensure the third call is to generate endpoint to inference
@@ -126,15 +132,16 @@ def test_takeoff_initialization_with_more_than_one_consumer_group(
     inf_port = 46253
     mgnt_url = f"http://localhost:{mgnt_port}/reader"
     embed_url = f"http://localhost:{inf_port}/embed"
-    reader_1 = {
-        "model_name": "test",
-        "device": "cpu",
-        "consumer_group": "embed",
-    }
-    reader_2 = reader_1.copy()
-    reader_2["model_name"] = "test2"
-    reader_2["device"] = "cuda"
-    reader_2["consumer_group"] = "embed2"
+    reader_1 = ReaderConfig(
+        model_name="test",
+        device=Device.cpu,
+        consumer_group="embed",
+    )
+    reader_2 = ReaderConfig(
+        model_name="test2",
+        device=Device.cuda,
+        consumer_group="embed2",
+    )
 
     httpx_mock.add_response(
         method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
@@ -152,22 +159,22 @@ def test_takeoff_initialization_with_more_than_one_consumer_group(
     # There was more than one consumer group specified during initialization so we
     # need to specify which one to use
     with pytest.raises(MissingConsumerGroup):
-        llm.embed_documents("What is 2 + 2?")
+        llm.embed_documents(["What is 2 + 2?"])
     with pytest.raises(MissingConsumerGroup):
         llm.embed_query("What is 2 + 2?")
 
-    output_1 = llm.embed_documents("What is 2 + 2?", "embed")
+    output_1 = llm.embed_documents(["What is 2 + 2?"], "embed")
     output_2 = llm.embed_query("What is 2 + 2?", "embed2")
 
     assert isinstance(output_1, list)
     assert isinstance(output_2, list)
     # Ensure the management api was called to create the reader
     assert len(httpx_mock.get_requests()) == 4
-    for key, value in reader_1.items():
+    for key, value in reader_1.dict().items():
         assert json.loads(httpx_mock.get_requests()[0].content)[key] == value
     assert httpx_mock.get_requests()[0].url == mgnt_url
     # Also second call should be made to spin uo reader 2
-    for key, value in reader_2.items():
+    for key, value in reader_2.dict().items():
         assert json.loads(httpx_mock.get_requests()[1].content)[key] == value
     assert httpx_mock.get_requests()[1].url == mgnt_url
     # Ensure the third call is to generate endpoint to inference