mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 08:27:03 +00:00
community: Fix rank-llm import paths for new 0.20.3 version (#29154)
# **PR title**: "community: Fix rank-llm import paths for new 0.20.3 version" - The "community" package is being modified to handle updated import paths for the new `rank-llm` version. --- ## Description This PR updates the import paths for the `rank-llm` package to account for changes introduced in version `0.20.3`. The changes ensure compatibility with both pre- and post-revamp versions of `rank-llm`, specifically version `0.12.8`. Conditional imports are introduced based on the detected version of `rank-llm` to handle different path structures for `VicunaReranker`, `ZephyrReranker`, and `SafeOpenai`. ## Issue RankLLMRerank usage throws an error when used GPT (not only) when rank-llm version is > 0.12.8 - #29156 ## Dependencies This change relies on the `packaging` and `pkg_resources` libraries to handle version checks. ## Twitter handle @tymzar
This commit is contained in:
parent
0e3115330d
commit
689592f9bb
@ -2,12 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
||||
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
from langchain_core.callbacks.manager import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from packaging.version import Version
|
||||
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -49,6 +51,10 @@ class RankLLMRerank(BaseDocumentCompressor):
|
||||
if not values.get("client"):
|
||||
client_name = values.get("model", "zephyr")
|
||||
|
||||
is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
|
||||
"0.12.8"
|
||||
)
|
||||
|
||||
try:
|
||||
model_enum = ModelType(client_name.lower())
|
||||
except ValueError:
|
||||
@ -58,15 +64,29 @@ class RankLLMRerank(BaseDocumentCompressor):
|
||||
|
||||
try:
|
||||
if model_enum == ModelType.VICUNA:
|
||||
from rank_llm.rerank.vicuna_reranker import VicunaReranker
|
||||
if is_pre_rank_llm_revamp:
|
||||
from rank_llm.rerank.vicuna_reranker import VicunaReranker
|
||||
else:
|
||||
from rank_llm.rerank.listwise.vicuna_reranker import (
|
||||
VicunaReranker,
|
||||
)
|
||||
|
||||
values["client"] = VicunaReranker()
|
||||
elif model_enum == ModelType.ZEPHYR:
|
||||
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
|
||||
if is_pre_rank_llm_revamp:
|
||||
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
|
||||
else:
|
||||
from rank_llm.rerank.listwise.zephyr_reranker import (
|
||||
ZephyrReranker,
|
||||
)
|
||||
|
||||
values["client"] = ZephyrReranker()
|
||||
elif model_enum == ModelType.GPT:
|
||||
from rank_llm.rerank.rank_gpt import SafeOpenai
|
||||
if is_pre_rank_llm_revamp:
|
||||
from rank_llm.rerank.rank_gpt import SafeOpenai
|
||||
else:
|
||||
from rank_llm.rerank.listwise.rank_gpt import SafeOpenai
|
||||
|
||||
from rank_llm.rerank.reranker import Reranker
|
||||
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
|
Loading…
Reference in New Issue
Block a user