mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
Fix for SVM retriever discarding document metadata (#9141)
As stated in the title the SVM retriever discarded the metadata of passed in docs. This code fixes that. I also added one unit test that should test that. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
bace17e0aa
commit
00bf472265
@ -38,6 +38,8 @@ class SVMRetriever(BaseRetriever):
|
||||
"""Index of embeddings."""
|
||||
texts: List[str]
|
||||
"""List of texts to index."""
|
||||
metadatas: Optional[List[dict]] = None
|
||||
"""List of metadatas corresponding with each text."""
|
||||
k: int = 4
|
||||
"""Number of results to return."""
|
||||
relevancy_threshold: Optional[float] = None
|
||||
@ -51,10 +53,20 @@ class SVMRetriever(BaseRetriever):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
|
||||
cls,
|
||||
texts: List[str],
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> SVMRetriever:
|
||||
index = create_index(texts, embeddings)
|
||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||
return cls(
|
||||
embeddings=embeddings,
|
||||
index=index,
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
@ -64,7 +76,9 @@ class SVMRetriever(BaseRetriever):
|
||||
**kwargs: Any,
|
||||
) -> SVMRetriever:
|
||||
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
|
||||
return cls.from_texts(texts=texts, embeddings=embeddings, **kwargs)
|
||||
return cls.from_texts(
|
||||
texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
@ -108,5 +122,7 @@ class SVMRetriever(BaseRetriever):
|
||||
self.relevancy_threshold is None
|
||||
or normalized_similarities[row] >= self.relevancy_threshold
|
||||
):
|
||||
top_k_results.append(Document(page_content=self.texts[row - 1]))
|
||||
metadata = self.metadatas[row - 1] if self.metadatas else {}
|
||||
doc = Document(page_content=self.texts[row - 1], metadata=metadata)
|
||||
top_k_results.append(doc)
|
||||
return top_k_results
|
||||
|
@ -25,3 +25,18 @@ class TestSVMRetriever:
|
||||
documents=input_docs, embeddings=FakeEmbeddings(size=100)
|
||||
)
|
||||
assert len(svm_retriever.texts) == 3
|
||||
|
||||
@pytest.mark.requires("sklearn")
|
||||
def test_metadata_persists(self) -> None:
|
||||
input_docs = [
|
||||
Document(page_content="I have a pen.", metadata={"foo": "bar"}),
|
||||
Document(page_content="How about you?", metadata={"foo": "baz"}),
|
||||
Document(page_content="I have a bag.", metadata={"foo": "qux"}),
|
||||
]
|
||||
svm_retriever = SVMRetriever.from_documents(
|
||||
documents=input_docs, embeddings=FakeEmbeddings(size=100)
|
||||
)
|
||||
query = "Have anything?"
|
||||
output_docs = svm_retriever.get_relevant_documents(query=query)
|
||||
for doc in output_docs:
|
||||
assert "foo" in doc.metadata
|
||||
|
Loading…
Reference in New Issue
Block a user