mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
cohere[patch]: Fix cohere rerank (#19624)
Fix cohere rerank inspired by https://github.com/langchain-ai/langchain/pull/19486
This commit is contained in:
parent
8ab7bb3166
commit
85f57ab4cd
@ -69,10 +69,14 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
model = model or self.model
|
||||
top_n = top_n if (top_n is None or top_n > 0) else self.top_n
|
||||
results = self.client.rerank(
|
||||
query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
|
||||
query=query,
|
||||
documents=docs,
|
||||
model=model,
|
||||
top_n=top_n,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
)
|
||||
result_dicts = []
|
||||
for res in results:
|
||||
for res in results.results:
|
||||
result_dicts.append(
|
||||
{"index": res.index, "relevance_score": res.relevance_score}
|
||||
)
|
||||
|
16
libs/partners/cohere/tests/integration_tests/test_rerank.py
Normal file
16
libs/partners/cohere/tests/integration_tests/test_rerank.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""Test Cohere reranks."""
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_cohere import CohereRerank
|
||||
|
||||
|
||||
def test_langchain_cohere_rerank_documents() -> None:
|
||||
"""Test cohere rerank."""
|
||||
rerank = CohereRerank()
|
||||
test_documents = [
|
||||
Document(page_content="This is a test document."),
|
||||
Document(page_content="Another test document."),
|
||||
]
|
||||
test_query = "Test query"
|
||||
results = rerank.rerank(test_documents, test_query)
|
||||
assert len(results) == 2
|
Loading…
Reference in New Issue
Block a user