mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
community: update Memgraph integration (#27017)
**Description:** - **Memgraph** no longer relies on `Neo4jGraphStore` but **implements `GraphStore`**, just like other graph databases. - **Memgraph** no longer relies on `GraphQAChain`, but implements `MemgraphQAChain`, just like other graph databases. - The refresh schema procedure has been updated to try using `SHOW SCHEMA INFO`. The fallback uses Cypher queries (a combination of schema and Cypher) → **LangChain integration no longer relies on MAGE library**. - The **schema structure** has been reformatted. Regardless of the procedures used to get schema, schema structure is the same. - The `add_graph_documents()` method has been implemented. It transforms `GraphDocument` into Cypher queries and creates a graph in Memgraph. It implements the ability to use `baseEntityLabel` to improve speed (`baseEntityLabel` has an index on the `id` property). It also implements the ability to include sources by creating a `MENTIONS` relationship to the source document. - Jupyter Notebook for Memgraph has been updated. - **Issue:** / - **Dependencies:** / - **Twitter handle:** supe_katarina (DX Engineer @ Memgraph) Closes #25606
This commit is contained in:
parent
5c6e2cbcda
commit
aba2711e7f
File diff suppressed because it is too large
Load Diff
BIN
docs/static/img/memgraph_kg.png
vendored
Normal file
BIN
docs/static/img/memgraph_kg.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 117 KiB |
BIN
docs/static/img/memgraph_kg_2.png
vendored
Normal file
BIN
docs/static/img/memgraph_kg_2.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
BIN
docs/static/img/memgraph_kg_3.png
vendored
Normal file
BIN
docs/static/img/memgraph_kg_3.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 137 KiB |
BIN
docs/static/img/memgraph_kg_4.png
vendored
Normal file
BIN
docs/static/img/memgraph_kg_4.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 251 KiB |
316
libs/community/langchain_community/chains/graph_qa/memgraph.py
Normal file
316
libs/community/langchain_community/chains/graph_qa/memgraph.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
"""Question answering over a graph."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
SystemMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.prompts import (
|
||||||
|
BasePromptTemplate,
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
MessagesPlaceholder,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain_community.chains.graph_qa.prompts import (
|
||||||
|
MEMGRAPH_GENERATION_PROMPT,
|
||||||
|
MEMGRAPH_QA_PROMPT,
|
||||||
|
)
|
||||||
|
from langchain_community.graphs.memgraph_graph import MemgraphGraph
|
||||||
|
|
||||||
|
# Key under which intermediate steps (generated Cypher, retrieved context)
# are placed in the chain result when return_intermediate_steps is True.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"

# System message used when the database context is wrapped as a tool/function
# response (the use_function_response mode of MemgraphQAChain).
FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human
understandable answers based on the provided information from tools.
Do not add any other information that wasn't present in the tools, and use
very concise style in interpreting results!
"""
|
||||||
|
|
||||||
|
|
||||||
|
def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        Cypher code extracted from the text.
    """
    # Grab the first segment fenced by triple backticks; when no fenced
    # block exists, assume the whole text is already a Cypher statement.
    fenced = re.search(r"```(.*?)```", text, re.DOTALL)
    return fenced.group(1) if fenced else text
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_response(
    question: str, context: List[Dict[str, Any]]
) -> List[BaseMessage]:
    """Wrap retrieved database context as a simulated tool exchange.

    Builds an assistant message that carries a single ``GetInformation``
    tool call for *question*, followed by the matching tool response that
    contains *context*, so a tool-aware LLM grounds its answer on the data.
    """
    TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
    tool_call = {
        "id": TOOL_ID,
        "function": {
            "arguments": '{"question":"' + question + '"}',
            "name": "GetInformation",
        },
        "type": "function",
    }
    assistant_msg = AIMessage(
        content="",
        additional_kwargs={"tool_calls": [tool_call]},
    )
    tool_msg = ToolMessage(content=str(context), tool_call_id=TOOL_ID)
    return [assistant_msg, tool_msg]
|
||||||
|
|
||||||
|
|
||||||
|
class MemgraphQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    # Graph wrapper used to execute the generated Cypher; excluded from
    # serialization since it holds a live driver connection.
    graph: MemgraphGraph = Field(exclude=True)
    # LCEL runnable producing a Cypher statement from question + schema.
    cypher_generation_chain: Runnable
    # LCEL runnable producing the final natural-language answer.
    qa_chain: Runnable
    # Textual schema snapshot captured at construction time (see from_llm).
    graph_schema: str
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Optional cypher validation tool"""
    use_function_response: bool = False
    """Whether to wrap the database context as tool/function response"""
    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not explicitly
                set to ``True`` — a mandatory opt-in acknowledging that the
                chain executes LLM-generated Cypher against the database.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`."
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database."
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: Optional[BasePromptTemplate] = None,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
        use_function_response: bool = False,
        function_response_system: str = FUNCTION_RESPONSE_SYSTEM,
        **kwargs: Any,
    ) -> MemgraphQAChain:
        """Initialize from LLM.

        Either a single ``llm`` or separate ``cypher_llm``/``qa_llm`` may be
        supplied (at most two of the three). Prompts may be given directly or
        inside the corresponding ``*_llm_kwargs`` dict, but not both.
        ``kwargs`` must include ``graph`` (a ``MemgraphGraph`` instance).
        """

        # Validate the llm / prompt combinations before building anything.
        if not cypher_llm and not llm:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not qa_llm and not llm:
            raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )
        if cypher_prompt and cypher_llm_kwargs:
            raise ValueError(
                "Specifying cypher_prompt and cypher_llm_kwargs together is"
                " not allowed. Please pass prompt via cypher_llm_kwargs."
            )
        if qa_prompt and qa_llm_kwargs:
            raise ValueError(
                "Specifying qa_prompt and qa_llm_kwargs together is"
                " not allowed. Please pass prompt via qa_llm_kwargs."
            )
        use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
        use_cypher_llm_kwargs = (
            cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
        )
        # Fall back to the Memgraph default prompts when none were supplied.
        if "prompt" not in use_qa_llm_kwargs:
            use_qa_llm_kwargs["prompt"] = (
                qa_prompt if qa_prompt is not None else MEMGRAPH_QA_PROMPT
            )
        if "prompt" not in use_cypher_llm_kwargs:
            use_cypher_llm_kwargs["prompt"] = (
                cypher_prompt
                if cypher_prompt is not None
                else MEMGRAPH_GENERATION_PROMPT
            )

        qa_llm = qa_llm or llm
        if use_function_response:
            try:
                # Probe for native tool support; raises if the LLM lacks it.
                qa_llm.bind_tools({})  # type: ignore[union-attr]
                response_prompt = ChatPromptTemplate.from_messages(
                    [
                        SystemMessage(content=function_response_system),
                        HumanMessagePromptTemplate.from_template("{question}"),
                        MessagesPlaceholder(variable_name="function_response"),
                    ]
                )
                qa_chain = response_prompt | qa_llm | StrOutputParser()  # type: ignore
            except (NotImplementedError, AttributeError):
                raise ValueError("Provided LLM does not support native tools/functions")
        else:
            qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser()  # type: ignore

        prompt = use_cypher_llm_kwargs["prompt"]
        llm_to_use = cypher_llm if cypher_llm is not None else llm

        if prompt is not None and llm_to_use is not None:
            cypher_generation_chain = prompt | llm_to_use | StrOutputParser()  # type: ignore[arg-type]
        else:
            raise ValueError(
                "Missing required components for the cypher generation chain: "
                "'prompt' or 'llm'"
            )

        # Snapshot the schema once at construction; KeyError here means the
        # required `graph` kwarg was not provided.
        graph_schema = kwargs["graph"].get_schema

        return cls(
            graph_schema=graph_schema,
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            use_function_response=use_function_response,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]
        args = {
            "question": question,
            "schema": self.graph_schema,
        }
        # Caller-supplied inputs take precedence over the defaults above.
        args.update(inputs)

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.invoke(
            args, callbacks=callbacks
        )
        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results
        # Generated Cypher be null if query corrector identifies invalid schema
        if generated_cypher:
            context = self.graph.query(generated_cypher)[: self.top_k]
        else:
            context = []

        if self.return_direct:
            # Skip the QA step entirely and hand back the raw query rows.
            result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})
            if self.use_function_response:
                function_response = get_function_response(question, context)
                result = self.qa_chain.invoke(  # type: ignore
                    {"question": question, "function_response": function_response},
                )
            else:
                result = self.qa_chain.invoke(  # type: ignore
                    {"question": question, "context": context},
                    callbacks=callbacks,
                )

        chain_result: Dict[str, Any] = {"result": result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
|
@ -411,3 +411,58 @@ NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
|
|||||||
input_variables=["schema", "question", "extra_instructions"],
|
input_variables=["schema", "question", "extra_instructions"],
|
||||||
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
|
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prompt instructing the LLM to translate a natural-language question into a
# Cypher query for Memgraph, constrained to the supplied {schema}.
MEMGRAPH_GENERATION_TEMPLATE = """Your task is to directly translate natural language inquiry into precise and executable Cypher query for Memgraph database.
You will utilize a provided database schema to understand the structure, nodes and relationships within the Memgraph database.
Instructions:
- Use provided node and relationship labels and property names from the
schema which describes the database's structure. Upon receiving a user
question, synthesize the schema to craft a precise Cypher query that
directly corresponds to the user's intent.
- Generate valid executable Cypher queries on top of Memgraph database.
Any explanation, context, or additional information that is not a part
of the Cypher query syntax should be omitted entirely.
- Use Memgraph MAGE procedures instead of Neo4j APOC procedures.
- Do not include any explanations or apologies in your responses.
- Do not include any text except the generated Cypher statement.
- For queries that ask for information or functionalities outside the direct
generation of Cypher queries, use the Cypher query format to communicate
limitations or capabilities. For example: RETURN "I am designed to generate
Cypher queries based on the provided schema only."
Schema:
{schema}

With all the above information and instructions, generate Cypher query for the
user question.

The question is:
{question}"""

MEMGRAPH_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=MEMGRAPH_GENERATION_TEMPLATE
)


# Prompt that turns the retrieved query results ({context}) into a concise
# natural-language answer to {question}.
MEMGRAPH_QA_TEMPLATE = """Your task is to form nice and human
understandable answers. The information part contains the provided
information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to
use your internal knowledge to correct it. Make the answer sound as a
response to the question. Do not mention that you based the result on the
given information. Here is an example:

Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.

Follow this example when generating answers. If the provided information is
empty, say that you don't know the answer.

Information:
{context}

Question: {question}
Helpful Answer:"""
MEMGRAPH_QA_PROMPT = PromptTemplate(
    input_variables=["context", "question"], template=MEMGRAPH_QA_TEMPLATE
)
|
||||||
|
@ -1,15 +1,272 @@
|
|||||||
from langchain_community.graphs.neo4j_graph import Neo4jGraph
|
import logging
|
||||||
|
from hashlib import md5
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||||
|
from langchain_community.graphs.graph_store import GraphStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


# Secondary label attached to every imported node when baseEntityLabel is
# enabled; add_graph_documents relies on it carrying an index on `id`.
BASE_ENTITY_LABEL = "__Entity__"

# Preferred schema source; requires Memgraph's schema-info feature to be
# enabled, otherwise the query() call raises and callers fall back to the
# Cypher-based queries below.
SCHEMA_QUERY = """
SHOW SCHEMA INFO
"""

# Fallback: node labels and their properties via the built-in schema module.
NODE_PROPERTIES_QUERY = """
CALL schema.node_type_properties()
YIELD nodeType AS label, propertyName AS property, propertyTypes AS type
WITH label AS nodeLabels, collect({key: property, types: type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""

# Fallback: relationship types, endpoint labels, and property name/type
# pairs discovered by scanning existing edges. properties_info entries are
# [name, valueType] pairs; COLLECT skips the null sentinel used for
# property-less relationships.
REL_QUERY = """
MATCH (n)-[e]->(m)
WITH DISTINCT
labels(n) AS start_node_labels,
type(e) AS rel_type,
labels(m) AS end_node_labels,
e,
keys(e) AS properties
UNWIND CASE WHEN size(properties) > 0 THEN properties ELSE [null] END AS prop
WITH
start_node_labels,
rel_type,
end_node_labels,
CASE WHEN prop IS NULL THEN [] ELSE [prop, valueType(e[prop])] END AS property_info
RETURN
start_node_labels,
rel_type,
end_node_labels,
COLLECT(DISTINCT CASE
WHEN property_info <> []
THEN property_info
ELSE null END) AS properties_info
"""

# Bulk node import; merge.node is a Memgraph MAGE procedure.
NODE_IMPORT_QUERY = """
UNWIND $data AS row
CALL merge.node(row.label, row.properties, {}, {})
YIELD node
RETURN distinct 'done' AS result
"""

# Ensures relationship endpoints exist before relationships are imported.
REL_NODES_IMPORT_QUERY = """
UNWIND $data AS row
MERGE (source {id: row.source_id})
MERGE (target {id: row.target_id})
RETURN distinct 'done' AS result
"""

# Bulk relationship import between previously created endpoint nodes.
REL_IMPORT_QUERY = """
UNWIND $data AS row
MATCH (source {id: row.source_id})
MATCH (target {id: row.target_id})
WITH source, target, row
CALL merge.relationship(source, row.type, {}, {}, target, {})
YIELD rel
RETURN distinct 'done' AS result
"""

# Stores the source document as a Document node.
INCLUDE_DOCS_QUERY = """
MERGE (d:Document {id:$document.metadata.id})
SET d.content = $document.page_content
SET d += $document.metadata
RETURN distinct 'done' AS result
"""

# Links the Document node to every entity it mentions.
INCLUDE_DOCS_SOURCE_QUERY = """
UNWIND $data AS row
MATCH (source {id: row.source_id}), (d:Document {id: $document.metadata.id})
MERGE (d)-[:MENTIONS]->(source)
RETURN distinct 'done' AS result
"""

# Section headers used by transform_schema_to_text when rendering the
# structured schema as prompt text.
NODE_PROPS_TEXT = """
Node labels and properties (name and type) are:
"""

REL_PROPS_TEXT = """
Relationship labels and properties are:
"""

REL_TEXT = """
Nodes are connected with the following relationships:
"""
|
||||||
|
|
||||||
|
|
||||||
class MemgraphGraph(Neo4jGraph):
|
def get_schema_subset(data: Dict[str, Any]) -> Dict[str, Any]:
    """Project a raw schema payload onto the fields used downstream.

    Keeps only node labels, relationship endpoints/types, and property
    name/type pairs, lower-casing every property type name along the way.
    """

    def _normalized_props(props: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Keep key/types only; type names are folded to lower case.
        return [
            {
                "key": p["key"],
                "types": [{"type": t["type"].lower()} for t in p["types"]],
            }
            for p in props
        ]

    edges: List[Dict[str, Any]] = []
    for edge in data["edges"]:
        edges.append(
            {
                "end_node_labels": edge["end_node_labels"],
                "properties": _normalized_props(edge["properties"]),
                "start_node_labels": edge["start_node_labels"],
                "type": edge["type"],
            }
        )

    nodes: List[Dict[str, Any]] = []
    for node in data["nodes"]:
        nodes.append(
            {
                "labels": node["labels"],
                "properties": _normalized_props(node["properties"]),
            }
        )

    return {"edges": edges, "nodes": nodes}
|
||||||
|
|
||||||
|
|
||||||
|
def get_reformated_schema(
    nodes: List[Dict[str, Any]], rels: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Assemble the unified schema dict from the Cypher fallback queries.

    ``nodes`` comes from the ``schema.node_type_properties()`` query and
    ``rels`` from the edge-scanning relationship query; the output matches
    the shape produced by ``get_schema_subset``.
    """
    edges: List[Dict[str, Any]] = []
    for rel in rels:
        # properties_info holds [name, valueType] pairs.
        props = [
            {"key": info[0], "types": [{"type": info[1].lower()}]}
            for info in rel["properties_info"]
        ]
        edges.append(
            {
                "end_node_labels": rel["end_node_labels"],
                "properties": props,
                "start_node_labels": rel["start_node_labels"],
                "type": rel["rel_type"],
            }
        )

    node_entries: List[Dict[str, Any]] = []
    for node in nodes:
        # Labels arrive as ":`Label`" — drop backticks and the leading colon.
        label = _remove_backticks(node["labels"])[1:]
        # An empty key on the first property marks a node with no real
        # properties; in that case the whole list is dropped.
        if node["properties"][0]["key"] != "":
            node_props = [
                {
                    "key": prop["key"],
                    "types": [{"type": t.lower()} for t in prop["types"]],
                }
                for prop in node["properties"]
            ]
        else:
            node_props = []
        node_entries.append({"labels": [label], "properties": node_props})

    return {"edges": edges, "nodes": node_entries}
|
||||||
|
|
||||||
|
|
||||||
|
def transform_schema_to_text(schema: Dict[str, Any]) -> str:
    """Render the structured schema dict as prompt-friendly text.

    Produces up to three sections (node properties, relationship
    properties, relationship patterns); a section header is emitted only
    when that section has content.
    """
    node_props_data = ""
    rel_props_data = ""
    rel_data = ""

    for node in schema["nodes"]:
        node_props_data += f"- labels: (:{':'.join(node['labels'])})\n"
        if node["properties"] == []:
            continue
        node_props_data += "  properties:\n"
        for prop in node["properties"]:
            # The set comprehension de-duplicates repeated type names, so
            # the " or "-joined order is unspecified for multi-type props.
            prop_types_str = " or ".join(
                {prop_types["type"] for prop_types in prop["types"]}
            )
            node_props_data += f"  - {prop['key']}: {prop_types_str}\n"

    for rel in schema["edges"]:
        rel_type = rel["type"]
        start_labels = ":".join(rel["start_node_labels"])
        end_labels = ":".join(rel["end_node_labels"])
        rel_data += f"(:{start_labels})-[:{rel_type}]->(:{end_labels})\n"

        if rel["properties"] == []:
            continue

        rel_props_data += f"- labels: {rel_type}\n  properties:\n"
        for prop in rel["properties"]:
            # Unlike the node branch above, type names are lower-cased here.
            prop_types_str = " or ".join(
                {prop_types["type"].lower() for prop_types in prop["types"]}
            )
            rel_props_data += f"  - {prop['key']}: {prop_types_str}\n"

    return "".join(
        [
            NODE_PROPS_TEXT + node_props_data if node_props_data else "",
            REL_PROPS_TEXT + rel_props_data if rel_props_data else "",
            REL_TEXT + rel_data if rel_data else "",
        ]
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_backticks(text: str) -> str:
|
||||||
|
return text.replace("`", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_nodes(nodes: list[Node], baseEntityLabel: bool) -> List[dict]:
    """Convert graph-document nodes into dicts for the node import query.

    Each entry carries a ``label`` list (node type, plus the base entity
    label when requested) and a ``properties`` dict that always includes
    the node id.
    """
    transformed: List[dict] = []
    for node in nodes:
        labels = [_remove_backticks(node.type)]
        if baseEntityLabel:
            labels.append(BASE_ENTITY_LABEL)
        transformed.append(
            {
                "label": labels,
                # Merge node properties with the id; id wins on conflict.
                "properties": node.properties | {"id": node.id},
            }
        )
    return transformed
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_relationships(
    relationships: list[Relationship], baseEntityLabel: bool
) -> List[dict]:
    """Convert graph-document relationships into dicts for the rel import query.

    When ``baseEntityLabel`` is set, endpoints are addressed through the
    shared base entity label (which is indexed on ``id``) instead of their
    concrete node types.
    """
    transformed: List[dict] = []
    for rel in relationships:
        if baseEntityLabel:
            source_label = [BASE_ENTITY_LABEL]
            target_label = [BASE_ENTITY_LABEL]
        else:
            source_label = [_remove_backticks(rel.source.type)]
            target_label = [_remove_backticks(rel.target.type)]
        transformed.append(
            {
                "type": _remove_backticks(rel.type),
                "source_label": source_label,
                "source_id": rel.source.id,
                "target_label": target_label,
                "target_id": rel.target.id,
            }
        )
    return transformed
|
||||||
|
|
||||||
|
|
||||||
|
class MemgraphGraph(GraphStore):
|
||||||
"""Memgraph wrapper for graph operations.
|
"""Memgraph wrapper for graph operations.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
url (Optional[str]): The URL of the Memgraph database server.
|
||||||
|
username (Optional[str]): The username for database authentication.
|
||||||
|
password (Optional[str]): The password for database authentication.
|
||||||
|
database (str): The name of the database to connect to. Default is 'memgraph'.
|
||||||
|
refresh_schema (bool): A flag whether to refresh schema information
|
||||||
|
at initialization. Default is True.
|
||||||
|
driver_config (Dict): Configuration passed to Neo4j Driver.
|
||||||
|
|
||||||
*Security note*: Make sure that the database connection uses credentials
|
*Security note*: Make sure that the database connection uses credentials
|
||||||
that are narrowly-scoped to only include necessary permissions.
|
that are narrowly-scoped to only include necessary permissions.
|
||||||
Failure to do so may result in data corruption or loss, since the calling
|
Failure to do so may result in data corruption or loss, since the calling
|
||||||
@ -23,49 +280,247 @@ class MemgraphGraph(Neo4jGraph):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
    self,
    url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    database: Optional[str] = None,
    refresh_schema: bool = True,
    *,
    driver_config: Optional[Dict] = None,
) -> None:
    """Create a new Memgraph graph wrapper instance.

    Missing connection parameters fall back to the MEMGRAPH_URI,
    MEMGRAPH_USERNAME, MEMGRAPH_PASSWORD and MEMGRAPH_DATABASE
    environment variables.

    Raises:
        ImportError: If the ``neo4j`` driver package is not installed.
        ValueError: If connectivity or authentication verification fails.
    """
    try:
        import neo4j
    except ImportError:
        raise ImportError(
            "Could not import neo4j python package. "
            "Please install it with `pip install neo4j`."
        )

    url = get_from_dict_or_env({"url": url}, "url", "MEMGRAPH_URI")

    # if username and password are "", assume auth is disabled
    if username == "" and password == "":
        auth = None
    else:
        username = get_from_dict_or_env(
            {"username": username},
            "username",
            "MEMGRAPH_USERNAME",
        )
        password = get_from_dict_or_env(
            {"password": password},
            "password",
            "MEMGRAPH_PASSWORD",
        )
        auth = (username, password)
    database = get_from_dict_or_env(
        {"database": database}, "database", "MEMGRAPH_DATABASE", "memgraph"
    )

    self._driver = neo4j.GraphDatabase.driver(
        url, auth=auth, **(driver_config or {})
    )

    self._database = database
    # Textual and structured schema caches, populated by refresh_schema().
    self.schema: str = ""
    self.structured_schema: Dict[str, Any] = {}

    # Verify connection
    try:
        self._driver.verify_connectivity()
    except neo4j.exceptions.ServiceUnavailable:
        raise ValueError(
            "Could not connect to Memgraph database. "
            "Please ensure that the url is correct"
        )
    except neo4j.exceptions.AuthError:
        raise ValueError(
            "Could not connect to Memgraph database. "
            "Please ensure that the username and password are correct"
        )

    # Set schema
    if refresh_schema:
        try:
            self.refresh_schema()
        except neo4j.exceptions.ClientError as e:
            # Re-raised as-is so callers see the original driver error.
            raise e
