diff --git a/libs/community/langchain_community/embeddings/xinference.py b/libs/community/langchain_community/embeddings/xinference.py index db9b56fcbf9..858a2fea415 100644 --- a/libs/community/langchain_community/embeddings/xinference.py +++ b/libs/community/langchain_community/embeddings/xinference.py @@ -1,11 +1,11 @@ """Wrapper around Xinference embedding models.""" + from typing import Any, List, Optional from langchain_core.embeddings import Embeddings class XinferenceEmbeddings(Embeddings): - """Xinference embedding models. To use, you should have the xinference library installed: @@ -14,6 +14,12 @@ class XinferenceEmbeddings(Embeddings): pip install xinference + If you're simply using the services provided by Xinference, you can utilize the xinference_client package: + + .. code-block:: bash + + pip install xinference_client + Check out: https://github.com/xorbitsai/inference To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers. @@ -32,6 +38,12 @@ class XinferenceEmbeddings(Embeddings): $ xinference-supervisor + If you're simply using the services provided by Xinference, you can utilize the xinference_client package: + + .. code-block:: bash + + pip install xinference_client + Starting the worker: .. code-block:: bash @@ -72,11 +84,14 @@ class XinferenceEmbeddings(Embeddings): ): try: from xinference.client import RESTfulClient - except ImportError as e: - raise ImportError( - "Could not import RESTfulClient from xinference. Please install it" - " with `pip install xinference`." - ) from e + except ImportError: + try: + from xinference_client import RESTfulClient + except ImportError as e: + raise ImportError( + "Could not import RESTfulClient from xinference. Please install it" + " with `pip install xinference` or `pip install xinference_client`." + ) from e super().__init__() diff --git a/libs/community/tests/integration_tests/embeddings/test_xinference.py b/libs/community/tests/integration_tests/embeddings/test_xinference.py index 6c52fa98981..13d441ed133 100644 --- a/libs/community/tests/integration_tests/embeddings/test_xinference.py +++ b/libs/community/tests/integration_tests/embeddings/test_xinference.py @@ -73,3 +73,13 @@ def test_xinference_embedding_query(setup: Tuple[str, str]) -> None: document = "foo bar" output = xinference.embed_query(document) assert len(output) == 4096 + + +def test_xinference_embedding() -> None: + embedding_model = XinferenceEmbeddings( + server_url="http://xinference-hostname:9997", model_uid="foo" + ) + + embedding_model.embed_documents( + texts=["hello", "i'm trying to upgrade xinference embedding"] + )