From 22638e5927f7874fab9416821d45cf22d87f3f75 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 12 Feb 2024 17:21:53 -0800 Subject: [PATCH] community[patch]: give reranker default client val (#17289) --- .../document_compressors/cohere_rerank.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index b36ea305c78..3374f7906b9 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -14,14 +14,15 @@ from langchain.utils import get_from_dict_or_env class CohereRerank(BaseDocumentCompressor): """Document compressor that uses `Cohere Rerank API`.""" - client: Any + client: Any = None """Cohere client to use for compressing documents.""" top_n: Optional[int] = 3 """Number of documents to return.""" model: str = "rerank-english-v2.0" """Model to use for reranking.""" - cohere_api_key: Optional[str] = None + """Cohere API key. Must be specified directly or via environment variable + COHERE_API_KEY.""" user_agent: str = "langchain" """Identifier for the application making the request.""" @@ -34,18 +35,19 @@ class CohereRerank(BaseDocumentCompressor): @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - try: - import cohere - except ImportError: - raise ImportError( - "Could not import cohere python package. " - "Please install it with `pip install cohere`." + if not values.get("client"): + try: + import cohere + except ImportError: + raise ImportError( + "Could not import cohere python package. " + "Please install it with `pip install cohere`." + ) + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" ) - cohere_api_key = get_from_dict_or_env( - values, "cohere_api_key", "COHERE_API_KEY" - ) - client_name = values.get("user_agent", "langchain") - values["client"] = cohere.Client(cohere_api_key, client_name=client_name) + client_name = values.get("user_agent", "langchain") + values["client"] = cohere.Client(cohere_api_key, client_name=client_name) return values def rerank(