mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
Harrison/nebula graph (#5865)
Co-authored-by: Wey Gu <weyl.gu@gmail.com> Co-authored-by: chenweisomebody <chenweisomebody@gmail.com>
This commit is contained in:
parent
658f8bdee7
commit
35cfd25db3
270
docs/modules/chains/examples/graph_nebula_qa.ipynb
Normal file
270
docs/modules/chains/examples/graph_nebula_qa.ipynb
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c94240f5",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# NebulaGraphQAChain\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use LLMs to provide a natural language interface to a NebulaGraph database."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "dbc0ee68",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You will need to have a running NebulaGraph cluster, for which you can run a containerized cluster by running the following script:\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"curl -fsSL nebula-up.siwei.io/install.sh | bash\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Other options are:\n",
|
||||||
|
"- Install as a [Docker Desktop Extension](https://www.docker.com/blog/distributed-cloud-native-graph-database-nebulagraph-docker-extension/). See [here](https://docs.nebula-graph.io/3.5.0/2.quick-start/1.quick-start-workflow/)\n",
|
||||||
|
"- NebulaGraph Cloud Service. See [here](https://www.nebula-graph.io/cloud)\n",
|
||||||
|
"- Deploy from package, source code, or via Kubernetes. See [here](https://docs.nebula-graph.io/)\n",
|
||||||
|
"\n",
|
||||||
|
"Once the cluster is running, we can create the SPACE and SCHEMA for the database."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c82f4141",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install ipython-ngql\n",
|
||||||
|
"%load_ext ngql\n",
|
||||||
|
"\n",
|
||||||
|
"# connect ngql jupyter extension to nebulagraph\n",
|
||||||
|
"%ngql --address 127.0.0.1 --port 9669 --user root --password nebula\n",
|
||||||
|
"# create a new space\n",
|
||||||
|
"%ngql CREATE SPACE IF NOT EXISTS langchain(partition_num=1, replica_factor=1, vid_type=fixed_string(128));\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "eda0809a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Wait for a few seconds for the space to be created.\n",
|
||||||
|
"%ngql USE langchain;"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "119fe35c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Create the schema. For the full dataset, refer to [this guide](https://www.siwei.io/en/nebulagraph-etl-dbt/)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5aa796ee",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%%ngql\n",
|
||||||
|
"CREATE TAG IF NOT EXISTS movie(name string);\n",
|
||||||
|
"CREATE TAG IF NOT EXISTS person(name string, birthdate string);\n",
|
||||||
|
"CREATE EDGE IF NOT EXISTS acted_in();\n",
|
||||||
|
"CREATE TAG INDEX IF NOT EXISTS person_index ON person(name(128));\n",
|
||||||
|
"CREATE TAG INDEX IF NOT EXISTS movie_index ON movie(name(128));"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "66e4799a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Wait for schema creation to complete, then we can insert some data."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "d8eea530",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"UsageError: Cell magic `%%ngql` not found.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%%ngql\n",
|
||||||
|
"INSERT VERTEX person(name, birthdate) VALUES \"Al Pacino\":(\"Al Pacino\", \"1940-04-25\");\n",
|
||||||
|
"INSERT VERTEX movie(name) VALUES \"The Godfather II\":(\"The Godfather II\");\n",
|
||||||
|
"INSERT VERTEX movie(name) VALUES \"The Godfather Coda: The Death of Michael Corleone\":(\"The Godfather Coda: The Death of Michael Corleone\");\n",
|
||||||
|
"INSERT EDGE acted_in() VALUES \"Al Pacino\"->\"The Godfather II\":();\n",
|
||||||
|
"INSERT EDGE acted_in() VALUES \"Al Pacino\"->\"The Godfather Coda: The Death of Michael Corleone\":();"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "62812aad",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import ChatOpenAI\n",
|
||||||
|
"from langchain.chains import NebulaGraphQAChain\n",
|
||||||
|
"from langchain.graphs import NebulaGraph"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "0928915d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"graph = NebulaGraph(\n",
|
||||||
|
" space=\"langchain\",\n",
|
||||||
|
" username=\"root\",\n",
|
||||||
|
" password=\"nebula\",\n",
|
||||||
|
" address=\"127.0.0.1\",\n",
|
||||||
|
" port=9669,\n",
|
||||||
|
" session_pool_size=30,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "58c1a8ea",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Refresh graph schema information\n",
|
||||||
|
"\n",
|
||||||
|
"If the schema of database changes, you can refresh the schema information needed to generate nGQL statements."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4e3de44f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# graph.refresh_schema()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "1fe76ccd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Node properties: [{'tag': 'movie', 'properties': [('name', 'string')]}, {'tag': 'person', 'properties': [('name', 'string'), ('birthdate', 'string')]}]\n",
|
||||||
|
"Edge properties: [{'edge': 'acted_in', 'properties': []}]\n",
|
||||||
|
"Relationships: ['(:person)-[:acted_in]->(:movie)']\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(graph.get_schema)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "68a3c677",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Querying the graph\n",
|
||||||
|
"\n",
|
||||||
|
"We can now use the NebulaGraph QA chain to ask questions of the graph."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "7476ce98",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = NebulaGraphQAChain.from_llm(\n",
|
||||||
|
" ChatOpenAI(temperature=0), graph=graph, verbose=True\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "ef8ee27b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new NebulaGraphQAChain chain...\u001b[0m\n",
|
||||||
|
"Generated nGQL:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mMATCH (p:`person`)-[:acted_in]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'\n",
|
||||||
|
"RETURN p.`person`.`name`\u001b[0m\n",
|
||||||
|
"Full Context:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3m{'p.person.name': ['Al Pacino']}\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Al Pacino played in The Godfather II.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain.run(\"Who played in The Godfather II?\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -11,6 +11,7 @@ from langchain.chains.conversational_retrieval.base import (
|
|||||||
from langchain.chains.flare.base import FlareChain
|
from langchain.chains.flare.base import FlareChain
|
||||||
from langchain.chains.graph_qa.base import GraphQAChain
|
from langchain.chains.graph_qa.base import GraphQAChain
|
||||||
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
|
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
|
||||||
|
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
|
||||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_bash.base import LLMBashChain
|
from langchain.chains.llm_bash.base import LLMBashChain
|
||||||
@ -67,4 +68,5 @@ __all__ = [
|
|||||||
"ConversationalRetrievalChain",
|
"ConversationalRetrievalChain",
|
||||||
"OpenAPIEndpointChain",
|
"OpenAPIEndpointChain",
|
||||||
"FlareChain",
|
"FlareChain",
|
||||||
|
"NebulaGraphQAChain",
|
||||||
]
|
]
|
||||||
|
91
langchain/chains/graph_qa/nebulagraph.py
Normal file
91
langchain/chains/graph_qa/nebulagraph.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
"""Question answering over a graph."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.graphs.nebula_graph import NebulaGraph
|
||||||
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
class NebulaGraphQAChain(Chain):
    """Answer questions about a NebulaGraph space by generating and running nGQL."""

    graph: NebulaGraph = Field(exclude=True)
    ngql_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """List containing the single input key this chain reads.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """List containing the single output key this chain writes.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> NebulaGraphQAChain:
        """Build the chain from one LLM shared by both internal LLM chains.

        ``qa_prompt`` drives the answering chain and ``ngql_prompt`` drives the
        statement-generation chain; remaining keyword arguments (for example
        ``graph``) are forwarded to the constructor unchanged.
        """
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            ngql_generation_chain=LLMChain(llm=llm, prompt=ngql_prompt),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Turn the question into nGQL, run it, and answer from the query result."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        user_question = inputs[self.input_key]

        # Step 1: ask the LLM for an nGQL statement, given the cached schema.
        generation_inputs = {
            "question": user_question,
            "schema": self.graph.get_schema,
        }
        ngql_statement = self.ngql_generation_chain.run(
            generation_inputs, callbacks=child_callbacks
        )
        manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
        manager.on_text(
            ngql_statement, color="green", end="\n", verbose=self.verbose
        )

        # Step 2: execute the statement against the graph.
        graph_context = self.graph.query(ngql_statement)
        manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        manager.on_text(
            str(graph_context), color="green", end="\n", verbose=self.verbose
        )

        # Step 3: let the QA chain phrase an answer from the query result.
        qa_output = self.qa_chain(
            {"question": user_question, "context": graph_context},
            callbacks=child_callbacks,
        )
        answer = qa_output[self.qa_chain.output_key]
        return {self.output_key: answer}
|
@ -49,6 +49,29 @@ CYPHER_GENERATION_PROMPT = PromptTemplate(
|
|||||||
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
|
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extra guidance telling the LLM to write Cypher first and then convert it to
# NebulaGraph's Cypher dialect. The text begins with "Instructions:" because it
# is substituted for that same marker in the template built below.
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
Instructions:

First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard):
1. it requires explicit label specification when referring to node properties: v.`Foo`.name
2. it uses double equals sign for comparison: `==` rather than `=`
For instance:
```diff
< MATCH (p:person)-[:directed]->(m:movie) WHERE m.name = 'The Godfather II'
< RETURN p.name;
---
> MATCH (p:`person`)-[:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'
> RETURN p.`person`.`name`;
```\n"""

# Derive the nGQL template from the Cypher one by plain text substitution:
# "Generate Cypher" becomes "Generate NebulaGraph Cypher", and "Instructions:"
# is replaced by the NebulaGraph-specific instruction block above.
# NOTE(review): str.replace swaps every occurrence, so this relies on how often
# those markers appear in CYPHER_GENERATION_TEMPLATE -- confirm against it.
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Generate Cypher", "Generate NebulaGraph Cypher"
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)

# Prompt for nGQL generation; it takes "schema" and "question" as inputs.
NGQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
)
|
||||||
|
|
||||||
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
|
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
|
||||||
The information part contains the provided information that you must use to construct an answer.
|
The information part contains the provided information that you must use to construct an answer.
|
||||||
The provided information is authorative, you must never doubt it or try to use your internal knowledge to correct it.
|
The provided information is authorative, you must never doubt it or try to use your internal knowledge to correct it.
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Graph implementations."""
|
"""Graph implementations."""
|
||||||
|
from langchain.graphs.nebula_graph import NebulaGraph
|
||||||
from langchain.graphs.neo4j_graph import Neo4jGraph
|
from langchain.graphs.neo4j_graph import Neo4jGraph
|
||||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
||||||
|
|
||||||
__all__ = ["NetworkxEntityGraph", "Neo4jGraph"]
|
__all__ = ["NetworkxEntityGraph", "Neo4jGraph", "NebulaGraph"]
|
||||||
|
201
langchain/graphs/nebula_graph.py
Normal file
201
langchain/graphs/nebula_graph.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
import logging
from string import Template
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
# nGQL used to discover one "(:src_tag)-[:edge]->(:dst_tag)" pattern per edge
# type: sample a single edge of that type, then report the first tag of its
# source and destination vertices.
rel_query = Template(
    """
MATCH ()-[e:`$edge_type`]->()
WITH e limit 1
MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e)
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
"""
)

# Maximum number of retries `NebulaGraph.execute` performs before giving up.
RETRY_TIMES = 3


class NebulaGraph:
    """NebulaGraph wrapper for graph operations.

    Owns a nebula3 session pool bound to a single graph space and caches a
    textual description of that space's schema (see ``refresh_schema``).
    This is a standalone class; it does not subclass any other graph wrapper.
    """

    def __init__(
        self,
        space: str,
        username: str = "root",
        password: str = "nebula",
        address: str = "127.0.0.1",
        port: int = 9669,
        session_pool_size: int = 30,
    ) -> None:
        """Create a new NebulaGraph wrapper instance.

        Args:
            space: Name of the graph space the session pool is opened on.
            username: Account used to authenticate.
            password: Password for ``username``.
            address: Host of the graph service.
            port: Port of the graph service.
            session_pool_size: Value assigned to the pool config's ``max_size``.

        Raises:
            ValueError: If ``nebula3`` or ``pandas`` cannot be imported, if the
                session pool cannot be created, or if the initial schema
                refresh fails.
        """
        try:
            import nebula3  # noqa: F401
            import pandas  # noqa: F401
        except ImportError as e:
            raise ValueError(
                "Please install NebulaGraph Python client and pandas first: "
                "`pip install nebula3-python pandas`"
            ) from e

        self.username = username
        self.password = password
        self.address = address
        self.port = port
        self.space = space
        self.session_pool_size = session_pool_size

        self.session_pool = self._get_session_pool()
        self.schema = ""
        # Set schema
        try:
            self.refresh_schema()
        except Exception as e:
            raise ValueError(f"Could not refresh schema. Error: {e}") from e

    def _get_session_pool(self) -> Any:
        """Create and initialise a session pool from the stored settings.

        Raises:
            AssertionError: If any connection setting is empty/falsy.
            ValueError: If the host is invalid, authentication fails, or pool
                initialisation raises ``RuntimeError``.
        """
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert all(
            [self.username, self.password, self.address, self.port, self.space]
        ), (
            "Please provide all of the following parameters: "
            "username, password, address, port, space"
        )

        from nebula3.Config import SessionPoolConfig
        from nebula3.Exception import AuthFailedException, InValidHostname
        from nebula3.gclient.net.SessionPool import SessionPool

        config = SessionPoolConfig()
        config.max_size = self.session_pool_size

        try:
            session_pool = SessionPool(
                self.username,
                self.password,
                self.space,
                [(self.address, self.port)],
            )
        except InValidHostname as e:
            raise ValueError(
                "Could not connect to NebulaGraph database. "
                "Please ensure that the address and port are correct"
            ) from e

        try:
            session_pool.init(config)
        except AuthFailedException as e:
            raise ValueError(
                "Could not connect to NebulaGraph database. "
                "Please ensure that the username and password are correct"
            ) from e
        except RuntimeError as e:
            raise ValueError(f"Error initializing session pool. Error: {e}") from e

        return session_pool

    def __del__(self) -> None:
        """Best-effort close of the session pool; failures are only logged."""
        try:
            self.session_pool.close()
        except Exception as e:
            logging.warning(f"Could not close session pool. Error: {e}")

    @property
    def get_schema(self) -> str:
        """Returns the schema of the NebulaGraph database"""
        return self.schema

    def execute(
        self, query: str, params: Optional[dict] = None, retry: int = 0
    ) -> Any:
        """Query NebulaGraph database.

        Args:
            query: Statement passed to the pool's ``execute_parameter``.
            params: Query parameters; an empty dict is used when omitted.
            retry: Number of retries already performed (used by the recursive
                retry; callers normally leave it at 0).

        Returns:
            The result object returned by the session pool. A result whose
            ``is_succeeded()`` is false is logged and still returned.

        Raises:
            ValueError: If no valid session is available, or if the query keeps
                failing with ``RuntimeError`` or a connection error after
                ``RETRY_TIMES`` retries.
        """
        from nebula3.Exception import IOErrorException, NoValidSessionException
        from nebula3.fbthrift.transport.TTransport import TTransportException

        # A fresh dict per call avoids sharing one mutable default across calls.
        if params is None:
            params = {}

        try:
            result = self.session_pool.execute_parameter(query, params)
            if not result.is_succeeded():
                logging.warning(
                    f"Error executing query to NebulaGraph. "
                    f"Error: {result.error_msg()}\n"
                    f"Query: {query} \n"
                )
            return result

        except NoValidSessionException as e:
            message = (
                "No valid session found in session pool. "
                "Please consider increasing the session pool size. "
                f"Current size: {self.session_pool_size}"
            )
            logging.warning(message)
            raise ValueError(message) from e

        except RuntimeError as e:
            if retry < RETRY_TIMES:
                retry += 1
                logging.warning(
                    f"Error executing query to NebulaGraph. "
                    f"Retrying ({retry}/{RETRY_TIMES})...\n"
                    f"query: {query} \n"
                    f"Error: {e}"
                )
                return self.execute(query, params, retry)
            raise ValueError(
                f"Error executing query to NebulaGraph. Error: {e}"
            ) from e

        except (TTransportException, IOErrorException) as e:
            # connection issue, try to recreate session pool
            if retry < RETRY_TIMES:
                retry += 1
                logging.warning(
                    f"Connection issue with NebulaGraph. "
                    f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool"
                )
                self.session_pool = self._get_session_pool()
                return self.execute(query, params, retry)
            # Previously this path fell through and returned None, which made
            # callers fail later with an unrelated AttributeError.
            raise ValueError(
                f"Connection issue with NebulaGraph after {RETRY_TIMES} retries. "
                f"Error: {e}"
            ) from e

    def _describe_properties(self, kind: str, name: str) -> list:
        """Return ``(field, type)`` pairs from ``DESCRIBE <kind> `<name>```."""
        described = self.execute(f"DESCRIBE {kind} `{name}`")
        fields = described.column_values("Field")
        field_types = described.column_values("Type")
        return [(f.cast(), t.cast()) for f, t in zip(fields, field_types)]

    def refresh_schema(self) -> None:
        """
        Refreshes the NebulaGraph schema information.

        Rebuilds ``self.schema`` from ``SHOW TAGS`` / ``SHOW EDGES``, a
        ``DESCRIBE`` per tag and edge type, and one sampled relationship
        pattern per edge type (edge types with no edges contribute none).
        """
        tags_schema, edge_types_schema, relationships = [], [], []

        for tag in self.execute("SHOW TAGS").column_values("Name"):
            tag_name = tag.cast()
            tags_schema.append(
                {
                    "tag": tag_name,
                    "properties": self._describe_properties("TAG", tag_name),
                }
            )

        for edge_type in self.execute("SHOW EDGES").column_values("Name"):
            edge_type_name = edge_type.cast()
            edge_types_schema.append(
                {
                    "edge": edge_type_name,
                    "properties": self._describe_properties("EDGE", edge_type_name),
                }
            )

            # build relationships types
            rels = self.execute(
                rel_query.substitute(edge_type=edge_type_name)
            ).column_values("rels")
            if len(rels) > 0:
                relationships.append(rels[0].cast())

        self.schema = (
            f"Node properties: {tags_schema}\n"
            f"Edge properties: {edge_types_schema}\n"
            f"Relationships: {relationships}\n"
        )

    def query(self, query: str, retry: int = 0) -> Dict[str, Any]:
        """Run ``query`` and return its result as ``{column name: [values]}``.

        Each value is obtained by calling ``cast()`` on the corresponding cell
        of the result returned by ``execute``.
        """
        result = self.execute(query, retry=retry)
        return {
            col_name: [cell.cast() for cell in result.column_values(col_name)]
            for col_name in result.keys()
        }
|
33
poetry.lock
generated
33
poetry.lock
generated
@ -2355,6 +2355,17 @@ smb = ["smbprotocol"]
|
|||||||
ssh = ["paramiko"]
|
ssh = ["paramiko"]
|
||||||
tqdm = ["tqdm"]
|
tqdm = ["tqdm"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "future"
|
||||||
|
version = "0.18.3"
|
||||||
|
description = "Clean single-source support for Python 3 and 2"
|
||||||
|
category = "main"
|
||||||
|
optional = true
|
||||||
|
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||||
|
files = [
|
||||||
|
{file = "future-0.18.3.tar.gz", hash = "sha256:34a17436ed1e96697a86f9de3d15a3b0be01d8bc8de9c1dffd59fb8234ed5307"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gast"
|
name = "gast"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
@ -5184,6 +5195,24 @@ nbformat = "*"
|
|||||||
sphinx = ">=1.8"
|
sphinx = ">=1.8"
|
||||||
traitlets = ">=5"
|
traitlets = ">=5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nebula3-python"
|
||||||
|
version = "3.4.0"
|
||||||
|
description = "Python client for NebulaGraph V3.4"
|
||||||
|
category = "main"
|
||||||
|
optional = true
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "nebula3-python-3.4.0.tar.gz", hash = "sha256:47bd8b1b4bb2c2f0e5122bc147926cb50578a66841acf6a743cae4d0362c9eaa"},
|
||||||
|
{file = "nebula3_python-3.4.0-py3-none-any.whl", hash = "sha256:d9d94c6a41712875e6ec866907de0789057f860e64f547f87d9f199439759dd6"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
future = ">=0.18.0"
|
||||||
|
httplib2 = ">=0.20.0"
|
||||||
|
pytz = ">=2021.1"
|
||||||
|
six = ">=1.16.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "neo4j"
|
name = "neo4j"
|
||||||
version = "5.9.0"
|
version = "5.9.0"
|
||||||
@ -11311,7 +11340,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
|||||||
cffi = ["cffi (>=1.11)"]
|
cffi = ["cffi (>=1.11)"]
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "tigrisdb"]
|
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "nebula3-python"]
|
||||||
azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech"]
|
azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech"]
|
||||||
cohere = ["cohere"]
|
cohere = ["cohere"]
|
||||||
docarray = ["docarray"]
|
docarray = ["docarray"]
|
||||||
@ -11325,4 +11354,4 @@ text-helpers = ["chardet"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "faeb3cc6feb059096a66ba8b1fd2271cd91e3a9553cb4f05e5ea493610ac3763"
|
content-hash = "836aca50cdc2300a684e7c039cfe8100b705f21496014bd49f64d92f4a6baa10"
|
||||||
|
@ -104,6 +104,7 @@ bibtexparser = {version = "^1.4.0", optional = true}
|
|||||||
singlestoredb = {version = "^0.6.1", optional = true}
|
singlestoredb = {version = "^0.6.1", optional = true}
|
||||||
pyspark = {version = "^3.4.0", optional = true}
|
pyspark = {version = "^3.4.0", optional = true}
|
||||||
tigrisdb = {version = "^1.0.0b6", optional = true}
|
tigrisdb = {version = "^1.0.0b6", optional = true}
|
||||||
|
nebula3-python = {version = "^3.4.0", optional = true}
|
||||||
langchainplus-sdk = ">=0.0.6"
|
langchainplus-sdk = ">=0.0.6"
|
||||||
|
|
||||||
|
|
||||||
@ -283,7 +284,7 @@ all = [
|
|||||||
"azure-cognitiveservices-speech",
|
"azure-cognitiveservices-speech",
|
||||||
"momento",
|
"momento",
|
||||||
"singlestoredb",
|
"singlestoredb",
|
||||||
"tigrisdb"
|
"nebula3-python",
|
||||||
]
|
]
|
||||||
|
|
||||||
# An extra used to be able to add extended testing.
|
# An extra used to be able to add extended testing.
|
||||||
|
90
tests/integration_tests/test_nebulagraph.py
Normal file
90
tests/integration_tests/test_nebulagraph.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
import unittest
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from langchain.graphs import NebulaGraph
|
||||||
|
|
||||||
|
|
||||||
|
class TestNebulaGraph(unittest.TestCase):
    """Checks for the NebulaGraph wrapper with nebula3's SessionPool patched out."""

    def setUp(self) -> None:
        # Connection settings shared by every test; no real server is contacted.
        self.space = "test_space"
        self.username = "test_user"
        self.password = "test_password"
        self.address = "test_address"
        self.port = 1234
        self.session_pool_size = 10

    def _connect(self, pool_cls: Any) -> Any:
        """Make the patched pool class return a MagicMock and build a wrapper."""
        pool_cls.return_value = MagicMock()
        return NebulaGraph(
            self.space,
            self.username,
            self.password,
            self.address,
            self.port,
            self.session_pool_size,
        )

    @patch("nebula3.gclient.net.SessionPool.SessionPool")
    def test_init(self, mock_session_pool: Any) -> None:
        graph = self._connect(mock_session_pool)
        # Every constructor argument must be stored under the same name.
        for attr in (
            "space",
            "username",
            "password",
            "address",
            "port",
            "session_pool_size",
        ):
            self.assertEqual(getattr(graph, attr), getattr(self, attr))

    @patch("nebula3.gclient.net.SessionPool.SessionPool")
    def test_get_session_pool(self, mock_session_pool: Any) -> None:
        graph = self._connect(mock_session_pool)
        pool = graph._get_session_pool()
        self.assertIsInstance(pool, MagicMock)

    @patch("nebula3.gclient.net.SessionPool.SessionPool")
    def test_del(self, mock_session_pool: Any) -> None:
        graph = self._connect(mock_session_pool)
        # Call the finaliser explicitly so the close() call can be asserted.
        graph.__del__()
        mock_session_pool.return_value.close.assert_called_once()

    @patch("nebula3.gclient.net.SessionPool.SessionPool")
    def test_execute(self, mock_session_pool: Any) -> None:
        graph = self._connect(mock_session_pool)
        statement = "SELECT * FROM test_table"
        outcome = graph.execute(statement)
        self.assertIsInstance(outcome, MagicMock)

    @patch("nebula3.gclient.net.SessionPool.SessionPool")
    def test_refresh_schema(self, mock_session_pool: Any) -> None:
        graph = self._connect(mock_session_pool)
        graph.refresh_schema()
        self.assertNotEqual(graph.get_schema, "")
|
Loading…
Reference in New Issue
Block a user