||||||
|
|
||||||
|
def close(self) -> None:
    """Close the underlying driver connection, if one is still open."""
    driver = self._driver
    if not driver:
        return
    logger.info("Closing the driver connection.")
    driver.close()
    # Mark the connection as released so repeated calls are no-ops.
    self._driver = None
|
||||||
|
|
||||||
|
@property
def get_schema(self) -> str:
    """Returns the schema of the Graph database"""
    current_schema = self.schema
    return current_schema
|
||||||
|
|
||||||
|
@property
def get_structured_schema(self) -> Dict[str, Any]:
    """Returns the structured schema of the Graph database"""
    current_structured = self.structured_schema
    return current_structured
|
||||||
|
|
||||||
|
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
    """Query the graph.

    Args:
        query (str): The Cypher query to execute.
        params (dict): The parameters to pass to the query.

    Returns:
        List[Dict[str, Any]]: The list of dictionaries containing the query results.
    """
    from neo4j.exceptions import Neo4jError

    try:
        records, _, _ = self._driver.execute_query(
            query,
            database_=self._database,
            parameters_=params,
        )
        return [record.data() for record in records]
    except Neo4jError as e:
        # Certain errors only mean the statement must run in an implicit
        # transaction; anything else is a genuine failure and is re-raised.
        call_in_tx = (
            e.code
            in (
                "Neo.DatabaseError.Statement.ExecutionFailed",
                "Neo.DatabaseError.Transaction.TransactionStartFailed",
            )
            and "in an implicit transaction" in e.message
        )
        periodic_commit = e.code == "Neo.ClientError.Statement.SemanticError" and (
            "in an open transaction is not possible" in e.message
            or "tried to execute in an explicit transaction" in e.message
        )
        memgraph_specific = (
            e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
            and (
                "in multicommand transactions" in e.message
                or "SchemaInfo disabled" in e.message
            )
        )
        if not (call_in_tx or periodic_commit or memgraph_specific):
            raise

    # fallback to allow implicit transactions
    with self._driver.session(database=self._database) as session:
        result = session.run(query, params)
        return [record.data() for record in result]
