From a8561bc303396929cf723e53944c5dc870d3283e Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Thu, 19 Sep 2024 21:15:33 +0800 Subject: [PATCH] Fix async parsing for llm graph transformer (#26650) --- .../graph_transformers/llm.py | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index b7655c90e7e..79ebe8b841f 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -817,8 +817,39 @@ class LLMGraphTransformer: """ text = document.page_content raw_schema = await self.chain.ainvoke({"input": text}, config=config) - raw_schema = cast(Dict[Any, Any], raw_schema) - nodes, relationships = _convert_to_graph_document(raw_schema) + if self._function_call: + raw_schema = cast(Dict[Any, Any], raw_schema) + nodes, relationships = _convert_to_graph_document(raw_schema) + else: + nodes_set = set() + relationships = [] + if not isinstance(raw_schema, str): + raw_schema = raw_schema.content + parsed_json = self.json_repair.loads(raw_schema) + if isinstance(parsed_json, dict): + parsed_json = [parsed_json] + for rel in parsed_json: + # Check if mandatory properties are there + if ( + not rel.get("head") + or not rel.get("tail") + or not rel.get("relation") + ): + continue + # Nodes need to be deduplicated using a set + # Use default Node label for nodes if missing + nodes_set.add((rel["head"], rel.get("head_type", "Node"))) + nodes_set.add((rel["tail"], rel.get("tail_type", "Node"))) + + source_node = Node(id=rel["head"], type=rel.get("head_type", "Node")) + target_node = Node(id=rel["tail"], type=rel.get("tail_type", "Node")) + relationships.append( + Relationship( + source=source_node, target=target_node, type=rel["relation"] + ) + ) + # Create nodes list + nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)] if self.strict_mode and (self.allowed_nodes or self.allowed_relationships): if self.allowed_nodes: