mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
community: FlashrankRerank support loading customer client (#23350)
Description: FlashrankRerank Document compressor support loading customer client Issue: #23338 Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
parent
f58c40b4e3
commit
b33d2346db
@ -38,17 +38,20 @@ class FlashrankRerank(BaseDocumentCompressor):
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
from flashrank import Ranker
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import flashrank python package. "
|
||||
"Please install it with `pip install flashrank`."
|
||||
)
|
||||
if "client" in values:
|
||||
return values
|
||||
else:
|
||||
try:
|
||||
from flashrank import Ranker
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import flashrank python package. "
|
||||
"Please install it with `pip install flashrank`."
|
||||
)
|
||||
|
||||
values["model"] = values.get("model", DEFAULT_MODEL_NAME)
|
||||
values["client"] = Ranker(model_name=values["model"])
|
||||
return values
|
||||
values["model"] = values.get("model", DEFAULT_MODEL_NAME)
|
||||
values["client"] = Ranker(model_name=values["model"])
|
||||
return values
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user