Add sentiment and confidence levels to diffbotgraphtransformer (#21590)

Co-authored-by: Erick Friis <erickfriis@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Tomaz Bratanic 2024-05-14 01:00:52 +02:00 committed by GitHub
parent 526ba235f3
commit 89ff6a3d3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,6 +10,7 @@ from langchain_core.documents import Document
class TypeOption(str, Enum):
FACTS = "facts"
ENTITIES = "entities"
SENTIMENT = "sentiment"
def format_property_key(s: str) -> str:
@ -148,6 +149,8 @@ class DiffbotGraphTransformer:
include_evidence: bool = True,
simplified_schema: bool = True,
extract_types: List[TypeOption] = [TypeOption.FACTS],
*,
include_confidence: bool = False,
) -> None:
"""
Initialize the graph transformer with various options.
@ -165,10 +168,12 @@ class DiffbotGraphTransformer:
simplified_schema (bool):
Whether to use a simplified schema for relationships.
extract_types (List[TypeOption]):
A list of data types to extract. Only facts or entities
are supported. By default, the option is set to facts.
A fact represents a combination of source and target
nodes with a relationship type.
A list of data types to extract. Facts, entities, and
sentiment are supported. By default, the option is
set to facts. A fact represents a combination of
source and target nodes with a relationship type.
include_confidence (bool):
Whether to include confidence scores on nodes and rels
"""
self.diffbot_api_key = diffbot_api_key or get_from_env(
"diffbot_api_key", "DIFFBOT_API_KEY"
@ -176,6 +181,7 @@ class DiffbotGraphTransformer:
self.fact_threshold_confidence = fact_confidence_threshold
self.include_qualifiers = include_qualifiers
self.include_evidence = include_evidence
self.include_confidence = include_confidence
self.simplified_schema = None
if simplified_schema:
self.simplified_schema = SimplifiedSchema()
@ -250,6 +256,17 @@ class DiffbotGraphTransformer:
nodes_list.add_node_property(
(source_id, source_label), {"name": source_name}
)
if record.get("sentiment") is not None:
nodes_list.add_node_property(
(source_id, source_label),
{"sentiment": record.get("sentiment")},
)
if self.include_confidence:
nodes_list.add_node_property(
(source_id, source_label),
{"confidence": record.get("confidence")},
)
relationships = list()
# Relationships are a list because we don't deduplicate nor anything else
if "facts" in payload and payload["facts"]:
@ -307,6 +324,8 @@ class DiffbotGraphTransformer:
][0]
if self.include_evidence:
rel_properties.update({"evidence": relationship_evidence})
if self.include_confidence:
rel_properties.update({"confidence": record["confidence"]})
if self.include_qualifiers and record.get("qualifiers"):
for property in record["qualifiers"]:
prop_key = format_property_key(property["property"]["name"])