diff --git a/libs/partners/cohere/langchain_cohere/rerank.py b/libs/partners/cohere/langchain_cohere/rerank.py index 5c8c2bcfc8d..f946b3ea366 100644 --- a/libs/partners/cohere/langchain_cohere/rerank.py +++ b/libs/partners/cohere/langchain_cohere/rerank.py @@ -69,10 +69,14 @@ class CohereRerank(BaseDocumentCompressor): model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( - query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc + query=query, + documents=docs, + model=model, + top_n=top_n, + max_chunks_per_doc=max_chunks_per_doc, ) result_dicts = [] - for res in results: + for res in results.results: result_dicts.append( {"index": res.index, "relevance_score": res.relevance_score} ) diff --git a/libs/partners/cohere/tests/integration_tests/test_rerank.py b/libs/partners/cohere/tests/integration_tests/test_rerank.py new file mode 100644 index 00000000000..f9a2ebd0aec --- /dev/null +++ b/libs/partners/cohere/tests/integration_tests/test_rerank.py @@ -0,0 +1,16 @@ +"""Test Cohere reranks.""" +from langchain_core.documents import Document + +from langchain_cohere import CohereRerank + + +def test_langchain_cohere_rerank_documents() -> None: + """Test cohere rerank.""" + rerank = CohereRerank() + test_documents = [ + Document(page_content="This is a test document."), + Document(page_content="Another test document."), + ] + test_query = "Test query" + results = rerank.rerank(test_documents, test_query) + assert len(results) == 2