||||||
def refresh_schema(self) -> None:
    """
    Refreshes the Memgraph graph schema information.

    Tries the ``SHOW SCHEMA INFO`` query first; when that feature is
    disabled on the server, falls back to plain Cypher introspection
    queries. Leaves the schema empty for an empty database.
    """
    import ast

    from neo4j.exceptions import Neo4jError

    # leave schema empty if db is empty
    if self.query("MATCH (n) RETURN n LIMIT 1") == []:
        return

    # first try with SHOW SCHEMA INFO
    try:
        raw_schema = self.query(SCHEMA_QUERY)[0].get("schema")
        # The server may return the schema as a string literal; parse it.
        parsed_schema = (
            ast.literal_eval(raw_schema)
            if raw_schema is not None and isinstance(raw_schema, (str, ast.AST))
            else raw_schema
        )
        assert parsed_schema is not None
        schema_subset = get_schema_subset(parsed_schema)
        self.structured_schema = schema_subset
        self.schema = transform_schema_to_text(schema_subset)
        return
    except Neo4jError as e:
        if (
            e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
            and "SchemaInfo disabled" in e.message
        ):
            logger.info(
                "Schema generation with SHOW SCHEMA INFO query failed. "
                "Set --schema-info-enabled=true to use SHOW SCHEMA INFO query. "
                "Falling back to alternative queries."
            )

    # fallback on Cypher without SHOW SCHEMA INFO
    nodes = [row["output"] for row in self.query(NODE_PROPERTIES_QUERY)]
    rels = self.query(REL_QUERY)

    fallback_schema = get_reformated_schema(nodes, rels)
    self.structured_schema = fallback_schema
    self.schema = transform_schema_to_text(fallback_schema)
