mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 08:56:27 +00:00
community: Fix rank-llm import paths for new 0.20.3 version (#29154)
# **PR title**: "community: Fix rank-llm import paths for new 0.20.3 version" - The "community" package is being modified to handle updated import paths for the new `rank-llm` version. --- ## Description This PR updates the import paths for the `rank-llm` package to account for changes introduced in version `0.20.3`. The changes ensure compatibility with both pre- and post-revamp versions of `rank-llm`, specifically version `0.12.8`. Conditional imports are introduced based on the detected version of `rank-llm` to handle different path structures for `VicunaReranker`, `ZephyrReranker`, and `SafeOpenai`. ## Issue RankLLMRerank usage throws an error when used GPT (not only) when rank-llm version is > 0.12.8 - #29156 ## Dependencies This change relies on the `packaging` and `pkg_resources` libraries to handle version checks. ## Twitter handle @tymzar
This commit is contained in:
parent
0e3115330d
commit
689592f9bb
@ -2,12 +2,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from importlib.metadata import version
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
||||||
|
|
||||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||||
from langchain_core.callbacks.manager import Callbacks
|
from langchain_core.callbacks.manager import Callbacks
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
from packaging.version import Version
|
||||||
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
|
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -49,6 +51,10 @@ class RankLLMRerank(BaseDocumentCompressor):
|
|||||||
if not values.get("client"):
|
if not values.get("client"):
|
||||||
client_name = values.get("model", "zephyr")
|
client_name = values.get("model", "zephyr")
|
||||||
|
|
||||||
|
is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
|
||||||
|
"0.12.8"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_enum = ModelType(client_name.lower())
|
model_enum = ModelType(client_name.lower())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -58,15 +64,29 @@ class RankLLMRerank(BaseDocumentCompressor):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if model_enum == ModelType.VICUNA:
|
if model_enum == ModelType.VICUNA:
|
||||||
|
if is_pre_rank_llm_revamp:
|
||||||
from rank_llm.rerank.vicuna_reranker import VicunaReranker
|
from rank_llm.rerank.vicuna_reranker import VicunaReranker
|
||||||
|
else:
|
||||||
|
from rank_llm.rerank.listwise.vicuna_reranker import (
|
||||||
|
VicunaReranker,
|
||||||
|
)
|
||||||
|
|
||||||
values["client"] = VicunaReranker()
|
values["client"] = VicunaReranker()
|
||||||
elif model_enum == ModelType.ZEPHYR:
|
elif model_enum == ModelType.ZEPHYR:
|
||||||
|
if is_pre_rank_llm_revamp:
|
||||||
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
|
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
|
||||||
|
else:
|
||||||
|
from rank_llm.rerank.listwise.zephyr_reranker import (
|
||||||
|
ZephyrReranker,
|
||||||
|
)
|
||||||
|
|
||||||
values["client"] = ZephyrReranker()
|
values["client"] = ZephyrReranker()
|
||||||
elif model_enum == ModelType.GPT:
|
elif model_enum == ModelType.GPT:
|
||||||
|
if is_pre_rank_llm_revamp:
|
||||||
from rank_llm.rerank.rank_gpt import SafeOpenai
|
from rank_llm.rerank.rank_gpt import SafeOpenai
|
||||||
|
else:
|
||||||
|
from rank_llm.rerank.listwise.rank_gpt import SafeOpenai
|
||||||
|
|
||||||
from rank_llm.rerank.reranker import Reranker
|
from rank_llm.rerank.reranker import Reranker
|
||||||
|
|
||||||
openai_api_key = get_from_dict_or_env(
|
openai_api_key = get_from_dict_or_env(
|
||||||
|
Loading…
Reference in New Issue
Block a user