From 689592f9bb58fe600def230e4837c7b78a9d2195 Mon Sep 17 00:00:00 2001
From: Tymon Żarski <68420753+tymzar@users.noreply.github.com>
Date: Mon, 13 Jan 2025 16:22:14 +0100
Subject: [PATCH] community: Fix rank-llm import paths for new 0.20.3 version
 (#29154)

# **PR title**: "community: Fix rank-llm import paths for new 0.20.3 version"

- The "community" package is modified to handle the updated import paths of the new `rank-llm` version.

---

## Description

This PR updates the import paths for the `rank-llm` package to account for changes introduced in version `0.20.3`. The changes keep compatibility with both the pre-revamp package layout (`rank-llm` <= `0.12.8`) and the post-revamp layout. Conditional imports, chosen based on the detected `rank-llm` version, handle the different module paths for `VicunaReranker`, `ZephyrReranker`, and `SafeOpenai`.

## Issue

`RankLLMRerank` throws an import error (observed with GPT models, among others) when the installed `rank-llm` version is > `0.12.8` - #29156

## Dependencies

This change relies on the `packaging` library and `importlib.metadata` to handle the version check.

## Twitter handle

@tymzar
---
 .../document_compressors/rankllm_rerank.py | 26 ++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/libs/community/langchain_community/document_compressors/rankllm_rerank.py b/libs/community/langchain_community/document_compressors/rankllm_rerank.py
index bcf18652928..fb5a92ba006 100644
--- a/libs/community/langchain_community/document_compressors/rankllm_rerank.py
+++ b/libs/community/langchain_community/document_compressors/rankllm_rerank.py
@@ -2,12 +2,14 @@ from __future__ import annotations
 
 from copy import deepcopy
 from enum import Enum
+from importlib.metadata import version
 from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
 
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain_core.callbacks.manager import Callbacks
 from langchain_core.documents import Document
 from langchain_core.utils import get_from_dict_or_env
+from packaging.version import Version
 from pydantic import ConfigDict, Field, PrivateAttr, model_validator
 
 if TYPE_CHECKING:
@@ -49,6 +51,10 @@ class RankLLMRerank(BaseDocumentCompressor):
         if not values.get("client"):
             client_name = values.get("model", "zephyr")
 
+            is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
+                "0.12.8"
+            )
+
             try:
                 model_enum = ModelType(client_name.lower())
             except ValueError:
@@ -58,15 +64,29 @@ class RankLLMRerank(BaseDocumentCompressor):
 
             try:
                 if model_enum == ModelType.VICUNA:
-                    from rank_llm.rerank.vicuna_reranker import VicunaReranker
+                    if is_pre_rank_llm_revamp:
+                        from rank_llm.rerank.vicuna_reranker import VicunaReranker
+                    else:
+                        from rank_llm.rerank.listwise.vicuna_reranker import (
+                            VicunaReranker,
+                        )
 
                     values["client"] = VicunaReranker()
                 elif model_enum == ModelType.ZEPHYR:
-                    from rank_llm.rerank.zephyr_reranker import ZephyrReranker
+                    if is_pre_rank_llm_revamp:
+                        from rank_llm.rerank.zephyr_reranker import ZephyrReranker
+                    else:
+                        from rank_llm.rerank.listwise.zephyr_reranker import (
+                            ZephyrReranker,
+                        )
 
                     values["client"] = ZephyrReranker()
                 elif model_enum == ModelType.GPT:
-                    from rank_llm.rerank.rank_gpt import SafeOpenai
+                    if is_pre_rank_llm_revamp:
+                        from rank_llm.rerank.rank_gpt import SafeOpenai
+                    else:
+                        from rank_llm.rerank.listwise.rank_gpt import SafeOpenai
+                    from rank_llm.rerank.reranker import Reranker
 
                     openai_api_key = get_from_dict_or_env(