|
def add_graph_documents(
    self,
    graph_documents: List[GraphDocument],
    include_source: bool = False,
    baseEntityLabel: bool = False,
) -> None:
    """
    Take GraphDocument as input as uses it to construct a graph in Memgraph.

    Parameters:
    - graph_documents (List[GraphDocument]): A list of GraphDocument objects
    that contain the nodes and relationships to be added to the graph. Each
    GraphDocument should encapsulate the structure of part of the graph,
    including nodes, relationships, and the source document information.
    - include_source (bool, optional): If True, stores the source document
    and links it to nodes in the graph using the MENTIONS relationship.
    This is useful for tracing back the origin of data. Merges source
    documents based on the `id` property from the source document metadata
    if available; otherwise it calculates the MD5 hash of `page_content`
    for merging process. Defaults to False.
    - baseEntityLabel (bool, optional): If True, each newly created node
    gets a secondary __Entity__ label, which is indexed and improves import
    speed and performance. Defaults to False.
    """
    if baseEntityLabel:
        # Unique constraint + indexes on the shared entity label keep the
        # MERGE-by-id lookups fast during import.
        self.query(
            f"CREATE CONSTRAINT ON (b:{BASE_ENTITY_LABEL}) "
            "ASSERT b.id IS UNIQUE;"
        )
        self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);")
        self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL};")

    for document in graph_documents:
        source = document.source
        if include_source:
            # Merge key for the source node: explicit metadata id when
            # available, otherwise the MD5 hash of the page content.
            if not source.metadata.get("id"):
                source.metadata["id"] = md5(
                    source.page_content.encode("utf-8")
                ).hexdigest()
            self.query(INCLUDE_DOCS_QUERY, {"document": source.__dict__})

        # Import nodes first so the relationship queries can match them.
        self.query(
            NODE_IMPORT_QUERY,
            {"data": _transform_nodes(document.nodes, baseEntityLabel)},
        )

        rel_data = _transform_relationships(document.relationships, baseEntityLabel)
        self.query(REL_NODES_IMPORT_QUERY, {"data": rel_data})
        self.query(REL_IMPORT_QUERY, {"data": rel_data})

        if include_source:
            self.query(
                INCLUDE_DOCS_SOURCE_QUERY,
                {"data": rel_data, "document": source.__dict__},
            )

    self.refresh_schema()
