Add the extract types to diffbot graph transformer (#21315)

Before you could only extract triples (diffbot calls it facts) from
diffbot to avoid isolated nodes. However, sometimes isolated nodes can
still be useful like for prefiltering, so we want to allow users to
extract them if they want. Default behaviour is unchanged.
This commit is contained in:
Tomaz Bratanic 2024-05-06 15:19:52 +02:00 committed by GitHub
parent c038991590
commit 5b6d1a907d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import requests import requests
@ -6,6 +7,11 @@ from langchain_community.graphs.graph_document import GraphDocument, Node, Relat
from langchain_core.documents import Document from langchain_core.documents import Document
class TypeOption(str, Enum):
FACTS = "facts"
ENTITIES = "entities"
def format_property_key(s: str) -> str: def format_property_key(s: str) -> str:
"""Formats a string to be used as a property key.""" """Formats a string to be used as a property key."""
@ -141,6 +147,7 @@ class DiffbotGraphTransformer:
include_qualifiers: bool = True, include_qualifiers: bool = True,
include_evidence: bool = True, include_evidence: bool = True,
simplified_schema: bool = True, simplified_schema: bool = True,
extract_types: List[TypeOption] = [TypeOption.FACTS],
) -> None: ) -> None:
""" """
Initialize the graph transformer with various options. Initialize the graph transformer with various options.
@ -157,6 +164,11 @@ class DiffbotGraphTransformer:
Whether to include evidence for the relationships. Whether to include evidence for the relationships.
simplified_schema (bool): simplified_schema (bool):
Whether to use a simplified schema for relationships. Whether to use a simplified schema for relationships.
extract_types (List[TypeOption]):
A list of data types to extract. Only facts or entities
are supported. By default, the option is set to facts.
A fact represents a combination of source and target
nodes with a relationship type.
""" """
self.diffbot_api_key = diffbot_api_key or get_from_env( self.diffbot_api_key = diffbot_api_key or get_from_env(
"diffbot_api_key", "DIFFBOT_API_KEY" "diffbot_api_key", "DIFFBOT_API_KEY"
@ -167,6 +179,13 @@ class DiffbotGraphTransformer:
self.simplified_schema = None self.simplified_schema = None
if simplified_schema: if simplified_schema:
self.simplified_schema = SimplifiedSchema() self.simplified_schema = SimplifiedSchema()
if not extract_types:
raise ValueError(
"`extract_types` cannot be an empty array. "
"Allowed values are 'facts', 'entities', or both."
)
self.extract_types = extract_types
def nlp_request(self, text: str) -> Dict[str, Any]: def nlp_request(self, text: str) -> Dict[str, Any]:
""" """
@ -185,7 +204,7 @@ class DiffbotGraphTransformer:
"lang": "en", "lang": "en",
} }
FIELDS = "facts" FIELDS = ",".join(self.extract_types)
HOST = "nl.diffbot.com" HOST = "nl.diffbot.com"
url = ( url = (
f"https://{HOST}/v1/?fields={FIELDS}&" f"https://{HOST}/v1/?fields={FIELDS}&"
@ -209,13 +228,31 @@ class DiffbotGraphTransformer:
""" """
# Return empty result if there are no facts # Return empty result if there are no facts
if "facts" not in payload or not payload["facts"]: if ("facts" not in payload or not payload["facts"]) and (
"entities" not in payload or not payload["entities"]
):
return GraphDocument(nodes=[], relationships=[], source=document) return GraphDocument(nodes=[], relationships=[], source=document)
# Nodes are a custom class because we need to deduplicate # Nodes are a custom class because we need to deduplicate
nodes_list = NodesList() nodes_list = NodesList()
# Relationships are a list because we don't deduplicate nor anything else if "entities" in payload and payload["entities"]:
for record in payload["entities"]:
# Ignore if it doesn't have a type
if not record["allTypes"]:
continue
# Define source node
source_id = (
record["allUris"][0] if record["allUris"] else record["name"]
)
source_label = record["allTypes"][0]["name"].capitalize()
source_name = record["name"]
nodes_list.add_node_property(
(source_id, source_label), {"name": source_name}
)
relationships = list() relationships = list()
# Relationships are a list because we don't deduplicate nor anything else
if "facts" in payload and payload["facts"]:
for record in payload["facts"]: for record in payload["facts"]:
# Skip if the fact is below the threshold confidence # Skip if the fact is below the threshold confidence
if record["confidence"] < self.fact_threshold_confidence: if record["confidence"] < self.fact_threshold_confidence:
@ -265,7 +302,9 @@ class DiffbotGraphTransformer:
# Relationship qualifiers/properties # Relationship qualifiers/properties
rel_properties = dict() rel_properties = dict()
relationship_evidence = [el["passage"] for el in record["evidence"]][0] relationship_evidence = [
el["passage"] for el in record["evidence"]
][0]
if self.include_evidence: if self.include_evidence:
rel_properties.update({"evidence": relationship_evidence}) rel_properties.update({"evidence": relationship_evidence})
if self.include_qualifiers and record.get("qualifiers"): if self.include_qualifiers and record.get("qualifiers"):