langchain/libs/community/tests/integration_tests/llms/test_ipex_llm.py
Yuwen Hu 566e9ba164
community: add Intel GPU support to ipex-llm llm integration (#22458)
**Description:** [IPEX-LLM](https://github.com/intel-analytics/ipex-llm)
is a PyTorch library for running LLMs on Intel CPUs and GPUs (e.g., a local
PC with an iGPU, or discrete GPUs such as Arc, Flex and Max) with very low
latency. This PR adds Intel GPU support to the `ipex-llm` LLM integration.
**Dependencies:** `ipex-llm`
**Contribution maintainer**: @ivy-lv11 @Oscilloscope98
**tests and docs**: 
- Add: langchain/docs/docs/integrations/llms/ipex_llm_gpu.ipynb
- Update: langchain/docs/docs/integrations/llms/ipex_llm_gpu.ipynb
- Update: langchain/libs/community/tests/integration_tests/llms/test_ipex_llm.py

---------

Co-authored-by: ivy-lv11 <zhicunlv@gmail.com>
2024-09-02 08:49:08 -04:00

96 lines
2.5 KiB
Python

"""Test IPEX LLM"""
import os
from pathlib import Path
from typing import Any

import pytest
from langchain_core.outputs import LLMResult

from langchain_community.llms import IpexLLM
# Comma-separated model ids (or local paths) to exercise, read from the
# TEST_IPEXLLM_MODEL_IDS environment variable.
_raw_model_ids = os.getenv("TEST_IPEXLLM_MODEL_IDS") or ""
# Drop blank entries so values like "a,b," or " " never produce a test case
# that tries to load an empty model id.
model_ids_to_test = [
    model_id.strip() for model_id in _raw_model_ids.split(",") if model_id.strip()
]
# Skip decision is based on the cleaned list, not the raw string, so a value
# containing only separators/whitespace is treated the same as "unset".
skip_if_no_model_ids = pytest.mark.skipif(
    not model_ids_to_test,
    reason="TEST_IPEXLLM_MODEL_IDS environment variable not set.",
)
# Device forwarded to IpexLLM through model_kwargs; "cpu" unless overridden
# by TEST_IPEXLLM_MODEL_DEVICE.
device = os.getenv("TEST_IPEXLLM_MODEL_DEVICE") or "cpu"
def load_model(model_id: str) -> Any:
    """Create an IpexLLM for ``model_id`` via ``IpexLLM.from_model_id``.

    Generation is kept short and deterministic (temperature 0, max_length 16),
    remote code is trusted, and the model is placed on the module-level
    ``device`` (from TEST_IPEXLLM_MODEL_DEVICE, default "cpu").
    """
    kwargs = {
        "temperature": 0,
        "max_length": 16,
        "trust_remote_code": True,
        "device": device,
    }
    return IpexLLM.from_model_id(model_id=model_id, model_kwargs=kwargs)
def load_model_more_types(model_id: str, load_in_low_bit: str) -> Any:
    """Create an IpexLLM for ``model_id`` using a specific low-bit format.

    Args:
        model_id: model id or path passed to ``IpexLLM.from_model_id``.
        load_in_low_bit: low-bit data type name forwarded to
            ``IpexLLM.from_model_id`` (e.g. "asym_int4").

    Returns:
        The loaded LLM, configured with the same generation settings and the
        same ``device`` as ``load_model``.
    """
    llm = IpexLLM.from_model_id(
        model_id=model_id,
        load_in_low_bit=load_in_low_bit,
        model_kwargs={
            "temperature": 0,
            "max_length": 16,
            "trust_remote_code": True,
            # Without this, low-bit tests ignored TEST_IPEXLLM_MODEL_DEVICE
            # and silently ran on the default device.
            "device": device,
        },
    )
    return llm
@skip_if_no_model_ids
@pytest.mark.parametrize("model_id", model_ids_to_test)
def test_call(model_id: str) -> None:
    """Invoking the model on a short prompt must return a string."""
    result = load_model(model_id).invoke("Hello!")
    assert isinstance(result, str)
@skip_if_no_model_ids
@pytest.mark.parametrize("model_id", model_ids_to_test)
def test_asym_int4(model_id: str) -> None:
    """A model loaded with the "asym_int4" low-bit type must answer with a string."""
    asym_llm = load_model_more_types(model_id=model_id, load_in_low_bit="asym_int4")
    assert isinstance(asym_llm.invoke("Hello!"), str)
@skip_if_no_model_ids
@pytest.mark.parametrize(
    "model_id",
    model_ids_to_test,
)
def test_generate(model_id: str) -> None:
    """Test valid generate.

    ``generate`` is called with a single-prompt batch; the result must be an
    ``LLMResult`` holding exactly one non-empty list of text generations.
    """
    llm = load_model(model_id)
    output = llm.generate(["Hello!"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    # One prompt in -> one list of generations out; type checks alone would
    # also accept an empty result.
    assert len(output.generations) == 1
    assert output.generations[0]
    assert all(isinstance(gen.text, str) for gen in output.generations[0])
@skip_if_no_model_ids
@pytest.mark.parametrize(
    "model_id",
    model_ids_to_test,
)
def test_save_load_lowbit(model_id: str, tmp_path: Path) -> None:
    """Test save and load lowbit model.

    The model's low-bit weights are written with ``save_low_bit`` into a
    per-test temporary directory, then reloaded with
    ``IpexLLM.from_model_id_low_bit`` (tokenizer taken from ``model_id``) on
    the configured ``device`` and invoked once.
    """
    # pytest's tmp_path is unique per test invocation and portable (the old
    # hard-coded "/tmp/saved_model" was POSIX-only, shared by every
    # parametrized model/concurrent run, and never cleaned up).
    saved_lowbit_path = str(tmp_path / "saved_model")
    llm = load_model(model_id)
    llm.model.save_low_bit(saved_lowbit_path)
    # Release the first copy before loading the second one.
    del llm
    loaded_llm = IpexLLM.from_model_id_low_bit(
        model_id=saved_lowbit_path,
        tokenizer_id=model_id,
        model_kwargs={
            "temperature": 0,
            "max_length": 16,
            "trust_remote_code": True,
            # Keep the reloaded model on the same device as load_model().
            "device": device,
        },
    )
    output = loaded_llm.invoke("Hello!")
    assert isinstance(output, str)