||||||
|
@ -1,24 +1,44 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from langchain_community.graphs import MemgraphGraph
|
from langchain_community.graphs import MemgraphGraph
|
||||||
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||||
|
from langchain_community.graphs.memgraph_graph import NODE_PROPERTIES_QUERY, REL_QUERY
|
||||||
|
|
||||||
|
_FOO = {"id": "foo", "type": "foo"}
_BAR = {"id": "bar", "type": "bar"}

# Shared fixture: a single graph document with two nodes joined by one
# REL relationship, attributed to a small source document.
test_data = [
    GraphDocument(
        nodes=[Node(**_FOO), Node(**_BAR)],
        relationships=[
            Relationship(source=Node(**_FOO), target=Node(**_BAR), type="REL")
        ],
        source=Document(page_content="source document"),
    )
]
|
||||||
def test_cypher_return_correct_schema() -> None:
    """Test that chain returns direct results."""
    url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687")
    username = os.environ.get("MEMGRAPH_USERNAME", "")
    password = os.environ.get("MEMGRAPH_PASSWORD", "")

    for value in (url, username, password):
        assert value is not None

    graph = MemgraphGraph(url=url, username=username, password=password)

    # Wipe any leftover data so the introspection below is deterministic.
    graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL")
    graph.query("DROP GRAPH")
    graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL")

    # Create two nodes and a relationship
    # NOTE(review): this CREATE statement was elided in the diff view and is
    # reconstructed from the expected schema results below — confirm upstream.
    graph.query(
        """
        CREATE (la:LabelA {property_a: 'a'})
        CREATE (lb:LabelB)
        CREATE (lc:LabelC)
        MERGE (la)-[:REL_TYPE]-> (lb)
        MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc)
        """
    )
    # Refresh schema information
    graph.refresh_schema()

    node_properties = graph.query(NODE_PROPERTIES_QUERY)
    relationships = graph.query(REL_QUERY)

    expected_node_properties = [
        {
            "output": {
                "labels": ":`LabelA`",
                "properties": [{"key": "property_a", "types": ["String"]}],
            }
        },
        {"output": {"labels": ":`LabelB`", "properties": [{"key": "", "types": []}]}},
        {"output": {"labels": ":`LabelC`", "properties": [{"key": "", "types": []}]}},
    ]
    expected_relationships = [
        {
            "start_node_labels": ["LabelA"],
            "rel_type": "REL_TYPE",
            "end_node_labels": ["LabelC"],
            "properties_info": [["rel_prop", "STRING"]],
        },
        {
            "start_node_labels": ["LabelA"],
            "rel_type": "REL_TYPE",
            "end_node_labels": ["LabelB"],
            "properties_info": [],
        },
    ]

    graph.close()

    assert node_properties == expected_node_properties
    assert relationships == expected_relationships
