langchain[patch]: expose cohere rerank score, add parent doc param (#16887)
This commit is contained in:
parent 35c1bf339d
commit 02ef9164b5
@@ -484,8 +484,8 @@ class ElasticsearchStore(VectorStore):
             from langchain_community.vectorstores.utils import DistanceStrategy

             vectorstore = ElasticsearchStore(
+                "langchain-demo",
                 embedding=OpenAIEmbeddings(),
-                index_name="langchain-demo",
                 es_url="http://localhost:9200",
                 distance_strategy="DOT_PRODUCT"
             )
@@ -0,0 +1,3 @@
+from langchain.chains.query_constructor.base import load_query_constructor_runnable
+
+__all__ = ["load_query_constructor_runnable"]
@@ -323,7 +323,8 @@ def load_query_constructor_runnable(

     Args:
         llm: BaseLanguageModel to use for the chain.
-        document_contents: The contents of the document to be queried.
+        document_contents: Description of the page contents of the document to be
+            queried.
         attribute_info: Sequence of attributes in the document.
         examples: Optional list of examples to use for the chain.
         allowed_comparators: Sequence of allowed comparators. Defaults to all
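
For orientation, a minimal sketch of how this runnable might be constructed; the `AttributeInfo` values and the `ChatOpenAI` model choice are illustrative assumptions, not part of this diff:

    from langchain.chains.query_constructor.base import load_query_constructor_runnable
    from langchain.chains.query_constructor.schema import AttributeInfo
    from langchain_community.chat_models import ChatOpenAI

    # Describe the filterable metadata attributes of the documents.
    attribute_info = [
        AttributeInfo(name="year", description="Year the movie was released", type="integer"),
        AttributeInfo(name="genre", description="Genre of the movie", type="string"),
    ]
    chain = load_query_constructor_runnable(
        ChatOpenAI(temperature=0),
        "Brief summaries of movies",  # document_contents: description of page contents
        attribute_info,
    )
    structured_query = chain.invoke({"query": "sci-fi movies released after 2000"})
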
@@ -1,6 +1,7 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union

 from langchain_core.documents import Document
 from langchain_core.pydantic_v1 import Extra, root_validator
@@ -9,23 +10,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain.utils import get_from_dict_or_env

-if TYPE_CHECKING:
-    from cohere import Client
-else:
-    # We do to avoid pydantic annotation issues when actually instantiating
-    # while keeping this import optional
-    try:
-        from cohere import Client
-    except ImportError:
-        pass
-

 class CohereRerank(BaseDocumentCompressor):
     """Document compressor that uses `Cohere Rerank API`."""

-    client: Client
+    client: Any
     """Cohere client to use for compressing documents."""
-    top_n: int = 3
+    top_n: Optional[int] = 3
     """Number of documents to return."""
     model: str = "rerank-english-v2.0"
     """Model to use for reranking."""
@@ -57,6 +48,42 @@ class CohereRerank(BaseDocumentCompressor):
         values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
         return values

+    def rerank(
+        self,
+        documents: Sequence[Union[str, Document, dict]],
+        query: str,
+        *,
+        model: Optional[str] = None,
+        top_n: Optional[int] = -1,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        """Returns an ordered list of documents ordered by their relevance to the provided query.
+
+        Args:
+            query: The query to use for reranking.
+            documents: A sequence of documents to rerank.
+            model: The model to use for re-ranking. Default to self.model.
+            top_n : The number of results to return. If None returns all results.
+                Defaults to self.top_n.
+            max_chunks_per_doc : The maximum number of chunks derived from a document.
+        """  # noqa: E501
+        if len(documents) == 0:  # to avoid empty api call
+            return []
+        docs = [
+            doc.page_content if isinstance(doc, Document) else doc for doc in documents
+        ]
+        model = model or self.model
+        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
+        results = self.client.rerank(
+            query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
+        )
+        result_dicts = []
+        for res in results:
+            result_dicts.append(
+                {"index": res.index, "relevance_score": res.relevance_score}
+            )
+        return result_dicts
+
     def compress_documents(
         self,
         documents: Sequence[Document],
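
To show the new `rerank` method in isolation, a hedged sketch (it assumes a valid COHERE_API_KEY in the environment; the sample texts are invented):

    from langchain.retrievers.document_compressors import CohereRerank

    reranker = CohereRerank(top_n=2)  # picks up COHERE_API_KEY from the environment
    # Plain strings are accepted alongside Document objects.
    scored = reranker.rerank(
        documents=["Paris is the capital of France.", "The Nile is a river in Africa."],
        query="What is the capital of France?",
    )
    # Each result refers back to the input by position, e.g.
    # [{'index': 0, 'relevance_score': 0.99}, {'index': 1, 'relevance_score': 0.01}]
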
@@ -74,16 +101,10 @@ class CohereRerank(BaseDocumentCompressor):
         Returns:
             A sequence of compressed documents.
         """
-        if len(documents) == 0:  # to avoid empty api call
-            return []
-        doc_list = list(documents)
-        _docs = [d.page_content for d in doc_list]
-        results = self.client.rerank(
-            model=self.model, query=query, documents=_docs, top_n=self.top_n
-        )
-        final_results = []
-        for r in results:
-            doc = doc_list[r.index]
-            doc.metadata["relevance_score"] = r.relevance_score
-            final_results.append(doc)
-        return final_results
+        compressed = []
+        for res in self.rerank(documents, query):
+            doc = documents[res["index"]]
+            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
+            doc_copy.metadata["relevance_score"] = res["relevance_score"]
+            compressed.append(doc_copy)
+        return compressed
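
The net effect of this rewrite: `compress_documents` now delegates to `rerank` and attaches the relevance score to a deep copy of each document's metadata, so the caller's originals are not mutated. A sketch continuing the example above:

    from langchain_core.documents import Document

    docs = [
        Document(page_content="Paris is the capital of France."),
        Document(page_content="The Nile is a river in Africa."),
    ]
    compressed = reranker.compress_documents(docs, query="capital of France")
    for doc in compressed:
        # The score named in the commit title, now exposed in metadata.
        print(doc.metadata["relevance_score"], doc.page_content)
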
@@ -1,5 +1,5 @@
 import uuid
-from typing import List, Optional
+from typing import List, Optional, Sequence

 from langchain_core.documents import Document

@@ -31,17 +31,16 @@ class ParentDocumentRetriever(MultiVectorRetriever):

        .. code-block:: python

-            # Imports
-            from langchain_community.vectorstores import Chroma
             from langchain_community.embeddings import OpenAIEmbeddings
+            from langchain_community.vectorstores import Chroma
             from langchain.text_splitter import RecursiveCharacterTextSplitter
             from langchain.storage import InMemoryStore

             # This text splitter is used to create the parent documents
-            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
+            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, add_start_index=True)
             # This text splitter is used to create the child documents
             # It should create documents smaller than the parent
-            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, add_start_index=True)
             # The vectorstore to use to index the child chunks
             vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
             # The storage layer for the parent documents
@@ -54,7 +53,7 @@ class ParentDocumentRetriever(MultiVectorRetriever):
                 child_splitter=child_splitter,
                 parent_splitter=parent_splitter,
             )
-    """
+    """  # noqa: E501

     child_splitter: TextSplitter
     """The text splitter to use to create child documents."""
@@ -65,6 +64,11 @@ class ParentDocumentRetriever(MultiVectorRetriever):
     """The text splitter to use to create parent documents.
     If none, then the parent documents will be the raw documents passed in."""

+    child_metadata_fields: Optional[Sequence[str]] = None
+    """Metadata fields to leave in child documents. If None, leave all parent document
+        metadata.
+    """
+
     def add_documents(
         self,
         documents: List[Document],
@@ -76,7 +80,7 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         Args:
             documents: List of documents to add
             ids: Optional list of ids for documents. If provided should be the same
-                length as the list of documents. Can provided if parent documents
+                length as the list of documents. Can be provided if parent documents
                 are already in the document store and you don't want to re-add
                 to the docstore. If not provided, random UUIDs will be used as
                 ids.
@@ -106,6 +110,11 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         for i, doc in enumerate(documents):
             _id = doc_ids[i]
             sub_docs = self.child_splitter.split_documents([doc])
+            if self.child_metadata_fields is not None:
+                for _doc in sub_docs:
+                    _doc.metadata = {
+                        k: _doc.metadata[k] for k in self.child_metadata_fields
+                    }
             for _doc in sub_docs:
                 _doc.metadata[self.id_key] = _id
             docs.extend(sub_docs)
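
A hedged sketch of the new parameter in use; "source" is an invented metadata key, and the rest mirrors the docstring example above:

    from langchain.retrievers import ParentDocumentRetriever
    from langchain.storage import InMemoryStore
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.embeddings import OpenAIEmbeddings
    from langchain_community.vectorstores import Chroma

    retriever = ParentDocumentRetriever(
        vectorstore=Chroma(embedding_function=OpenAIEmbeddings()),
        docstore=InMemoryStore(),
        child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
        parent_splitter=RecursiveCharacterTextSplitter(chunk_size=2000),
        # New in this patch: child chunks keep only these metadata fields
        # (the retriever still adds its own id key afterwards).
        child_metadata_fields=["source"],
    )
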
@@ -649,7 +649,7 @@ class ChatOpenAI(BaseChatModel):
                 Must be the name of the single provided function or
                 "auto" to automatically determine which function to call
                 (if any).
-            kwargs: Any additional parameters to pass to the
+            **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """

@@ -701,22 +701,21 @@ class ChatOpenAI(BaseChatModel):
                 "auto" to automatically determine which function to call
                 (if any), or a dict of the form:
                 {"type": "function", "function": {"name": <<tool_name>>}}.
-            kwargs: Any additional parameters to pass to the
+            **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """

         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None:
-            if isinstance(tool_choice, str) and tool_choice not in ("auto", "none"):
+            if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
                 tool_choice = {"type": "function", "function": {"name": tool_choice}}
-            if isinstance(tool_choice, dict) and len(formatted_tools) != 1:
+            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                 raise ValueError(
                     "When specifying `tool_choice`, you must provide exactly one "
                     f"tool. Received {len(formatted_tools)} tools."
                 )
-            if (
-                isinstance(tool_choice, dict)
-                and formatted_tools[0]["function"]["name"]
+            if isinstance(tool_choice, dict) and (
+                formatted_tools[0]["function"]["name"]
                 != tool_choice["function"]["name"]
             ):
                 raise ValueError(
@@ -724,7 +723,4 @@ class ChatOpenAI(BaseChatModel):
                     f"provided tool was {formatted_tools[0]['function']['name']}."
                 )
             kwargs["tool_choice"] = tool_choice
-        return super().bind(
-            tools=formatted_tools,
-            **kwargs,
-        )
+        return super().bind(tools=formatted_tools, **kwargs)
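
For reference, a sketch of calling the reworked `bind_tools`; the `GetWeather` schema is an invented example, and the import path assumes the `langchain_openai` flavor of `ChatOpenAI` (the diff does not show the file path):

    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_openai import ChatOpenAI

    class GetWeather(BaseModel):
        """Get the current weather for a city."""

        city: str = Field(description="City name, e.g. 'Paris'")

    llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
    # A bare string tool_choice is expanded by the code above to
    # {"type": "function", "function": {"name": "GetWeather"}}.
    llm_with_tools = llm.bind_tools([GetWeather], tool_choice="GetWeather")
    ai_msg = llm_with_tools.invoke("What's the weather in Paris?")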