mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Update Cohere Reranker (#4180)
The forward ref annotations don't get updated if we only iimport with type checking --------- Co-authored-by: Abhinav Verma <abhinav_win12@yahoo.co.in>
This commit is contained in:
parent
d84bb02881
commit
84cfa76e00
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, Sequence
|
from typing import TYPE_CHECKING, Dict, Sequence
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@ -10,6 +10,13 @@ from langchain.utils import get_from_dict_or_env
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from cohere import Client
|
from cohere import Client
|
||||||
|
else:
|
||||||
|
# We do to avoid pydantic annotation issues when actually instantiating
|
||||||
|
# while keeping this import optional
|
||||||
|
try:
|
||||||
|
from cohere import Client
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CohereRerank(BaseDocumentCompressor):
|
class CohereRerank(BaseDocumentCompressor):
|
||||||
@ -17,7 +24,13 @@ class CohereRerank(BaseDocumentCompressor):
|
|||||||
top_n: int = 3
|
top_n: int = 3
|
||||||
model: str = "rerank-english-v2.0"
|
model: str = "rerank-english-v2.0"
|
||||||
|
|
||||||
@root_validator()
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
cohere_api_key = get_from_dict_or_env(
|
cohere_api_key = get_from_dict_or_env(
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
"""Test the cohere reranker."""
|
||||||
|
|
||||||
|
from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_reranker_init() -> None:
|
||||||
|
"""Test the cohere reranker initializes correctly."""
|
||||||
|
CohereRerank()
|
Loading…
Reference in New Issue
Block a user