FEAT: Integrate Xinference LLMs and Embeddings (#8171)

- [Xorbits Inference (Xinference)](https://github.com/xorbitsai/inference) is a
powerful and versatile library designed to serve language, speech
recognition, and multimodal models. Xinference supports a variety of
GGML-compatible models including chatglm, whisper, and vicuna, and
utilizes heterogeneous hardware and a distributed architecture for
seamless cross-device and cross-server model deployment.
- This PR integrates Xinference models and Xinference embeddings into
LangChain.
- Dependencies: To install the dependencies for this integration, run
    
    `pip install "xinference[all]"`
    
- Example Usage:

To start a local instance of Xinference, run `xinference`.

To deploy Xinference in a distributed cluster, first start an Xinference
supervisor using `xinference-supervisor`:

`xinference-supervisor -H "${supervisor_host}"`

Then, start the Xinference workers using `xinference-worker` on each
server you want to run them on.

`xinference-worker -e "http://${supervisor_host}:9997"`

To use Xinference with LangChain, you also need to launch a model. You
can use the command line interface (CLI) to do so. For example: `xinference
launch -n vicuna-v1.3 -f ggmlv3 -q q4_0`. This launches a model named
vicuna-v1.3 with `model_format="ggmlv3"` and `quantization="q4_0"`. A
model UID is returned for you to use.

Now you can use Xinference with LangChain:

```python
from langchain.llms import Xinference

llm = Xinference(
    server_url="http://0.0.0.0:9997", # suppose the supervisor_host is "0.0.0.0"
    model_uid = {model_uid} # model UID returned from launching a model
)

llm(
    prompt="Q: where can we visit in the capital of France? A:",
    generate_config={"max_tokens": 1024},
)
```
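
Since `generate_config` accepts `"stream": True`, you can also stream tokens through LangChain's callback system. The sketch below is not taken from this PR; it assumes LangChain's standard `StreamingStdOutCallbackHandler` and uses a hypothetical placeholder model UID:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Xinference

llm = Xinference(
    server_url="http://0.0.0.0:9997",
    model_uid="my-model-uid",  # hypothetical placeholder; use the UID returned from launching a model
)

# With "stream": True, each token is forwarded to the callback handler as it
# arrives, so StreamingStdOutCallbackHandler prints the output incrementally.
llm(
    prompt="Q: where can we visit in the capital of France? A:",
    generate_config={"max_tokens": 1024, "stream": True},
    callbacks=[StreamingStdOutCallbackHandler()],
)
```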

You can also use the RESTful client to launch a model:
```python
from xinference.client import RESTfulClient

client = RESTfulClient("http://0.0.0.0:9997")

model_uid = client.launch_model(model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0")
```
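
The returned `model_uid` can then be passed straight to the LangChain wrapper. A short sketch continuing from the snippet above (the prompt is just an illustration):

```python
from langchain.llms import Xinference

# Reuse the model_uid returned by client.launch_model above.
llm = Xinference(server_url="http://0.0.0.0:9997", model_uid=model_uid)

print(llm(prompt="Q: name one landmark in the capital of France. A:"))
```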

The following code block demonstrates how to use Xinference embeddings
with LangChain:
```python
from langchain.embeddings import XinferenceEmbeddings

xinference = XinferenceEmbeddings(
    server_url="http://0.0.0.0:9997",
    model_uid=model_uid,
)
```

```python
query_result = xinference.embed_query("This is a test query")
```

```python
doc_result = xinference.embed_documents(["text A", "text B"])
```
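
As with any LangChain `Embeddings` implementation, `XinferenceEmbeddings` can also back a vector store. A minimal sketch, assuming the optional `faiss-cpu` package is installed (FAISS is not part of this PR):

```python
from langchain.vectorstores import FAISS

# Build an in-memory FAISS index from a few documents using Xinference embeddings.
db = FAISS.from_texts(["text A", "text B"], embedding=xinference)

# Retrieve the document most similar to the query.
docs = db.similarity_search("This is a test query", k=1)
print(docs[0].page_content)
```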

Xinference is still under rapid development. Feel free to [join our
Slack
community](https://xorbitsio.slack.com/join/shared_invite/zt-1z3zsm9ep-87yI9YZ_B79HLB2ccTq4WA)
to get the latest updates!

- Request for review: @hwchase17, @baskaryan
- Twitter handle: https://twitter.com/Xorbitsio

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit 1efb9bae5f (parent 877d384bc9), authored by Jiayi Ni and committed by GitHub on 2023-07-28 12:23:19 +08:00.
13 changed files with 1556 additions and 541 deletions.


@@ -42,6 +42,7 @@ from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddi
from langchain.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
from langchain.embeddings.vertexai import VertexAIEmbeddings
from langchain.embeddings.xinference import XinferenceEmbeddings
logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ __all__ = [
"SpacyEmbeddings",
"NLPCloudEmbeddings",
"GPT4AllEmbeddings",
"XinferenceEmbeddings",
"LocalAIEmbeddings",
"AwaEmbeddings",
]


@@ -0,0 +1,113 @@
"""Wrapper around Xinference embedding models."""
from typing import Any, List, Optional
from langchain.embeddings.base import Embeddings
class XinferenceEmbeddings(Embeddings):
"""Wrapper around xinference embedding models.
To use, you should have the xinference library installed:
.. code-block:: bash
pip install xinference
Check out: https://github.com/xorbitsai/inference
To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers
Example:
To start a local instance of Xinference, run
.. code-block:: bash
$ xinference
You can also deploy Xinference in a distributed cluster. Here are the steps:
Starting the supervisor:
.. code-block:: bash
$ xinference-supervisor
Starting the worker:
.. code-block:: bash
$ xinference-worker
Then, launch a model using command line interface (CLI).
Example:
.. code-block:: bash
$ xinference launch -n orca -s 3 -q q4_0
It will return a model UID. Then you can use Xinference Embedding with LangChain.
Example:
.. code-block:: python
from langchain.embeddings import XinferenceEmbeddings
xinference = XinferenceEmbeddings(
server_url="http://0.0.0.0:9997",
model_uid = {model_uid} # replace model_uid with the model UID return from launching the model
)
""" # noqa: E501
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""
def __init__(
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
):
try:
from xinference.client import RESTfulClient
except ImportError as e:
raise ImportError(
"Could not import RESTfulClient from xinference. Please install it"
" with `pip install xinference`."
) from e
super().__init__()
if server_url is None:
raise ValueError("Please provide server URL")
if model_uid is None:
raise ValueError("Please provide the model UID")
self.server_url = server_url
self.model_uid = model_uid
self.client = RESTfulClient(server_url)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents using Xinference.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
model = self.client.get_model(self.model_uid)
embeddings = [
model.create_embedding(text)["data"][0]["embedding"] for text in texts
]
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Embed a query of documents using Xinference.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
model = self.client.get_model(self.model_uid)
embedding_res = model.create_embedding(text)
embedding = embedding_res["data"][0]["embedding"]
return list(map(float, embedding))


@@ -56,6 +56,7 @@ from langchain.llms.textgen import TextGen
from langchain.llms.tongyi import Tongyi
from langchain.llms.vertexai import VertexAI
from langchain.llms.writer import Writer
from langchain.llms.xinference import Xinference
__all__ = [
"AI21",
@@ -115,6 +116,7 @@ __all__ = [
"VertexAI",
"Writer",
"OctoAIEndpoint",
"Xinference",
]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -170,4 +172,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"openllm": OpenLLM,
"openllm_client": OpenLLM,
"writer": Writer,
"xinference": Xinference,
}


@@ -0,0 +1,185 @@
from typing import TYPE_CHECKING, Any, Generator, List, Mapping, Optional, Union

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

if TYPE_CHECKING:
    from xinference.client import (
        RESTfulChatModelHandle,
        RESTfulGenerateModelHandle,
    )
    from xinference.model.llm.core import LlamaCppGenerateConfig


class Xinference(LLM):
    """Wrapper for accessing Xinference's large-scale model inference service.

    To use, you should have the xinference library installed:

    .. code-block:: bash

        pip install "xinference[all]"

    Check out: https://github.com/xorbitsai/inference

    To run, you need to start a Xinference supervisor on one server and Xinference
    workers on the other servers.

    Example:
        To start a local instance of Xinference, run

        .. code-block:: bash

            $ xinference

        You can also deploy Xinference in a distributed cluster. Here are the steps:

        Starting the supervisor:

        .. code-block:: bash

            $ xinference-supervisor

        Starting the worker:

        .. code-block:: bash

            $ xinference-worker

    Then, launch a model using the command line interface (CLI).

    Example:

    .. code-block:: bash

        $ xinference launch -n orca -s 3 -q q4_0

    It will return a model UID. Then, you can use Xinference with LangChain.

    Example:

    .. code-block:: python

        from langchain.llms import Xinference

        llm = Xinference(
            server_url="http://0.0.0.0:9997",
            model_uid={model_uid}  # replace model_uid with the model UID returned from launching the model
        )

        llm(
            prompt="Q: where can we visit in the capital of France? A:",
            generate_config={"max_tokens": 1024, "stream": True},
        )

    To view all the supported builtin models, run:

    .. code-block:: bash

        $ xinference list --all
    """  # noqa: E501

    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""

    def __init__(
        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
    ):
        try:
            from xinference.client import RESTfulClient
        except ImportError as e:
            raise ImportError(
                "Could not import RESTfulClient from xinference. Please install it"
                " with `pip install xinference`."
            ) from e

        super().__init__(
            **{
                "server_url": server_url,
                "model_uid": model_uid,
            }
        )

        if self.server_url is None:
            raise ValueError("Please provide server URL")

        if self.model_uid is None:
            raise ValueError("Please provide the model UID")

        self.client = RESTfulClient(server_url)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "xinference"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"server_url": self.server_url},
            **{"model_uid": self.model_uid},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Returns:
            The string generated by the model.
        """
        model = self.client.get_model(self.model_uid)

        generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

        if stop:
            generate_config["stop"] = stop

        if generate_config and generate_config.get("stream"):
            combined_text_output = ""
            for token in self._stream_generate(
                model=model,
                prompt=prompt,
                run_manager=run_manager,
                generate_config=generate_config,
            ):
                combined_text_output += token
            return combined_text_output
        else:
            completion = model.generate(prompt=prompt, generate_config=generate_config)
            return completion["choices"][0]["text"]

    def _stream_generate(
        self,
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
        prompt: str,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        generate_config: Optional["LlamaCppGenerateConfig"] = None,
    ) -> Generator[str, None, None]:
        """
        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        streaming_response = model.generate(
            prompt=prompt, generate_config=generate_config
        )
        for chunk in streaming_response:
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token

File diff suppressed because it is too large.


@@ -124,6 +124,7 @@ langsmith = "~0.0.11"
rank-bm25 = {version = "^0.2.2", optional = true}
amadeus = {version = ">=8.1.0", optional = true}
geopandas = {version = "^0.13.1", optional = true}
xinference = {version = "^0.0.6", optional = true}
python-arango = {version = "^7.5.9", optional = true}
[tool.poetry.group.test.dependencies]
@@ -219,7 +220,7 @@ playwright = "^1.28.0"
setuptools = "^67.6.1"
[tool.poetry.extras]
llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers", "xinference"]
qdrant = ["qdrant-client"]
openai = ["openai", "tiktoken"]
text_helpers = ["chardet"]
@@ -315,6 +316,7 @@ all = [
"octoai-sdk",
"rdflib",
"amadeus",
"xinference",
"python-arango",
]
@@ -356,6 +358,7 @@ extended_testing = [
"rank_bm25",
"geopandas",
"jinja2",
"xinference",
]
[tool.ruff]


@@ -0,0 +1,74 @@
"""Test Xinference embeddings."""
import time
from typing import AsyncGenerator, Tuple
import pytest_asyncio
from langchain.embeddings import XinferenceEmbeddings
@pytest_asyncio.fixture
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
import xoscar as xo
from xinference.deploy.supervisor import start_supervisor_components
from xinference.deploy.utils import create_worker_actor_pool
from xinference.deploy.worker import start_worker_components
pool = await create_worker_actor_pool(
f"test://127.0.0.1:{xo.utils.get_next_port()}"
)
print(f"Pool running on localhost:{pool.external_address}")
endpoint = await start_supervisor_components(
pool.external_address, "127.0.0.1", xo.utils.get_next_port()
)
await start_worker_components(
address=pool.external_address, supervisor_address=pool.external_address
)
# wait for the api.
time.sleep(3)
async with pool:
yield endpoint, pool.external_address
def test_xinference_embedding_documents(setup: Tuple[str, str]) -> None:
"""Test xinference embeddings for documents."""
from xinference.client import RESTfulClient
endpoint, _ = setup
client = RESTfulClient(endpoint)
model_uid = client.launch_model(
model_name="vicuna-v1.3",
model_size_in_billions=7,
model_format="ggmlv3",
quantization="q4_0",
)
xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
documents = ["foo bar", "bar foo"]
output = xinference.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 4096
def test_xinference_embedding_query(setup: Tuple[str, str]) -> None:
"""Test xinference embeddings for query."""
from xinference.client import RESTfulClient
endpoint, _ = setup
client = RESTfulClient(endpoint)
model_uid = client.launch_model(
model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0"
)
xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
document = "foo bar"
output = xinference.embed_query(document)
assert len(output) == 4096


@@ -0,0 +1,57 @@
"""Test Xinference wrapper."""
import time
from typing import AsyncGenerator, Tuple
import pytest_asyncio
from langchain.llms import Xinference
@pytest_asyncio.fixture
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
import xoscar as xo
from xinference.deploy.supervisor import start_supervisor_components
from xinference.deploy.utils import create_worker_actor_pool
from xinference.deploy.worker import start_worker_components
pool = await create_worker_actor_pool(
f"test://127.0.0.1:{xo.utils.get_next_port()}"
)
print(f"Pool running on localhost:{pool.external_address}")
endpoint = await start_supervisor_components(
pool.external_address, "127.0.0.1", xo.utils.get_next_port()
)
await start_worker_components(
address=pool.external_address, supervisor_address=pool.external_address
)
# wait for the api.
time.sleep(3)
async with pool:
yield endpoint, pool.external_address
def test_xinference_llm_(setup: Tuple[str, str]) -> None:
from xinference.client import RESTfulClient
endpoint, _ = setup
client = RESTfulClient(endpoint)
model_uid = client.launch_model(
model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0"
)
llm = Xinference(server_url=endpoint, model_uid=model_uid)
answer = llm(prompt="Q: What food can we try in the capital of France? A:")
assert isinstance(answer, str)
answer = llm(
prompt="Q: where can we visit in the capital of France? A:",
generate_config={"max_tokens": 1024, "stream": True},
)
assert isinstance(answer, str)