langchain[patch]: expose cohere rerank score, add parent doc param (#16887)

This commit is contained in:
parent 35c1bf339d
commit 02ef9164b5
@@ -484,8 +484,8 @@ class ElasticsearchStore(VectorStore):
             from langchain_community.vectorstores.utils import DistanceStrategy

             vectorstore = ElasticsearchStore(
-                "langchain-demo",
                 embedding=OpenAIEmbeddings(),
+                index_name="langchain-demo",
                 es_url="http://localhost:9200",
                 distance_strategy="DOT_PRODUCT"
             )

@@ -0,0 +1,3 @@
+from langchain.chains.query_constructor.base import load_query_constructor_runnable
+
+__all__ = ["load_query_constructor_runnable"]

@@ -323,7 +323,8 @@ def load_query_constructor_runnable(

     Args:
         llm: BaseLanguageModel to use for the chain.
-        document_contents: The contents of the document to be queried.
+        document_contents: Description of the page contents of the document to be
+            queried.
         attribute_info: Sequence of attributes in the document.
         examples: Optional list of examples to use for the chain.
         allowed_comparators: Sequence of allowed comparators. Defaults to all

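For context, a minimal sketch of how the exported runnable above can be used; the attribute schema and the choice of ChatOpenAI are illustrative, not part of this commit:

    from langchain.chains.query_constructor.base import load_query_constructor_runnable
    from langchain.chains.query_constructor.schema import AttributeInfo
    from langchain.chat_models import ChatOpenAI  # any chat model would do

    # Describe the page contents (per the clarified docstring) and the metadata schema.
    attribute_info = [
        AttributeInfo(name="year", description="Year the paper was published", type="integer"),
        AttributeInfo(name="author", description="Author of the paper", type="string"),
    ]
    chain = load_query_constructor_runnable(
        llm=ChatOpenAI(temperature=0),
        document_contents="Short summaries of research papers",
        attribute_info=attribute_info,
    )
    # Produces a structured query with a filter over the attributes above.
    structured_query = chain.invoke({"query": "papers by Alice published after 2020"})
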
@@ -1,6 +1,7 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union

 from langchain_core.documents import Document
 from langchain_core.pydantic_v1 import Extra, root_validator

@@ -9,23 +10,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain.utils import get_from_dict_or_env

-if TYPE_CHECKING:
-    from cohere import Client
-else:
-    # We do to avoid pydantic annotation issues when actually instantiating
-    # while keeping this import optional
-    try:
-        from cohere import Client
-    except ImportError:
-        pass
-

 class CohereRerank(BaseDocumentCompressor):
     """Document compressor that uses `Cohere Rerank API`."""

-    client: Client
+    client: Any
     """Cohere client to use for compressing documents."""
-    top_n: int = 3
+    top_n: Optional[int] = 3
     """Number of documents to return."""
     model: str = "rerank-english-v2.0"
     """Model to use for reranking."""

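A note on the `client: Any` change: dropping the TYPE_CHECKING import means pydantic no longer needs a resolvable `Client` annotation, and the root validator still builds the client from the environment. A minimal sketch, assuming COHERE_API_KEY is set:

    from langchain.retrievers.document_compressors import CohereRerank

    # The root validator constructs cohere.Client itself, which is why the field
    # can be typed Any rather than the (optionally installed) cohere.Client.
    reranker = CohereRerank(top_n=5)
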
@@ -57,6 +48,42 @@ class CohereRerank(BaseDocumentCompressor):
         values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
         return values

+    def rerank(
+        self,
+        documents: Sequence[Union[str, Document, dict]],
+        query: str,
+        *,
+        model: Optional[str] = None,
+        top_n: Optional[int] = -1,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        """Returns an ordered list of documents ordered by their relevance to the provided query.
+
+        Args:
+            query: The query to use for reranking.
+            documents: A sequence of documents to rerank.
+            model: The model to use for re-ranking. Default to self.model.
+            top_n : The number of results to return. If None returns all results.
+                Defaults to self.top_n.
+            max_chunks_per_doc : The maximum number of chunks derived from a document.
+        """  # noqa: E501
+        if len(documents) == 0:  # to avoid empty api call
+            return []
+        docs = [
+            doc.page_content if isinstance(doc, Document) else doc for doc in documents
+        ]
+        model = model or self.model
+        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
+        results = self.client.rerank(
+            query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
+        )
+        result_dicts = []
+        for res in results:
+            result_dicts.append(
+                {"index": res.index, "relevance_score": res.relevance_score}
+            )
+        return result_dicts
+
     def compress_documents(
         self,
         documents: Sequence[Document],

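The new `rerank` method can be called directly on raw strings as well as Documents. A minimal sketch, assuming COHERE_API_KEY is set; the texts and the printed score are illustrative:

    from langchain.retrievers.document_compressors import CohereRerank

    reranker = CohereRerank()
    docs = [
        "Paris is the capital of France.",
        "The Eiffel Tower is in Paris.",
    ]
    scored = reranker.rerank(docs, "Where is the Eiffel Tower?", top_n=1)
    # e.g. [{"index": 1, "relevance_score": 0.98}]
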
@@ -74,16 +101,10 @@ class CohereRerank(BaseDocumentCompressor):
         Returns:
             A sequence of compressed documents.
         """
-        if len(documents) == 0:  # to avoid empty api call
-            return []
-        doc_list = list(documents)
-        _docs = [d.page_content for d in doc_list]
-        results = self.client.rerank(
-            model=self.model, query=query, documents=_docs, top_n=self.top_n
-        )
-        final_results = []
-        for r in results:
-            doc = doc_list[r.index]
-            doc.metadata["relevance_score"] = r.relevance_score
-            final_results.append(doc)
-        return final_results
+        compressed = []
+        for res in self.rerank(documents, query):
+            doc = documents[res["index"]]
+            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
+            doc_copy.metadata["relevance_score"] = res["relevance_score"]
+            compressed.append(doc_copy)
+        return compressed

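This rewrite is what exposes the rerank score: each returned Document is a copy carrying relevance_score in its metadata, and the deepcopy means callers' input documents are no longer mutated. A minimal sketch, assuming COHERE_API_KEY is set:

    from langchain_core.documents import Document
    from langchain.retrievers.document_compressors import CohereRerank

    reranker = CohereRerank()
    originals = [Document(page_content="The Eiffel Tower is in Paris.")]
    compressed = reranker.compress_documents(originals, "Where is the Eiffel Tower?")
    print(compressed[0].metadata["relevance_score"])        # score surfaced on the copy
    assert "relevance_score" not in originals[0].metadata   # original left untouched
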
@@ -1,5 +1,5 @@
 import uuid
-from typing import List, Optional
+from typing import List, Optional, Sequence

 from langchain_core.documents import Document

@@ -31,17 +31,16 @@ class ParentDocumentRetriever(MultiVectorRetriever):

        .. code-block:: python

-            # Imports
-            from langchain_community.vectorstores import Chroma
             from langchain_community.embeddings import OpenAIEmbeddings
+            from langchain_community.vectorstores import Chroma
             from langchain.text_splitter import RecursiveCharacterTextSplitter
             from langchain.storage import InMemoryStore

             # This text splitter is used to create the parent documents
-            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
+            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, add_start_index=True)
             # This text splitter is used to create the child documents
             # It should create documents smaller than the parent
-            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, add_start_index=True)
             # The vectorstore to use to index the child chunks
             vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
             # The storage layer for the parent documents

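The add_start_index=True additions make each splitter record a chunk's character offset within its source under metadata["start_index"], so child chunks can be traced back to a position in their parent. A small self-contained illustration of that TextSplitter option:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=10, chunk_overlap=0, add_start_index=True
    )
    docs = splitter.create_documents(["abcdefghij0123456789"])
    print(docs[1].metadata)  # {'start_index': 10}: offset of the chunk in its source
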
@@ -54,7 +53,7 @@ class ParentDocumentRetriever(MultiVectorRetriever):
                 child_splitter=child_splitter,
                 parent_splitter=parent_splitter,
             )
-    """
+    """  # noqa: E501

     child_splitter: TextSplitter
     """The text splitter to use to create child documents."""

@@ -65,6 +64,11 @@ class ParentDocumentRetriever(MultiVectorRetriever):
     """The text splitter to use to create parent documents.
     If none, then the parent documents will be the raw documents passed in."""

+    child_metadata_fields: Optional[Sequence[str]] = None
+    """Metadata fields to leave in child documents. If None, leave all parent document
+        metadata.
+    """
+
     def add_documents(
         self,
         documents: List[Document],

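Reusing the names from the docstring example above, the new parameter might be passed like this (the "source" field is illustrative):

    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        child_metadata_fields=["source"],  # child chunks keep only "source" metadata
    )
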
@@ -76,7 +80,7 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         Args:
             documents: List of documents to add
             ids: Optional list of ids for documents. If provided should be the same
-                length as the list of documents. Can provided if parent documents
+                length as the list of documents. Can be provided if parent documents
                 are already in the document store and you don't want to re-add
                 to the docstore. If not provided, random UUIDs will be used as
                 ids.

@@ -106,6 +110,11 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         for i, doc in enumerate(documents):
             _id = doc_ids[i]
             sub_docs = self.child_splitter.split_documents([doc])
+            if self.child_metadata_fields is not None:
+                for _doc in sub_docs:
+                    _doc.metadata = {
+                        k: _doc.metadata[k] for k in self.child_metadata_fields
+                    }
             for _doc in sub_docs:
                 _doc.metadata[self.id_key] = _id
             docs.extend(sub_docs)

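The effect of the new filter, in isolation: child chunks keep only the listed keys before the parent id is stamped on. Illustrative values:

    parent_meta = {"source": "report.pdf", "page": 3, "internal_id": "x1"}
    child_metadata_fields = ["source"]
    child_meta = {k: parent_meta[k] for k in child_metadata_fields}
    # {'source': 'report.pdf'}; the id_key -> parent id entry is added afterwards
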
@@ -649,7 +649,7 @@ class ChatOpenAI(BaseChatModel):
                 Must be the name of the single provided function or
                 "auto" to automatically determine which function to call
                 (if any).
-            kwargs: Any additional parameters to pass to the
+            **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """

@@ -701,22 +701,21 @@ class ChatOpenAI(BaseChatModel):
                 "auto" to automatically determine which function to call
                 (if any), or a dict of the form:
                 {"type": "function", "function": {"name": <<tool_name>>}}.
-            kwargs: Any additional parameters to pass to the
+            **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """

         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None:
-            if isinstance(tool_choice, str) and tool_choice not in ("auto", "none"):
+            if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
                 tool_choice = {"type": "function", "function": {"name": tool_choice}}
-            if isinstance(tool_choice, dict) and len(formatted_tools) != 1:
+            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                 raise ValueError(
                     "When specifying `tool_choice`, you must provide exactly one "
                     f"tool. Received {len(formatted_tools)} tools."
                 )
-            if (
-                isinstance(tool_choice, dict)
-                and formatted_tools[0]["function"]["name"]
+            if isinstance(tool_choice, dict) and (
+                formatted_tools[0]["function"]["name"]
                 != tool_choice["function"]["name"]
             ):
                 raise ValueError(

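A minimal sketch of the tool_choice handling above; the import path and the tool are illustrative assumptions. A bare tool name is expanded into the OpenAI dict form, and passing a dict while more than one tool is bound raises the ValueError shown:

    from langchain_community.chat_models import ChatOpenAI
    from langchain_core.pydantic_v1 import BaseModel

    class GetWeather(BaseModel):
        """Get the weather for a city."""

        city: str

    llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
    llm_with_tool = llm.bind_tools([GetWeather], tool_choice="GetWeather")
    # equivalent to tool_choice={"type": "function", "function": {"name": "GetWeather"}}
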
@@ -724,7 +723,4 @@ class ChatOpenAI(BaseChatModel):
                     f"provided tool was {formatted_tools[0]['function']['name']}."
                 )
             kwargs["tool_choice"] = tool_choice
-        return super().bind(
-            tools=formatted_tools,
-            **kwargs,
-        )
+        return super().bind(tools=formatted_tools, **kwargs)