From b04e6634263712c8965281d2a8aa033fb4866321 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Thu, 28 Mar 2024 03:35:34 +0100 Subject: [PATCH] experimental[patch]: Flatten relationships in LLM graph transformer (#19642) --- .../graph_transformers/llm.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index 34c069b772e..e671bc24754 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -112,8 +112,18 @@ def create_simple_model( class SimpleRelationship(BaseModel): """Represents a directed relationship between two nodes in a graph.""" - source: SimpleNode = Field(description="The source node of the relationship.") - target: SimpleNode = Field(description="The target node of the relationship.") + source_node_id: str = Field( + description="Name or human-readable unique identifier of source node" + ) + source_node_type: str = optional_enum_field( + node_labels, description="The type or label of the source node." + ) + target_node_id: str = Field( + description="Name or human-readable unique identifier of target node" + ) + target_node_type: str = optional_enum_field( + node_labels, description="The type or label of the target node." + ) type: str = optional_enum_field( rel_types, description="The type of the relationship.", is_rel=True ) @@ -136,8 +146,8 @@ def map_to_base_node(node: Any) -> Node: def map_to_base_relationship(rel: Any) -> Relationship: """Map the SimpleRelationship to the base Relationship.""" - source = map_to_base_node(rel.source) - target = map_to_base_node(rel.target) + source = Node(id=rel.source_node_id.title(), type=rel.source_node_type.capitalize()) + target = Node(id=rel.target_node_id.title(), type=rel.target_node_type.capitalize()) return Relationship( source=source, target=target, type=rel.type.replace(" ", "_").upper() )