|
def test_add_graph_documents() -> None:
    """Test that Memgraph correctly imports graph document."""
    url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687")
    username = os.environ.get("MEMGRAPH_USERNAME", "")
    password = os.environ.get("MEMGRAPH_PASSWORD", "")

    for value in (url, username, password):
        assert value is not None

    graph = MemgraphGraph(
        url=url, username=username, password=password, refresh_schema=False
    )
    # Wipe any leftover data so the label counts below are deterministic.
    graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL")
    graph.query("DROP GRAPH")
    graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL")

    # Import the shared fixture and inspect the resulting node labels.
    graph.add_graph_documents(test_data)
    output = graph.query("MATCH (n) RETURN labels(n) AS label, count(*) AS count")
    # Close the connection
    graph.close()

    assert output == [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}]
|
def test_add_graph_documents_base_entity() -> None:
    """Test that Memgraph correctly imports graph document with Entity label."""
    url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687")
    username = os.environ.get("MEMGRAPH_USERNAME", "")
    password = os.environ.get("MEMGRAPH_PASSWORD", "")

    for value in (url, username, password):
        assert value is not None

    graph = MemgraphGraph(
        url=url, username=username, password=password, refresh_schema=False
    )
    # Wipe any leftover data so the label counts below are deterministic.
    graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL")
    graph.query("DROP GRAPH")
    graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL")

    # Import with baseEntityLabel: every node gains the __Entity__ label.
    graph.add_graph_documents(test_data, baseEntityLabel=True)
    output = graph.query("MATCH (n) RETURN labels(n) AS label, count(*) AS count")

    # Close the connection
    graph.close()

    assert output == [
        {"label": ["__Entity__", "bar"], "count": 1},
        {"label": ["__Entity__", "foo"], "count": 1},
    ]
|
def test_add_graph_documents_include_source() -> None:
    """Test that Memgraph correctly imports graph document with source included."""
    url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687")
    username = os.environ.get("MEMGRAPH_USERNAME", "")
    password = os.environ.get("MEMGRAPH_PASSWORD", "")

    for value in (url, username, password):
        assert value is not None

    graph = MemgraphGraph(
        url=url, username=username, password=password, refresh_schema=False
    )
    # Wipe any leftover data so the label counts below are deterministic.
    graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL")
    graph.query("DROP GRAPH")
    graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL")

    # Import with include_source: the source Document becomes a node too.
    graph.add_graph_documents(test_data, include_source=True)
    output = graph.query("MATCH (n) RETURN labels(n) AS label, count(*) AS count")

    # Close the connection
    graph.close()

    assert output == [
        {"label": ["bar"], "count": 1},
        {"label": ["foo"], "count": 1},
        {"label": ["Document"], "count": 1},
    ]
|
Loading…
Reference in New Issue
Block a user