community: Fix rank-llm import paths for new 0.20.3 version (#29154)

# **PR title**: "community: Fix rank-llm import paths for new 0.20.3
version"
- The "community" package is being modified to handle updated import
paths for the new `rank-llm` version.

---

## Description
This PR updates the import paths for the `rank-llm` package to account
for changes introduced in version `0.20.3`. The changes ensure
compatibility with both pre- and post-revamp versions of `rank-llm`,
treating version `0.12.8` as the last pre-revamp release. Conditional
imports are introduced based on the detected version of `rank-llm` to
handle the different module paths for `VicunaReranker`,
`ZephyrReranker`, and `SafeOpenai`.

## Issue
RankLLMRerank throws an import error when used with GPT (and other
models) when the installed `rank-llm` version is > 0.12.8 - #29156

## Dependencies
This change relies on the `packaging` library and the standard-library
`importlib.metadata` module to handle version checks.

## Twitter handle
@tymzar
This commit is contained in:
Tymon Żarski 2025-01-13 16:22:14 +01:00 committed by GitHub
parent 0e3115330d
commit 689592f9bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,12 +2,14 @@ from __future__ import annotations
from copy import deepcopy
from enum import Enum
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.utils import get_from_dict_or_env
from packaging.version import Version
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
if TYPE_CHECKING:
@ -49,6 +51,10 @@ class RankLLMRerank(BaseDocumentCompressor):
if not values.get("client"):
client_name = values.get("model", "zephyr")
is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
"0.12.8"
)
try:
model_enum = ModelType(client_name.lower())
except ValueError:
@ -58,15 +64,29 @@ class RankLLMRerank(BaseDocumentCompressor):
try:
if model_enum == ModelType.VICUNA:
from rank_llm.rerank.vicuna_reranker import VicunaReranker
if is_pre_rank_llm_revamp:
from rank_llm.rerank.vicuna_reranker import VicunaReranker
else:
from rank_llm.rerank.listwise.vicuna_reranker import (
VicunaReranker,
)
values["client"] = VicunaReranker()
elif model_enum == ModelType.ZEPHYR:
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
if is_pre_rank_llm_revamp:
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
else:
from rank_llm.rerank.listwise.zephyr_reranker import (
ZephyrReranker,
)
values["client"] = ZephyrReranker()
elif model_enum == ModelType.GPT:
from rank_llm.rerank.rank_gpt import SafeOpenai
if is_pre_rank_llm_revamp:
from rank_llm.rerank.rank_gpt import SafeOpenai
else:
from rank_llm.rerank.listwise.rank_gpt import SafeOpenai
from rank_llm.rerank.reranker import Reranker
openai_api_key = get_from_dict_or_env(