mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
Add relik transformer config (#25019)
This commit is contained in:
parent
1dcee68cb8
commit
f9a11a9197
@ -1,4 +1,5 @@
|
|||||||
from typing import List, Sequence
|
import logging
|
||||||
|
from typing import Any, Dict, List, Sequence
|
||||||
|
|
||||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -22,23 +23,33 @@ class RelikGraphTransformer:
|
|||||||
model (str): The name of the pretrained Relik model to use.
|
model (str): The name of the pretrained Relik model to use.
|
||||||
Default is "relik-ie/relik-relation-extraction-small-wikipedia".
|
Default is "relik-ie/relik-relation-extraction-small".
|
||||||
relationship_confidence_threshold (float): The confidence threshold for
|
relationship_confidence_threshold (float): The confidence threshold for
|
||||||
filtering relationships. Default is 0.0.
|
filtering relationships. Default is 0.1.
|
||||||
|
model_config (Dict[str, Any]): Additional configuration options for the
|
||||||
|
Relik model. Default is an empty dictionary.
|
||||||
|
ignore_self_loops (bool): Whether to ignore relationships where the
|
||||||
|
source and target nodes are the same. Default is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str = "relik-ie/relik-relation-extraction-small-wikipedia",
|
model: str = "relik-ie/relik-relation-extraction-small",
|
||||||
relationship_confidence_threshold: float = 0.0,
|
relationship_confidence_threshold: float = 0.1,
|
||||||
|
model_config: Dict[str, Any] = {},
|
||||||
|
ignore_self_loops: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
import relik # type: ignore
|
import relik # type: ignore
|
||||||
|
|
||||||
|
# Remove default INFO logging
|
||||||
|
logging.getLogger("relik").setLevel(logging.WARNING)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import relik python package. "
|
"Could not import relik python package. "
|
||||||
"Please install it with `pip install relik`."
|
"Please install it with `pip install relik`."
|
||||||
)
|
)
|
||||||
self.relik_model = relik.Relik.from_pretrained(model)
|
self.relik_model = relik.Relik.from_pretrained(model, **model_config)
|
||||||
self.relationship_confidence_threshold = relationship_confidence_threshold
|
self.relationship_confidence_threshold = relationship_confidence_threshold
|
||||||
|
self.ignore_self_loops = ignore_self_loops
|
||||||
|
|
||||||
def process_document(self, document: Document) -> GraphDocument:
|
def process_document(self, document: Document) -> GraphDocument:
|
||||||
relik_out = self.relik_model(document.page_content)
|
relik_out = self.relik_model(document.page_content)
|
||||||
@ -60,6 +71,9 @@ class RelikGraphTransformer:
|
|||||||
# Ignore relationship if below confidence threshold
|
# Ignore relationship if below confidence threshold
|
||||||
if triple.confidence < self.relationship_confidence_threshold:
|
if triple.confidence < self.relationship_confidence_threshold:
|
||||||
continue
|
continue
|
||||||
|
# Ignore self loops
|
||||||
|
if self.ignore_self_loops and triple.subject.text == triple.object.text:
|
||||||
|
continue
|
||||||
source_node = Node(
|
source_node = Node(
|
||||||
id=triple.subject.text,
|
id=triple.subject.text,
|
||||||
type=DEFAULT_NODE_TYPE
|
type=DEFAULT_NODE_TYPE
|
||||||
|
Loading…
Reference in New Issue
Block a user