mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Update Cohere Reranker (#4180)
The forward ref annotations don't get updated if we only iimport with type checking --------- Co-authored-by: Abhinav Verma <abhinav_win12@yahoo.co.in>
This commit is contained in:
parent
d84bb02881
commit
84cfa76e00
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Sequence
|
||||
|
||||
from pydantic import root_validator
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
from langchain.schema import Document
|
||||
@ -10,6 +10,13 @@ from langchain.utils import get_from_dict_or_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cohere import Client
|
||||
else:
|
||||
# We do to avoid pydantic annotation issues when actually instantiating
|
||||
# while keeping this import optional
|
||||
try:
|
||||
from cohere import Client
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class CohereRerank(BaseDocumentCompressor):
|
||||
@ -17,7 +24,13 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
top_n: int = 3
|
||||
model: str = "rerank-english-v2.0"
|
||||
|
||||
@root_validator()
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
|
@ -0,0 +1,8 @@
|
||||
"""Test the cohere reranker."""
|
||||
|
||||
from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
|
||||
|
||||
|
||||
def test_cohere_reranker_init() -> None:
|
||||
"""Test the cohere reranker initializes correctly."""
|
||||
CohereRerank()
|
Loading…
Reference in New Issue
Block a user