mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
community[minor]: add graph store implementation for apache age (#20582)
**Description:** implemented GraphStore class for Apache Age graph db **Dependencies:** depends on psycopg2 Unit and integration tests included. Formatting and linting have been run. --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
749
libs/community/langchain_community/graphs/age_graph.py
Normal file
749
libs/community/langchain_community/graphs/age_graph.py
Normal file
@@ -0,0 +1,749 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from hashlib import md5
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import psycopg2.extras
|
||||
|
||||
|
||||
class AGEQueryException(Exception):
|
||||
"""Exception for the AGE queries."""
|
||||
|
||||
def __init__(self, exception: Union[str, Dict]) -> None:
|
||||
if isinstance(exception, dict):
|
||||
self.message = exception["message"] if "message" in exception else "unknown"
|
||||
self.details = exception["details"] if "details" in exception else "unknown"
|
||||
else:
|
||||
self.message = exception
|
||||
self.details = "unknown"
|
||||
|
||||
def get_message(self) -> str:
|
||||
return self.message
|
||||
|
||||
def get_details(self) -> Any:
|
||||
return self.details
|
||||
|
||||
|
||||
class AGEGraph(GraphStore):
|
||||
"""
|
||||
Apache AGE wrapper for graph operations.
|
||||
|
||||
Args:
|
||||
graph_name (str): the name of the graph to connect to or create
|
||||
conf (Dict[str, Any]): the pgsql connection config passed directly
|
||||
to psycopg2.connect
|
||||
create (bool): if True and graph doesn't exist, attempt to create it
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
# python type mapping for providing readable types to LLM
|
||||
types = {
|
||||
"str": "STRING",
|
||||
"float": "DOUBLE",
|
||||
"int": "INTEGER",
|
||||
"list": "LIST",
|
||||
"dict": "MAP",
|
||||
"bool": "BOOLEAN",
|
||||
}
|
||||
|
||||
# precompiled regex for checking chars in graph labels
|
||||
label_regex = re.compile("[^0-9a-zA-Z]+")
|
||||
|
||||
def __init__(
|
||||
self, graph_name: str, conf: Dict[str, Any], create: bool = True
|
||||
) -> None:
|
||||
"""Create a new AGEGraph instance."""
|
||||
|
||||
self.graph_name = graph_name
|
||||
|
||||
# check that psycopg2 is installed
|
||||
try:
|
||||
import psycopg2
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import psycopg2 python package. "
|
||||
"Please install it with `pip install psycopg2`."
|
||||
)
|
||||
|
||||
self.connection = psycopg2.connect(**conf)
|
||||
|
||||
with self._get_cursor() as curs:
|
||||
# check if graph with name graph_name exists
|
||||
graph_id_query = (
|
||||
"""SELECT graphid FROM ag_catalog.ag_graph WHERE name = '{}'""".format(
|
||||
graph_name
|
||||
)
|
||||
)
|
||||
|
||||
curs.execute(graph_id_query)
|
||||
data = curs.fetchone()
|
||||
|
||||
# if graph doesn't exist and create is True, create it
|
||||
if data is None:
|
||||
if create:
|
||||
create_statement = """
|
||||
SELECT ag_catalog.create_graph('{}');
|
||||
""".format(graph_name)
|
||||
|
||||
try:
|
||||
curs.execute(create_statement)
|
||||
self.connection.commit()
|
||||
except psycopg2.Error as e:
|
||||
raise AGEQueryException(
|
||||
{
|
||||
"message": "Could not create the graph",
|
||||
"detail": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
(
|
||||
'Graph "{}" does not exist in the database '
|
||||
+ 'and "create" is set to False'
|
||||
).format(graph_name)
|
||||
)
|
||||
|
||||
curs.execute(graph_id_query)
|
||||
data = curs.fetchone()
|
||||
|
||||
# store graph id and refresh the schema
|
||||
self.graphid = data.graphid
|
||||
self.refresh_schema()
|
||||
|
||||
def _get_cursor(self) -> psycopg2.extras.NamedTupleCursor:
|
||||
"""
|
||||
get cursor, load age extension and set search path
|
||||
"""
|
||||
|
||||
try:
|
||||
import psycopg2.extras
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import psycopg2, please install with "
|
||||
"`pip install -U psycopg2`."
|
||||
) from e
|
||||
cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
|
||||
cursor.execute("""LOAD 'age';""")
|
||||
cursor.execute("""SET search_path = ag_catalog, "$user", public;""")
|
||||
return cursor
|
||||
|
||||
def _get_labels(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Get all labels of a graph (for both edges and vertices)
|
||||
by querying the graph metadata table directly
|
||||
|
||||
Returns
|
||||
Tuple[List[str]]: 2 lists, the first containing vertex
|
||||
labels and the second containing edge labels
|
||||
"""
|
||||
|
||||
e_labels_records = self.query(
|
||||
"""MATCH ()-[e]-() RETURN collect(distinct label(e)) as labels"""
|
||||
)
|
||||
e_labels = e_labels_records[0]["labels"] if e_labels_records else []
|
||||
|
||||
n_labels_records = self.query(
|
||||
"""MATCH (n) RETURN collect(distinct label(n)) as labels"""
|
||||
)
|
||||
n_labels = n_labels_records[0]["labels"] if n_labels_records else []
|
||||
|
||||
return n_labels, e_labels
|
||||
|
||||
def _get_triples(self, e_labels: List[str]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Get a set of distinct relationship types (as a list of dicts) in the graph
|
||||
to be used as context by an llm.
|
||||
|
||||
Args:
|
||||
e_labels (List[str]): a list of edge labels to filter for
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: relationships as a list of dicts in the format
|
||||
"{'start':<from_label>, 'type':<edge_label>, 'end':<from_label>}"
|
||||
"""
|
||||
|
||||
# age query to get distinct relationship types
|
||||
try:
|
||||
import psycopg2
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import psycopg2, please install with "
|
||||
"`pip install -U psycopg2`."
|
||||
) from e
|
||||
triple_query = """
|
||||
SELECT * FROM ag_catalog.cypher('{graph_name}', $$
|
||||
MATCH (a)-[e:`{e_label}`]->(b)
|
||||
WITH a,e,b LIMIT 3000
|
||||
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
|
||||
LIMIT 10
|
||||
$$) AS (f agtype, edge agtype, t agtype);
|
||||
"""
|
||||
|
||||
triple_schema = []
|
||||
|
||||
# iterate desired edge types and add distinct relationship types to result
|
||||
with self._get_cursor() as curs:
|
||||
for label in e_labels:
|
||||
q = triple_query.format(graph_name=self.graph_name, e_label=label)
|
||||
try:
|
||||
curs.execute(q)
|
||||
data = curs.fetchall()
|
||||
for d in data:
|
||||
# use json.loads to convert returned
|
||||
# strings to python primitives
|
||||
triple_schema.append(
|
||||
{
|
||||
"start": json.loads(d.f)[0],
|
||||
"type": json.loads(d.edge),
|
||||
"end": json.loads(d.t)[0],
|
||||
}
|
||||
)
|
||||
except psycopg2.Error as e:
|
||||
raise AGEQueryException(
|
||||
{
|
||||
"message": "Error fetching triples",
|
||||
"detail": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return triple_schema
|
||||
|
||||
def _get_triples_str(self, e_labels: List[str]) -> List[str]:
|
||||
"""
|
||||
Get a set of distinct relationship types (as a list of strings) in the graph
|
||||
to be used as context by an llm.
|
||||
|
||||
Args:
|
||||
e_labels (List[str]): a list of edge labels to filter for
|
||||
|
||||
Returns:
|
||||
List[str]: relationships as a list of strings in the format
|
||||
"(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
|
||||
"""
|
||||
|
||||
triples = self._get_triples(e_labels)
|
||||
|
||||
return self._format_triples(triples)
|
||||
|
||||
@staticmethod
|
||||
def _format_triples(triples: List[Dict[str, str]]) -> List[str]:
|
||||
"""
|
||||
Convert a list of relationships from dictionaries to formatted strings
|
||||
to be better readable by an llm
|
||||
|
||||
Args:
|
||||
triples (List[Dict[str,str]]): a list relationships in the form
|
||||
{'start':<from_label>, 'type':<edge_label>, 'end':<from_label>}
|
||||
|
||||
Returns:
|
||||
List[str]: a list of relationships in the form
|
||||
"(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
|
||||
"""
|
||||
triple_template = "(:`{start}`)-[:`{type}`]->(:`{end}`)"
|
||||
triple_schema = [triple_template.format(**triple) for triple in triples]
|
||||
|
||||
return triple_schema
|
||||
|
||||
def _get_node_properties(self, n_labels: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch a list of available node properties by node label to be used
|
||||
as context for an llm
|
||||
|
||||
Args:
|
||||
n_labels (List[str]): a list of node labels to filter for
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: a list of node labels and
|
||||
their corresponding properties in the form
|
||||
"{
|
||||
'labels': <node_label>,
|
||||
'properties': [
|
||||
{
|
||||
'property': <property_name>,
|
||||
'type': <property_type>
|
||||
},...
|
||||
]
|
||||
}"
|
||||
"""
|
||||
try:
|
||||
import psycopg2
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import psycopg2, please install with "
|
||||
"`pip install -U psycopg2`."
|
||||
) from e
|
||||
|
||||
# cypher query to fetch properties of a given label
|
||||
node_properties_query = """
|
||||
SELECT * FROM ag_catalog.cypher('{graph_name}', $$
|
||||
MATCH (a:`{n_label}`)
|
||||
RETURN properties(a) AS props
|
||||
LIMIT 100
|
||||
$$) AS (props agtype);
|
||||
"""
|
||||
|
||||
node_properties = []
|
||||
with self._get_cursor() as curs:
|
||||
for label in n_labels:
|
||||
q = node_properties_query.format(
|
||||
graph_name=self.graph_name, n_label=label
|
||||
)
|
||||
|
||||
try:
|
||||
curs.execute(q)
|
||||
except psycopg2.Error as e:
|
||||
raise AGEQueryException(
|
||||
{
|
||||
"message": "Error fetching node properties",
|
||||
"detail": str(e),
|
||||
}
|
||||
)
|
||||
data = curs.fetchall()
|
||||
|
||||
# build a set of distinct properties
|
||||
s = set({})
|
||||
for d in data:
|
||||
# use json.loads to convert to python
|
||||
# primitive and get readable type
|
||||
for k, v in json.loads(d.props).items():
|
||||
s.add((k, self.types[type(v).__name__]))
|
||||
|
||||
np = {
|
||||
"properties": [{"property": k, "type": v} for k, v in s],
|
||||
"labels": label,
|
||||
}
|
||||
node_properties.append(np)
|
||||
|
||||
return node_properties
|
||||
|
||||
def _get_edge_properties(self, e_labels: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch a list of available edge properties by edge label to be used
|
||||
as context for an llm
|
||||
|
||||
Args:
|
||||
e_labels (List[str]): a list of edge labels to filter for
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: a list of edge labels
|
||||
and their corresponding properties in the form
|
||||
"{
|
||||
'labels': <edge_label>,
|
||||
'properties': [
|
||||
{
|
||||
'property': <property_name>,
|
||||
'type': <property_type>
|
||||
},...
|
||||
]
|
||||
}"
|
||||
"""
|
||||
|
||||
try:
|
||||
import psycopg2
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import psycopg2, please install with "
|
||||
"`pip install -U psycopg2`."
|
||||
) from e
|
||||
# cypher query to fetch properties of a given label
|
||||
edge_properties_query = """
|
||||
SELECT * FROM ag_catalog.cypher('{graph_name}', $$
|
||||
MATCH ()-[e:`{e_label}`]->()
|
||||
RETURN properties(e) AS props
|
||||
LIMIT 100
|
||||
$$) AS (props agtype);
|
||||
"""
|
||||
edge_properties = []
|
||||
with self._get_cursor() as curs:
|
||||
for label in e_labels:
|
||||
q = edge_properties_query.format(
|
||||
graph_name=self.graph_name, e_label=label
|
||||
)
|
||||
|
||||
try:
|
||||
curs.execute(q)
|
||||
except psycopg2.Error as e:
|
||||
raise AGEQueryException(
|
||||
{
|
||||
"message": "Error fetching edge properties",
|
||||
"detail": str(e),
|
||||
}
|
||||
)
|
||||
data = curs.fetchall()
|
||||
|
||||
# build a set of distinct properties
|
||||
s = set({})
|
||||
for d in data:
|
||||
# use json.loads to convert to python
|
||||
# primitive and get readable type
|
||||
for k, v in json.loads(d.props).items():
|
||||
s.add((k, self.types[type(v).__name__]))
|
||||
|
||||
np = {
|
||||
"properties": [{"property": k, "type": v} for k, v in s],
|
||||
"type": label,
|
||||
}
|
||||
edge_properties.append(np)
|
||||
|
||||
return edge_properties
|
||||
|
||||
def refresh_schema(self) -> None:
|
||||
"""
|
||||
Refresh the graph schema information by updating the available
|
||||
labels, relationships, and properties
|
||||
"""
|
||||
|
||||
# fetch graph schema information
|
||||
n_labels, e_labels = self._get_labels()
|
||||
triple_schema = self._get_triples(e_labels)
|
||||
|
||||
node_properties = self._get_node_properties(n_labels)
|
||||
edge_properties = self._get_edge_properties(e_labels)
|
||||
|
||||
# update the formatted string representation
|
||||
self.schema = f"""
|
||||
Node properties are the following:
|
||||
{node_properties}
|
||||
Relationship properties are the following:
|
||||
{edge_properties}
|
||||
The relationships are the following:
|
||||
{self._format_triples(triple_schema)}
|
||||
"""
|
||||
|
||||
# update the dictionary representation
|
||||
self.structured_schema = {
|
||||
"node_props": {el["labels"]: el["properties"] for el in node_properties},
|
||||
"rel_props": {el["type"]: el["properties"] for el in edge_properties},
|
||||
"relationships": triple_schema,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
@property
|
||||
def get_schema(self) -> str:
|
||||
"""Returns the schema of the Graph"""
|
||||
return self.schema
|
||||
|
||||
@property
|
||||
def get_structured_schema(self) -> Dict[str, Any]:
|
||||
"""Returns the structured schema of the Graph"""
|
||||
return self.structured_schema
|
||||
|
||||
@staticmethod
|
||||
def _get_col_name(field: str, idx: int) -> str:
|
||||
"""
|
||||
Convert a cypher return field to a pgsql select field
|
||||
If possible keep the cypher column name, but create a generic name if necessary
|
||||
|
||||
Args:
|
||||
field (str): a return field from a cypher query to be formatted for pgsql
|
||||
idx (int): the position of the field in the return statement
|
||||
|
||||
Returns:
|
||||
str: the field to be used in the pgsql select statement
|
||||
"""
|
||||
# remove white space
|
||||
field = field.strip()
|
||||
# if an alias is provided for the field, use it
|
||||
if " as " in field:
|
||||
return field.split(" as ")[-1].strip()
|
||||
# if the return value is an unnamed primitive, give it a generic name
|
||||
elif field.isnumeric() or field in ("true", "false", "null"):
|
||||
return f"column_{idx}"
|
||||
# otherwise return the value stripping out some common special chars
|
||||
else:
|
||||
return field.replace("(", "_").replace(")", "")
|
||||
|
||||
@staticmethod
|
||||
def _wrap_query(query: str, graph_name: str) -> str:
|
||||
"""
|
||||
Convert a cypher query to an Apache Age compatible
|
||||
sql query by wrapping the cypher query in ag_catalog.cypher,
|
||||
casting results to agtype and building a select statement
|
||||
|
||||
Args:
|
||||
query (str): a valid cypher query
|
||||
graph_name (str): the name of the graph to query
|
||||
|
||||
Returns:
|
||||
str: an equivalent pgsql query
|
||||
"""
|
||||
|
||||
# pgsql template
|
||||
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
||||
{query}
|
||||
$$) AS ({fields});"""
|
||||
|
||||
# if there are any returned fields they must be added to the pgsql query
|
||||
if "return" in query.lower():
|
||||
# parse return statement to identify returned fields
|
||||
fields = (
|
||||
query.lower()
|
||||
.split("return")[-1]
|
||||
.split("distinct")[-1]
|
||||
.split("order by")[0]
|
||||
.split("skip")[0]
|
||||
.split("limit")[0]
|
||||
.split(",")
|
||||
)
|
||||
|
||||
# raise exception if RETURN * is found as we can't resolve the fields
|
||||
if "*" in [x.strip() for x in fields]:
|
||||
raise ValueError(
|
||||
"AGE graph does not support 'RETURN *'"
|
||||
+ " statements in Cypher queries"
|
||||
)
|
||||
|
||||
# get pgsql formatted field names
|
||||
fields = [
|
||||
AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
|
||||
]
|
||||
|
||||
# build resulting pgsql relation
|
||||
fields_str = ", ".join(
|
||||
[field.split(".")[-1] + " agtype" for field in fields]
|
||||
)
|
||||
|
||||
# if no return statement we still need to return a single field of type agtype
|
||||
else:
|
||||
fields_str = "a agtype"
|
||||
|
||||
select_str = "*"
|
||||
|
||||
return template.format(
|
||||
graph_name=graph_name,
|
||||
query=query,
|
||||
fields=fields_str,
|
||||
projection=select_str,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a record returned from an age query to a dictionary
|
||||
|
||||
Args:
|
||||
record (): a record from an age query result
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: a dictionary representation of the record where
|
||||
the dictionary key is the field name and the value is the
|
||||
value converted to a python type
|
||||
"""
|
||||
# result holder
|
||||
d = {}
|
||||
|
||||
# prebuild a mapping of vertex_id to vertex mappings to be used
|
||||
# later to build edges
|
||||
vertices = {}
|
||||
for k in record._fields:
|
||||
v = getattr(record, k)
|
||||
# agtype comes back '{key: value}::type' which must be parsed
|
||||
if isinstance(v, str) and "::" in v:
|
||||
dtype = v.split("::")[-1]
|
||||
v = v.split("::")[0]
|
||||
if dtype == "vertex":
|
||||
vertex = json.loads(v)
|
||||
vertices[vertex["id"]] = vertex.get("properties")
|
||||
|
||||
# iterate returned fields and parse appropriately
|
||||
for k in record._fields:
|
||||
v = getattr(record, k)
|
||||
if isinstance(v, str) and "::" in v:
|
||||
dtype = v.split("::")[-1]
|
||||
v = v.split("::")[0]
|
||||
else:
|
||||
dtype = ""
|
||||
|
||||
if dtype == "vertex":
|
||||
d[k] = json.loads(v).get("properties")
|
||||
# convert edge from id-label->id by replacing id with node information
|
||||
# we only do this if the vertex was also returned in the query
|
||||
# this is an attempt to be consistent with neo4j implementation
|
||||
elif dtype == "edge":
|
||||
edge = json.loads(v)
|
||||
d[k] = (
|
||||
vertices.get(edge["start_id"], {}),
|
||||
edge["label"],
|
||||
vertices.get(edge["end_id"], {}),
|
||||
)
|
||||
else:
|
||||
d[k] = json.loads(v) if isinstance(v, str) else v
|
||||
|
||||
return d
|
||||
|
||||
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query the graph by taking a cypher query, converting it to an
|
||||
age compatible query, executing it and converting the result
|
||||
|
||||
Args:
|
||||
query (str): a cypher query to be executed
|
||||
params (dict): parameters for the query (not used in this implementation)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: a list of dictionaries containing the result set
|
||||
"""
|
||||
try:
|
||||
import psycopg2
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import psycopg2, please install with "
|
||||
"`pip install -U psycopg2`."
|
||||
) from e
|
||||
|
||||
# convert cypher query to pgsql/age query
|
||||
wrapped_query = self._wrap_query(query, self.graph_name)
|
||||
|
||||
# execute the query, rolling back on an error
|
||||
with self._get_cursor() as curs:
|
||||
try:
|
||||
curs.execute(wrapped_query)
|
||||
self.connection.commit()
|
||||
except psycopg2.Error as e:
|
||||
self.connection.rollback()
|
||||
raise AGEQueryException(
|
||||
{
|
||||
"message": "Error executing graph query: {}".format(query),
|
||||
"detail": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
data = curs.fetchall()
|
||||
if data is None:
|
||||
result = []
|
||||
# convert to dictionaries
|
||||
else:
|
||||
result = [self._record_to_dict(d) for d in data]
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _format_properties(
|
||||
properties: Dict[str, Any], id: Union[str, None] = None
|
||||
) -> str:
|
||||
"""
|
||||
Convert a dictionary of properties to a string representation that
|
||||
can be used in a cypher query insert/merge statement.
|
||||
|
||||
Args:
|
||||
properties (Dict[str,str]): a dictionary containing node/edge properties
|
||||
id (Union[str, None]): the id of the node or None if none exists
|
||||
|
||||
Returns:
|
||||
str: the properties dictionary as a properly formatted string
|
||||
"""
|
||||
props = []
|
||||
# wrap property key in backticks to escape
|
||||
for k, v in properties.items():
|
||||
prop = f"`{k}`: {json.dumps(v)}"
|
||||
props.append(prop)
|
||||
if id is not None and "id" not in properties:
|
||||
props.append(
|
||||
f"id: {json.dumps(id)}" if isinstance(id, str) else f"id: {id}"
|
||||
)
|
||||
return "{" + ", ".join(props) + "}"
|
||||
|
||||
@staticmethod
|
||||
def clean_graph_labels(label: str) -> str:
|
||||
"""
|
||||
remove any disallowed characters from a label and replace with '_'
|
||||
|
||||
Args:
|
||||
label (str): the original label
|
||||
|
||||
Returns:
|
||||
str: the sanitized version of the label
|
||||
"""
|
||||
return re.sub(AGEGraph.label_regex, "_", label)
|
||||
|
||||
def add_graph_documents(
|
||||
self, graph_documents: List[GraphDocument], include_source: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
insert a list of graph documents into the graph
|
||||
|
||||
Args:
|
||||
graph_documents (List[GraphDocument]): the list of documents to be inserted
|
||||
include_source (bool): if True add nodes for the sources
|
||||
with MENTIONS edges to the entities they mention
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# query for inserting nodes
|
||||
node_insert_query = (
|
||||
"""
|
||||
MERGE (n:`{label}` {properties})
|
||||
"""
|
||||
if not include_source
|
||||
else """
|
||||
MERGE (n:`{label}` {properties})
|
||||
MERGE (d:Document {d_properties})
|
||||
MERGE (d)-[:MENTIONS]->(n)
|
||||
"""
|
||||
)
|
||||
|
||||
# query for inserting edges
|
||||
edge_insert_query = """
|
||||
MERGE (from:`{f_label}` {f_properties})
|
||||
MERGE (to:`{t_label}` {t_properties})
|
||||
MERGE (from)-[:`{r_label}` {r_properties}]->(to)
|
||||
"""
|
||||
# iterate docs and insert them
|
||||
for doc in graph_documents:
|
||||
# if we are adding sources, create an id for the source
|
||||
if include_source:
|
||||
if not doc.source.metadata.get("id"):
|
||||
doc.source.metadata["id"] = md5(
|
||||
doc.source.page_content.encode("utf-8")
|
||||
).hexdigest()
|
||||
|
||||
# insert entity nodes
|
||||
for node in doc.nodes:
|
||||
node.properties["id"] = node.id
|
||||
if include_source:
|
||||
query = node_insert_query.format(
|
||||
label=node.type,
|
||||
properties=self._format_properties(node.properties),
|
||||
d_properties=self._format_properties(doc.source.metadata),
|
||||
)
|
||||
else:
|
||||
query = node_insert_query.format(
|
||||
label=AGEGraph.clean_graph_labels(node.type),
|
||||
properties=self._format_properties(node.properties),
|
||||
)
|
||||
|
||||
self.query(query)
|
||||
|
||||
# insert relationships
|
||||
for edge in doc.relationships:
|
||||
edge.source.properties["id"] = edge.source.id
|
||||
edge.target.properties["id"] = edge.target.id
|
||||
inputs = {
|
||||
"f_label": AGEGraph.clean_graph_labels(edge.source.type),
|
||||
"f_properties": self._format_properties(edge.source.properties),
|
||||
"t_label": AGEGraph.clean_graph_labels(edge.target.type),
|
||||
"t_properties": self._format_properties(edge.target.properties),
|
||||
"r_label": AGEGraph.clean_graph_labels(edge.type).upper(),
|
||||
"r_properties": self._format_properties(edge.properties),
|
||||
}
|
||||
|
||||
query = edge_insert_query.format(**inputs)
|
||||
self.query(query)
|
337
libs/community/tests/integration_tests/graphs/test_age_graph.py
Normal file
337
libs/community/tests/integration_tests/graphs/test_age_graph.py
Normal file
@@ -0,0 +1,337 @@
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.graphs.age_graph import AGEGraph
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
|
||||
test_data = [
|
||||
GraphDocument(
|
||||
nodes=[Node(id="foo", type="foo"), Node(id="bar", type="bar")],
|
||||
relationships=[
|
||||
Relationship(
|
||||
source=Node(id="foo", type="foo"),
|
||||
target=Node(id="bar", type="bar"),
|
||||
type="REL",
|
||||
)
|
||||
],
|
||||
source=Document(page_content="source document"),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class TestAGEGraph(unittest.TestCase):
|
||||
def test_node_properties(self) -> None:
|
||||
conf = {
|
||||
"database": os.getenv("AGE_PGSQL_DB"),
|
||||
"user": os.getenv("AGE_PGSQL_USER"),
|
||||
"password": os.getenv("AGE_PGSQL_PASSWORD"),
|
||||
"host": os.getenv("AGE_PGSQL_HOST", "localhost"),
|
||||
"port": int(os.getenv("AGE_PGSQL_PORT", 5432)),
|
||||
}
|
||||
|
||||
self.assertIsNotNone(conf["database"])
|
||||
self.assertIsNotNone(conf["user"])
|
||||
self.assertIsNotNone(conf["password"])
|
||||
|
||||
graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
|
||||
|
||||
graph = AGEGraph(graph_name, conf)
|
||||
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"""
|
||||
CREATE (la:LabelA {property_a: 'a'})
|
||||
CREATE (lb:LabelB)
|
||||
CREATE (lc:LabelC)
|
||||
MERGE (la)-[:REL_TYPE]-> (lb)
|
||||
MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc)
|
||||
"""
|
||||
)
|
||||
# Refresh schema information
|
||||
# graph.refresh_schema()
|
||||
|
||||
n_labels, e_labels = graph._get_labels()
|
||||
|
||||
node_properties = graph._get_node_properties(n_labels)
|
||||
|
||||
expected_node_properties = [
|
||||
{
|
||||
"properties": [{"property": "property_a", "type": "STRING"}],
|
||||
"labels": "LabelA",
|
||||
},
|
||||
{
|
||||
"properties": [],
|
||||
"labels": "LabelB",
|
||||
},
|
||||
{
|
||||
"properties": [],
|
||||
"labels": "LabelC",
|
||||
},
|
||||
]
|
||||
|
||||
self.assertEqual(
|
||||
sorted(node_properties, key=lambda x: x["labels"]), expected_node_properties
|
||||
)
|
||||
|
||||
def test_edge_properties(self) -> None:
|
||||
conf = {
|
||||
"database": os.getenv("AGE_PGSQL_DB"),
|
||||
"user": os.getenv("AGE_PGSQL_USER"),
|
||||
"password": os.getenv("AGE_PGSQL_PASSWORD"),
|
||||
"host": os.getenv("AGE_PGSQL_HOST", "localhost"),
|
||||
"port": int(os.getenv("AGE_PGSQL_PORT", 5432)),
|
||||
}
|
||||
|
||||
self.assertIsNotNone(conf["database"])
|
||||
self.assertIsNotNone(conf["user"])
|
||||
self.assertIsNotNone(conf["password"])
|
||||
|
||||
graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
|
||||
|
||||
graph = AGEGraph(graph_name, conf)
|
||||
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"""
|
||||
CREATE (la:LabelA {property_a: 'a'})
|
||||
CREATE (lb:LabelB)
|
||||
CREATE (lc:LabelC)
|
||||
MERGE (la)-[:REL_TYPE]-> (lb)
|
||||
MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc)
|
||||
"""
|
||||
)
|
||||
# Refresh schema information
|
||||
# graph.refresh_schema()
|
||||
|
||||
n_labels, e_labels = graph._get_labels()
|
||||
|
||||
relationships_properties = graph._get_edge_properties(e_labels)
|
||||
|
||||
expected_relationships_properties = [
|
||||
{
|
||||
"type": "REL_TYPE",
|
||||
"properties": [{"property": "rel_prop", "type": "STRING"}],
|
||||
}
|
||||
]
|
||||
|
||||
self.assertEqual(relationships_properties, expected_relationships_properties)
|
||||
|
||||
def test_relationships(self) -> None:
|
||||
conf = {
|
||||
"database": os.getenv("AGE_PGSQL_DB"),
|
||||
"user": os.getenv("AGE_PGSQL_USER"),
|
||||
"password": os.getenv("AGE_PGSQL_PASSWORD"),
|
||||
"host": os.getenv("AGE_PGSQL_HOST", "localhost"),
|
||||
"port": int(os.getenv("AGE_PGSQL_PORT", 5432)),
|
||||
}
|
||||
|
||||
self.assertIsNotNone(conf["database"])
|
||||
self.assertIsNotNone(conf["user"])
|
||||
self.assertIsNotNone(conf["password"])
|
||||
|
||||
graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
|
||||
|
||||
graph = AGEGraph(graph_name, conf)
|
||||
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"""
|
||||
CREATE (la:LabelA {property_a: 'a'})
|
||||
CREATE (lb:LabelB)
|
||||
CREATE (lc:LabelC)
|
||||
MERGE (la)-[:REL_TYPE]-> (lb)
|
||||
MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc)
|
||||
"""
|
||||
)
|
||||
# Refresh schema information
|
||||
# graph.refresh_schema()
|
||||
|
||||
n_labels, e_labels = graph._get_labels()
|
||||
|
||||
relationships = graph._get_triples(e_labels)
|
||||
|
||||
expected_relationships = [
|
||||
{"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"},
|
||||
{"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"},
|
||||
]
|
||||
|
||||
self.assertEqual(
|
||||
sorted(relationships, key=lambda x: x["end"]), expected_relationships
|
||||
)
|
||||
|
||||
def test_add_documents(self) -> None:
|
||||
conf = {
|
||||
"database": os.getenv("AGE_PGSQL_DB"),
|
||||
"user": os.getenv("AGE_PGSQL_USER"),
|
||||
"password": os.getenv("AGE_PGSQL_PASSWORD"),
|
||||
"host": os.getenv("AGE_PGSQL_HOST", "localhost"),
|
||||
"port": int(os.getenv("AGE_PGSQL_PORT", 5432)),
|
||||
}
|
||||
|
||||
self.assertIsNotNone(conf["database"])
|
||||
self.assertIsNotNone(conf["user"])
|
||||
self.assertIsNotNone(conf["password"])
|
||||
|
||||
graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
|
||||
|
||||
graph = AGEGraph(graph_name, conf)
|
||||
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.add_graph_documents(test_data)
|
||||
output = graph.query(
|
||||
"MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY labels(n)"
|
||||
)
|
||||
self.assertEqual(
|
||||
output, [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}]
|
||||
)
|
||||
|
||||
def test_add_documents_source(self) -> None:
|
||||
conf = {
|
||||
"database": os.getenv("AGE_PGSQL_DB"),
|
||||
"user": os.getenv("AGE_PGSQL_USER"),
|
||||
"password": os.getenv("AGE_PGSQL_PASSWORD"),
|
||||
"host": os.getenv("AGE_PGSQL_HOST", "localhost"),
|
||||
"port": int(os.getenv("AGE_PGSQL_PORT", 5432)),
|
||||
}
|
||||
|
||||
self.assertIsNotNone(conf["database"])
|
||||
self.assertIsNotNone(conf["user"])
|
||||
self.assertIsNotNone(conf["password"])
|
||||
|
||||
graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
|
||||
|
||||
graph = AGEGraph(graph_name, conf)
|
||||
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.add_graph_documents(test_data, include_source=True)
|
||||
output = graph.query(
|
||||
"MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY labels(n)"
|
||||
)
|
||||
|
||||
expected = [
|
||||
{"label": ["bar"], "count": 1},
|
||||
{"label": ["Document"], "count": 1},
|
||||
{"label": ["foo"], "count": 1},
|
||||
]
|
||||
self.assertEqual(output, expected)
|
||||
|
||||
def test_get_schema(self) -> None:
|
||||
conf = {
|
||||
"database": os.getenv("AGE_PGSQL_DB"),
|
||||
"user": os.getenv("AGE_PGSQL_USER"),
|
||||
"password": os.getenv("AGE_PGSQL_PASSWORD"),
|
||||
"host": os.getenv("AGE_PGSQL_HOST", "localhost"),
|
||||
"port": int(os.getenv("AGE_PGSQL_PORT", 5432)),
|
||||
}
|
||||
|
||||
self.assertIsNotNone(conf["database"])
|
||||
self.assertIsNotNone(conf["user"])
|
||||
self.assertIsNotNone(conf["password"])
|
||||
|
||||
graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
|
||||
|
||||
graph = AGEGraph(graph_name, conf)
|
||||
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
|
||||
graph.refresh_schema()
|
||||
|
||||
expected = """
|
||||
Node properties are the following:
|
||||
[]
|
||||
Relationship properties are the following:
|
||||
[]
|
||||
The relationships are the following:
|
||||
[]
|
||||
"""
|
||||
# check that works on empty schema
|
||||
self.assertEqual(
|
||||
re.sub(r"\s", "", graph.get_schema), re.sub(r"\s", "", expected)
|
||||
)
|
||||
|
||||
expected_structured: Dict[str, Any] = {
|
||||
"node_props": {},
|
||||
"rel_props": {},
|
||||
"relationships": [],
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
self.assertEqual(graph.get_structured_schema, expected_structured)
|
||||
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"""
|
||||
MERGE (a:a {id: 1})-[b:b {id: 2}]-> (c:c {id: 3})
|
||||
"""
|
||||
)
|
||||
|
||||
# check that schema doesn't update without refresh
|
||||
self.assertEqual(
|
||||
re.sub(r"\s", "", graph.get_schema), re.sub(r"\s", "", expected)
|
||||
)
|
||||
self.assertEqual(graph.get_structured_schema, expected_structured)
|
||||
|
||||
# two possible orderings of node props
|
||||
expected_possibilities = [
|
||||
"""
|
||||
Node properties are the following:
|
||||
[
|
||||
{'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'a'},
|
||||
{'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'c'}
|
||||
]
|
||||
Relationship properties are the following:
|
||||
[
|
||||
{'properties': [{'property': 'id', 'type': 'INTEGER'}], 'type': 'b'}
|
||||
]
|
||||
The relationships are the following:
|
||||
[
|
||||
'(:`a`)-[:`b`]->(:`c`)'
|
||||
]
|
||||
""",
|
||||
"""
|
||||
Node properties are the following:
|
||||
[
|
||||
{'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'c'},
|
||||
{'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'a'}
|
||||
]
|
||||
Relationship properties are the following:
|
||||
[
|
||||
{'properties': [{'property': 'id', 'type': 'INTEGER'}], 'type': 'b'}
|
||||
]
|
||||
The relationships are the following:
|
||||
[
|
||||
'(:`a`)-[:`b`]->(:`c`)'
|
||||
]
|
||||
""",
|
||||
]
|
||||
|
||||
expected_structured2 = {
|
||||
"node_props": {
|
||||
"a": [{"property": "id", "type": "INTEGER"}],
|
||||
"c": [{"property": "id", "type": "INTEGER"}],
|
||||
},
|
||||
"rel_props": {"b": [{"property": "id", "type": "INTEGER"}]},
|
||||
"relationships": [{"start": "a", "type": "b", "end": "c"}],
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
graph.refresh_schema()
|
||||
|
||||
# check that schema is refreshed
|
||||
self.assertIn(
|
||||
re.sub(r"\s", "", graph.get_schema),
|
||||
[re.sub(r"\s", "", x) for x in expected_possibilities],
|
||||
)
|
||||
self.assertEqual(graph.get_structured_schema, expected_structured2)
|
145
libs/community/tests/unit_tests/graphs/test_age_graph.py
Normal file
145
libs/community/tests/unit_tests/graphs/test_age_graph.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import re
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_community.graphs.age_graph import AGEGraph
|
||||
|
||||
|
||||
class TestAGEGraph(unittest.TestCase):
|
||||
def test_format_triples(self) -> None:
|
||||
test_input = [
|
||||
{"start": "from_a", "type": "edge_a", "end": "to_a"},
|
||||
{"start": "from_b", "type": "edge_b", "end": "to_b"},
|
||||
]
|
||||
|
||||
expected = [
|
||||
"(:`from_a`)-[:`edge_a`]->(:`to_a`)",
|
||||
"(:`from_b`)-[:`edge_b`]->(:`to_b`)",
|
||||
]
|
||||
|
||||
self.assertEqual(AGEGraph._format_triples(test_input), expected)
|
||||
|
||||
def test_get_col_name(self) -> None:
|
||||
inputs = [
|
||||
("a", 1),
|
||||
("a as b", 1),
|
||||
(" c ", 1),
|
||||
(" c as d ", 1),
|
||||
("sum(a)", 1),
|
||||
("sum(a) as b", 1),
|
||||
("count(*)", 1),
|
||||
("count(*) as cnt", 1),
|
||||
("true", 1),
|
||||
("false", 1),
|
||||
("null", 1),
|
||||
]
|
||||
|
||||
expected = [
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"sum_a",
|
||||
"b",
|
||||
"count_*",
|
||||
"cnt",
|
||||
"column_1",
|
||||
"column_1",
|
||||
"column_1",
|
||||
]
|
||||
|
||||
for idx, value in enumerate(inputs):
|
||||
self.assertEqual(AGEGraph._get_col_name(*value), expected[idx])
|
||||
|
||||
def test_wrap_query(self) -> None:
|
||||
inputs = [
|
||||
"""
|
||||
MATCH (keanu:Person {name:'Keanu Reeves'})
|
||||
RETURN keanu.name AS name, keanu.born AS born
|
||||
""",
|
||||
"""
|
||||
MERGE (n:a {id: 1})
|
||||
""",
|
||||
]
|
||||
|
||||
expected = [
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (keanu:Person {name:'Keanu Reeves'})
|
||||
RETURN keanu.name AS name, keanu.born AS born
|
||||
$$) AS (name agtype, born agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MERGE (n:a {id: 1})
|
||||
$$) AS (a agtype);
|
||||
""",
|
||||
]
|
||||
|
||||
for idx, value in enumerate(inputs):
|
||||
self.assertEqual(
|
||||
re.sub(r"\s", "", AGEGraph._wrap_query(value, "test")),
|
||||
re.sub(r"\s", "", expected[idx]),
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
AGEGraph._wrap_query(
|
||||
"""
|
||||
MATCH ()
|
||||
RETURN *
|
||||
""",
|
||||
"test",
|
||||
)
|
||||
|
||||
def test_format_properties(self) -> None:
|
||||
inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]
|
||||
|
||||
expected = ["{}", '{`a`: "b"}', '{`a`: "b", `c`: 1, `d`: true}']
|
||||
|
||||
for idx, value in enumerate(inputs):
|
||||
self.assertEqual(AGEGraph._format_properties(value), expected[idx])
|
||||
|
||||
def test_clean_graph_labels(self) -> None:
|
||||
inputs = ["label", "label 1", "label#$"]
|
||||
|
||||
expected = ["label", "label_1", "label_"]
|
||||
|
||||
for idx, value in enumerate(inputs):
|
||||
self.assertEqual(AGEGraph.clean_graph_labels(value), expected[idx])
|
||||
|
||||
def test_record_to_dict(self) -> None:
|
||||
Record = namedtuple("Record", ["node1", "edge", "node2"])
|
||||
r = Record(
|
||||
node1='{"id": 1, "label": "label1", "properties":'
|
||||
+ ' {"prop": "a"}}::vertex',
|
||||
edge='{"id": 3, "label": "edge", "end_id": 2, '
|
||||
+ '"start_id": 1, "properties": {"test": "abc"}}::edge',
|
||||
node2='{"id": 2, "label": "label1", '
|
||||
+ '"properties": {"prop": "b"}}::vertex',
|
||||
)
|
||||
|
||||
result = AGEGraph._record_to_dict(r)
|
||||
|
||||
expected = {
|
||||
"node1": {"prop": "a"},
|
||||
"edge": ({"prop": "a"}, "edge", {"prop": "b"}),
|
||||
"node2": {"prop": "b"},
|
||||
}
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
Record2 = namedtuple("Record2", ["string", "int", "float", "bool", "null"])
|
||||
r2 = Record2('"test"', "1", "1.5", "true", None)
|
||||
|
||||
result = AGEGraph._record_to_dict(r2)
|
||||
|
||||
expected2 = {
|
||||
"string": "test",
|
||||
"int": 1,
|
||||
"float": 1.5,
|
||||
"bool": True,
|
||||
"null": None,
|
||||
}
|
||||
|
||||
self.assertEqual(result, expected2)
|
Reference in New Issue
Block a user