community[patch]: fix public interface for embeddings module (#21650)

## Description

The existing public interface for `langchain_community.embeddings` is
broken. In this file, `__all__` is statically defined, but is
subsequently overwritten with a dynamic expression, which type checkers
like pyright do not support. pyright actually gives the following
diagnostic on the line I am requesting we remove:


[reportUnsupportedDunderAll](https://github.com/microsoft/pyright/blob/main/docs/configuration.md#reportUnsupportedDunderAll):

```
Operation on "__all__" is not supported, so exported symbol list may be incorrect
```

Currently, I get the following errors when attempting to use publicly
exported classes in `langchain_community.embeddings`:

```python
import langchain_community.embeddings

langchain_community.embeddings.HuggingFaceEmbeddings(...)  #  error: "HuggingFaceEmbeddings" is not exported from module "langchain_community.embeddings" (reportPrivateImportUsage)
```

This is solved easily by removing the dynamic expression.
This commit is contained in:
Matthew Hoffman 2024-05-22 08:42:15 -07:00 committed by GitHub
parent 6548052f9e
commit 4f2e3bd7fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 33 deletions

View File

@ -203,6 +203,9 @@ if TYPE_CHECKING:
from langchain_community.embeddings.tensorflow_hub import ( from langchain_community.embeddings.tensorflow_hub import (
TensorflowHubEmbeddings, TensorflowHubEmbeddings,
) )
from langchain_community.embeddings.titan_takeoff import (
TitanTakeoffEmbed,
)
from langchain_community.embeddings.vertexai import ( from langchain_community.embeddings.vertexai import (
VertexAIEmbeddings, VertexAIEmbeddings,
) )
@ -288,6 +291,7 @@ __all__ = [
"SpacyEmbeddings", "SpacyEmbeddings",
"SparkLLMTextEmbeddings", "SparkLLMTextEmbeddings",
"TensorflowHubEmbeddings", "TensorflowHubEmbeddings",
"TitanTakeoffEmbed",
"VertexAIEmbeddings", "VertexAIEmbeddings",
"VolcanoEmbeddings", "VolcanoEmbeddings",
"VoyageEmbeddings", "VoyageEmbeddings",
@ -380,8 +384,6 @@ def __getattr__(name: str) -> Any:
raise AttributeError(f"module {__name__} has no attribute {name}") raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = list(_module_lookup.keys())
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
@ -142,11 +142,12 @@ class TitanTakeoffEmbed(Embeddings):
def _embed( def _embed(
self, input: Union[List[str], str], consumer_group: Optional[str] self, input: Union[List[str], str], consumer_group: Optional[str]
) -> dict: ) -> Dict[str, Any]:
"""Embed text. """Embed text.
Args: Args:
input (List[str]): prompt/document or list of prompts/documents to embed input (Union[List[str], str]): prompt/document or list of prompts/documents
to embed
consumer_group (Optional[str]): what consumer group to send the embedding consumer_group (Optional[str]): what consumer group to send the embedding
request to. If not specified and there is only one request to. If not specified and there is only one
consumer group specified during initialization, it will be used. If there consumer group specified during initialization, it will be used. If there

View File

@ -7,7 +7,11 @@ from typing import Any
import pytest import pytest
from langchain_community.embeddings import TitanTakeoffEmbed from langchain_community.embeddings import TitanTakeoffEmbed
from langchain_community.embeddings.titan_takeoff import MissingConsumerGroup from langchain_community.embeddings.titan_takeoff import (
Device,
MissingConsumerGroup,
ReaderConfig,
)
@pytest.mark.requires("pytest_httpx") @pytest.mark.requires("pytest_httpx")
@ -24,7 +28,7 @@ def test_titan_takeoff_call(httpx_mock: Any) -> None:
embedding = TitanTakeoffEmbed(port=port) embedding = TitanTakeoffEmbed(port=port)
output_1 = embedding.embed_documents("What is 2 + 2?", "primary") output_1 = embedding.embed_documents(["What is 2 + 2?"], "primary")
output_2 = embedding.embed_query("What is 2 + 2?", "primary") output_2 = embedding.embed_query("What is 2 + 2?", "primary")
assert isinstance(output_1, list) assert isinstance(output_1, list)
@ -53,12 +57,12 @@ def test_no_consumer_group_fails(httpx_mock: Any) -> None:
embedding = TitanTakeoffEmbed(port=port) embedding = TitanTakeoffEmbed(port=port)
with pytest.raises(MissingConsumerGroup): with pytest.raises(MissingConsumerGroup):
embedding.embed_documents("What is 2 + 2?") embedding.embed_documents(["What is 2 + 2?"])
with pytest.raises(MissingConsumerGroup): with pytest.raises(MissingConsumerGroup):
embedding.embed_query("What is 2 + 2?") embedding.embed_query("What is 2 + 2?")
# Check specifying a consumer group works # Check specifying a consumer group works
embedding.embed_documents("What is 2 + 2?", "primary") embedding.embed_documents(["What is 2 + 2?"], "primary")
embedding.embed_query("What is 2 + 2?", "primary") embedding.embed_query("What is 2 + 2?", "primary")
@ -70,14 +74,16 @@ def test_takeoff_initialization(httpx_mock: Any) -> None:
inf_port = 46253 inf_port = 46253
mgnt_url = f"http://localhost:{mgnt_port}/reader" mgnt_url = f"http://localhost:{mgnt_port}/reader"
embed_url = f"http://localhost:{inf_port}/embed" embed_url = f"http://localhost:{inf_port}/embed"
reader_1 = { reader_1 = ReaderConfig(
"model_name": "test", model_name="test",
"device": "cpu", device=Device.cpu,
"consumer_group": "embed", consumer_group="embed",
} )
reader_2 = reader_1.copy() reader_2 = ReaderConfig(
reader_2["model_name"] = "test2" model_name="test2",
reader_2["device"] = "cuda" device=Device.cuda,
consumer_group="embed",
)
httpx_mock.add_response( httpx_mock.add_response(
method="POST", url=mgnt_url, json={"key": "value"}, status_code=201 method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
@ -94,18 +100,18 @@ def test_takeoff_initialization(httpx_mock: Any) -> None:
) )
# Shouldn't need to specify consumer group as there is only one specified during # Shouldn't need to specify consumer group as there is only one specified during
# initialization # initialization
output_1 = llm.embed_documents("What is 2 + 2?") output_1 = llm.embed_documents(["What is 2 + 2?"])
output_2 = llm.embed_query("What is 2 + 2?") output_2 = llm.embed_query("What is 2 + 2?")
assert isinstance(output_1, list) assert isinstance(output_1, list)
assert isinstance(output_2, list) assert isinstance(output_2, list)
# Ensure the management api was called to create the reader # Ensure the management api was called to create the reader
assert len(httpx_mock.get_requests()) == 4 assert len(httpx_mock.get_requests()) == 4
for key, value in reader_1.items(): for key, value in reader_1.dict().items():
assert json.loads(httpx_mock.get_requests()[0].content)[key] == value assert json.loads(httpx_mock.get_requests()[0].content)[key] == value
assert httpx_mock.get_requests()[0].url == mgnt_url assert httpx_mock.get_requests()[0].url == mgnt_url
# Also second call should be made to spin uo reader 2 # Also second call should be made to spin uo reader 2
for key, value in reader_2.items(): for key, value in reader_2.dict().items():
assert json.loads(httpx_mock.get_requests()[1].content)[key] == value assert json.loads(httpx_mock.get_requests()[1].content)[key] == value
assert httpx_mock.get_requests()[1].url == mgnt_url assert httpx_mock.get_requests()[1].url == mgnt_url
# Ensure the third call is to generate endpoint to inference # Ensure the third call is to generate endpoint to inference
@ -126,15 +132,16 @@ def test_takeoff_initialization_with_more_than_one_consumer_group(
inf_port = 46253 inf_port = 46253
mgnt_url = f"http://localhost:{mgnt_port}/reader" mgnt_url = f"http://localhost:{mgnt_port}/reader"
embed_url = f"http://localhost:{inf_port}/embed" embed_url = f"http://localhost:{inf_port}/embed"
reader_1 = { reader_1 = ReaderConfig(
"model_name": "test", model_name="test",
"device": "cpu", device=Device.cpu,
"consumer_group": "embed", consumer_group="embed",
} )
reader_2 = reader_1.copy() reader_2 = ReaderConfig(
reader_2["model_name"] = "test2" model_name="test2",
reader_2["device"] = "cuda" device=Device.cuda,
reader_2["consumer_group"] = "embed2" consumer_group="embed2",
)
httpx_mock.add_response( httpx_mock.add_response(
method="POST", url=mgnt_url, json={"key": "value"}, status_code=201 method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
@ -152,22 +159,22 @@ def test_takeoff_initialization_with_more_than_one_consumer_group(
# There was more than one consumer group specified during initialization so we # There was more than one consumer group specified during initialization so we
# need to specify which one to use # need to specify which one to use
with pytest.raises(MissingConsumerGroup): with pytest.raises(MissingConsumerGroup):
llm.embed_documents("What is 2 + 2?") llm.embed_documents(["What is 2 + 2?"])
with pytest.raises(MissingConsumerGroup): with pytest.raises(MissingConsumerGroup):
llm.embed_query("What is 2 + 2?") llm.embed_query("What is 2 + 2?")
output_1 = llm.embed_documents("What is 2 + 2?", "embed") output_1 = llm.embed_documents(["What is 2 + 2?"], "embed")
output_2 = llm.embed_query("What is 2 + 2?", "embed2") output_2 = llm.embed_query("What is 2 + 2?", "embed2")
assert isinstance(output_1, list) assert isinstance(output_1, list)
assert isinstance(output_2, list) assert isinstance(output_2, list)
# Ensure the management api was called to create the reader # Ensure the management api was called to create the reader
assert len(httpx_mock.get_requests()) == 4 assert len(httpx_mock.get_requests()) == 4
for key, value in reader_1.items(): for key, value in reader_1.dict().items():
assert json.loads(httpx_mock.get_requests()[0].content)[key] == value assert json.loads(httpx_mock.get_requests()[0].content)[key] == value
assert httpx_mock.get_requests()[0].url == mgnt_url assert httpx_mock.get_requests()[0].url == mgnt_url
# Also second call should be made to spin uo reader 2 # Also second call should be made to spin uo reader 2
for key, value in reader_2.items(): for key, value in reader_2.dict().items():
assert json.loads(httpx_mock.get_requests()[1].content)[key] == value assert json.loads(httpx_mock.get_requests()[1].content)[key] == value
assert httpx_mock.get_requests()[1].url == mgnt_url assert httpx_mock.get_requests()[1].url == mgnt_url
# Ensure the third call is to generate endpoint to inference # Ensure the third call is to generate endpoint to inference