Switch to md5 for deduplication in neo4j integrations (#18846)

Deduplicate documents using MD5 of the page_content. Also allows for
custom deduplication with graph ingestion method by providing metadata
id attribute

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Tomaz Bratanic 2024-03-09 22:28:55 +01:00 committed by GitHub
parent 246724faab
commit a28be31a96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 7 deletions

View File

@ -1,3 +1,4 @@
from hashlib import md5
from typing import Any, Dict, List, Optional
from langchain_core.utils import get_from_dict_or_env
@ -39,7 +40,7 @@ RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
include_docs_query = (
"CREATE (d:Document) "
"MERGE (d:Document {id:$document.metadata.id}) "
"SET d.text = $document.page_content "
"SET d += $document.metadata "
"WITH d "
@ -339,7 +340,10 @@ class Neo4jGraph(GraphStore):
including nodes, relationships, and the source document information.
- include_source (bool, optional): If True, stores the source document
and links it to nodes in the graph using the MENTIONS relationship.
This is useful for tracing back the origin of data. Defaults to False.
This is useful for tracing back the origin of data. Merges source
documents based on the `id` property from the source document metadata
if available; otherwise it calculates the MD5 hash of `page_content`
for merging process. Defaults to False.
- baseEntityLabel (bool, optional): If True, each newly created node
gets a secondary __Entity__ label, which is indexed and improves import
speed and performance. Defaults to False.
@ -365,6 +369,11 @@ class Neo4jGraph(GraphStore):
node_import_query = _get_node_import_query(baseEntityLabel, include_source)
rel_import_query = _get_rel_import_query(baseEntityLabel)
for document in graph_documents:
if not document.source.metadata.get("id"):
document.source.metadata["id"] = md5(
document.source.page_content.encode("utf-8")
).hexdigest()
# Import nodes
self.query(
node_import_query,

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import enum
import logging
import os
import uuid
from hashlib import md5
from typing import (
Any,
Callable,
@ -434,7 +434,7 @@ class Neo4jVector(VectorStore):
**kwargs: Any,
) -> Neo4jVector:
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
if not metadatas:
metadatas = [{} for _ in texts]
@ -501,7 +501,7 @@ class Neo4jVector(VectorStore):
kwargs: vectorstore specific parameters
"""
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
if not metadatas:
metadatas = [{} for _ in texts]

View File

@ -1,7 +1,6 @@
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.agents import create_sql_agent
from tests.unit_tests.llms.fake_llm import FakeLLM