diff --git a/libs/experimental/langchain_experimental/graph_transformers/diffbot.py b/libs/experimental/langchain_experimental/graph_transformers/diffbot.py index b174b1353de..dda51576c94 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/diffbot.py +++ b/libs/experimental/langchain_experimental/graph_transformers/diffbot.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import requests @@ -6,6 +7,11 @@ from langchain_community.graphs.graph_document import GraphDocument, Node, Relat from langchain_core.documents import Document +class TypeOption(str, Enum): + FACTS = "facts" + ENTITIES = "entities" + + def format_property_key(s: str) -> str: """Formats a string to be used as a property key.""" @@ -141,6 +147,7 @@ class DiffbotGraphTransformer: include_qualifiers: bool = True, include_evidence: bool = True, simplified_schema: bool = True, + extract_types: List[TypeOption] = [TypeOption.FACTS], ) -> None: """ Initialize the graph transformer with various options. @@ -157,6 +164,11 @@ class DiffbotGraphTransformer: Whether to include evidence for the relationships. simplified_schema (bool): Whether to use a simplified schema for relationships. + extract_types (List[TypeOption]): + A list of data types to extract. Only facts or entities + are supported. By default, the option is set to facts. + A fact represents a combination of source and target + nodes with a relationship type. """ self.diffbot_api_key = diffbot_api_key or get_from_env( "diffbot_api_key", "DIFFBOT_API_KEY" @@ -167,6 +179,13 @@ class DiffbotGraphTransformer: self.simplified_schema = None if simplified_schema: self.simplified_schema = SimplifiedSchema() + if not extract_types: + raise ValueError( + "`extract_types` cannot be an empty array. " + "Allowed values are 'facts', 'entities', or both." + ) + + self.extract_types = extract_types def nlp_request(self, text: str) -> Dict[str, Any]: """ @@ -185,7 +204,7 @@ class DiffbotGraphTransformer: "lang": "en", } - FIELDS = "facts" + FIELDS = ",".join(self.extract_types) HOST = "nl.diffbot.com" url = ( f"https://{HOST}/v1/?fields={FIELDS}&" @@ -209,77 +228,97 @@ class DiffbotGraphTransformer: """ # Return empty result if there are no facts - if "facts" not in payload or not payload["facts"]: + if ("facts" not in payload or not payload["facts"]) and ( + "entities" not in payload or not payload["entities"] + ): return GraphDocument(nodes=[], relationships=[], source=document) # Nodes are a custom class because we need to deduplicate nodes_list = NodesList() - # Relationships are a list because we don't deduplicate nor anything else + if "entities" in payload and payload["entities"]: + for record in payload["entities"]: + # Ignore if it doesn't have a type + if not record["allTypes"]: + continue + + # Define source node + source_id = ( + record["allUris"][0] if record["allUris"] else record["name"] + ) + source_label = record["allTypes"][0]["name"].capitalize() + source_name = record["name"] + nodes_list.add_node_property( + (source_id, source_label), {"name": source_name} + ) relationships = list() - for record in payload["facts"]: - # Skip if the fact is below the threshold confidence - if record["confidence"] < self.fact_threshold_confidence: - continue + # Relationships are a list because we don't deduplicate nor anything else + if "facts" in payload and payload["facts"]: + for record in payload["facts"]: + # Skip if the fact is below the threshold confidence + if record["confidence"] < self.fact_threshold_confidence: + continue - # TODO: It should probably be treated as a node property - if not record["value"]["allTypes"]: - continue + # TODO: It should probably be treated as a node property + if not record["value"]["allTypes"]: + continue - # Define source node - source_id = ( - record["entity"]["allUris"][0] - if record["entity"]["allUris"] - else record["entity"]["name"] - ) - source_label = record["entity"]["allTypes"][0]["name"].capitalize() - source_name = record["entity"]["name"] - source_node = Node(id=source_id, type=source_label) - nodes_list.add_node_property( - (source_id, source_label), {"name": source_name} - ) - - # Define target node - target_id = ( - record["value"]["allUris"][0] - if record["value"]["allUris"] - else record["value"]["name"] - ) - target_label = record["value"]["allTypes"][0]["name"].capitalize() - target_name = record["value"]["name"] - # Some facts are better suited as node properties - if target_label in FACT_TO_PROPERTY_TYPE: + # Define source node + source_id = ( + record["entity"]["allUris"][0] + if record["entity"]["allUris"] + else record["entity"]["name"] + ) + source_label = record["entity"]["allTypes"][0]["name"].capitalize() + source_name = record["entity"]["name"] + source_node = Node(id=source_id, type=source_label) nodes_list.add_node_property( - (source_id, source_label), - {format_property_key(record["property"]["name"]): target_name}, + (source_id, source_label), {"name": source_name} ) - else: # Define relationship - # Define target node object - target_node = Node(id=target_id, type=target_label) - nodes_list.add_node_property( - (target_id, target_label), {"name": target_name} - ) - # Define relationship type - rel_type = record["property"]["name"].replace(" ", "_").upper() - if self.simplified_schema: - rel_type = self.simplified_schema.get_type(rel_type) - # Relationship qualifiers/properties - rel_properties = dict() - relationship_evidence = [el["passage"] for el in record["evidence"]][0] - if self.include_evidence: - rel_properties.update({"evidence": relationship_evidence}) - if self.include_qualifiers and record.get("qualifiers"): - for property in record["qualifiers"]: - prop_key = format_property_key(property["property"]["name"]) - rel_properties[prop_key] = property["value"]["name"] - - relationship = Relationship( - source=source_node, - target=target_node, - type=rel_type, - properties=rel_properties, + # Define target node + target_id = ( + record["value"]["allUris"][0] + if record["value"]["allUris"] + else record["value"]["name"] ) - relationships.append(relationship) + target_label = record["value"]["allTypes"][0]["name"].capitalize() + target_name = record["value"]["name"] + # Some facts are better suited as node properties + if target_label in FACT_TO_PROPERTY_TYPE: + nodes_list.add_node_property( + (source_id, source_label), + {format_property_key(record["property"]["name"]): target_name}, + ) + else: # Define relationship + # Define target node object + target_node = Node(id=target_id, type=target_label) + nodes_list.add_node_property( + (target_id, target_label), {"name": target_name} + ) + # Define relationship type + rel_type = record["property"]["name"].replace(" ", "_").upper() + if self.simplified_schema: + rel_type = self.simplified_schema.get_type(rel_type) + + # Relationship qualifiers/properties + rel_properties = dict() + relationship_evidence = [ + el["passage"] for el in record["evidence"] + ][0] + if self.include_evidence: + rel_properties.update({"evidence": relationship_evidence}) + if self.include_qualifiers and record.get("qualifiers"): + for property in record["qualifiers"]: + prop_key = format_property_key(property["property"]["name"]) + rel_properties[prop_key] = property["value"]["name"] + + relationship = Relationship( + source=source_node, + target=target_node, + type=rel_type, + properties=rel_properties, + ) + relationships.append(relationship) return GraphDocument( nodes=nodes_list.return_node_list(),