community[minor]: Add IPEX-LLM BGE embedding support on both Intel CPU and GPU (#22226)

**Description:** [IPEX-LLM](https://github.com/intel-analytics/ipex-llm)
is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local
PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low
latency. This PR adds ipex-llm integrations to langchain for BGE
embedding support on both Intel CPU and GPU.
**Dependencies:** `ipex-llm`, `sentence-transformers`
**Contribution maintainer**: @Oscilloscope98 
**tests and docs**: 
- langchain/docs/docs/integrations/text_embedding/ipex_llm.ipynb
- langchain/docs/docs/integrations/text_embedding/ipex_llm_gpu.ipynb
-
langchain/libs/community/tests/integration_tests/embeddings/test_ipex_llm.py

---------

Co-authored-by: Shengsheng Huang <shannie.huang@gmail.com>
This commit is contained in:
Yuwen Hu 2024-06-04 03:37:10 +08:00 committed by GitHub
parent c01467b1f4
commit ba0dca46d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 461 additions and 0 deletions

View File

@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local BGE Embeddings with IPEX-LLM on Intel CPU\n",
"\n",
"> [IPEX-LLM](https://github.com/intel-analytics/ipex-llm) is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency.\n",
"\n",
"This example goes over how to use LangChain to conduct embedding tasks with `ipex-llm` optimizations on Intel CPU. This would be helpful in applications such as RAG, document QA, etc.\n",
"\n",
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain langchain-community"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install IPEX-LLM for optimizations on Intel CPU, as well as `sentence-transformers`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu\n",
"%pip install sentence-transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note**\n",
">\n",
"> For Windows users, `--extra-index-url https://download.pytorch.org/whl/cpu` when install `ipex-llm` is not required.\n",
"\n",
"## Basic Usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import IpexLLMBgeEmbeddings\n",
"\n",
"embedding_model = IpexLLMBgeEmbeddings(\n",
" model_name=\"BAAI/bge-large-en-v1.5\",\n",
" model_kwargs={},\n",
" encode_kwargs={\"normalize_embeddings\": True},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"API Reference\n",
"- [IpexLLMBgeEmbeddings](https://api.python.langchain.com/en/latest/embeddings/langchain_community.embeddings.ipex_llm.IpexLLMBgeEmbeddings.html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"IPEX-LLM is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency.\"\n",
"query = \"What is IPEX-LLM?\"\n",
"\n",
"text_embeddings = embedding_model.embed_documents([sentence, query])\n",
"print(f\"text_embeddings[0][:10]: {text_embeddings[0][:10]}\")\n",
"print(f\"text_embeddings[1][:10]: {text_embeddings[1][:10]}\")\n",
"\n",
"query_embedding = embedding_model.embed_query(query)\n",
"print(f\"query_embedding[:10]: {query_embedding[:10]}\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,164 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local BGE Embeddings with IPEX-LLM on Intel GPU\n",
"\n",
"> [IPEX-LLM](https://github.com/intel-analytics/ipex-llm) is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency.\n",
"\n",
"This example goes over how to use LangChain to conduct embedding tasks with `ipex-llm` optimizations on Intel GPU. This would be helpful in applications such as RAG, document QA, etc.\n",
"\n",
"> **Note**\n",
">\n",
"> It is recommended that only Windows users with Intel Arc A-Series GPU (except for Intel Arc A300-Series or Pro A60) run this Jupyter notebook directly. For other cases (e.g. Linux users, Intel iGPU, etc.), it is recommended to run the code with Python scripts in terminal for best experiences.\n",
"\n",
"## Install Prerequisites\n",
"To benefit from IPEX-LLM on Intel GPUs, there are several prerequisite steps for tools installation and environment preparation.\n",
"\n",
"If you are a Windows user, visit the [Install IPEX-LLM on Windows with Intel GPU Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_windows_gpu.html), and follow [Install Prerequisites](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_windows_gpu.html#install-prerequisites) to update GPU driver (optional) and install Conda.\n",
"\n",
"If you are a Linux user, visit the [Install IPEX-LLM on Linux with Intel GPU](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html), and follow [**Install Prerequisites**](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-prerequisites) to install GPU driver, Intel® oneAPI Base Toolkit 2024.0, and Conda.\n",
"\n",
"## Setup\n",
"\n",
"After the prerequisites installation, you should have created a conda environment with all prerequisites installed. **Start the jupyter service in this conda environment**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain langchain-community"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install IPEX-LLM for optimizations on Intel GPU, as well as `sentence-transformers`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/\n",
"%pip install sentence-transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note**\n",
">\n",
"> You can also use `https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/` as the extra-indel-url.\n",
"\n",
"## Runtime Configuration\n",
"\n",
"For optimal performance, it is recommended to set several environment variables based on your device:\n",
"\n",
"### For Windows Users with Intel Core Ultra integrated GPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"SYCL_CACHE_PERSISTENT\"] = \"1\"\n",
"os.environ[\"BIGDL_LLM_XMX_DISABLED\"] = \"1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### For Windows Users with Intel Arc A-Series GPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"SYCL_CACHE_PERSISTENT\"] = \"1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note**\n",
">\n",
"> For the first time that each model runs on Intel iGPU/Intel Arc A300-Series or Pro A60, it may take several minutes to compile.\n",
">\n",
"> For other GPU type, please refer to [here](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html#runtime-configuration) for Windows users, and [here](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html#id5) for Linux users.\n",
"\n",
"\n",
"## Basic Usage\n",
"\n",
"Setting `device` to `\"xpu\"` in `model_kwargs` when initializing `IpexLLMBgeEmbeddings` will put the embedding model on Intel GPU and benefit from IPEX-LLM optimizations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import IpexLLMBgeEmbeddings\n",
"\n",
"embedding_model = IpexLLMBgeEmbeddings(\n",
" model_name=\"BAAI/bge-large-en-v1.5\",\n",
" model_kwargs={\"device\": \"xpu\"},\n",
" encode_kwargs={\"normalize_embeddings\": True},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"API Reference\n",
"- [IpexLLMBgeEmbeddings](https://api.python.langchain.com/en/latest/embeddings/langchain_community.embeddings.ipex_llm.IpexLLMBgeEmbeddings.html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"IPEX-LLM is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency.\"\n",
"query = \"What is IPEX-LLM?\"\n",
"\n",
"text_embeddings = embedding_model.embed_documents([sentence, query])\n",
"print(f\"text_embeddings[0][:10]: {text_embeddings[0][:10]}\")\n",
"print(f\"text_embeddings[1][:10]: {text_embeddings[1][:10]}\")\n",
"\n",
"query_embedding = embedding_model.embed_query(query)\n",
"print(f\"query_embedding[:10]: {query_embedding[:10]}\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -104,6 +104,7 @@ if TYPE_CHECKING:
from langchain_community.embeddings.infinity_local import (
InfinityEmbeddingsLocal,
)
from langchain_community.embeddings.ipex_llm import IpexLLMBgeEmbeddings
from langchain_community.embeddings.itrex import (
QuantizedBgeEmbeddings,
)
@ -258,6 +259,7 @@ __all__ = [
"HuggingFaceInstructEmbeddings",
"InfinityEmbeddings",
"InfinityEmbeddingsLocal",
"IpexLLMBgeEmbeddings",
"JavelinAIGatewayEmbeddings",
"JinaEmbeddings",
"JohnSnowLabsEmbeddings",
@ -336,6 +338,7 @@ _module_lookup = {
"HuggingFaceInstructEmbeddings": "langchain_community.embeddings.huggingface",
"InfinityEmbeddings": "langchain_community.embeddings.infinity",
"InfinityEmbeddingsLocal": "langchain_community.embeddings.infinity_local",
"IpexLLMBgeEmbeddings": "langchain_community.embeddings.ipex_llm",
"JavelinAIGatewayEmbeddings": "langchain_community.embeddings.javelin_ai_gateway",
"JinaEmbeddings": "langchain_community.embeddings.jina",
"JohnSnowLabsEmbeddings": "langchain_community.embeddings.johnsnowlabs",

View File

@ -0,0 +1,140 @@
# This file is adapted from
# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/embeddings/huggingface.py
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
DEFAULT_BGE_MODEL = "BAAI/bge-small-en-v1.5"
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
"Represent this question for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"
class IpexLLMBgeEmbeddings(BaseModel, Embeddings):
"""Wrapper around the BGE embedding model
with IPEX-LLM optimizations on Intel CPUs and GPUs.
To use, you should have the ``ipex-llm``
and ``sentence_transformers`` package installed. Refer to
`here <https://python.langchain.com/v0.1/docs/integrations/text_embedding/ipex_llm/>`_
for installation on Intel CPU.
Example on Intel CPU:
.. code-block:: python
from langchain_community.embeddings import IpexLLMBgeEmbeddings
embedding_model = IpexLLMBgeEmbeddings(
model_name="BAAI/bge-large-en-v1.5",
model_kwargs={},
encode_kwargs={"normalize_embeddings": True},
)
Refer to
`here <https://python.langchain.com/v0.1/docs/integrations/text_embedding/ipex_llm_gpu/>`_
for installation on Intel GPU.
Example on Intel GPU:
.. code-block:: python
from langchain_community.embeddings import IpexLLMBgeEmbeddings
embedding_model = IpexLLMBgeEmbeddings(
model_name="BAAI/bge-large-en-v1.5",
model_kwargs={"device": "xpu"},
encode_kwargs={"normalize_embeddings": True},
)
"""
client: Any #: :meta private:
model_name: str = DEFAULT_BGE_MODEL
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
"""Instruction to use for embedding query."""
embed_instruction: str = ""
"""Instruction to use for embedding document."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers
from ipex_llm.transformers.convert import _optimize_post, _optimize_pre
except ImportError as exc:
base_url = (
"https://python.langchain.com/v0.1/docs/integrations/text_embedding/"
)
raise ImportError(
"Could not import ipex_llm or sentence_transformers. "
f"Please refer to {base_url}/ipex_llm/ "
"for install required packages on Intel CPU. "
f"And refer to {base_url}/ipex_llm_gpu/ "
"for install required packages on Intel GPU. "
) from exc
# Set "cpu" as default device
if "device" not in self.model_kwargs:
self.model_kwargs["device"] = "cpu"
if self.model_kwargs["device"] not in ["cpu", "xpu"]:
raise ValueError(
"IpexLLMBgeEmbeddings currently only supports device to be "
f"'cpu' or 'xpu', but you have: {self.model_kwargs['device']}."
)
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
# Add ipex-llm optimizations
self.client = _optimize_pre(self.client)
self.client = _optimize_post(self.client)
if self.model_kwargs["device"] == "xpu":
self.client = self.client.half().to("xpu")
if "-zh" in self.model_name:
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
texts = [self.embed_instruction + t.replace("\n", " ") for t in texts]
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
text = text.replace("\n", " ")
embedding = self.client.encode(
self.query_instruction + text, **self.encode_kwargs
)
return embedding.tolist()

View File

@ -0,0 +1,52 @@
"""Test IPEX LLM"""
import os
import pytest
from langchain_community.embeddings import IpexLLMBgeEmbeddings
model_ids_to_test = os.getenv("TEST_IPEXLLM_BGE_EMBEDDING_MODEL_IDS") or ""
skip_if_no_model_ids = pytest.mark.skipif(
not model_ids_to_test,
reason="TEST_IPEXLLM_BGE_EMBEDDING_MODEL_IDS environment variable not set.",
)
model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")] # type: ignore
device = os.getenv("TEST_IPEXLLM_BGE_EMBEDDING_MODEL_DEVICE") or "cpu"
sentence = "IPEX-LLM is a PyTorch library for running LLM on Intel CPU and GPU (e.g., \
local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency."
query = "What is IPEX-LLM?"
@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
)
def test_embed_documents(model_id: str) -> None:
"""Test IpexLLMBgeEmbeddings embed_documents"""
embedding_model = IpexLLMBgeEmbeddings(
model_name=model_id,
model_kwargs={"device": device},
encode_kwargs={"normalize_embeddings": True},
)
output = embedding_model.embed_documents([sentence, query])
assert len(output) == 2
@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
)
def test_embed_query(model_id: str) -> None:
"""Test IpexLLMBgeEmbeddings embed_documents"""
embedding_model = IpexLLMBgeEmbeddings(
model_name=model_id,
model_kwargs={"device": device},
encode_kwargs={"normalize_embeddings": True},
)
output = embedding_model.embed_query(query)
assert isinstance(output, list)

View File

@ -55,6 +55,7 @@ EXPECTED_ALL = [
"LocalAIEmbeddings",
"AwaEmbeddings",
"HuggingFaceBgeEmbeddings",
"IpexLLMBgeEmbeddings",
"ErnieEmbeddings",
"JavelinAIGatewayEmbeddings",
"OllamaEmbeddings",