mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-05 00:30:18 +00:00

Compare commits: 11 commits, eugene/lan... → bagatur/0....

| Author | SHA1 | Date |
|---|---|---|
| | 0495ca0d10 | |
| | a84310cdcb | |
| | 58b8747c44 | |
| | c57e506f9c | |
| | 068620a871 | |
| | 4812403b48 | |
| | ed75bccda8 | |
| | 5c194ee224 | |
| | 408bdd5604 | |
| | 6a93ff2a4b | |
| | 7e96a7eaea | |
@@ -1,15 +1,4 @@
"""Toolkits for agents."""
from abc import ABC, abstractmethod
from typing import List

from langchain_core.tools import BaseToolkit
from langchain_core.pydantic_v1 import BaseModel

from langchain_community.tools import BaseTool


class BaseToolkit(BaseModel, ABC):
    """Base Toolkit representing a collection of related tools."""

    @abstractmethod
    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""

__all__ = ["BaseToolkit"]
@@ -47,6 +47,7 @@ from typing import (
    cast,
)

from langchain_core._api import deprecated
from sqlalchemy import Column, Integer, String, create_engine, delete, select
from sqlalchemy.engine import Row
from sqlalchemy.engine.base import Engine
@@ -187,6 +188,9 @@ def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    return None


@deprecated(
    "0.0.29", alternative_import="langchain.cache.InMemoryCache", removal="0.2.0"
)
class InMemoryCache(BaseCache):
    """Cache that stores things in memory."""
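The `@deprecated` marker above makes the community `InMemoryCache` warn on construction and point at the `langchain.cache` import path. A minimal sketch of the observable effect, assuming `langchain-community` >= 0.0.29 is installed:

import warnings

from langchain_community.cache import InMemoryCache

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    InMemoryCache()  # deprecated since 0.0.29; slated for removal in 0.2.0
print(caught[0].category.__name__)  # LangChainDeprecationWarning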
@@ -0,0 +1,48 @@
from typing import Any

_LANGCHAIN_DEPENDENT = [
    "ArangoGraphQAChain",
    "GraphQAChain",
    "GraphCypherQAChain",
    "FalkorDBQAChain",
    "GremlinQAChain",
    "HugeGraphQAChain",
    "KuzuQAChain",
    "NebulaGraphQAChain",
    "NeptuneOpenCypherQAChain",
    "NeptuneSparqlQAChain",
    "OntotextGraphDBQAChain",
    "GraphSparqlQAChain",
]

try:
    from langchain.chains.base import Chain  # probe for the optional langchain dependency
except ImportError:
    __all__ = []

    def __getattr__(name: str) -> Any:
        if name in _LANGCHAIN_DEPENDENT:
            raise ImportError(
                f"Must have `langchain` installed to use {name}. Please install it "
                f"with `pip install -U langchain` and re-run your {name} import."
            )
        raise AttributeError(name)
else:
    from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
    from langchain_community.chains.graph_qa.base import GraphQAChain
    from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
    from langchain_community.chains.graph_qa.falkordb import FalkorDBQAChain
    from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
    from langchain_community.chains.graph_qa.hugegraph import HugeGraphQAChain
    from langchain_community.chains.graph_qa.kuzu import KuzuQAChain
    from langchain_community.chains.graph_qa.nebulagraph import NebulaGraphQAChain
    from langchain_community.chains.graph_qa.neptune_cypher import (
        NeptuneOpenCypherQAChain,
    )
    from langchain_community.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain
    from langchain_community.chains.graph_qa.ontotext_graphdb import (
        OntotextGraphDBQAChain,
    )
    from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain

    __all__ = _LANGCHAIN_DEPENDENT
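The try/except above (presumably the package `__init__`) makes the graph QA chains import lazily: without `langchain` installed, the error surfaces only when one of the listed names is accessed. A sketch of that behavior in an environment lacking `langchain`:

from langchain_community.chains import graph_qa  # assumed module path

try:
    graph_qa.GraphQAChain  # listed in _LANGCHAIN_DEPENDENT
except ImportError as e:
    print(e)  # "Must have `langchain` installed to use GraphQAChain. ..."

try:
    graph_qa.no_such_chain  # unknown names fall through to AttributeError
except AttributeError:
    pass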
libs/community/langchain_community/chains/graph_qa/arangodb.py (new file, 248 lines)
@@ -0,0 +1,248 @@
"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use ArangoGraphQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain_community.chains.graph_qa.prompts import (
    AQL_FIX_PROMPT,
    AQL_GENERATION_PROMPT,
    AQL_QA_PROMPT,
)
from langchain_community.graphs.arangodb_graph import ArangoGraph


class ArangoGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating AQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: ArangoGraph = Field(exclude=True)
    aql_generation_chain: LLMChain
    aql_fix_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    # Specifies the maximum number of AQL Query Results to return
    top_k: int = 10

    # Specifies the set of AQL Query Examples that promote few-shot learning
    aql_examples: str = ""

    # Specify whether to return the AQL Query in the output dictionary
    return_aql_query: bool = False

    # Specify whether to return the AQL JSON Result in the output dictionary
    return_aql_result: bool = False

    # Specify the maximum number of AQL Generation attempts that should be made
    max_aql_generation_attempts: int = 3

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        return "graph_aql_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = AQL_QA_PROMPT,
        aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT,
        aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT,
        **kwargs: Any,
    ) -> ArangoGraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt)
        aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt)

        return cls(
            qa_chain=qa_chain,
            aql_generation_chain=aql_generation_chain,
            aql_fix_chain=aql_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """
        Generate an AQL statement from user input, use it to retrieve a response
        from an ArangoDB Database instance, and respond to the user input
        in natural language.

        Users can modify the following ArangoGraphQAChain Class Variables:

        :var top_k: The maximum number of AQL Query Results to return
        :type top_k: int

        :var aql_examples: A set of AQL Query Examples that are passed to
            the AQL Generation Prompt Template to promote few-shot learning.
            Defaults to an empty string.
        :type aql_examples: str

        :var return_aql_query: Whether to return the AQL Query in the
            output dictionary. Defaults to False.
        :type return_aql_query: bool

        :var return_aql_result: Whether to return the AQL Result in the
            output dictionary. Defaults to False.
        :type return_aql_result: bool

        :var max_aql_generation_attempts: The maximum number of AQL
            Generation attempts to be made prior to raising the last
            AQL Query Execution Error. Defaults to 3.
        :type max_aql_generation_attempts: int
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        user_input = inputs[self.input_key]

        #########################
        # Generate AQL Query #
        aql_generation_output = self.aql_generation_chain.run(
            {
                "adb_schema": self.graph.schema,
                "aql_examples": self.aql_examples,
                "user_input": user_input,
            },
            callbacks=callbacks,
        )
        #########################

        aql_query = ""
        aql_error = ""
        aql_result = None
        aql_generation_attempt = 1

        while (
            aql_result is None
            and aql_generation_attempt < self.max_aql_generation_attempts + 1
        ):
            #####################
            # Extract AQL Query #
            pattern = r"```(?i:aql)?(.*?)```"
            matches = re.findall(pattern, aql_generation_output, re.DOTALL)
            if not matches:
                _run_manager.on_text(
                    "Invalid Response: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_generation_output, color="red", end="\n", verbose=self.verbose
                )
                raise ValueError(f"Response is Invalid: {aql_generation_output}")

            aql_query = matches[0]
            #####################

            _run_manager.on_text(
                f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
            )
            _run_manager.on_text(
                aql_query, color="green", end="\n", verbose=self.verbose
            )

            #####################
            # Execute AQL Query #
            from arango import AQLQueryExecuteError

            try:
                aql_result = self.graph.query(aql_query, self.top_k)
            except AQLQueryExecuteError as e:
                aql_error = e.error_message

                _run_manager.on_text(
                    "AQL Query Execution Error: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_error, color="yellow", end="\n\n", verbose=self.verbose
                )

                ########################
                # Retry AQL Generation #
                aql_generation_output = self.aql_fix_chain.run(
                    {
                        "adb_schema": self.graph.schema,
                        "aql_query": aql_query,
                        "aql_error": aql_error,
                    },
                    callbacks=callbacks,
                )
                ########################

            #####################

            aql_generation_attempt += 1

        if aql_result is None:
            m = f"""
                Maximum number of AQL Query Generation attempts reached.
                Unable to execute the AQL Query due to the following error:
                {aql_error}
            """
            raise ValueError(m)

        _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(aql_result), color="green", end="\n", verbose=self.verbose
        )

        ########################
        # Interpret AQL Result #
        result = self.qa_chain(
            {
                "adb_schema": self.graph.schema,
                "user_input": user_input,
                "aql_query": aql_query,
                "aql_result": aql_result,
            },
            callbacks=callbacks,
        )
        ########################

        # Return results #
        result = {self.output_key: result[self.qa_chain.output_key]}

        if self.return_aql_query:
            result["aql_query"] = aql_query

        if self.return_aql_result:
            result["aql_result"] = aql_result

        return result
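For reference, a minimal wiring of the ArangoGraphQAChain added above; the ArangoDB endpoint, credentials, and the `langchain-openai` model are illustrative assumptions, not part of this change:

from arango import ArangoClient
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_community.graphs.arangodb_graph import ArangoGraph
from langchain_openai import ChatOpenAI

# Placeholder connection details; any python-arango StandardDatabase works.
db = ArangoClient(hosts="http://localhost:8529").db(
    "_system", username="root", password="password"
)
graph = ArangoGraph(db)

chain = ArangoGraphQAChain.from_llm(
    ChatOpenAI(temperature=0),
    graph=graph,
    top_k=5,
    return_aql_query=True,  # also surface the generated AQL in the output dict
)
out = chain.invoke({"query": "Who starred in Pulp Fiction?"})
print(out["result"], out["aql_query"])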
libs/community/langchain_community/chains/graph_qa/base.py (new file, 106 lines)
@@ -0,0 +1,106 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use GraphQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain_community.chains.graph_qa.prompts import (
    ENTITY_EXTRACTION_PROMPT,
    GRAPH_QA_PROMPT,
)
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities


class GraphQAChain(Chain):
    """Chain for question-answering against a graph.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: NetworkxEntityGraph = Field(exclude=True)
    entity_extraction_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT,
        entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
        **kwargs: Any,
    ) -> GraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        entity_chain = LLMChain(llm=llm, prompt=entity_prompt)

        return cls(
            qa_chain=qa_chain,
            entity_extraction_chain=entity_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Extract entities, look up info and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        entity_string = self.entity_extraction_chain.run(question)

        _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            entity_string, color="green", end="\n", verbose=self.verbose
        )
        entities = get_entities(entity_string)
        context = ""
        all_triplets = []
        for entity in entities:
            all_triplets.extend(self.graph.get_entity_knowledge(entity))
        context = "\n".join(all_triplets)
        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: result[self.qa_chain.output_key]}
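A minimal end-to-end sketch of GraphQAChain over an in-memory NetworkX graph; the triple and the `langchain-openai` model are illustrative assumptions:

from langchain_community.chains.graph_qa.base import GraphQAChain
from langchain_community.graphs.networkx_graph import (
    KnowledgeTriple,
    NetworkxEntityGraph,
)
from langchain_openai import ChatOpenAI

# Seed the graph with one (subject, predicate, object) triple.
graph = NetworkxEntityGraph()
graph.add_triple(
    KnowledgeTriple("Ada Lovelace", "collaborated with", "Charles Babbage")
)

chain = GraphQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "Who did Ada Lovelace collaborate with?"})["result"])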
libs/community/langchain_community/chains/graph_qa/cypher.py (new file, 328 lines)
@@ -0,0 +1,328 @@
"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
    from langchain.chains.loading import load_chain_from_config
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use GraphCypherQAChain. Please install "
        "it with `pip install -U langchain`."
    ) from e
from langchain_community.chains.graph_qa.cypher_utils import (
    CypherQueryCorrector,
    Schema,
)
from langchain_community.chains.graph_qa.prompts import (
    CYPHER_GENERATION_PROMPT,
    CYPHER_QA_PROMPT,
)

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        Cypher code extracted from the text.
    """
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text


def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: List[str],
    exclude_types: List[str],
) -> str:
    """Filter the schema based on included or excluded types."""

    def filter_func(x: str) -> bool:
        return x in include_types if include_types else x not in exclude_types

    filtered_schema: Dict[str, Any] = {
        "node_props": {
            k: v
            for k, v in structured_schema.get("node_props", {}).items()
            if filter_func(k)
        },
        "rel_props": {
            k: v
            for k, v in structured_schema.get("rel_props", {}).items()
            if filter_func(k)
        },
        "relationships": [
            r
            for r in structured_schema.get("relationships", [])
            if all(filter_func(r[t]) for t in ["start", "end", "type"])
        ],
    }

    # Format node properties
    formatted_node_props = []
    for label, properties in filtered_schema["node_props"].items():
        props_str = ", ".join(
            [f"{prop['property']}: {prop['type']}" for prop in properties]
        )
        formatted_node_props.append(f"{label} {{{props_str}}}")

    # Format relationship properties
    formatted_rel_props = []
    for rel_type, properties in filtered_schema["rel_props"].items():
        props_str = ", ".join(
            [f"{prop['property']}: {prop['type']}" for prop in properties]
        )
        formatted_rel_props.append(f"{rel_type} {{{props_str}}}")

    # Format relationships
    formatted_rels = [
        f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
        for el in filtered_schema["relationships"]
    ]

    return "\n".join(
        [
            "Node properties are the following:",
            ",".join(formatted_node_props),
            "Relationship properties are the following:",
            ",".join(formatted_rel_props),
            "The relationships are the following:",
            ",".join(formatted_rels),
        ]
    )


class GraphCypherQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: Any = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    graph_schema: str
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    cypher_query_corrector: Optional[CypherQueryCorrector] = None
    """Optional cypher validation tool"""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: Optional[BasePromptTemplate] = None,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[BaseLanguageModel] = None,
        exclude_types: List[str] = [],
        include_types: List[str] = [],
        validate_cypher: bool = False,
        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> GraphCypherQAChain:
        """Initialize from LLM."""

        if not cypher_llm and not llm:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not qa_llm and not llm:
            raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )
        if cypher_prompt and cypher_llm_kwargs:
            raise ValueError(
                "Specifying cypher_prompt and cypher_llm_kwargs together is"
                " not allowed. Please pass prompt via cypher_llm_kwargs."
            )
        if qa_prompt and qa_llm_kwargs:
            raise ValueError(
                "Specifying qa_prompt and qa_llm_kwargs together is"
                " not allowed. Please pass prompt via qa_llm_kwargs."
            )
        use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
        use_cypher_llm_kwargs = (
            cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
        )
        if "prompt" not in use_qa_llm_kwargs:
            use_qa_llm_kwargs["prompt"] = (
                qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
            )
        if "prompt" not in use_cypher_llm_kwargs:
            use_cypher_llm_kwargs["prompt"] = (
                cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
            )

        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]

        cypher_generation_chain = LLMChain(
            llm=cypher_llm or llm,  # type: ignore[arg-type]
            **use_cypher_llm_kwargs,  # type: ignore[arg-type]
        )

        if exclude_types and include_types:
            raise ValueError(
                "Either `exclude_types` or `include_types` "
                "can be provided, but not both"
            )

        graph_schema = construct_schema(
            kwargs["graph"].get_structured_schema, include_types, exclude_types
        )

        cypher_query_corrector = None
        if validate_cypher:
            corrector_schema = [
                Schema(el["start"], el["type"], el["end"])
                for el in kwargs["graph"].structured_schema.get("relationships")
            ]
            cypher_query_corrector = CypherQueryCorrector(corrector_schema)

        return cls(
            graph_schema=graph_schema,
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            cypher_query_corrector=cypher_query_corrector,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        # Correct Cypher query if enabled
        if self.cypher_query_corrector:
            generated_cypher = self.cypher_query_corrector(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results
        # Generated Cypher may be empty if the query corrector identifies an
        # invalid schema
        if generated_cypher:
            context = self.graph.query(generated_cypher)[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result


def load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
    if "graph" in kwargs:
        graph = kwargs.pop("graph")
    else:
        raise ValueError("`graph` must be present.")
    if "cypher_generation_chain" in config:
        cypher_generation_chain_config = config.pop("cypher_generation_chain")
        cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config)
    else:
        raise ValueError("`cypher_generation_chain` must be present.")
    if "qa_chain" in config:
        qa_chain_config = config.pop("qa_chain")
        qa_chain = load_chain_from_config(qa_chain_config)
    else:
        raise ValueError("`qa_chain` must be present.")

    return GraphCypherQAChain(
        graph=graph,
        cypher_generation_chain=cypher_generation_chain,  # type: ignore[arg-type]
        qa_chain=qa_chain,  # type: ignore[arg-type]
        **config,
    )
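A minimal wiring of GraphCypherQAChain; the Neo4j connection details and the chat model are illustrative assumptions. `validate_cypher=True` activates the CypherQueryCorrector defined in cypher_utils.py below:

from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI

graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")

chain = GraphCypherQAChain.from_llm(
    ChatOpenAI(temperature=0),
    graph=graph,
    validate_cypher=True,            # flip mis-directed relationships before querying
    return_intermediate_steps=True,  # expose the generated Cypher and raw context
)
out = chain.invoke({"query": "How many movies has Tom Hanks acted in?"})
print(out["result"], out["intermediate_steps"])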
libs/community/langchain_community/chains/graph_qa/cypher_utils.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import re
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple

Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])


class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition
    """

    property_pattern = re.compile(r"\{.+?\}")
    node_pattern = re.compile(r"\(.+?\)")
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: List[Schema]):
        """
        Args:
            schemas: list of schemas
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Args:
            node: node in string format
        """
        node = re.sub(self.property_pattern, "", node)
        node = node.replace("(", "")
        node = node.replace(")", "")
        node = node.strip()
        return node

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Args:
            query: cypher query
        """
        nodes = re.findall(self.node_pattern, query)
        nodes = [self.clean_node(node) for node in nodes]
        res: Dict[str, Any] = {}
        for node in nodes:
            parts = node.split(":")
            if parts == [""]:  # skip empty node strings
                continue
            variable = parts[0]
            if variable not in res:
                res[variable] = []
            res[variable] += parts[1:]
        return res

    def extract_paths(self, query: str) -> List[str]:
        """
        Args:
            query: cypher query
        """
        paths = []
        idx = 0
        while matched := self.path_pattern.findall(query[idx:]):
            matched = matched[0]
            matched = [
                m for i, m in enumerate(matched) if i not in [1, len(matched) - 1]
            ]
            path = "".join(matched)
            idx = query.find(path) + len(path) - len(matched[-1])
            paths.append(path)
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Args:
            relation: relation in string format
        """
        direction = "BIDIRECTIONAL"
        if relation[0] == "<":
            direction = "INCOMING"
        if relation[-1] == ">":
            direction = "OUTGOING"
        return direction

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Args:
            part: node in string format
        """
        part = part.lstrip("(").rstrip(")")
        idx = part.find(":")
        if idx != -1:
            part = part[:idx]
        return None if part == "" else part

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Args:
            str_node: node in string format
            node_variable_dict: dictionary of node variables
        """
        splitted_node = str_node.split(":")
        variable = splitted_node[0]
        labels = []
        if variable in node_variable_dict:
            labels = node_variable_dict[variable]
        elif variable == "" and len(splitted_node) > 1:
            labels = splitted_node[1:]
        return labels

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node
        """
        valid_schemas = self.schemas
        if from_node_labels != []:
            from_node_labels = [label.strip("`") for label in from_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[0] in from_node_labels
            ]
        if to_node_labels != []:
            to_node_labels = [label.strip("`") for label in to_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[2] in to_node_labels
            ]
        if relation_types != []:
            relation_types = [type.strip("`") for type in relation_types]
            valid_schemas = [
                schema for schema in valid_schemas if schema[1] in relation_types
            ]
        return valid_schemas != []

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Args:
            str_relation: relation in string format
        """
        relation_direction = self.judge_direction(str_relation)
        relation_type = self.relation_type_pattern.search(str_relation)
        if relation_type is None or relation_type.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!")
            for t in relation_type.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Args:
            query: cypher query
        """
        node_variable_dict = self.detect_node_variables(query)
        paths = self.extract_paths(query)
        for path in paths:
            original_path = path
            start_idx = 0
            while start_idx < len(path):
                match_res = re.match(self.node_relation_node_pattern, path[start_idx:])
                if match_res is None:
                    break
                start_idx += match_res.start()
                match_dict = match_res.groupdict()
                left_node_labels = self.detect_labels(
                    match_dict["left_node"], node_variable_dict
                )
                right_node_labels = self.detect_labels(
                    match_dict["right_node"], node_variable_dict
                )
                end_idx = (
                    start_idx
                    + 4
                    + len(match_dict["left_node"])
                    + len(match_dict["relation"])
                    + len(match_dict["right_node"])
                )
                original_partial_path = original_path[start_idx : end_idx + 1]
                relation_direction, relation_types = self.detect_relation_types(
                    match_dict["relation"]
                )

                if relation_types != [] and "".join(relation_types).find("*") != -1:
                    start_idx += (
                        len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
                    )
                    continue

                if relation_direction == "OUTGOING":
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    if not is_legal:
                        is_legal = self.verify_schema(
                            right_node_labels, relation_types, left_node_labels
                        )
                        if is_legal:
                            corrected_relation = "<" + match_dict["relation"][:-1]
                            corrected_partial_path = original_partial_path.replace(
                                match_dict["relation"], corrected_relation
                            )
                            query = query.replace(
                                original_partial_path, corrected_partial_path
                            )
                        else:
                            return ""
                elif relation_direction == "INCOMING":
                    is_legal = self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        is_legal = self.verify_schema(
                            left_node_labels, relation_types, right_node_labels
                        )
                        if is_legal:
                            corrected_relation = match_dict["relation"][1:] + ">"
                            corrected_partial_path = original_partial_path.replace(
                                match_dict["relation"], corrected_relation
                            )
                            query = query.replace(
                                original_partial_path, corrected_partial_path
                            )
                        else:
                            return ""
                else:
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    is_legal |= self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        return ""

                start_idx += (
                    len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
                )
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid. If the query cannot be made
        consistent with the schema, return an empty string.

        Args:
            query: cypher query
        """
        return self.correct_query(query)
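A standalone sketch of what the corrector does: given a schema where ACTED_IN runs from Person to Movie, a query written in the wrong direction is flipped, and an unsatisfiable pattern yields an empty string (labels and relationship names are illustrative):

from langchain_community.chains.graph_qa.cypher_utils import (
    CypherQueryCorrector,
    Schema,
)

corrector = CypherQueryCorrector([Schema("Person", "ACTED_IN", "Movie")])

# Direction is wrong relative to the schema, so it gets reversed.
wrong = "MATCH (m:Movie)-[:ACTED_IN]->(p:Person) RETURN p"
print(corrector(wrong))
# -> "MATCH (m:Movie)<-[:ACTED_IN]-(p:Person) RETURN p"

# No schema matches in either direction, so the query is rejected.
print(corrector("MATCH (a:Studio)-[:ACTED_IN]->(b:Studio) RETURN a"))
# -> ""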
libs/community/langchain_community/chains/graph_qa/falkordb.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use FalkorDBQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e
from langchain_community.chains.graph_qa.prompts import (
    CYPHER_GENERATION_PROMPT,
    CYPHER_QA_PROMPT,
)

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_cypher(text: str) -> str:
    """
    Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        Cypher code extracted from the text.
    """
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text


class FalkorDBQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: Any = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> FalkorDBQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph.schema}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results
        context = self.graph.query(generated_cypher)[: self.top_k]

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
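A minimal wiring of FalkorDBQAChain; the host, port, database name, and chat model are illustrative assumptions:

from langchain_community.chains.graph_qa.falkordb import FalkorDBQAChain
from langchain_community.graphs import FalkorDBGraph
from langchain_openai import ChatOpenAI

graph = FalkorDBGraph(database="movies", host="localhost", port=6379)

chain = FalkorDBQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "Which actors played in the movie Casino?"})["result"])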
libs/community/langchain_community/chains/graph_qa/gremlin.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use GremlinQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e
from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    GRAPHDB_SPARQL_FIX_TEMPLATE,
    GREMLIN_GENERATION_PROMPT,
)

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_gremlin(text: str) -> str:
    """Extract Gremlin code from a text.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    text = text.replace("`", "")
    if text.startswith("gremlin"):
        text = text[len("gremlin") :]
    return text.replace("\n", "")


class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: Any = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    max_fix_retries: int = 3
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 100
    return_direct: bool = False
    return_intermediate_steps: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        if generated_gremlin:
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        try:
            return self.graph.query(query)
        except Exception as e:
            if hasattr(e, "status_message"):
                raise ValueError(e.status_message)
            else:
                raise ValueError(str(e))

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        try:
            return self.execute_query(generated_gremlin)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)

            while retries < self.max_fix_retries:
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            # we are borrowing the fix template from sparql
                            "generated_sparql": generated_gremlin,
                            "schema": self.graph.get_schema,
                        },
                        callbacks=callbacks,
                    )
                    fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
                    return self.execute_query(fixed_gremlin)
                except Exception as e:
                    retries += 1
                    parse_exception = str(e)
                    self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception)

            raise ValueError("The generated Gremlin query is invalid.")

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )
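A minimal wiring of GremlinQAChain against a Gremlin endpoint such as Azure Cosmos DB; every connection detail below is an illustrative placeholder:

from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
from langchain_community.graphs import GremlinGraph
from langchain_openai import AzureChatOpenAI

graph = GremlinGraph(
    url="wss://example.gremlin.cosmos.azure.com:443/",
    username="/dbs/sample/colls/graph",
    password="<account-key>",
)

chain = GremlinQAChain.from_llm(
    AzureChatOpenAI(temperature=0, azure_deployment="gpt-4"),
    graph=graph,
    max_fix_retries=2,  # bounds the fix-and-retry loop in execute_with_retry
)
print(chain.invoke({"query": "How many people live in Seattle?"})["result"])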
libs/community/langchain_community/chains/graph_qa/hugegraph.py (new file, 111 lines)
@@ -0,0 +1,111 @@
"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use HugeGraphQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e
from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    GREMLIN_GENERATION_PROMPT,
)


class HugeGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: Any = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> HugeGraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)

        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        generated_gremlin = self.gremlin_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )
        context = self.graph.query(generated_gremlin)

        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=callbacks,
        )
        return {self.output_key: result[self.qa_chain.output_key]}
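A minimal wiring of HugeGraphQAChain; connection details and the chat model are illustrative assumptions:

from langchain_community.chains.graph_qa.hugegraph import HugeGraphQAChain
from langchain_community.graphs import HugeGraph
from langchain_openai import ChatOpenAI

graph = HugeGraph(
    username="admin",
    password="password",
    address="localhost",
    port=8080,
    graph="hugegraph",
)

chain = HugeGraphQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "Who played in The Godfather?"})["result"])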
libs/community/langchain_community/chains/graph_qa/kuzu.py (new file, 112 lines)
@@ -0,0 +1,112 @@
"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use KuzuQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e

from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    KUZU_GENERATION_PROMPT,
)


class KuzuQAChain(Chain):
    """Question-answering against a graph by generating Cypher statements for Kùzu.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: Any = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> KuzuQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )
        context = self.graph.query(generated_cypher)

        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=callbacks,
        )
        return {self.output_key: result[self.qa_chain.output_key]}
libs/community/langchain_community/chains/graph_qa/nebulagraph.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use NebulaGraphQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e

from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    NGQL_GENERATION_PROMPT,
)


class NebulaGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating nGQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
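
    Example (illustrative sketch added in this edit, not from the diff; assumes
    ``llm`` is a chat model and ``graph`` is a NebulaGraph wrapper such as
    ``NebulaGraph`` exposing ``get_schema`` and ``query``):
        .. code-block:: python

            # hypothetical wiring; `llm` and `graph` must be constructed first
            chain = NebulaGraphQAChain.from_llm(llm=llm, graph=graph)
            response = chain.run("Who directed the movie The Godfather II?")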
"""
|
||||
|
||||
graph: Any = Field(exclude=True)
|
||||
ngql_generation_chain: LLMChain
|
||||
qa_chain: LLMChain
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
||||
ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> NebulaGraphQAChain:
|
||||
"""Initialize from LLM."""
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt)
|
||||
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
ngql_generation_chain=ngql_generation_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Generate nGQL statement, use it to look up in db and answer question."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
generated_ngql = self.ngql_generation_chain.run(
|
||||
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
|
||||
)
|
||||
|
||||
_run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_ngql, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
context = self.graph.query(generated_ngql)
|
||||
|
||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(context), color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
result = self.qa_chain(
|
||||
{"question": question, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return {self.output_key: result[self.qa_chain.output_key]}
|
||||
libs/community/langchain_community/chains/graph_qa/neptune_cypher.py (new file, 225 lines)
@@ -0,0 +1,225 @@
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
    from langchain.chains.prompt_selector import ConditionalPromptSelector
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use NeptuneOpenCypherQAChain. Please "
        "install it with `pip install -U langchain`."
    ) from e

from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def trim_query(query: str) -> str:
    """Trim the query to only include Cypher keywords."""
    keywords = (
        "CALL",
        "CREATE",
        "DELETE",
        "DETACH",
        "LIMIT",
        "MATCH",
        "MERGE",
        "OPTIONAL",
        "ORDER",
        "REMOVE",
        "RETURN",
        "SET",
        "SKIP",
        "UNWIND",
        "WITH",
        "WHERE",
        "//",
    )

    lines = query.split("\n")
    new_query = ""

    for line in lines:
        if line.strip().upper().startswith(keywords):
            new_query += line + "\n"

    return new_query


def extract_cypher(text: str) -> str:
    """Extract Cypher code from text using Regex."""
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text
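

# Editor's note (illustrative, not part of the diff): `extract_cypher` pulls
# the first triple-backtick block out of an LLM response, e.g.
#   extract_cypher("Query:\n```\nMATCH (n) RETURN n LIMIT 1\n```")
# returns "\nMATCH (n) RETURN n LIMIT 1\n"; text without a fenced block is
# returned unchanged.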


def use_simple_prompt(llm: BaseLanguageModel) -> bool:
    """Decides whether to use the simple prompt."""
    if llm._llm_type and "anthropic" in llm._llm_type:  # type: ignore
        return True

    # Bedrock anthropic
    if hasattr(llm, "model_id") and "anthropic" in llm.model_id:  # type: ignore
        return True

    return False


PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
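
# Editor's note (illustrative, not part of the diff): with an Anthropic-backed
# model (detected via `_llm_type`, or a Bedrock `model_id` containing
# "anthropic"), PROMPT_SELECTOR.get_prompt(llm) returns the simple prompt;
# any other model gets NEPTUNE_OPENCYPHER_GENERATION_PROMPT.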


class NeptuneOpenCypherQAChain(Chain):
    """Chain for question-answering against a Neptune graph
    by generating openCypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

            chain = NeptuneOpenCypherQAChain.from_llm(
                llm=llm,
                graph=graph
            )
            response = chain.run(query)
    """

    graph: Any = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    extra_instructions: Optional[str] = None
    """Extra instructions appended to the query generation prompt."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        extra_instructions: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneOpenCypherQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

        _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
        cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            extra_instructions=extra_instructions,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {
                "question": question,
                "schema": self.graph.get_schema,
                "extra_instructions": self.extra_instructions or "",
            },
            callbacks=callbacks,
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)
        generated_cypher = trim_query(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        context = self.graph.query(generated_cypher)

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
libs/community/langchain_community/chains/graph_qa/neptune_sparql.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use NeptuneSparqlQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e

from langchain_community.chains.graph_qa.prompts import SPARQL_QA_PROMPT

INTERMEDIATE_STEPS_KEY = "intermediate_steps"

SPARQL_GENERATION_TEMPLATE = """
Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following
query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
    ?person foaf:name "John Doe" .
    ?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.

Examples:

Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than
for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.

The question is:
{prompt}"""

SPARQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
)


def extract_sparql(query: str) -> str:
    query = query.strip()
    querytoks = query.split("```")
    if len(querytoks) == 3:
        query = querytoks[1]

        if query.startswith("sparql"):
            query = query[6:]
    elif query.startswith("<sparql>") and query.endswith("</sparql>"):
        query = query[8:-9]
    return query
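

# Editor's note (illustrative, not part of the diff): `extract_sparql` strips a
# triple-backtick fence and an optional leading "sparql" tag or <sparql> tags,
# e.g. extract_sparql("```sparql\nSELECT ?s WHERE { ?s ?p ?o }\n```")
# returns "\nSELECT ?s WHERE { ?s ?p ?o }\n".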


class NeptuneSparqlQAChain(Chain):
    """Chain for question-answering against a Neptune graph
    by generating SPARQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

            chain = NeptuneSparqlQAChain.from_llm(
                llm=llm,
                graph=graph
            )
            response = chain.invoke(query)
    """

    graph: Any = Field(exclude=True)
    sparql_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    extra_instructions: Optional[str] = None
    """Extra instructions appended to the query generation prompt."""

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
        sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
        examples: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneSparqlQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        template_to_use = SPARQL_GENERATION_TEMPLATE
        if examples:
            template_to_use = template_to_use.replace(
                "Examples:", "Examples: " + examples
            )
            sparql_prompt = PromptTemplate(
                input_variables=["schema", "prompt"], template=template_to_use
            )
        sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)

        return cls(  # type: ignore[call-arg]
            qa_chain=qa_chain,
            sparql_generation_chain=sparql_generation_chain,
            examples=examples,
            **kwargs,
        )
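
    # Editor's note (illustrative, not part of the diff): passing `examples`
    # splices few-shot examples into the generation template before the prompt
    # is built, e.g.
    #   NeptuneSparqlQAChain.from_llm(llm, graph=graph, examples=my_examples)
    # where `my_examples` is a hypothetical string of example question/query pairs.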

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """
        Generate SPARQL query, use it to retrieve a response from the graph
        database, and answer the question.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prompt = inputs[self.input_key]

        intermediate_steps: List = []

        generated_sparql = self.sparql_generation_chain.run(
            {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        # Extract SPARQL
        generated_sparql = extract_sparql(generated_sparql)

        _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_sparql, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_sparql})

        context = self.graph.query(generated_sparql)

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"prompt": prompt, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
libs/community/langchain_community/chains/graph_qa/ontotext_graphdb.py (new file, 195 lines)
@@ -0,0 +1,195 @@
"""Question answering over a graph."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    import rdflib

from langchain_core.callbacks.manager import (
    CallbackManager,
    CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use OntotextGraphDBQAChain. Please install "
        "it with `pip install -U langchain`."
    ) from e

from langchain_community.chains.graph_qa.prompts import (
    GRAPHDB_QA_PROMPT,
    GRAPHDB_SPARQL_FIX_PROMPT,
    GRAPHDB_SPARQL_GENERATION_PROMPT,
)


class OntotextGraphDBQAChain(Chain):
    """Question-answering against Ontotext GraphDB
    https://graphdb.ontotext.com/ by generating SPARQL queries.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
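
    Example (illustrative sketch added in this edit, not from the diff; assumes
    ``llm`` is a chat model and ``graph`` is an Ontotext GraphDB wrapper such
    as ``OntotextGraphDBGraph``):
        .. code-block:: python

            # hypothetical wiring; `llm` and `graph` must be constructed first
            chain = OntotextGraphDBQAChain.from_llm(llm=llm, graph=graph)
            response = chain.invoke({"query": "What is the tallest mountain?"})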
"""
|
||||
|
||||
graph: Any = Field(exclude=True)
|
||||
sparql_generation_chain: LLMChain
|
||||
sparql_fix_chain: LLMChain
|
||||
max_fix_retries: int
|
||||
qa_chain: LLMChain
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT,
|
||||
sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT,
|
||||
max_fix_retries: int = 5,
|
||||
qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> OntotextGraphDBQAChain:
|
||||
"""Initialize from LLM."""
|
||||
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_generation_prompt)
|
||||
sparql_fix_chain = LLMChain(llm=llm, prompt=sparql_fix_prompt)
|
||||
max_fix_retries = max_fix_retries
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
sparql_generation_chain=sparql_generation_chain,
|
||||
sparql_fix_chain=sparql_fix_chain,
|
||||
max_fix_retries=max_fix_retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate a SPARQL query, use it to retrieve a response from GraphDB and answer
|
||||
the question.
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
prompt = inputs[self.input_key]
|
||||
ontology_schema = self.graph.get_schema
|
||||
|
||||
sparql_generation_chain_result = self.sparql_generation_chain.invoke(
|
||||
{"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks
|
||||
)
|
||||
generated_sparql = sparql_generation_chain_result[
|
||||
self.sparql_generation_chain.output_key
|
||||
]
|
||||
|
||||
generated_sparql = self._get_prepared_sparql_query(
|
||||
_run_manager, callbacks, generated_sparql, ontology_schema
|
||||
)
|
||||
query_results = self._execute_query(generated_sparql)
|
||||
|
||||
qa_chain_result = self.qa_chain.invoke(
|
||||
{"prompt": prompt, "context": query_results}, callbacks=callbacks
|
||||
)
|
||||
result = qa_chain_result[self.qa_chain.output_key]
|
||||
return {self.output_key: result}
|
||||
|
||||
def _get_prepared_sparql_query(
|
||||
self,
|
||||
_run_manager: CallbackManagerForChainRun,
|
||||
callbacks: CallbackManager,
|
||||
generated_sparql: str,
|
||||
ontology_schema: str,
|
||||
) -> str:
|
||||
try:
|
||||
return self._prepare_sparql_query(_run_manager, generated_sparql)
|
||||
except Exception as e:
|
||||
retries = 0
|
||||
error_message = str(e)
|
||||
self._log_invalid_sparql_query(
|
||||
_run_manager, generated_sparql, error_message
|
||||
)
|
||||
|
||||
while retries < self.max_fix_retries:
|
||||
try:
|
||||
sparql_fix_chain_result = self.sparql_fix_chain.invoke(
|
||||
{
|
||||
"error_message": error_message,
|
||||
"generated_sparql": generated_sparql,
|
||||
"schema": ontology_schema,
|
||||
},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
generated_sparql = sparql_fix_chain_result[
|
||||
self.sparql_fix_chain.output_key
|
||||
]
|
||||
return self._prepare_sparql_query(_run_manager, generated_sparql)
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
parse_exception = str(e)
|
||||
self._log_invalid_sparql_query(
|
||||
_run_manager, generated_sparql, parse_exception
|
||||
)
|
||||
|
||||
raise ValueError("The generated SPARQL query is invalid.")
|
||||
|
||||
def _prepare_sparql_query(
|
||||
self, _run_manager: CallbackManagerForChainRun, generated_sparql: str
|
||||
) -> str:
|
||||
from rdflib.plugins.sparql import prepareQuery
|
||||
|
||||
prepareQuery(generated_sparql)
|
||||
self._log_prepared_sparql_query(_run_manager, generated_sparql)
|
||||
return generated_sparql
|
||||
|
||||
def _log_prepared_sparql_query(
|
||||
self, _run_manager: CallbackManagerForChainRun, generated_query: str
|
||||
) -> None:
|
||||
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_query, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
def _log_invalid_sparql_query(
|
||||
self,
|
||||
_run_manager: CallbackManagerForChainRun,
|
||||
generated_query: str,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
_run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_query, color="red", end="\n", verbose=self.verbose
|
||||
)
|
||||
_run_manager.on_text(
|
||||
"SPARQL Query Parse Error: ", end="\n", verbose=self.verbose
|
||||
)
|
||||
_run_manager.on_text(
|
||||
error_message, color="red", end="\n\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
|
||||
try:
|
||||
return self.graph.query(query)
|
||||
except Exception:
|
||||
raise ValueError("Failed to execute the generated SPARQL query.")
|
||||
libs/community/langchain_community/chains/graph_qa/prompts.py (new file, 414 lines)
@@ -0,0 +1,414 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate

_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.

Return the output as a single comma-separated list, or NONE if there is nothing of note to return.

EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE

EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Sam.
Output: Langchain, Sam
END OF EXAMPLE

Begin!

{input}
Output:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
    input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)

_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Helpful Answer:"""
GRAPH_QA_PROMPT = PromptTemplate(
    template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"]
)

CYPHER_GENERATION_TEMPLATE = """Task: Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.

The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)

NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
Instructions:

First, generate cypher then convert it to NebulaGraph Cypher dialect (rather than standard):
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
3. it uses double equals sign for comparison: `==` rather than `=`
For instance:
```diff
< MATCH (p:person)-[e:directed]->(m:movie) WHERE m.name = 'The Godfather II'
< RETURN p.name, e.year, m.name;
---
> MATCH (p:`person`)-[e:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'
> RETURN p.`person`.`name`, e.year, m.`movie`.`name`;
```\n"""

NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Generate Cypher", "Generate NebulaGraph Cypher"
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)

NGQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
)

KUZU_EXTRA_INSTRUCTIONS = """
Instructions:

Generate statement with Kùzu Cypher dialect (rather than standard):
1. do not use `WHERE EXISTS` clause to check the existence of a property because Kùzu database has a fixed schema.
2. do not omit relationship pattern. Always use `()-[]->()` instead of `()->()`.
3. do not include any notes or comments even if the statement does not produce the expected result.
\n"""

KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Generate Cypher", "Generate Kùzu Cypher"
).replace("Instructions:", KUZU_EXTRA_INSTRUCTIONS)

KUZU_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=KUZU_GENERATION_TEMPLATE
)

GREMLIN_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace("Cypher", "Gremlin")

GREMLIN_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=GREMLIN_GENERATION_TEMPLATE
)

CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here is an example:

Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.

Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Information:
{context}

Question: {question}
Helpful Answer:"""
CYPHER_QA_PROMPT = PromptTemplate(
    input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
)

SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type.
You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types.
Consider only the following query types:
* SELECT: this query type corresponds to questions
* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type.
Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'.

The prompt is:
{prompt}
Helpful Answer:"""
SPARQL_INTENT_PROMPT = PromptTemplate(
    input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE
)

SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
    ?person foaf:name "John Doe" .
    ?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.

The question is:
{prompt}"""
SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE
)

SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database.
For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
INSERT {{
    ?person foaf:mbox <mailto:jane.doe@foo.bar> .
}}
WHERE {{
    ?person foaf:name "Jane Doe" .
}}
```
Instructions:
Make the query as short as possible and avoid adding unnecessary triples.
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Return only the generated SPARQL query, nothing else.

The information to be inserted is:
{prompt}"""
SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE
)

SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Information:
{context}

Question: {prompt}
Helpful Answer:"""
SPARQL_QA_PROMPT = PromptTemplate(
    input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
)

GRAPHDB_SPARQL_GENERATION_TEMPLATE = """
Write a SPARQL SELECT query for querying a graph database.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
Use only the classes and properties provided in the schema to construct the SPARQL query.
Do not use any classes or properties that are not explicitly provided in the SPARQL query.
Include all necessary prefixes.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The question delimited by triple backticks is:
```
{prompt}
```
"""
GRAPHDB_SPARQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"],
    template=GRAPHDB_SPARQL_GENERATION_TEMPLATE,
)

GRAPHDB_SPARQL_FIX_TEMPLATE = """
The following SPARQL query delimited by triple backticks
```
{generated_sparql}
```
is not valid.
The error delimited by triple backticks is
```
{error_message}
```
Give me a correct version of the SPARQL query.
Do not change the logic of the query.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
"""

GRAPHDB_SPARQL_FIX_PROMPT = PromptTemplate(
    input_variables=["error_message", "generated_sparql", "schema"],
    template=GRAPHDB_SPARQL_FIX_TEMPLATE,
)

GRAPHDB_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Don't use internal knowledge to answer the question, just say you don't know if no information is available.
Information:
{context}

Question: {prompt}
Helpful Answer:"""
GRAPHDB_QA_PROMPT = PromptTemplate(
    input_variables=["context", "prompt"], template=GRAPHDB_QA_TEMPLATE
)

AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input.

You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query.

You are given an `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.

You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used.

Things you should do:
- Think step by step.
- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query.
- Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required.
- Return the `AQL Query` wrapped in 3 backticks (```).
- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries.
- Only answer to requests related to generating an AQL Query.
- If a request is unrelated to generating AQL Query, say that you cannot help the user.

Things you should not do:
- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`.
- Do not include any text except the generated AQL Query.
- Do not provide explanations or apologies in your responses.
- Do not generate an AQL Query that removes or deletes any data.

Under no circumstance should you generate an AQL Query that deletes any data whatsoever.

ArangoDB Schema:
{adb_schema}

AQL Query Examples (Optional):
{aql_examples}

User Input:
{user_input}

AQL Query:
"""

AQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["adb_schema", "aql_examples", "user_input"],
    template=AQL_GENERATION_TEMPLATE,
)

AQL_FIX_TEMPLATE = """Task: Address the ArangoDB Query Language (AQL) error message of an ArangoDB Query Language query.

You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`.

The `AQL Error` explains why the `AQL Query` could not be executed in the database.
The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`.
For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`.

You are also given the `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.

You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query.

Remember to think step by step.

ArangoDB Schema:
{adb_schema}

AQL Query:
{aql_query}

AQL Error:
{aql_error}

Corrected AQL Query:
"""

AQL_FIX_PROMPT = PromptTemplate(
    input_variables=[
        "adb_schema",
        "aql_query",
        "aql_error",
    ],
    template=AQL_FIX_TEMPLATE,
)

AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query.

You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`.

A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format.
You are responsible for creating a `Summary` based on the AQL Result.

You are given the following information:
- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database.
- `User Input`: the original question/request of the user, which has been translated into an AQL Query.
- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query.
- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database.

Remember to think step by step.

Your `Summary` should sound like it is a response to the `User Input`.
Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`.

ArangoDB Schema:
{adb_schema}

User Input:
{user_input}

AQL Query:
{aql_query}

AQL Result:
{aql_result}
"""
AQL_QA_PROMPT = PromptTemplate(
    input_variables=["adb_schema", "user_input", "aql_query", "aql_result"],
    template=AQL_QA_TEMPLATE,
)


NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
\n"""

NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
)

NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question", "extra_instructions"],
    template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)

NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""

NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
    input_variables=["schema", "question", "extra_instructions"],
    template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)
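
# Editor's note: an illustrative sketch of how these templates render (not part
# of the diff; the toy schema string below is made up):
#   text = CYPHER_GENERATION_PROMPT.format(
#       schema="(:Person)-[:DIRECTED]->(:Movie)",
#       question="Who directed The Godfather?",
#   )
#   # `text` is the full instruction block with the schema and question spliced in.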
libs/community/langchain_community/chains/graph_qa/sparql.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

try:
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
except ImportError as e:
    raise ImportError(
        "Must have `langchain` installed to use GraphSparqlQAChain. Please install it "
        "with `pip install -U langchain`."
    ) from e

from langchain_community.chains.graph_qa.prompts import (
    SPARQL_GENERATION_SELECT_PROMPT,
    SPARQL_GENERATION_UPDATE_PROMPT,
    SPARQL_INTENT_PROMPT,
    SPARQL_QA_PROMPT,
)


class GraphSparqlQAChain(Chain):
    """Question-answering against an RDF or OWL graph by generating SPARQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
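
    Example (illustrative sketch added in this edit, not from the diff; assumes
    ``llm`` is a chat model and ``graph`` is an RDF graph wrapper such as
    ``RdfGraph`` exposing ``get_schema``, ``query`` and ``update``):
        .. code-block:: python

            # hypothetical wiring; `llm` and `graph` must be constructed first
            chain = GraphSparqlQAChain.from_llm(llm=llm, graph=graph)
            response = chain.run("What is Tim Berners-Lee's work homepage?")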
"""
|
||||
|
||||
graph: Any = Field(exclude=True)
|
||||
sparql_generation_select_chain: LLMChain
|
||||
sparql_generation_update_chain: LLMChain
|
||||
sparql_intent_chain: LLMChain
|
||||
qa_chain: LLMChain
|
||||
return_sparql_query: bool = False
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
sparql_query_key: str = "sparql_query" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
|
||||
sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT,
|
||||
sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT,
|
||||
sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> GraphSparqlQAChain:
|
||||
"""Initialize from LLM."""
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt)
|
||||
sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt)
|
||||
sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt)
|
||||
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
sparql_generation_select_chain=sparql_generation_select_chain,
|
||||
sparql_generation_update_chain=sparql_generation_update_chain,
|
||||
sparql_intent_chain=sparql_intent_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate SPARQL query, use it to retrieve a response from the gdb and answer
|
||||
the question.
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
prompt = inputs[self.input_key]
|
||||
|
||||
_intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks)
|
||||
intent = _intent.strip()
|
||||
|
||||
if "SELECT" in intent and "UPDATE" not in intent:
|
||||
sparql_generation_chain = self.sparql_generation_select_chain
|
||||
intent = "SELECT"
|
||||
elif "UPDATE" in intent and "SELECT" not in intent:
|
||||
sparql_generation_chain = self.sparql_generation_update_chain
|
||||
intent = "UPDATE"
|
||||
else:
|
||||
raise ValueError(
|
||||
"I am sorry, but this prompt seems to fit none of the currently "
|
||||
"supported SPARQL query types, i.e., SELECT and UPDATE."
|
||||
)
|
||||
|
||||
_run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose)
|
||||
|
||||
generated_sparql = sparql_generation_chain.run(
|
||||
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
|
||||
)
|
||||
|
||||
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_sparql, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
if intent == "SELECT":
|
||||
context = self.graph.query(generated_sparql)
|
||||
|
||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(context), color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
result = self.qa_chain(
|
||||
{"prompt": prompt, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
res = result[self.qa_chain.output_key]
|
||||
elif intent == "UPDATE":
|
||||
self.graph.update(generated_sparql)
|
||||
res = "Successfully inserted triples into the graph."
|
||||
else:
|
||||
raise ValueError("Unsupported SPARQL query type.")
|
||||
|
||||
chain_result: Dict[str, Any] = {self.output_key: res}
|
||||
if self.return_sparql_query:
|
||||
chain_result[self.sparql_query_key] = generated_sparql
|
||||
return chain_result
|
||||
@@ -3,6 +3,7 @@ import logging
from pathlib import Path
from typing import List

from langchain_core._api.deprecation import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
@@ -13,6 +14,11 @@ from langchain_core.messages import (
logger = logging.getLogger(__name__)


@deprecated(
    "0.0.29",
    alternative_import="langchain.memory.chat_message_histories.FileChatMessageHistory",
    removal="0.2.0",
)
class FileChatMessageHistory(BaseChatMessageHistory):
    """
    Chat message history that stores history in a local file.

@@ -1,10 +1,16 @@
from typing import List, Sequence

from langchain_core._api import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field


@deprecated(
    "0.0.29",
    alternative_import="langchain.memory.chat_message_histories.ChatMessageHistory",
    removal="0.2.0",
)
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
    """In memory implementation of chat message history.

@@ -0,0 +1,4 @@
from langchain_community.entity_stores.redis import RedisEntityStore
from langchain_community.entity_stores.upstash import UpstashRedisEntityStore

__all__ = ["RedisEntityStore", "UpstashRedisEntityStore"]
libs/community/langchain_community/entity_stores/redis.py (new file, 134 lines)
@@ -0,0 +1,134 @@
from __future__ import annotations

import logging
from itertools import islice
from typing import TYPE_CHECKING, Any, Iterable, Optional

from langchain_core.entity_stores import BaseEntityStore

from langchain_community.utilities.redis import _redis_sentinel_client

if TYPE_CHECKING:
    from redis import Redis as RedisType

logger = logging.getLogger(__name__)


class RedisEntityStore(BaseEntityStore):
    """Redis-backed entity store.

    Entities get a TTL of 1 day by default, and that TTL is extended
    by 3 days every time the entity is read back.

    Must have `redis` and `langchain-community` installed.
    """

    redis_client: Any
    session_id: str = "default"
    key_prefix: str = "memory_store"
    ttl: Optional[int] = 60 * 60 * 24
    recall_ttl: Optional[int] = 60 * 60 * 24 * 3

    def __init__(
        self,
        session_id: str = "default",
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        try:
            import redis
        except ImportError:
            raise ImportError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )

        super().__init__(*args, **kwargs)
        try:
            self.redis_client = self._get_client(redis_url=url, decode_responses=True)
        except redis.exceptions.ConnectionError as error:
            logger.error(error)

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        self.recall_ttl = recall_ttl or ttl

    @staticmethod
    def _get_client(redis_url: str, **kwargs: Any) -> RedisType:
        try:
            import redis
        except ImportError:
            raise ImportError(
                "Could not import redis python package. "
                "Please install it with `pip install redis>=4.1.0`."
            )

        # check if normal redis:// or redis+sentinel:// url
        if redis_url.startswith("redis+sentinel"):
            redis_client = _redis_sentinel_client(redis_url, **kwargs)
        elif redis_url.startswith(
            "rediss+sentinel"
        ):  # sentinel with TLS support enabled
            kwargs["ssl"] = True
            if "ssl_cert_reqs" not in kwargs:
                kwargs["ssl_cert_reqs"] = "none"
            redis_client = _redis_sentinel_client(redis_url, **kwargs)
        else:
            # connect to redis server from url, reconnect with cluster client if needed
            redis_client = redis.from_url(redis_url, **kwargs)

            try:
                cluster_info = redis_client.info("cluster")
                cluster_enabled = cluster_info["cluster_enabled"] == 1
            except redis.exceptions.RedisError:
                cluster_enabled = False
            if cluster_enabled:
                from redis.cluster import RedisCluster

                redis_client.close()
                return RedisCluster.from_url(redis_url, **kwargs)
        return redis_client

    @property
    def full_key_prefix(self) -> str:
        return f"{self.key_prefix}:{self.session_id}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        res = (
            self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
            or default
            or ""
        )
        logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        if not value:
            return self.delete(key)
        self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
        logger.debug(
            f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
        )

    def delete(self, key: str) -> None:
        self.redis_client.delete(f"{self.full_key_prefix}:{key}")

    def exists(self, key: str) -> bool:
        return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

    def clear(self) -> None:
        # iterate a list in batches of size batch_size
        def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
            iterator = iter(iterable)
            while batch := list(islice(iterator, batch_size)):
                yield batch

        for keybatch in batched(
            self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
        ):
            self.redis_client.delete(*keybatch)
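A minimal usage sketch for the new store (assumes a Redis server reachable at the default URL):

from langchain_community.entity_stores import RedisEntityStore

store = RedisEntityStore(session_id="user-42", url="redis://localhost:6379/0")
store.set("Alice", "Alice is a software engineer.")  # written with a 1-day TTL
assert store.exists("Alice")
print(store.get("Alice"))  # reading back extends the TTL via GETEX
store.clear()              # deletes every key under memory_store:user-42:*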
libs/community/langchain_community/entity_stores/upstash.py (new file)
@@ -0,0 +1,86 @@
from __future__ import annotations

import logging
from typing import Any, Optional

from langchain_core.entity_stores import BaseEntityStore

logger = logging.getLogger(__name__)


class UpstashRedisEntityStore(BaseEntityStore):
    """Upstash Redis backed entity store.

    Entities get a TTL of 1 day by default, and that TTL is extended
    by 3 days every time the entity is read back.
    """

    def __init__(
        self,
        session_id: str = "default",
        url: str = "",
        token: str = "",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        try:
            from upstash_redis import Redis
        except ImportError:
            raise ImportError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            )

        super().__init__(*args, **kwargs)

        try:
            self.redis_client = Redis(url=url, token=token)
        except Exception:
            logger.error("Upstash Redis instance could not be initiated.")

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        self.recall_ttl = recall_ttl or ttl

    @property
    def full_key_prefix(self) -> str:
        return f"{self.key_prefix}:{self.session_id}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        res = (
            self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
            or default
            or ""
        )
        logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        if not value:
            return self.delete(key)
        self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
        logger.debug(
            f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
        )

    def delete(self, key: str) -> None:
        self.redis_client.delete(f"{self.full_key_prefix}:{key}")

    def exists(self, key: str) -> bool:
        return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

    def clear(self) -> None:
        def scan_and_delete(cursor: int) -> int:
            cursor, keys_to_delete = self.redis_client.scan(
                cursor, f"{self.full_key_prefix}:*"
            )
            self.redis_client.delete(*keys_to_delete)
            return cursor

        cursor = scan_and_delete(0)
        while cursor != 0:
            cursor = scan_and_delete(cursor)
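Same surface as RedisEntityStore, but over Upstash's REST client; the URL and token below are placeholders for real Upstash credentials:

from langchain_community.entity_stores import UpstashRedisEntityStore

store = UpstashRedisEntityStore(
    session_id="user-42",
    url="https://your-instance.upstash.io",  # placeholder
    token="your-token",                      # placeholder
)
store.set("Bob", "Bob runs the data platform team.")
print(store.get("Bob"))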
@@ -1,7 +1,7 @@
# flake8: noqa

from langchain_community.graphs.networkx_graph import KG_TRIPLE_DELIMITER
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts import PromptTemplate

_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
    "You are a networked intelligence helping a human track knowledge triples"
libs/community/langchain_community/indexes/graph.py (new file)
@@ -0,0 +1,47 @@
"""Graph Index Creator."""
from typing import Optional, Type

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel

from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
from langchain_community.indexes._prompts.knowledge_triplet_extraction import (
    KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)


class GraphIndexCreator(BaseModel):
    """Functionality to create graph index."""

    llm: Optional[BaseLanguageModel] = None
    graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph

    def from_text(
        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    ) -> NetworkxEntityGraph:
        """Create graph index from text."""
        if self.llm is None:
            raise ValueError("llm should not be None")
        graph = self.graph_type()
        chain = prompt | self.llm | StrOutputParser()
        output = chain.invoke({"text": text})
        knowledge = parse_triples(output)
        for triple in knowledge:
            graph.add_triple(triple)
        return graph

    async def afrom_text(
        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    ) -> NetworkxEntityGraph:
        """Create graph index from text asynchronously."""
        if self.llm is None:
            raise ValueError("llm should not be None")
        graph = self.graph_type()
        chain = prompt | self.llm | StrOutputParser()
        output = await chain.ainvoke({"text": text})
        knowledge = parse_triples(output)
        for triple in knowledge:
            graph.add_triple(triple)
        return graph
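A quick sketch of the LCEL-based creator above (the chat model choice is illustrative; any BaseLanguageModel works):

from langchain_openai import ChatOpenAI  # illustrative model choice
from langchain_community.indexes.graph import GraphIndexCreator

creator = GraphIndexCreator(llm=ChatOpenAI(temperature=0))
graph = creator.from_text("Apple was founded by Steve Jobs in Cupertino.")
print(graph.get_triples())  # [(subject, object, predicate), ...]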
@@ -1,5 +1,3 @@
from langchain_core.exceptions import LangChainException
from langchain_core.exceptions import InvalidKeyException


class InvalidKeyException(LangChainException):
    """Raised when a key is invalid; e.g., uses incorrect characters."""
__all__ = ["InvalidKeyException"]
@@ -0,0 +1,51 @@
from langchain_community.structured_query_translators.astradb import AstraDBTranslator
from langchain_community.structured_query_translators.chroma import ChromaTranslator
from langchain_community.structured_query_translators.dashvector import (
    DashvectorTranslator,
)
from langchain_community.structured_query_translators.deeplake import DeepLakeTranslator
from langchain_community.structured_query_translators.dingo import DingoDBTranslator
from langchain_community.structured_query_translators.elasticsearch import (
    ElasticsearchTranslator,
)
from langchain_community.structured_query_translators.milvus import MilvusTranslator
from langchain_community.structured_query_translators.mongodb_atlas import (
    MongoDBAtlasTranslator,
)
from langchain_community.structured_query_translators.myscale import MyScaleTranslator
from langchain_community.structured_query_translators.opensearch import (
    OpenSearchTranslator,
)
from langchain_community.structured_query_translators.pgvector import PGVectorTranslator
from langchain_community.structured_query_translators.pinecone import PineconeTranslator
from langchain_community.structured_query_translators.qdrant import QdrantTranslator
from langchain_community.structured_query_translators.redis import RedisTranslator
from langchain_community.structured_query_translators.supabase import (
    SupabaseVectorTranslator,
)
from langchain_community.structured_query_translators.timescalevector import (
    TimescaleVectorTranslator,
)
from langchain_community.structured_query_translators.vectara import VectaraTranslator
from langchain_community.structured_query_translators.weaviate import WeaviateTranslator

__all__ = [
    "AstraDBTranslator",
    "ChromaTranslator",
    "DashvectorTranslator",
    "DeepLakeTranslator",
    "DingoDBTranslator",
    "ElasticsearchTranslator",
    "MilvusTranslator",
    "MongoDBAtlasTranslator",
    "MyScaleTranslator",
    "OpenSearchTranslator",
    "PGVectorTranslator",
    "PineconeTranslator",
    "QdrantTranslator",
    "RedisTranslator",
    "SupabaseVectorTranslator",
    "TimescaleVectorTranslator",
    "VectaraTranslator",
    "WeaviateTranslator",
]
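For a sense of what these translators share, a small sketch that converts the common IR into a store-native filter (Chroma chosen arbitrarily):

from langchain_core.structured_query.ir import Comparator, Comparison, StructuredQuery
from langchain_community.structured_query_translators import ChromaTranslator

sq = StructuredQuery(
    query="sci-fi movies",
    filter=Comparison(comparator=Comparator.GT, attribute="year", value=1990),
    limit=None,
)
# Each translator is a Visitor over the IR and returns (query, search kwargs).
query, kwargs = ChromaTranslator().visit_structured_query(sq)
print(query, kwargs)  # -> sci-fi movies {'filter': {'year': {'$gt': 1990}}}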
@@ -1,7 +1,7 @@
"""Logic for converting internal query language to a valid AstraDB query."""
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,7 +1,7 @@
"""Logic for converting internal query language to a valid DashVector query."""
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,7 +1,7 @@
"""Logic for converting internal query language to a valid Chroma query."""
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,7 +1,7 @@
"""Logic for converting internal query language to a valid Milvus query."""
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,7 +1,7 @@
"""Logic for converting internal query language to a valid MongoDB Atlas query."""
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,7 +1,7 @@
import re
from typing import Any, Callable, Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -2,7 +2,7 @@ from __future__ import annotations

from typing import TYPE_CHECKING, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -2,6 +2,15 @@ from __future__ import annotations

from typing import Any, Tuple

from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
    Visitor,
)

from langchain_community.vectorstores.redis import Redis
from langchain_community.vectorstores.redis.filters import (
    RedisFilterExpression,
@@ -11,16 +20,6 @@ from langchain_community.vectorstores.redis.filters import (
    RedisTag,
    RedisText,
)
from langchain_community.vectorstores.redis.schema import RedisModel

from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
    Visitor,
)

_COMPARATOR_TO_BUILTIN_METHOD = {
    Comparator.EQ: "__eq__",
@@ -51,7 +50,7 @@ class RedisTranslator(Visitor):
    allowed_operators = (Operator.AND, Operator.OR)
    """Subset of allowed logical operators."""

    def __init__(self, schema: RedisModel) -> None:
    def __init__(self, schema: Any) -> None:
        self._schema = schema

    def _attribute_to_filter_field(self, attribute: str) -> RedisFilterField:

@@ -1,6 +1,6 @@
from typing import Any, Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -2,7 +2,7 @@ from __future__ import annotations

from typing import TYPE_CHECKING, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,6 +1,6 @@
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
@@ -459,10 +459,10 @@ class APIOperation(BaseModel):
    """The HTTP method of the operation."""

    properties: Sequence[APIProperty] = Field(alias="properties")
    """The properties of the operation."""

    # TODO: Add parse in used components to be able to specify what type of
    # referenced object it is.
    # """The properties of the operation."""
    # components: Dict[str, BaseModel] = Field(alias="components")

    request_body: Optional[APIRequestBody] = Field(alias="request_body")

@@ -24,9 +24,11 @@ from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.structured_query.ir import Visitor
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.astradb import AstraDBTranslator
from langchain_community.utilities.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
@@ -1283,3 +1285,6 @@ class AstraDB(VectorStore):
        an `AstraDB` vectorstore.
        """
        return super().from_documents(documents, embedding, **kwargs)

    def get_structured_query_translator(self) -> Visitor:
        return AstraDBTranslator()
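The same three-line hook recurs for every vectorstore below; it lets a self-query consumer obtain the matching translator from the store itself instead of a hard-coded lookup table. A rough sketch of the intended call pattern (the wiring is illustrative, not an API guarantee):

structured_query = query_constructor.invoke({"query": user_question})  # hypothetical constructor chain
translator = vectorstore.get_structured_query_translator()
new_query, search_kwargs = translator.visit_structured_query(structured_query)
docs = vectorstore.search(new_query, "similarity", **search_kwargs)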
@@ -18,9 +18,11 @@ from typing import (
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.chroma import ChromaTranslator
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
@@ -795,3 +797,6 @@ class Chroma(VectorStore):
            ids: List of ids to delete.
        """
        self._collection.delete(ids=ids)

    def get_structured_query_translator(self) -> Visitor:
        return ChromaTranslator()

@@ -13,9 +13,13 @@ from typing import (
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.utils import get_from_env
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.dashvector import (
    DashvectorTranslator,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)
@@ -399,3 +403,6 @@ class DashVector(VectorStore):
        dashvector_vector_db = cls(collection, embedding, text_field)
        dashvector_vector_db.add_texts(texts, metadatas, ids, batch_size)
        return dashvector_vector_db

    def get_structured_query_translator(self) -> Visitor:
        return DashvectorTranslator()

@@ -4,6 +4,9 @@ import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from langchain_core.structured_query.ir import Visitor

from langchain_community.structured_query_translators.deeplake import DeepLakeTranslator

try:
    import deeplake
@@ -956,3 +959,6 @@ class DeepLake(VectorStore):
        if kwargs:
            unsupported_items = "`, `".join(set(kwargs.keys()))
            return unsupported_items

    def get_structured_query_translator(self) -> Visitor:
        return DeepLakeTranslator()

@@ -7,8 +7,10 @@ from typing import Any, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.dingo import DingoDBTranslator
from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)
@@ -380,3 +382,6 @@ class Dingo(VectorStore):
            raise ValueError("No ids provided to delete.")

        return self._client.vector_delete(self._index_name, ids=ids)

    def get_structured_query_translator(self) -> Visitor:
        return DingoDBTranslator()

@@ -18,8 +18,12 @@ import numpy as np
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.elasticsearch import (
    ElasticsearchTranslator,
)
from langchain_community.vectorstores.utils import (
    DistanceStrategy,
    maximal_marginal_relevance,
@@ -1320,3 +1324,6 @@ class ElasticsearchStore(VectorStore):
        deployed to Elasticsearch.
        """
        return SparseRetrievalStrategy(model_id=model_id)

    def get_structured_query_translator(self) -> Visitor:
        return ElasticsearchTranslator()

@@ -7,8 +7,10 @@ from uuid import uuid4
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.milvus import MilvusTranslator
from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)
@@ -1057,3 +1059,6 @@ class Milvus(VectorStore):
                "Failed to upsert entities: %s error: %s", self.collection_name, exc
            )
            raise exc

    def get_structured_query_translator(self) -> Visitor:
        return MilvusTranslator()

@@ -19,8 +19,12 @@ import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.mongodb_atlas import (
    MongoDBAtlasTranslator,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
@@ -374,3 +378,6 @@ class MongoDBAtlasVectorSearch(VectorStore):
        vectorstore = cls(collection, embedding, **kwargs)
        vectorstore.add_texts(texts, metadatas=metadatas)
        return vectorstore

    def get_structured_query_translator(self) -> Visitor:
        return MongoDBAtlasTranslator()

@@ -9,8 +9,11 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseSettings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.myscale import MyScaleTranslator

logger = logging.getLogger()


@@ -613,3 +616,6 @@ class MyScaleWithoutJSON(MyScale):
    @property
    def metadata_column(self) -> str:
        return ""

    def get_structured_query_translator(self) -> Visitor:
        return MyScaleTranslator()

@@ -7,9 +7,13 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.opensearch import (
    OpenSearchTranslator,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance

IMPORT_OPENSEARCH_PY_ERROR = (
@@ -979,3 +983,6 @@ class OpenSearchVectorSearch(VectorStore):
        )
        kwargs["engine"] = engine
        return cls(opensearch_url, index_name, embedding, **kwargs)

    def get_structured_query_translator(self) -> Visitor:
        return OpenSearchTranslator()

@@ -20,10 +20,13 @@ from typing import (
import numpy as np
import sqlalchemy
from langchain_core._api import warn_deprecated
from langchain_core.structured_query.ir import Visitor
from sqlalchemy import SQLColumnExpression, delete, func
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
from sqlalchemy.orm import Session, relationship

from langchain_community.structured_query_translators.pgvector import PGVectorTranslator

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
@@ -1341,3 +1344,6 @@ class PGVector(VectorStore):
            filter=filter,
            **kwargs,
        )

    def get_structured_query_translator(self) -> Visitor:
        return PGVectorTranslator()

@@ -10,10 +10,12 @@ import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore
from packaging import version

from langchain_community.structured_query_translators.pinecone import PineconeTranslator
from langchain_community.vectorstores.utils import (
    DistanceStrategy,
    maximal_marginal_relevance,
@@ -486,3 +488,6 @@ class Pinecone(VectorStore):
            raise ValueError("Either ids, delete_all, or filter must be provided.")

        return None

    def get_structured_query_translator(self) -> Visitor:
        return PineconeTranslator()

@@ -24,9 +24,11 @@ from typing import (
import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.docstore.document import Document
from langchain_community.structured_query_translators.qdrant import QdrantTranslator
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
@@ -2234,3 +2236,6 @@ class Qdrant(VectorStore):
        )

        return sync_client, async_client

    def get_structured_query_translator(self) -> Visitor:
        return QdrantTranslator(metadata_key=self.metadata_payload_key)

@@ -18,8 +18,12 @@ from typing import (
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.supabase import (
    SupabaseVectorTranslator,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
@@ -478,3 +482,6 @@ class SupabaseVectorStore(VectorStore):
        # TODO: Check if this can be done in bulk
        for row in rows:
            self._client.from_(self.table_name).delete().eq("id", row["id"]).execute()

    def get_structured_query_translator(self) -> Visitor:
        return SupabaseVectorTranslator()

@@ -20,9 +20,13 @@ from typing import (

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.timescalevector import (
    TimescaleVectorTranslator,
)
from langchain_community.vectorstores.utils import DistanceStrategy

if TYPE_CHECKING:
@@ -880,3 +884,6 @@ class TimescaleVector(VectorStore):

    def drop_index(self) -> None:
        self.sync_client.drop_embedding_index()

    def get_structured_query_translator(self) -> Visitor:
        return TimescaleVectorTranslator()

@@ -11,8 +11,11 @@ import requests
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

from langchain_community.structured_query_translators.vectara import VectaraTranslator

logger = logging.getLogger(__name__)


@@ -603,3 +606,6 @@ class VectaraRetriever(VectorStoreRetriever):
            metadatas (List[dict]): Metadata dicts, must line up with existing store
        """
        self.vectorstore.add_texts(texts, metadatas, doc_metadata or {})

    def get_structured_query_translator(self) -> Visitor:
        return VectaraTranslator()

@@ -17,8 +17,10 @@ from uuid import uuid4
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.structured_query.ir import Visitor
from langchain_core.vectorstores import VectorStore

from langchain_community.structured_query_translators.weaviate import WeaviateTranslator
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
@@ -526,3 +528,6 @@ class Weaviate(VectorStore):
        # TODO: Check if this can be done in bulk
        for id in ids:
            self._client.data_object.delete(uuid=id)

    def get_structured_query_translator(self) -> Visitor:
        return WeaviateTranslator()
libs/community/tests/unit_tests/chains/__init__.py (new empty file)
@@ -1,18 +1,24 @@
from typing import Any, Dict, List

import pandas as pd
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain_core.prompts import PromptTemplate

from langchain.chains.graph_qa.cypher import (
from langchain_community.chains.graph_qa.cypher import (
    GraphCypherQAChain,
    construct_schema,
    extract_cypher,
)
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain_community.chains.graph_qa.cypher_utils import (
    CypherQueryCorrector,
    Schema,
)
from langchain_community.chains.graph_qa.prompts import (
    CYPHER_GENERATION_PROMPT,
    CYPHER_QA_PROMPT,
)
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.astradb import AstraDBTranslator

from langchain_community.structured_query_translators import AstraDBTranslator

DEFAULT_TRANSLATOR = AstraDBTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.chroma import ChromaTranslator

from langchain_community.structured_query_translators import ChromaTranslator

DEFAULT_TRANSLATOR = ChromaTranslator()

@@ -1,14 +1,14 @@
from typing import Any, Tuple

import pytest

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
)
from langchain.retrievers.self_query.dashvector import DashvectorTranslator

from langchain_community.structured_query_translators import DashvectorTranslator

DEFAULT_TRANSLATOR = DashvectorTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator

from langchain_community.structured_query_translators.deeplake import DeepLakeTranslator

DEFAULT_TRANSLATOR = DeepLakeTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.dingo import DingoDBTranslator

from langchain_community.structured_query_translators.dingo import DingoDBTranslator

DEFAULT_TRANSLATOR = DingoDBTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator

from langchain_community.structured_query_translators import ElasticsearchTranslator

DEFAULT_TRANSLATOR = ElasticsearchTranslator()

@@ -1,15 +1,15 @@
from typing import Any, Dict, Tuple

import pytest

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.milvus import MilvusTranslator

from langchain_community.structured_query_translators import MilvusTranslator

DEFAULT_TRANSLATOR = MilvusTranslator()

@@ -1,13 +1,16 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator

from langchain_community.structured_query_translators.mongodb_atlas import (
    MongoDBAtlasTranslator,
)

DEFAULT_TRANSLATOR = MongoDBAtlasTranslator()

@@ -1,15 +1,15 @@
from typing import Any, Dict, Tuple

import pytest

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.myscale import MyScaleTranslator

from langchain_community.structured_query_translators import MyScaleTranslator

DEFAULT_TRANSLATOR = MyScaleTranslator()

@@ -1,11 +1,14 @@
from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.opensearch import OpenSearchTranslator

from langchain_community.structured_query_translators.opensearch import (
    OpenSearchTranslator,
)

DEFAULT_TRANSLATOR = OpenSearchTranslator()

@@ -1,15 +1,15 @@
from typing import Dict, Tuple

import pytest as pytest

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.pgvector import PGVectorTranslator

from langchain_community.structured_query_translators import PGVectorTranslator

DEFAULT_TRANSLATOR = PGVectorTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.pinecone import PineconeTranslator

from langchain_community.structured_query_translators.pinecone import PineconeTranslator

DEFAULT_TRANSLATOR = PineconeTranslator()

@@ -1,6 +1,15 @@
from typing import Dict, Tuple

import pytest
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)

from langchain_community.structured_query_translators.redis import RedisTranslator
from langchain_community.vectorstores.redis.filters import (
    RedisFilterExpression,
    RedisNum,
@@ -14,15 +23,6 @@ from langchain_community.vectorstores.redis.schema import (
    TextFieldSchema,
)

from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.redis import RedisTranslator


@pytest.fixture
def translator() -> RedisTranslator:

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator

from langchain_community.structured_query_translators import SupabaseVectorTranslator

DEFAULT_TRANSLATOR = SupabaseVectorTranslator()

@@ -1,15 +1,15 @@
from typing import Dict, Tuple

import pytest as pytest

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator

from langchain_community.structured_query_translators import TimescaleVectorTranslator

DEFAULT_TRANSLATOR = TimescaleVectorTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.vectara import VectaraTranslator

from langchain_community.structured_query_translators import VectaraTranslator

DEFAULT_TRANSLATOR = VectaraTranslator()

@@ -1,13 +1,14 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
from langchain_core.structured_query.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.weaviate import WeaviateTranslator

from langchain_community.structured_query_translators.weaviate import WeaviateTranslator

DEFAULT_TRANSLATOR = WeaviateTranslator()
libs/core/langchain_core/entity_stores.py (new file)
@@ -0,0 +1,35 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from pydantic import BaseModel


class BaseEntityStore(BaseModel, ABC):
    """Abstract base class for entity stores."""

    @abstractmethod
    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Get entity value from store."""
        pass

    @abstractmethod
    def set(self, key: str, value: Optional[str]) -> None:
        """Set entity value in store."""
        pass

    @abstractmethod
    def delete(self, key: str) -> None:
        """Delete entity value from store."""
        pass

    @abstractmethod
    def exists(self, key: str) -> bool:
        """Check if entity exists in store."""
        pass

    @abstractmethod
    def clear(self) -> None:
        """Delete all entities from store."""
        pass
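For contrast with the Redis-backed stores above, a minimal in-memory implementation sketch of this interface (hypothetical class, for illustration only):

from typing import Dict, Optional
from langchain_core.entity_stores import BaseEntityStore

class InMemoryEntityStore(BaseEntityStore):  # hypothetical example
    store: Dict[str, Optional[str]] = {}

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        return self.store.get(key, default)

    def set(self, key: str, value: Optional[str]) -> None:
        self.store[key] = value

    def delete(self, key: str) -> None:
        del self.store[key]

    def exists(self, key: str) -> bool:
        return key in self.store

    def clear(self) -> None:
        self.store.clear()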
@@ -47,3 +47,7 @@ class OutputParserException(ValueError, LangChainException):
        self.observation = observation
        self.llm_output = llm_output
        self.send_to_llm = send_to_llm


class InvalidKeyException(LangChainException):
    """Raised when a key for a BaseStore is invalid; e.g., uses incorrect characters."""
@@ -22,7 +22,7 @@ from __future__ import annotations

import inspect
import uuid
import warnings
from abc import abstractmethod
from abc import abstractmethod, ABC
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union

@@ -919,3 +919,11 @@ def tool(
        return _partial
    else:
        raise ValueError("Too many arguments for tool decorator")


class BaseToolkit(BaseModel, ABC):
    """Base Toolkit representing a collection of related tools."""

    @abstractmethod
    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
@@ -40,6 +40,17 @@ def _warn_on_import(name: str, replacement: Optional[str] = None) -> None:
surface_langchain_deprecation_warnings()


def _raise_community_deprecation_error(name: str, new_module: str) -> None:
    raise ImportError(
        f"{name} has been moved to the langchain-community package. "
        f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
        f"information.\n\nTo use it install langchain-community:\n\n"
        f"`pip install -U langchain-community`\n\n"
        f"then import with:\n\n"
        f"`from {new_module} import {name}`"
    )


def __getattr__(name: str) -> Any:
    if name == "MRKLChain":
        from langchain.agents import MRKLChain
@@ -124,112 +135,39 @@ def __getattr__(name: str) -> Any:

        return Wikipedia
    elif name == "Anthropic":
        from langchain_community.llms import Anthropic

        _warn_on_import(name, replacement="langchain_community.llms.Anthropic")

        return Anthropic
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "Banana":
        from langchain_community.llms import Banana

        _warn_on_import(name, replacement="langchain_community.llms.Banana")

        return Banana
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "CerebriumAI":
        from langchain_community.llms import CerebriumAI

        _warn_on_import(name, replacement="langchain_community.llms.CerebriumAI")

        return CerebriumAI
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "Cohere":
        from langchain_community.llms import Cohere

        _warn_on_import(name, replacement="langchain_community.llms.Cohere")

        return Cohere
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "ForefrontAI":
        from langchain_community.llms import ForefrontAI

        _warn_on_import(name, replacement="langchain_community.llms.ForefrontAI")

        return ForefrontAI
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "GooseAI":
        from langchain_community.llms import GooseAI

        _warn_on_import(name, replacement="langchain_community.llms.GooseAI")

        return GooseAI
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "HuggingFaceHub":
        from langchain_community.llms import HuggingFaceHub

        _warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub")

        return HuggingFaceHub
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "HuggingFaceTextGenInference":
        from langchain_community.llms import HuggingFaceTextGenInference

        _warn_on_import(
            name, replacement="langchain_community.llms.HuggingFaceTextGenInference"
        )

        return HuggingFaceTextGenInference
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "LlamaCpp":
        from langchain_community.llms import LlamaCpp

        _warn_on_import(name, replacement="langchain_community.llms.LlamaCpp")

        return LlamaCpp
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "Modal":
        from langchain_community.llms import Modal

        _warn_on_import(name, replacement="langchain_community.llms.Modal")

        return Modal
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "OpenAI":
        from langchain_community.llms import OpenAI

        _warn_on_import(name, replacement="langchain_community.llms.OpenAI")

        return OpenAI
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "Petals":
        from langchain_community.llms import Petals

        _warn_on_import(name, replacement="langchain_community.llms.Petals")

        return Petals
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "PipelineAI":
        from langchain_community.llms import PipelineAI

        _warn_on_import(name, replacement="langchain_community.llms.PipelineAI")

        return PipelineAI
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "SagemakerEndpoint":
        from langchain_community.llms import SagemakerEndpoint

        _warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint")

        return SagemakerEndpoint
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "StochasticAI":
        from langchain_community.llms import StochasticAI

        _warn_on_import(name, replacement="langchain_community.llms.StochasticAI")

        return StochasticAI
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "Writer":
        from langchain_community.llms import Writer

        _warn_on_import(name, replacement="langchain_community.llms.Writer")

        return Writer
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "HuggingFacePipeline":
        from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

        _warn_on_import(
            name,
            replacement="langchain_community.llms.huggingface_pipeline.HuggingFacePipeline",
        )

        return HuggingFacePipeline
        _raise_community_deprecation_error(name, "langchain_community.llms")
    elif name == "FewShotPromptTemplate":
        from langchain_core.prompts import FewShotPromptTemplate

@@ -267,90 +205,28 @@ def __getattr__(name: str) -> Any:

        return ArxivAPIWrapper
    elif name == "GoldenQueryAPIWrapper":
        from langchain_community.utilities import GoldenQueryAPIWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.GoldenQueryAPIWrapper"
        )

        return GoldenQueryAPIWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "GoogleSearchAPIWrapper":
        from langchain_community.utilities import GoogleSearchAPIWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.GoogleSearchAPIWrapper"
        )

        return GoogleSearchAPIWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "GoogleSerperAPIWrapper":
        from langchain_community.utilities import GoogleSerperAPIWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.GoogleSerperAPIWrapper"
        )

        return GoogleSerperAPIWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "PowerBIDataset":
        from langchain_community.utilities import PowerBIDataset

        _warn_on_import(
            name, replacement="langchain_community.utilities.PowerBIDataset"
        )

        return PowerBIDataset
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "SearxSearchWrapper":
        from langchain_community.utilities import SearxSearchWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.SearxSearchWrapper"
        )

        return SearxSearchWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "WikipediaAPIWrapper":
        from langchain_community.utilities import WikipediaAPIWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.WikipediaAPIWrapper"
        )

        return WikipediaAPIWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "WolframAlphaAPIWrapper":
        from langchain_community.utilities import WolframAlphaAPIWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.WolframAlphaAPIWrapper"
        )

        return WolframAlphaAPIWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "SQLDatabase":
        from langchain_community.utilities import SQLDatabase

        _warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase")

        return SQLDatabase
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "FAISS":
        from langchain_community.vectorstores import FAISS

        _warn_on_import(name, replacement="langchain_community.vectorstores.FAISS")

        return FAISS
        _raise_community_deprecation_error(name, "langchain_community.vectorstores")
    elif name == "ElasticVectorSearch":
        from langchain_community.vectorstores import ElasticVectorSearch

        _warn_on_import(
            name, replacement="langchain_community.vectorstores.ElasticVectorSearch"
        )

        return ElasticVectorSearch
        _raise_community_deprecation_error(name, "langchain_community.vectorstores")
    # For backwards compatibility
    elif name == "SerpAPIChain" or name == "SerpAPIWrapper":
        from langchain_community.utilities import SerpAPIWrapper

        _warn_on_import(
            name, replacement="langchain_community.utilities.SerpAPIWrapper"
        )

        return SerpAPIWrapper
        _raise_community_deprecation_error(name, "langchain_community.utilities")
    elif name == "verbose":
        from langchain.globals import _verbose

@@ -392,47 +268,15 @@ __all__ = [
    "LLMChain",
    "LLMCheckerChain",
    "LLMMathChain",
    "ArxivAPIWrapper",
    "GoldenQueryAPIWrapper",
    "SelfAskWithSearchChain",
    "SerpAPIWrapper",
    "SerpAPIChain",
    "SearxSearchWrapper",
    "GoogleSearchAPIWrapper",
    "GoogleSerperAPIWrapper",
    "WolframAlphaAPIWrapper",
    "WikipediaAPIWrapper",
    "Anthropic",
    "Banana",
    "CerebriumAI",
    "Cohere",
    "ForefrontAI",
    "GooseAI",
    "Modal",
    "OpenAI",
    "Petals",
    "PipelineAI",
    "StochasticAI",
    "Writer",
    "BasePromptTemplate",
    "Prompt",
    "FewShotPromptTemplate",
    "PromptTemplate",
    "ReActChain",
    "Wikipedia",
    "HuggingFaceHub",
    "SagemakerEndpoint",
    "HuggingFacePipeline",
    "SQLDatabase",
    "PowerBIDataset",
    "FAISS",
    "MRKLChain",
    "VectorDBQA",
    "ElasticVectorSearch",
    "InMemoryDocstore",
    "ConversationChain",
    "VectorDBQAWithSourcesChain",
    "QAWithSourcesChain",
    "LlamaCpp",
    "HuggingFaceTextGenInference",
]
|
||||
|
||||
@@ -1,20 +1,6 @@
from langchain_community.adapters.openai import (
    Chat,
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletions,
    Choice,
    ChoiceChunk,
    Completions,
    IndexableBaseModel,
    chat,
    convert_dict_to_message,
    convert_message_to_dict,
    convert_messages_for_finetuning,
    convert_openai_messages,
)
from typing import Any

__all__ = [
DEPRECATED_IMPORTS = [
    "IndexableBaseModel",
    "Choice",
    "ChatCompletions",
@@ -29,3 +15,16 @@ __all__ = [
    "Chat",
    "chat",
]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.adapters.openai import {name}`"
        )
    raise AttributeError()

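Because each shim module now defines a PEP 562 module-level `__getattr__`, a moved name fails at import time with a migration message instead of a bare AttributeError. A minimal illustration of what a caller sees (the try/except wrapper is illustrative, not part of the diff):

# Illustrative only: demonstrates the caller-facing behavior after this change.
try:
    from langchain.adapters.openai import ChatCompletion  # a moved name
except ImportError as err:
    # The message names the new home:
    # `from langchain_community.adapters.openai import ChatCompletion`
    print(err)
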
@@ -31,14 +31,6 @@ Agents select and use **Tools** and **Toolkits** for actions.
from pathlib import Path
from typing import Any

from langchain_community.agent_toolkits import (
    create_json_agent,
    create_openapi_agent,
    create_pbi_agent,
    create_pbi_chat_agent,
    create_spark_sql_agent,
    create_sql_agent,
)
from langchain_core._api.path import as_import_path

from langchain.agents.agent import (
@@ -91,6 +83,14 @@ DEPRECATED_CODE = [
    "create_spark_dataframe_agent",
    "create_xorbits_agent",
]
DEPRECATED_COMMUNITY_CODE = [
    "create_json_agent",
    "create_openapi_agent",
    "create_pbi_agent",
    "create_pbi_chat_agent",
    "create_spark_sql_agent",
    "create_sql_agent",
]


def __getattr__(name: str) -> Any:
@@ -109,6 +109,16 @@ def __getattr__(name: str) -> Any:
            "for more information.\n"
            f"Please update your import statement from: `{old_path}` to `{new_path}`."
        )
    if name in DEPRECATED_COMMUNITY_CODE:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits import {name}`"
        )

    raise AttributeError(f"{name} does not exist")


@@ -132,12 +142,6 @@ __all__ = [
    "StructuredChatAgent",
    "Tool",
    "ZeroShotAgent",
    "create_json_agent",
    "create_openapi_agent",
    "create_pbi_agent",
    "create_pbi_chat_agent",
    "create_spark_sql_agent",
    "create_sql_agent",
    "create_vectorstore_agent",
    "create_vectorstore_router_agent",
    "get_all_tool_names",

@@ -13,11 +13,9 @@ whether permissions of the given toolkit are appropriate for the application.

See [Security](https://python.langchain.com/docs/security) for more information.
"""
import warnings
from pathlib import Path
from typing import Any

from langchain_core._api import LangChainDeprecationWarning
from langchain_core._api.path import as_import_path

from langchain.agents.agent_toolkits.conversational_retrieval.openai_functions import (
@@ -33,7 +31,6 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import (
    VectorStoreToolkit,
)
from langchain.tools.retriever import create_retriever_tool
from langchain.utils.interactive_env import is_interactive_env

DEPRECATED_AGENTS = [
    "create_csv_agent",
@@ -42,38 +39,7 @@ DEPRECATED_AGENTS = [
    "create_python_agent",
    "create_spark_dataframe_agent",
]


def __getattr__(name: str) -> Any:
    """Get attr name."""
    if name in DEPRECATED_AGENTS:
        relative_path = as_import_path(Path(__file__).parent, suffix=name)
        old_path = "langchain." + relative_path
        new_path = "langchain_experimental." + relative_path
        raise ImportError(
            f"{name} has been moved to langchain experimental. "
            "See https://github.com/langchain-ai/langchain/discussions/11680 "
            "for more information.\n"
            f"Please update your import statement from: `{old_path}` to `{new_path}`."
        )

    from langchain_community import agent_toolkits

    # If not in interactive env, raise warning.
    if not is_interactive_env():
        warnings.warn(
            "Importing this agent toolkit from langchain is deprecated. Importing it "
            "from langchain will no longer be supported as of langchain==0.2.0. "
            "Please import from langchain-community instead:\n\n"
            f"`from langchain_community.agent_toolkits import {name}`.\n\n"
            "To install langchain-community run `pip install -U langchain-community`.",
            category=LangChainDeprecationWarning,
        )

    return getattr(agent_toolkits, name)


__all__ = [
DEPRECATED_COMMUNITY_AGENTS = [
    "AINetworkToolkit",
    "AmadeusToolkit",
    "AzureCognitiveServicesToolkit",
@@ -92,9 +58,6 @@ __all__ = [
    "SteamToolkit",
    "SQLDatabaseToolkit",
    "SparkSQLToolkit",
    "VectorStoreInfo",
    "VectorStoreRouterToolkit",
    "VectorStoreToolkit",
    "ZapierToolkit",
    "create_json_agent",
    "create_openapi_agent",
@@ -102,6 +65,37 @@ __all__ = [
    "create_pbi_chat_agent",
    "create_spark_sql_agent",
    "create_sql_agent",
]


def __getattr__(name: str) -> Any:
    """Get attr name."""
    if name in DEPRECATED_AGENTS:
        relative_path = as_import_path(Path(__file__).parent, suffix=name)
        old_path = "langchain." + relative_path
        new_path = "langchain_experimental." + relative_path
        raise ImportError(
            f"{name} has been moved to langchain experimental. "
            "See https://github.com/langchain-ai/langchain/discussions/11680 "
            "for more information.\n"
            f"Please update your import statement from: `{old_path}` to `{new_path}`."
        )
    if name in DEPRECATED_COMMUNITY_AGENTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits import {name}`"
        )
    raise AttributeError()


__all__ = [
    "VectorStoreInfo",
    "VectorStoreRouterToolkit",
    "VectorStoreToolkit",
    "create_vectorstore_agent",
    "create_vectorstore_router_agent",
    "create_conversational_retrieval_agent",

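The same `__getattr__` shim is stamped out nearly verbatim in each toolkit module below, differing only in the replacement import path and the list of moved names. A hypothetical consolidation (not part of this diff; `_make_deprecated_getattr` is an invented name) could generate the shim once per module:

# Hypothetical refactor sketch; nothing below appears in the actual diff.
from typing import Any, Callable, List


def _make_deprecated_getattr(
    new_module: str, deprecated_imports: List[str]
) -> Callable[[str], Any]:
    """Build a module-level __getattr__ that redirects moved names."""

    def __getattr__(name: str) -> Any:
        if name in deprecated_imports:
            raise ImportError(
                f"{name} has been moved to the langchain-community package. "
                f"See https://github.com/langchain-ai/langchain/discussions/19083 "
                f"for more information.\n\nTo use it install langchain-community:\n\n"
                f"`pip install -U langchain-community`\n\n"
                f"then import with:\n\n"
                f"`from {new_module} import {name}`"
            )
        raise AttributeError(f"module has no attribute {name!r}")

    return __getattr__


# Example use in a shim module such as langchain/agents/agent_toolkits/amadeus/toolkit.py:
# __getattr__ = _make_deprecated_getattr(
#     "langchain_community.agent_toolkits.amadeus.toolkit", ["AmadeusToolkit"]
# )
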
@@ -1,3 +1,17 @@
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
from typing import Any

__all__ = ["AINetworkToolkit"]
DEPRECATED_IMPORTS = ["AINetworkToolkit"]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.ainetwork.toolkit import {name}`"
        )

    raise AttributeError()

@@ -1,3 +1,17 @@
from langchain_community.agent_toolkits.amadeus.toolkit import AmadeusToolkit
from typing import Any

__all__ = ["AmadeusToolkit"]
DEPRECATED_IMPORTS = ["AmadeusToolkit"]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.amadeus.toolkit import {name}`"
        )

    raise AttributeError()

@@ -1,5 +1,18 @@
from langchain_community.agent_toolkits.azure_cognitive_services import (
    AzureCognitiveServicesToolkit,
)
from typing import Any

__all__ = ["AzureCognitiveServicesToolkit"]
DEPRECATED_IMPORTS = ["AzureCognitiveServicesToolkit"]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.azure_cognitive_services import "
            f"{name}`"
        )

    raise AttributeError()

@@ -1,3 +1,18 @@
from langchain_community.agent_toolkits.base import BaseToolkit
from typing import Any

__all__ = ["BaseToolkit"]
DEPRECATED_IMPORTS = ["BaseToolkit"]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.base import "
            f"{name}`"
        )

    raise AttributeError()

@@ -1,3 +1,17 @@
from langchain_community.agent_toolkits.clickup.toolkit import ClickupToolkit
from typing import Any

__all__ = ["ClickupToolkit"]
DEPRECATED_IMPORTS = ["ClickupToolkit"]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.clickup.toolkit import {name}`"
        )

    raise AttributeError()

@@ -1,3 +1,17 @@
from langchain.tools.retriever import create_retriever_tool
from typing import Any

__all__ = ["create_retriever_tool"]
DEPRECATED_IMPORTS = ["create_retriever_tool"]


def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.conversational_retrieval.tool import {name}`"  # noqa: E501
        )

    raise AttributeError()

@@ -1,7 +1,17 @@
"""Local file management toolkit."""
from typing import Any

from langchain_community.agent_toolkits.file_management.toolkit import (
    FileManagementToolkit,
)
DEPRECATED_IMPORTS = ["FileManagementToolkit"]

__all__ = ["FileManagementToolkit"]

def __getattr__(name: str) -> Any:
    if name in DEPRECATED_IMPORTS:
        raise ImportError(
            f"{name} has been moved to the langchain-community package. "
            f"See https://github.com/langchain-ai/langchain/discussions/19083 for more "
            f"information.\n\nTo use it install langchain-community:\n\n"
            f"`pip install -U langchain-community`\n\n"
            f"then import with:\n\n"
            f"`from langchain_community.agent_toolkits.file_management import {name}`"
        )

    raise AttributeError()

Some files were not shown because too many files have changed in this diff.