Add relik transformer config (#25019)

This commit is contained in:
Tomaz Bratanic 2024-08-03 14:41:45 +02:00 committed by GitHub
parent 1dcee68cb8
commit f9a11a9197
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,5 @@
from typing import List, Sequence import logging
from typing import Any, Dict, List, Sequence
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document from langchain_core.documents import Document
@ -22,23 +23,33 @@ class RelikGraphTransformer:
model (str): The name of the pretrained Relik model to use. model (str): The name of the pretrained Relik model to use.
Default is "relik-ie/relik-relation-extraction-small-wikipedia". Default is "relik-ie/relik-relation-extraction-small-wikipedia".
relationship_confidence_threshold (float): The confidence threshold for relationship_confidence_threshold (float): The confidence threshold for
filtering relationships. Default is 0.0. filtering relationships. Default is 0.1.
model_config (Dict[str, any]): Additional configuration options for the
Relik model. Default is an empty dictionary.
ignore_self_loops (bool): Whether to ignore relationships where the
source and target nodes are the same. Default is True.
""" """
def __init__( def __init__(
self, self,
model: str = "relik-ie/relik-relation-extraction-small-wikipedia", model: str = "relik-ie/relik-relation-extraction-small",
relationship_confidence_threshold: float = 0.0, relationship_confidence_threshold: float = 0.1,
model_config: Dict[str, Any] = {},
ignore_self_loops: bool = True,
) -> None: ) -> None:
try: try:
import relik # type: ignore import relik # type: ignore
# Remove default INFO logging
logging.getLogger("relik").setLevel(logging.WARNING)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import relik python package. " "Could not import relik python package. "
"Please install it with `pip install relik`." "Please install it with `pip install relik`."
) )
self.relik_model = relik.Relik.from_pretrained(model) self.relik_model = relik.Relik.from_pretrained(model, **model_config)
self.relationship_confidence_threshold = relationship_confidence_threshold self.relationship_confidence_threshold = relationship_confidence_threshold
self.ignore_self_loops = ignore_self_loops
def process_document(self, document: Document) -> GraphDocument: def process_document(self, document: Document) -> GraphDocument:
relik_out = self.relik_model(document.page_content) relik_out = self.relik_model(document.page_content)
@ -60,6 +71,9 @@ class RelikGraphTransformer:
# Ignore relationship if below confidence threshold # Ignore relationship if below confidence threshold
if triple.confidence < self.relationship_confidence_threshold: if triple.confidence < self.relationship_confidence_threshold:
continue continue
# Ignore self loops
if self.ignore_self_loops and triple.subject.text == triple.object.text:
continue
source_node = Node( source_node = Node(
id=triple.subject.text, id=triple.subject.text,
type=DEFAULT_NODE_TYPE type=DEFAULT_NODE_TYPE