Fix async parsing for LLM graph transformer (#26650)

This commit is contained in:
Tomaz Bratanic 2024-09-19 21:15:33 +08:00 committed by GitHub
parent 4e0a6ebe7d
commit a8561bc303
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -817,8 +817,39 @@ class LLMGraphTransformer:
"""
text = document.page_content
raw_schema = await self.chain.ainvoke({"input": text}, config=config)
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
if self._function_call:
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
else:
nodes_set = set()
relationships = []
if not isinstance(raw_schema, str):
raw_schema = raw_schema.content
parsed_json = self.json_repair.loads(raw_schema)
if isinstance(parsed_json, dict):
parsed_json = [parsed_json]
for rel in parsed_json:
# Check if mandatory properties are there
if (
not rel.get("head")
or not rel.get("tail")
or not rel.get("relation")
):
continue
# Nodes need to be deduplicated using a set
# Use default Node label for nodes if missing
nodes_set.add((rel["head"], rel.get("head_type", "Node")))
nodes_set.add((rel["tail"], rel.get("tail_type", "Node")))
source_node = Node(id=rel["head"], type=rel.get("head_type", "Node"))
target_node = Node(id=rel["tail"], type=rel.get("tail_type", "Node"))
relationships.append(
Relationship(
source=source_node, target=target_node, type=rel["relation"]
)
)
# Create nodes list
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes: