mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-21 10:31:23 +00:00
FEAT: Integrate Xinference LLMs and Embeddings (#8171)
- [Xorbits Inference(Xinference)](https://github.com/xorbitsai/inference) is a powerful and versatile library designed to serve language, speech recognition, and multimodal models. Xinference supports a variety of GGML-compatible models including chatglm, whisper, and vicuna, and utilizes heterogeneous hardware and a distributed architecture for seamless cross-device and cross-server model deployment. - This PR integrates Xinference models and Xinference embeddings into LangChain. - Dependencies: To install the depenedencies for this integration, run `pip install "xinference[all]"` - Example Usage: To start a local instance of Xinference, run `xinference`. To deploy Xinference in a distributed cluster, first start an Xinference supervisor using `xinference-supervisor`: `xinference-supervisor -H "${supervisor_host}"` Then, start the Xinference workers using `xinference-worker` on each server you want to run them on. `xinference-worker -e "http://${supervisor_host}:9997"` To use Xinference with LangChain, you also need to launch a model. You can use command line interface (CLI) to do so. Fo example: `xinference launch -n vicuna-v1.3 -f ggmlv3 -q q4_0`. This launches a model named vicuna-v1.3 with `model_format="ggmlv3"` and `quantization="q4_0"`. A model UID is returned for you to use. Now you can use Xinference with LangChain: ```python from langchain.llms import Xinference llm = Xinference( server_url="http://0.0.0.0:9997", # suppose the supervisor_host is "0.0.0.0" model_uid = {model_uid} # model UID returned from launching a model ) llm( prompt="Q: where can we visit in the capital of France? A:", generate_config={"max_tokens": 1024}, ) ``` You can also use RESTful client to launch a model: ```python from xinference.client import RESTfulClient client = RESTfulClient("http://0.0.0.0:9997") model_uid = client.launch_model(model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0") ``` The following code block demonstrates how to use Xinference embeddings with LangChain: ```python from langchain.embeddings import XinferenceEmbeddings xinference = XinferenceEmbeddings( server_url="http://0.0.0.0:9997", model_uid = model_uid ) ``` ```python query_result = xinference.embed_query("This is a test query") ``` ```python doc_result = xinference.embed_documents(["text A", "text B"]) ``` Xinference is still under rapid development. Feel free to [join our Slack community](https://xorbitsio.slack.com/join/shared_invite/zt-1z3zsm9ep-87yI9YZ_B79HLB2ccTq4WA) to get the latest updates! - Request for review: @hwchase17, @baskaryan - Twitter handle: https://twitter.com/Xorbitsio --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -42,6 +42,7 @@ from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddi
|
||||
from langchain.embeddings.spacy_embeddings import SpacyEmbeddings
|
||||
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
|
||||
from langchain.embeddings.vertexai import VertexAIEmbeddings
|
||||
from langchain.embeddings.xinference import XinferenceEmbeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -78,6 +79,7 @@ __all__ = [
|
||||
"SpacyEmbeddings",
|
||||
"NLPCloudEmbeddings",
|
||||
"GPT4AllEmbeddings",
|
||||
"XinferenceEmbeddings",
|
||||
"LocalAIEmbeddings",
|
||||
"AwaEmbeddings",
|
||||
]
|
||||
|
113
libs/langchain/langchain/embeddings/xinference.py
Normal file
113
libs/langchain/langchain/embeddings/xinference.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Wrapper around Xinference embedding models."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class XinferenceEmbeddings(Embeddings):
|
||||
|
||||
"""Wrapper around xinference embedding models.
|
||||
To use, you should have the xinference library installed:
|
||||
.. code-block:: bash
|
||||
|
||||
pip install xinference
|
||||
|
||||
Check out: https://github.com/xorbitsai/inference
|
||||
To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers
|
||||
Example:
|
||||
To start a local instance of Xinference, run
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference
|
||||
You can also deploy Xinference in a distributed cluster. Here are the steps:
|
||||
Starting the supervisor:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-supervisor
|
||||
Starting the worker:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-worker
|
||||
|
||||
Then, launch a model using command line interface (CLI).
|
||||
|
||||
Example:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference launch -n orca -s 3 -q q4_0
|
||||
|
||||
It will return a model UID. Then you can use Xinference Embedding with LangChain.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
xinference = XinferenceEmbeddings(
|
||||
server_url="http://0.0.0.0:9997",
|
||||
model_uid = {model_uid} # replace model_uid with the model UID return from launching the model
|
||||
)
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import RESTfulClient from xinference. Please install it"
|
||||
" with `pip install xinference`."
|
||||
) from e
|
||||
|
||||
super().__init__()
|
||||
|
||||
if server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.server_url = server_url
|
||||
|
||||
self.model_uid = model_uid
|
||||
|
||||
self.client = RESTfulClient(server_url)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents using Xinference.
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
embeddings = [
|
||||
model.create_embedding(text)["data"][0]["embedding"] for text in texts
|
||||
]
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a query of documents using Xinference.
|
||||
Args:
|
||||
text: The text to embed.
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
embedding_res = model.create_embedding(text)
|
||||
|
||||
embedding = embedding_res["data"][0]["embedding"]
|
||||
|
||||
return list(map(float, embedding))
|
@@ -56,6 +56,7 @@ from langchain.llms.textgen import TextGen
|
||||
from langchain.llms.tongyi import Tongyi
|
||||
from langchain.llms.vertexai import VertexAI
|
||||
from langchain.llms.writer import Writer
|
||||
from langchain.llms.xinference import Xinference
|
||||
|
||||
__all__ = [
|
||||
"AI21",
|
||||
@@ -115,6 +116,7 @@ __all__ = [
|
||||
"VertexAI",
|
||||
"Writer",
|
||||
"OctoAIEndpoint",
|
||||
"Xinference",
|
||||
]
|
||||
|
||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
@@ -170,4 +172,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"openllm": OpenLLM,
|
||||
"openllm_client": OpenLLM,
|
||||
"writer": Writer,
|
||||
"xinference": Xinference,
|
||||
}
|
||||
|
185
libs/langchain/langchain/llms/xinference.py
Normal file
185
libs/langchain/langchain/llms/xinference.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from typing import TYPE_CHECKING, Any, Generator, List, Mapping, Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
|
||||
from xinference.model.llm.core import LlamaCppGenerateConfig
|
||||
|
||||
|
||||
class Xinference(LLM):
|
||||
"""Wrapper for accessing Xinference's large-scale model inference service.
|
||||
To use, you should have the xinference library installed:
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "xinference[all]"
|
||||
|
||||
Check out: https://github.com/xorbitsai/inference
|
||||
To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers
|
||||
Example:
|
||||
To start a local instance of Xinference, run
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference
|
||||
|
||||
You can also deploy Xinference in a distributed cluster. Here are the steps:
|
||||
Starting the supervisor:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-supervisor
|
||||
|
||||
Starting the worker:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-worker
|
||||
|
||||
Then, launch a model using command line interface (CLI).
|
||||
|
||||
Example:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference launch -n orca -s 3 -q q4_0
|
||||
|
||||
It will return a model UID. Then, you can use Xinference with LangChain.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Xinference
|
||||
|
||||
llm = Xinference(
|
||||
server_url="http://0.0.0.0:9997",
|
||||
model_uid = {model_uid} # replace model_uid with the model UID return from launching the model
|
||||
)
|
||||
|
||||
llm(
|
||||
prompt="Q: where can we visit in the capital of France? A:",
|
||||
generate_config={"max_tokens": 1024, "stream": True},
|
||||
)
|
||||
|
||||
To view all the supported builtin models, run:
|
||||
.. code-block:: bash
|
||||
$ xinference list --all
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import RESTfulClient from xinference. Please install it"
|
||||
" with `pip install xinference`."
|
||||
) from e
|
||||
|
||||
super().__init__(
|
||||
**{
|
||||
"server_url": server_url,
|
||||
"model_uid": model_uid,
|
||||
}
|
||||
)
|
||||
|
||||
if self.server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if self.model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.client = RESTfulClient(server_url)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "xinference"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"server_url": self.server_url},
|
||||
**{"model_uid": self.model_uid},
|
||||
}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Returns:
|
||||
The generated string by the model.
|
||||
"""
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["text"]
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional["LlamaCppGenerateConfig"] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
model: The model used for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
for chunk in streaming_response:
|
||||
if isinstance(chunk, dict):
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
token = choice.get("text", "")
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
)
|
||||
yield token
|
1222
libs/langchain/poetry.lock
generated
1222
libs/langchain/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -124,6 +124,7 @@ langsmith = "~0.0.11"
|
||||
rank-bm25 = {version = "^0.2.2", optional = true}
|
||||
amadeus = {version = ">=8.1.0", optional = true}
|
||||
geopandas = {version = "^0.13.1", optional = true}
|
||||
xinference = {version = "^0.0.6", optional = true}
|
||||
python-arango = {version = "^7.5.9", optional = true}
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
@@ -219,7 +220,7 @@ playwright = "^1.28.0"
|
||||
setuptools = "^67.6.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
|
||||
llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers", "xinference"]
|
||||
qdrant = ["qdrant-client"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
text_helpers = ["chardet"]
|
||||
@@ -315,6 +316,7 @@ all = [
|
||||
"octoai-sdk",
|
||||
"rdflib",
|
||||
"amadeus",
|
||||
"xinference",
|
||||
"python-arango",
|
||||
]
|
||||
|
||||
@@ -356,6 +358,7 @@ extended_testing = [
|
||||
"rank_bm25",
|
||||
"geopandas",
|
||||
"jinja2",
|
||||
"xinference",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
@@ -0,0 +1,74 @@
|
||||
"""Test Xinference embeddings."""
|
||||
import time
|
||||
from typing import AsyncGenerator, Tuple
|
||||
|
||||
import pytest_asyncio
|
||||
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
|
||||
import xoscar as xo
|
||||
from xinference.deploy.supervisor import start_supervisor_components
|
||||
from xinference.deploy.utils import create_worker_actor_pool
|
||||
from xinference.deploy.worker import start_worker_components
|
||||
|
||||
pool = await create_worker_actor_pool(
|
||||
f"test://127.0.0.1:{xo.utils.get_next_port()}"
|
||||
)
|
||||
print(f"Pool running on localhost:{pool.external_address}")
|
||||
|
||||
endpoint = await start_supervisor_components(
|
||||
pool.external_address, "127.0.0.1", xo.utils.get_next_port()
|
||||
)
|
||||
await start_worker_components(
|
||||
address=pool.external_address, supervisor_address=pool.external_address
|
||||
)
|
||||
|
||||
# wait for the api.
|
||||
time.sleep(3)
|
||||
async with pool:
|
||||
yield endpoint, pool.external_address
|
||||
|
||||
|
||||
def test_xinference_embedding_documents(setup: Tuple[str, str]) -> None:
|
||||
"""Test xinference embeddings for documents."""
|
||||
from xinference.client import RESTfulClient
|
||||
|
||||
endpoint, _ = setup
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
model_uid = client.launch_model(
|
||||
model_name="vicuna-v1.3",
|
||||
model_size_in_billions=7,
|
||||
model_format="ggmlv3",
|
||||
quantization="q4_0",
|
||||
)
|
||||
|
||||
xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
|
||||
|
||||
documents = ["foo bar", "bar foo"]
|
||||
output = xinference.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 4096
|
||||
|
||||
|
||||
def test_xinference_embedding_query(setup: Tuple[str, str]) -> None:
|
||||
"""Test xinference embeddings for query."""
|
||||
from xinference.client import RESTfulClient
|
||||
|
||||
endpoint, _ = setup
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
model_uid = client.launch_model(
|
||||
model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0"
|
||||
)
|
||||
|
||||
xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
|
||||
|
||||
document = "foo bar"
|
||||
output = xinference.embed_query(document)
|
||||
assert len(output) == 4096
|
@@ -0,0 +1,57 @@
|
||||
"""Test Xinference wrapper."""
|
||||
import time
|
||||
from typing import AsyncGenerator, Tuple
|
||||
|
||||
import pytest_asyncio
|
||||
|
||||
from langchain.llms import Xinference
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
|
||||
import xoscar as xo
|
||||
from xinference.deploy.supervisor import start_supervisor_components
|
||||
from xinference.deploy.utils import create_worker_actor_pool
|
||||
from xinference.deploy.worker import start_worker_components
|
||||
|
||||
pool = await create_worker_actor_pool(
|
||||
f"test://127.0.0.1:{xo.utils.get_next_port()}"
|
||||
)
|
||||
print(f"Pool running on localhost:{pool.external_address}")
|
||||
|
||||
endpoint = await start_supervisor_components(
|
||||
pool.external_address, "127.0.0.1", xo.utils.get_next_port()
|
||||
)
|
||||
await start_worker_components(
|
||||
address=pool.external_address, supervisor_address=pool.external_address
|
||||
)
|
||||
|
||||
# wait for the api.
|
||||
time.sleep(3)
|
||||
async with pool:
|
||||
yield endpoint, pool.external_address
|
||||
|
||||
|
||||
def test_xinference_llm_(setup: Tuple[str, str]) -> None:
|
||||
from xinference.client import RESTfulClient
|
||||
|
||||
endpoint, _ = setup
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
model_uid = client.launch_model(
|
||||
model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0"
|
||||
)
|
||||
|
||||
llm = Xinference(server_url=endpoint, model_uid=model_uid)
|
||||
|
||||
answer = llm(prompt="Q: What food can we try in the capital of France? A:")
|
||||
|
||||
assert isinstance(answer, str)
|
||||
|
||||
answer = llm(
|
||||
prompt="Q: where can we visit in the capital of France? A:",
|
||||
generate_config={"max_tokens": 1024, "stream": True},
|
||||
)
|
||||
|
||||
assert isinstance(answer, str)
|
Reference in New Issue
Block a user