Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-16 23:13:31 +00:00)
merge from upstream/master
@@ -1,4 +1,7 @@
"""Data anonymizer package"""
from langchain_experimental.data_anonymizer.presidio import PresidioAnonymizer
from langchain_experimental.data_anonymizer.presidio import (
    PresidioAnonymizer,
    PresidioReversibleAnonymizer,
)

__all__ = ["PresidioAnonymizer"]
__all__ = ["PresidioAnonymizer", "PresidioReversibleAnonymizer"]
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional


class AnonymizerBase(ABC):
@@ -8,10 +9,24 @@ class AnonymizerBase(ABC):
    wrapping the behavior for all methods in a base class.
    """

    def anonymize(self, text: str) -> str:
    def anonymize(self, text: str, language: Optional[str] = None) -> str:
        """Anonymize text"""
        return self._anonymize(text)
        return self._anonymize(text, language)

    @abstractmethod
    def _anonymize(self, text: str) -> str:
    def _anonymize(self, text: str, language: Optional[str]) -> str:
        """Abstract method to anonymize text"""


class ReversibleAnonymizerBase(AnonymizerBase):
    """
    Base abstract class for reversible anonymizers.
    """

    def deanonymize(self, text: str) -> str:
        """Deanonymize text"""
        return self._deanonymize(text)

    @abstractmethod
    def _deanonymize(self, text: str) -> str:
        """Abstract method to deanonymize text"""
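A minimal sketch (not part of the diff) of how a concrete anonymizer plugs into this base class: the public `anonymize` wrapper now forwards the optional `language` argument to the subclass's `_anonymize`. The `EchoAnonymizer` name below is hypothetical and exists only to illustrate the new signature.

.. code-block:: python

    from typing import Optional

    from langchain_experimental.data_anonymizer.base import AnonymizerBase


    class EchoAnonymizer(AnonymizerBase):
        """Toy subclass used only to illustrate the new signature."""

        def _anonymize(self, text: str, language: Optional[str]) -> str:
            # A real implementation would detect and replace PII here;
            # `language` is whatever the caller passed to `anonymize()`.
            return text


    EchoAnonymizer().anonymize("Hello, my name is John Doe.", language="en")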
@@ -0,0 +1,21 @@
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict

MappingDataType = Dict[str, Dict[str, str]]


@dataclass
class DeanonymizerMapping:
    mapping: MappingDataType = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(str))
    )

    @property
    def data(self) -> MappingDataType:
        """Return the deanonymizer mapping"""
        return {k: dict(v) for k, v in self.mapping.items()}

    def update(self, new_mapping: MappingDataType) -> None:
        for entity_type, values in new_mapping.items():
            self.mapping[entity_type].update(values)
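For illustration only (not part of the commit), this is how the `DeanonymizerMapping` dataclass above behaves: `update()` merges the per-entity-type dictionaries instead of overwriting them, and `data` returns plain nested dicts. The entity values are made up.

.. code-block:: python

    from langchain_experimental.data_anonymizer.deanonymizer_mapping import (
        DeanonymizerMapping,
    )

    mapping = DeanonymizerMapping()
    mapping.update({"PERSON": {"Maria Lynch": "Slim Shady"}})
    mapping.update(
        {
            "PERSON": {"John Doe": "Jane Roe"},
            "PHONE_NUMBER": {"111-111-1111": "555-555-5555"},
        }
    )
    # Entries for the same entity type are merged, not overwritten:
    print(mapping.data["PERSON"])
    # {'Maria Lynch': 'Slim Shady', 'John Doe': 'Jane Roe'}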
@@ -0,0 +1,17 @@
from langchain_experimental.data_anonymizer.presidio import MappingDataType


def default_matching_strategy(text: str, deanonymizer_mapping: MappingDataType) -> str:
    """
    Default matching strategy for deanonymization.
    It replaces all the anonymized entities with the original ones.

    Args:
        text: text to deanonymize
        deanonymizer_mapping: mapping between anonymized entities and original ones"""

    # Iterate over all the entities (PERSON, EMAIL_ADDRESS, etc.)
    for entity_type in deanonymizer_mapping:
        for anonymized, original in deanonymizer_mapping[entity_type].items():
            text = text.replace(anonymized, original)
    return text
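A small usage sketch of `default_matching_strategy` (the mapping values are invented): every anonymized substring found in the mapping is replaced with its original value via plain `str.replace`.

.. code-block:: python

    from langchain_experimental.data_anonymizer.deanonymizer_matching_strategies import (
        default_matching_strategy,
    )

    deanonymizer_mapping = {
        "PERSON": {"Maria Lynch": "Slim Shady"},
        "PHONE_NUMBER": {"7344131647": "313-666-7440"},
    }
    text = "Maria Lynch can be reached at 7344131647."
    print(default_matching_strategy(text, deanonymizer_mapping))
    # -> "Slim Shady can be reached at 313-666-7440."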
@@ -1,8 +1,8 @@
import string
from typing import Callable, Dict
from typing import Callable, Dict, Optional


def get_pseudoanonymizer_mapping() -> Dict[str, Callable]:
def get_pseudoanonymizer_mapping(seed: Optional[int] = None) -> Dict[str, Callable]:
    try:
        from faker import Faker
    except ImportError as e:
@@ -11,6 +11,7 @@ def get_pseudoanonymizer_mapping() -> Dict[str, Callable]:
        ) from e

    fake = Faker()
    fake.seed_instance(seed)

    # Listed entities supported by Microsoft Presidio (for now, global and US only)
    # Source: https://microsoft.github.io/presidio/supported_entities/
@@ -26,8 +27,8 @@ def get_pseudoanonymizer_mapping() -> Dict[str, Callable]:
            fake.random_choices(string.ascii_lowercase + string.digits, length=26)
        ),
        "IP_ADDRESS": lambda _: fake.ipv4_public(),
        "LOCATION": lambda _: fake.address(),
        "DATE_TIME": lambda _: fake.iso8601(),
        "LOCATION": lambda _: fake.city(),
        "DATE_TIME": lambda _: fake.date(),
        "NRP": lambda _: str(fake.random_number(digits=8, fix_len=True)),
        "MEDICAL_LICENSE": lambda _: fake.bothify(text="??######").upper(),
        "URL": lambda _: fake.url(),
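Illustrative only: passing a `seed` makes the Faker-backed pseudo-anonymization reproducible, while `seed=None` keeps it random. This mirrors how `PresidioAnonymizerBase` forwards `faker_seed` in the next file; the example assumes the mapping contains a "PERSON" entry, as the rest of the commit does.

.. code-block:: python

    from langchain_experimental.data_anonymizer.faker_presidio_mapping import (
        get_pseudoanonymizer_mapping,
    )

    # With the same seed, the first fake value drawn for an entity type
    # is the same on every run.
    first = get_pseudoanonymizer_mapping(seed=42)["PERSON"](None)
    second = get_pseudoanonymizer_mapping(seed=42)["PERSON"](None)
    print(first, second)  # identical fake names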
@@ -1,24 +1,75 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional
import json
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from langchain_experimental.data_anonymizer.base import AnonymizerBase
import yaml

from langchain_experimental.data_anonymizer.base import (
    AnonymizerBase,
    ReversibleAnonymizerBase,
)
from langchain_experimental.data_anonymizer.deanonymizer_mapping import (
    DeanonymizerMapping,
    MappingDataType,
)
from langchain_experimental.data_anonymizer.deanonymizer_matching_strategies import (
    default_matching_strategy,
)
from langchain_experimental.data_anonymizer.faker_presidio_mapping import (
    get_pseudoanonymizer_mapping,
)

if TYPE_CHECKING:
    from presidio_analyzer import EntityRecognizer
try:
    from presidio_analyzer import AnalyzerEngine
    from presidio_analyzer.nlp_engine import NlpEngineProvider

except ImportError as e:
    raise ImportError(
        "Could not import presidio_analyzer, please install with "
        "`pip install presidio-analyzer`. You will also need to download a "
        "spaCy model to use the analyzer, e.g. "
        "`python -m spacy download en_core_web_lg`."
    ) from e
try:
    from presidio_anonymizer import AnonymizerEngine
    from presidio_anonymizer.entities import OperatorConfig
except ImportError as e:
    raise ImportError(
        "Could not import presidio_anonymizer, please install with "
        "`pip install presidio-anonymizer`."
    ) from e

if TYPE_CHECKING:
    from presidio_analyzer import EntityRecognizer, RecognizerResult
    from presidio_anonymizer.entities import EngineResult

# Configuring Anonymizer for multiple languages
# Detailed description and examples can be found here:
# langchain/docs/extras/guides/privacy/multi_language_anonymization.ipynb
DEFAULT_LANGUAGES_CONFIG = {
    # You can also use Stanza or transformers library.
    # See https://microsoft.github.io/presidio/analyzer/customizing_nlp_models/
    "nlp_engine_name": "spacy",
    "models": [
        {"lang_code": "en", "model_name": "en_core_web_lg"},
        # {"lang_code": "de", "model_name": "de_core_news_md"},
        # {"lang_code": "es", "model_name": "es_core_news_md"},
        # ...
        # List of available models: https://spacy.io/usage/models
    ],
}


class PresidioAnonymizer(AnonymizerBase):
    """Anonymizer using Microsoft Presidio."""

class PresidioAnonymizerBase(AnonymizerBase):
    def __init__(
        self,
        analyzed_fields: Optional[List[str]] = None,
        operators: Optional[Dict[str, OperatorConfig]] = None,
        languages_config: Dict = DEFAULT_LANGUAGES_CONFIG,
        faker_seed: Optional[int] = None,
    ):
        """
        Args:
@@ -28,25 +79,15 @@ class PresidioAnonymizer(AnonymizerBase):
                Operators allow for custom anonymization of detected PII.
                Learn more:
                https://microsoft.github.io/presidio/tutorial/10_simple_anonymization/
            languages_config: Configuration for the NLP engine.
                First language in the list will be used as the main language
                in self.anonymize(...) when no language is specified.
                Learn more:
                https://microsoft.github.io/presidio/analyzer/customizing_nlp_models/
            faker_seed: Seed used to initialize faker.
                Defaults to None, in which case faker will be seeded randomly
                and provide random values.
        """
        try:
            from presidio_analyzer import AnalyzerEngine
        except ImportError as e:
            raise ImportError(
                "Could not import presidio_analyzer, please install with "
                "`pip install presidio-analyzer`. You will also need to download a "
                "spaCy model to use the analyzer, e.g. "
                "`python -m spacy download en_core_web_lg`."
            ) from e
        try:
            from presidio_anonymizer import AnonymizerEngine
            from presidio_anonymizer.entities import OperatorConfig
        except ImportError as e:
            raise ImportError(
                "Could not import presidio_anonymizer, please install with "
                "`pip install presidio-anonymizer`."
            ) from e

        self.analyzed_fields = (
            analyzed_fields
            if analyzed_fields is not None
@@ -59,17 +100,66 @@ class PresidioAnonymizer(AnonymizerBase):
                field: OperatorConfig(
                    operator_name="custom", params={"lambda": faker_function}
                )
                for field, faker_function in get_pseudoanonymizer_mapping().items()
                for field, faker_function in get_pseudoanonymizer_mapping(
                    faker_seed
                ).items()
            }
        )
        self._analyzer = AnalyzerEngine()

        provider = NlpEngineProvider(nlp_configuration=languages_config)
        nlp_engine = provider.create_engine()

        self.supported_languages = list(nlp_engine.nlp.keys())

        self._analyzer = AnalyzerEngine(
            supported_languages=self.supported_languages, nlp_engine=nlp_engine
        )
        self._anonymizer = AnonymizerEngine()

    def _anonymize(self, text: str) -> str:
    def add_recognizer(self, recognizer: EntityRecognizer) -> None:
        """Add a recognizer to the analyzer

        Args:
            recognizer: Recognizer to add to the analyzer.
        """
        self._analyzer.registry.add_recognizer(recognizer)
        self.analyzed_fields.extend(recognizer.supported_entities)

    def add_operators(self, operators: Dict[str, OperatorConfig]) -> None:
        """Add operators to the anonymizer

        Args:
            operators: Operators to add to the anonymizer.
        """
        self.operators.update(operators)


class PresidioAnonymizer(PresidioAnonymizerBase):
    def _anonymize(self, text: str, language: Optional[str] = None) -> str:
        """Anonymize text.
        Each PII entity is replaced with a fake value.
        Each time fake values will be different, as they are generated randomly.

        Args:
            text: text to anonymize
            language: language to use for analysis of PII
                If None, the first (main) language in the list
                of languages specified in the configuration will be used.
        """
        if language is None:
            language = self.supported_languages[0]

        if language not in self.supported_languages:
            raise ValueError(
                f"Language '{language}' is not supported. "
                f"Supported languages are: {self.supported_languages}. "
                "Change your language configuration file to add more languages."
            )

        results = self._analyzer.analyze(
            text,
            entities=self.analyzed_fields,
            language="en",
            language=language,
        )

        return self._anonymizer.anonymize(
@@ -78,11 +168,199 @@ class PresidioAnonymizer(AnonymizerBase):
            operators=self.operators,
        ).text

    def add_recognizer(self, recognizer: EntityRecognizer) -> None:
        """Add a recognizer to the analyzer"""
        self._analyzer.registry.add_recognizer(recognizer)
        self.analyzed_fields.extend(recognizer.supported_entities)

    def add_operators(self, operators: Dict[str, OperatorConfig]) -> None:
        """Add operators to the anonymizer"""
        self.operators.update(operators)
class PresidioReversibleAnonymizer(PresidioAnonymizerBase, ReversibleAnonymizerBase):
    def __init__(
        self,
        analyzed_fields: Optional[List[str]] = None,
        operators: Optional[Dict[str, OperatorConfig]] = None,
        languages_config: Dict = DEFAULT_LANGUAGES_CONFIG,
        faker_seed: Optional[int] = None,
    ):
        super().__init__(analyzed_fields, operators, languages_config, faker_seed)
        self._deanonymizer_mapping = DeanonymizerMapping()

    @property
    def deanonymizer_mapping(self) -> MappingDataType:
        """Return the deanonymizer mapping"""
        return self._deanonymizer_mapping.data

    def _update_deanonymizer_mapping(
        self,
        original_text: str,
        analyzer_results: List[RecognizerResult],
        anonymizer_results: EngineResult,
    ) -> None:
        """Creates or updates the mapping used to de-anonymize text.

        This method exploits the results returned by the
        analysis and anonymization processes.

        It constructs a mapping from each anonymized entity
        back to its original text value.

        Mapping will be stored as "deanonymizer_mapping" property.

        Example of "deanonymizer_mapping":
            {
                "PERSON": {
                    "<anonymized>": "<original>",
                    "John Doe": "Slim Shady"
                },
                "PHONE_NUMBER": {
                    "111-111-1111": "555-555-5555"
                }
                ...
            }
        """

        # We are able to zip and loop through both lists because we expect
        # them to return corresponding entities for each identified piece
        # of analyzable data from our input.

        # We sort them by their 'start' attribute because it allows us to
        # match corresponding entities by their position in the input text.
        analyzer_results = sorted(analyzer_results, key=lambda d: d.start)
        anonymizer_results.items = sorted(
            anonymizer_results.items, key=lambda d: d.start
        )

        new_deanonymizer_mapping: MappingDataType = defaultdict(dict)

        for analyzed_entity, anonymized_entity in zip(
            analyzer_results, anonymizer_results.items
        ):
            original_value = original_text[analyzed_entity.start : analyzed_entity.end]
            new_deanonymizer_mapping[anonymized_entity.entity_type][
                anonymized_entity.text
            ] = original_value

        self._deanonymizer_mapping.update(new_deanonymizer_mapping)

    def _anonymize(self, text: str, language: Optional[str] = None) -> str:
        """Anonymize text.
        Each PII entity is replaced with a fake value.
        Each time fake values will be different, as they are generated randomly.
        At the same time, we will create a mapping from each anonymized entity
        back to its original text value.

        Args:
            text: text to anonymize
            language: language to use for analysis of PII
                If None, the first (main) language in the list
                of languages specified in the configuration will be used.
        """
        if language is None:
            language = self.supported_languages[0]

        if language not in self.supported_languages:
            raise ValueError(
                f"Language '{language}' is not supported. "
                f"Supported languages are: {self.supported_languages}. "
                "Change your language configuration file to add more languages."
            )

        analyzer_results = self._analyzer.analyze(
            text,
            entities=self.analyzed_fields,
            language=language,
        )

        filtered_analyzer_results = (
            self._anonymizer._remove_conflicts_and_get_text_manipulation_data(
                analyzer_results
            )
        )

        anonymizer_results = self._anonymizer.anonymize(
            text,
            analyzer_results=analyzer_results,
            operators=self.operators,
        )

        self._update_deanonymizer_mapping(
            text, filtered_analyzer_results, anonymizer_results
        )

        return anonymizer_results.text

    def _deanonymize(
        self,
        text_to_deanonymize: str,
        deanonymizer_matching_strategy: Callable[
            [str, MappingDataType], str
        ] = default_matching_strategy,
    ) -> str:
        """Deanonymize text.
        Each anonymized entity is replaced with its original value.
        This method exploits the mapping created during the anonymization process.

        Args:
            text_to_deanonymize: text to deanonymize
            deanonymizer_matching_strategy: function to use to match
                anonymized entities with their original values and replace them.
        """
        if not self._deanonymizer_mapping:
            raise ValueError(
                "Deanonymizer mapping is empty.",
                "Please call anonymize() and anonymize some text first.",
            )

        text_to_deanonymize = deanonymizer_matching_strategy(
            text_to_deanonymize, self.deanonymizer_mapping
        )

        return text_to_deanonymize

    def save_deanonymizer_mapping(self, file_path: Union[Path, str]) -> None:
        """Save the deanonymizer mapping to a JSON or YAML file.

        Args:
            file_path: Path to file to save the mapping to.

        Example:
        .. code-block:: python

            anonymizer.save_deanonymizer_mapping(file_path="path/mapping.json")
        """

        save_path = Path(file_path)

        if save_path.suffix not in [".json", ".yaml"]:
            raise ValueError(f"{save_path} must have an extension of .json or .yaml")

        # Make sure parent directories exist
        save_path.parent.mkdir(parents=True, exist_ok=True)

        if save_path.suffix == ".json":
            with open(save_path, "w") as f:
                json.dump(self.deanonymizer_mapping, f, indent=2)
        elif save_path.suffix == ".yaml":
            with open(save_path, "w") as f:
                yaml.dump(self.deanonymizer_mapping, f, default_flow_style=False)

    def load_deanonymizer_mapping(self, file_path: Union[Path, str]) -> None:
        """Load the deanonymizer mapping from a JSON or YAML file.

        Args:
            file_path: Path to file to load the mapping from.

        Example:
        .. code-block:: python

            anonymizer.load_deanonymizer_mapping(file_path="path/mapping.json")
        """

        load_path = Path(file_path)

        if load_path.suffix not in [".json", ".yaml"]:
            raise ValueError(f"{load_path} must have an extension of .json or .yaml")

        if load_path.suffix == ".json":
            with open(load_path, "r") as f:
                loaded_mapping = json.load(f)
        elif load_path.suffix == ".yaml":
            with open(load_path, "r") as f:
                loaded_mapping = yaml.load(f, Loader=yaml.FullLoader)

        self._deanonymizer_mapping.update(loaded_mapping)
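A hedged end-to-end sketch of the reversible anonymizer introduced above (requires `presidio-analyzer`, `presidio-anonymizer`, Faker, and the `en_core_web_lg` spaCy model; the concrete fake values will differ from run to run unless `faker_seed` is set):

.. code-block:: python

    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    anonymizer = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])

    anonymized = anonymizer.anonymize("Hello, my name is John Doe.")
    # The fake value is remembered so the process can be reversed:
    print(anonymizer.deanonymizer_mapping)    # {"PERSON": {"<fake name>": "John Doe"}}
    print(anonymizer.deanonymize(anonymized))  # "Hello, my name is John Doe."

    # The mapping can be persisted and restored as JSON or YAML:
    anonymizer.save_deanonymizer_mapping("mapping.json")
    anonymizer.load_deanonymizer_mapping("mapping.json")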
@@ -0,0 +1,5 @@
from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer

__all__ = [
    "DiffbotGraphTransformer",
]
@@ -0,0 +1,316 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import requests
from langchain.graphs.graph_document import GraphDocument, Node, Relationship
from langchain.schema import Document
from langchain.utils import get_from_env


def format_property_key(s: str) -> str:
    words = s.split()
    if not words:
        return s
    first_word = words[0].lower()
    capitalized_words = [word.capitalize() for word in words[1:]]
    return "".join([first_word] + capitalized_words)


class NodesList:
    """
    Manages a list of nodes with associated properties.

    Attributes:
        nodes (Dict[Tuple, Any]): Stores nodes as keys and their properties as values.
            Each key is a tuple where the first element is the
            node ID and the second is the node type.
    """

    def __init__(self) -> None:
        self.nodes: Dict[Tuple[Union[str, int], str], Any] = dict()

    def add_node_property(
        self, node: Tuple[Union[str, int], str], properties: Dict[str, Any]
    ) -> None:
        """
        Adds or updates node properties.

        If the node does not exist in the list, it's added along with its properties.
        If the node already exists, its properties are updated with the new values.

        Args:
            node (Tuple): A tuple containing the node ID and node type.
            properties (Dict): A dictionary of properties to add or update for the node.
        """
        if node not in self.nodes:
            self.nodes[node] = properties
        else:
            self.nodes[node].update(properties)

    def return_node_list(self) -> List[Node]:
        """
        Returns the nodes as a list of Node objects.

        Each Node object will have its ID, type, and properties populated.

        Returns:
            List[Node]: A list of Node objects.
        """
        nodes = [
            Node(id=key[0], type=key[1], properties=self.nodes[key])
            for key in self.nodes
        ]
        return nodes


# Properties that should be treated as node properties instead of relationships
FACT_TO_PROPERTY_TYPE = [
    "Date",
    "Number",
    "Job title",
    "Cause of death",
    "Organization type",
    "Academic title",
]


schema_mapping = [
    ("HEADQUARTERS", "ORGANIZATION_LOCATIONS"),
    ("RESIDENCE", "PERSON_LOCATION"),
    ("ALL_PERSON_LOCATIONS", "PERSON_LOCATION"),
    ("CHILD", "HAS_CHILD"),
    ("PARENT", "HAS_PARENT"),
    ("CUSTOMERS", "HAS_CUSTOMER"),
    ("SKILLED_AT", "INTERESTED_IN"),
]


class SimplifiedSchema:
    """
    Provides functionality for working with a simplified schema mapping.

    Attributes:
        schema (Dict): A dictionary containing the mapping to simplified schema types.
    """

    def __init__(self) -> None:
        """Initializes the schema dictionary based on the predefined list."""
        self.schema = dict()
        for row in schema_mapping:
            self.schema[row[0]] = row[1]

    def get_type(self, type: str) -> str:
        """
        Retrieves the simplified schema type for a given original type.

        Args:
            type (str): The original schema type to find the simplified type for.

        Returns:
            str: The simplified schema type if it exists;
                otherwise, returns the original type.
        """
        try:
            return self.schema[type]
        except KeyError:
            return type


class DiffbotGraphTransformer:
    """Transforms documents into graph documents using Diffbot's NLP API.

    A graph document transformation system takes a sequence of Documents and returns a
    sequence of Graph Documents.

    Example:
        .. code-block:: python

            class DiffbotGraphTransformer(BaseGraphDocumentTransformer):

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[GraphDocument]:
                    results = []

                    for document in documents:
                        raw_results = self.nlp_request(document.page_content)
                        graph_document = self.process_response(raw_results, document)
                        results.append(graph_document)
                    return results

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError
    """

    def __init__(
        self,
        diffbot_api_key: Optional[str] = None,
        fact_confidence_threshold: float = 0.7,
        include_qualifiers: bool = True,
        include_evidence: bool = True,
        simplified_schema: bool = True,
    ) -> None:
        """
        Initialize the graph transformer with various options.

        Args:
            diffbot_api_key (str):
                The API key for Diffbot's NLP services.

            fact_confidence_threshold (float):
                Minimum confidence level for facts to be included.
            include_qualifiers (bool):
                Whether to include qualifiers in the relationships.
            include_evidence (bool):
                Whether to include evidence for the relationships.
            simplified_schema (bool):
                Whether to use a simplified schema for relationships.
        """
        self.diffbot_api_key = diffbot_api_key or get_from_env(
            "diffbot_api_key", "DIFFBOT_API_KEY"
        )
        self.fact_threshold_confidence = fact_confidence_threshold
        self.include_qualifiers = include_qualifiers
        self.include_evidence = include_evidence
        self.simplified_schema = None
        if simplified_schema:
            self.simplified_schema = SimplifiedSchema()

    def nlp_request(self, text: str) -> Dict[str, Any]:
        """
        Make an API request to the Diffbot NLP endpoint.

        Args:
            text (str): The text to be processed.

        Returns:
            Dict[str, Any]: The JSON response from the API.
        """

        # Relationship extraction only works for English
        payload = {
            "content": text,
            "lang": "en",
        }

        FIELDS = "facts"
        HOST = "nl.diffbot.com"
        url = (
            f"https://{HOST}/v1/?fields={FIELDS}&"
            f"token={self.diffbot_api_key}&language=en"
        )
        result = requests.post(url, data=payload)
        return result.json()

    def process_response(
        self, payload: Dict[str, Any], document: Document
    ) -> GraphDocument:
        """
        Transform the Diffbot NLP response into a GraphDocument.

        Args:
            payload (Dict[str, Any]): The JSON response from Diffbot's NLP API.
            document (Document): The original document.

        Returns:
            GraphDocument: The transformed document as a graph.
        """

        # Return empty result if there are no facts
        if "facts" not in payload or not payload["facts"]:
            return GraphDocument(nodes=[], relationships=[], source=document)

        # Nodes are a custom class because we need to deduplicate
        nodes_list = NodesList()
        # Relationships are a list because we don't deduplicate nor anything else
        relationships = list()
        for record in payload["facts"]:
            # Skip if the fact is below the threshold confidence
            if record["confidence"] < self.fact_threshold_confidence:
                continue

            # TODO: It should probably be treated as a node property
            if not record["value"]["allTypes"]:
                continue

            # Define source node
            source_id = (
                record["entity"]["allUris"][0]
                if record["entity"]["allUris"]
                else record["entity"]["name"]
            )
            source_label = record["entity"]["allTypes"][0]["name"].capitalize()
            source_name = record["entity"]["name"]
            source_node = Node(id=source_id, type=source_label)
            nodes_list.add_node_property(
                (source_id, source_label), {"name": source_name}
            )

            # Define target node
            target_id = (
                record["value"]["allUris"][0]
                if record["value"]["allUris"]
                else record["value"]["name"]
            )
            target_label = record["value"]["allTypes"][0]["name"].capitalize()
            target_name = record["value"]["name"]
            # Some facts are better suited as node properties
            if target_label in FACT_TO_PROPERTY_TYPE:
                nodes_list.add_node_property(
                    (source_id, source_label),
                    {format_property_key(record["property"]["name"]): target_name},
                )
            else:  # Define relationship
                # Define target node object
                target_node = Node(id=target_id, type=target_label)
                nodes_list.add_node_property(
                    (target_id, target_label), {"name": target_name}
                )
                # Define relationship type
                rel_type = record["property"]["name"].replace(" ", "_").upper()
                if self.simplified_schema:
                    rel_type = self.simplified_schema.get_type(rel_type)

                # Relationship qualifiers/properties
                rel_properties = dict()
                relationship_evidence = [el["passage"] for el in record["evidence"]][0]
                if self.include_evidence:
                    rel_properties.update({"evidence": relationship_evidence})
                if self.include_qualifiers and record.get("qualifiers"):
                    for property in record["qualifiers"]:
                        prop_key = format_property_key(property["property"]["name"])
                        rel_properties[prop_key] = property["value"]["name"]

                relationship = Relationship(
                    source=source_node,
                    target=target_node,
                    type=rel_type,
                    properties=rel_properties,
                )
                relationships.append(relationship)

        return GraphDocument(
            nodes=nodes_list.return_node_list(),
            relationships=relationships,
            source=document,
        )

    def convert_to_graph_documents(
        self, documents: Sequence[Document]
    ) -> List[GraphDocument]:
        """Convert a sequence of documents into graph documents.

        Args:
            documents (Sequence[Document]): The original documents.
            **kwargs: Additional keyword arguments.

        Returns:
            Sequence[GraphDocument]: The transformed documents as graphs.
        """
        results = []
        for document in documents:
            raw_results = self.nlp_request(document.page_content)
            graph_document = self.process_response(raw_results, document)
            results.append(graph_document)
        return results
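A usage sketch (not part of the diff) for the new transformer. It assumes a valid Diffbot token, either passed explicitly or available via the `DIFFBOT_API_KEY` environment variable; the resulting `GraphDocument` objects can then be written to a graph store.

.. code-block:: python

    from langchain.schema import Document
    from langchain_experimental.graph_transformers import DiffbotGraphTransformer

    transformer = DiffbotGraphTransformer(diffbot_api_key="<your Diffbot token>")
    docs = [
        Document(
            page_content="Warren Buffett is the chairman of Berkshire Hathaway."
        )
    ]
    graph_documents = transformer.convert_to_graph_documents(docs)
    for graph_doc in graph_documents:
        print(graph_doc.nodes, graph_doc.relationships)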
@@ -0,0 +1,38 @@
"""Vector SQL Database Chain Retriever"""
from typing import Any, Dict, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document

from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain


class VectorSQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that uses SQLDatabase as Retriever"""

    sql_db_chain: VectorSQLDatabaseChain
    """SQL Database Chain"""
    page_content_key: str = "content"
    """column name for page content of documents"""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        ret: List[Dict[str, Any]] = self.sql_db_chain(
            query, callbacks=run_manager.get_child(), **kwargs
        )["result"]
        return [
            Document(page_content=r[self.page_content_key], metadata=r) for r in ret
        ]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        raise NotImplementedError
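A minimal wiring sketch, assuming a `VectorSQLDatabaseChain` instance named `chain` has already been built (see vector_sql.py below), that the query results expose a `content` column, and that the module lives at the import path suggested by the package layout:

.. code-block:: python

    from langchain_experimental.retrievers.vector_sql_database import (
        VectorSQLDatabaseChainRetriever,
    )

    retriever = VectorSQLDatabaseChainRetriever(
        sql_db_chain=chain,           # an existing VectorSQLDatabaseChain
        page_content_key="content",   # column used as Document.page_content
    )
    docs = retriever.get_relevant_documents("What is Feature Pyramid Network?")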
libs/experimental/langchain_experimental/sql/prompt.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate


PROMPT_SUFFIX = """Only use the following tables:
{table_info}

Question: {input}"""

_VECTOR_SQL_DEFAULT_TEMPLATE = """You are a {dialect} expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
{dialect} queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for {top_k} closest rows, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.

Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
"""

VECTOR_SQL_PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect", "top_k"],
    template=_VECTOR_SQL_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
)


_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for {top_k} closest rows, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.

Use the following format:

======== table info ========
<some table infos>

Question: "Question here"
SQLQuery: "SQL Query to run"


Here are some examples:

======== table info ========
CREATE TABLE "ChatPaper" (
    abstract String,
    id String,
    vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}


Let's begin:
======== table info ========
{table_info}

Question: {input}
SQLQuery: """

MYSCALE_PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=_myscale_prompt + PROMPT_SUFFIX,
)


VECTOR_SQL_PROMPTS = {
    "myscale": MYSCALE_PROMPT,
}
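Illustrative only: the MyScale prompt is an ordinary `PromptTemplate`, so it can be rendered directly to inspect what the LLM will see (the table schema string here is a placeholder):

.. code-block:: python

    rendered = MYSCALE_PROMPT.format(
        input="What is Feature Pyramid Network?",
        table_info='CREATE TABLE "ChatPaper" (abstract String, id String, vector Array(Float32)) ...',
        top_k=5,
    )
    print(rendered)  # the full instruction text, table info, and question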
libs/experimental/langchain_experimental/sql/vector_sql.py (new file, 237 lines)
@@ -0,0 +1,237 @@
"""Vector SQL Database Chain"""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
from langchain.embeddings.base import Embeddings
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser, BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.sql_database.prompt import QUERY_CHECKER
from langchain.utilities.sql_database import SQLDatabase

from langchain_experimental.sql.base import INTERMEDIATE_STEPS_KEY, SQLDatabaseChain


class VectorSQLOutputParser(BaseOutputParser[str]):
    """Output Parser for Vector SQL
    1. finds `NeuralArray()` and replaces it with the embedding
    2. finds `DISTANCE()` and replaces it with the distance name in the backend SQL
    """

    model: Embeddings
    """Embedding model to extract embedding for entity"""
    distance_func_name: str = "distance"
    """Distance name for Vector SQL"""

    class Config:
        arbitrary_types_allowed = 1

    @property
    def _type(self) -> str:
        return "vector_sql_parser"

    @classmethod
    def from_embeddings(
        cls, model: Embeddings, distance_func_name: str = "distance", **kwargs: Any
    ) -> BaseOutputParser:
        return cls(model=model, distance_func_name=distance_func_name, **kwargs)

    def parse(self, text: str) -> str:
        text = text.strip()
        start = text.find("NeuralArray(")
        _sql_str_compl = text
        if start > 0:
            _matched = text[text.find("NeuralArray(") + len("NeuralArray(") :]
            end = _matched.find(")") + start + len("NeuralArray(") + 1
            entity = _matched[: _matched.find(")")]
            vecs = self.model.embed_query(entity)
            vecs_str = "[" + ",".join(map(str, vecs)) + "]"
            _sql_str_compl = text.replace("DISTANCE", self.distance_func_name).replace(
                text[start:end], vecs_str
            )
            if _sql_str_compl[-1] == ";":
                _sql_str_compl = _sql_str_compl[:-1]
        return _sql_str_compl


class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser):
    """Based on VectorSQLOutputParser
    It also modifies the SQL to get all columns
    """

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_all_parser"

    def parse(self, text: str) -> str:
        text = text.strip()
        start = text.upper().find("SELECT")
        if start >= 0:
            end = text.upper().find("FROM")
            text = text.replace(text[start + len("SELECT") + 1 : end - 1], "*")
        return super().parse(text)


def _try_eval(x: Any) -> Any:
    try:
        return eval(x)
    except Exception:
        return x


def get_result_from_sqldb(
    db: SQLDatabase, cmd: str
) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]:
    result = db._execute(cmd, fetch="all")  # type: ignore
    if isinstance(result, list):
        return [{k: _try_eval(v) for k, v in dict(d._asdict()).items()} for d in result]
    else:
        return {
            k: _try_eval(v) for k, v in dict(result._asdict()).items()  # type: ignore
        }


class VectorSQLDatabaseChain(SQLDatabaseChain):
    """Chain for interacting with Vector SQL Database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain import OpenAI, SQLDatabase, OpenAIEmbeddings
            db = SQLDatabase(...)
            db_chain = VectorSQLDatabaseChain.from_llm(OpenAI(), db, OpenAIEmbeddings())

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
    """

    sql_cmd_parser: VectorSQLOutputParser
    """Parser for Vector SQL"""
    native_format: bool = False
    """If return_direct, controls whether to return in python native format"""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs)  # input: sql generation
            llm_out = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            )
            sql_cmd = self.sql_cmd_parser.parse(llm_out)
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(llm_out, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    llm_out
                )  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": llm_out})  # input: sql exec
                result = get_result_from_sqldb(self.database, sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm,
                    prompt=query_checker_prompt,
                    output_parser=self.llm_chain.output_parser,
                )
                query_checker_inputs = {
                    "query": llm_out,
                    "dialect": self.database.dialect,
                }
                checked_llm_out = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                )
                checked_sql_command = self.sql_cmd_parser.parse(checked_llm_out)
                intermediate_steps.append(
                    checked_llm_out
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_llm_out, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_llm_out}
                )  # input: sql exec
                result = get_result_from_sqldb(self.database, checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                llm_out = checked_llm_out
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{llm_out}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs)  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

    @property
    def _chain_type(self) -> str:
        return "vector_sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        sql_cmd_parser: Optional[VectorSQLOutputParser] = None,
        **kwargs: Any,
    ) -> VectorSQLDatabaseChain:
        assert sql_cmd_parser, "`sql_cmd_parser` must be set in VectorSQLDatabaseChain."
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            llm_chain=llm_chain, database=db, sql_cmd_parser=sql_cmd_parser, **kwargs
        )
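A hedged construction sketch for the new chain: `from_llm` requires an explicit `sql_cmd_parser`, which in turn needs an embedding model so `NeuralArray(...)` placeholders can be expanded into literal vectors. Class names come from this diff; `OpenAI`, `OpenAIEmbeddings`, and the connection string are placeholders.

.. code-block:: python

    from langchain.embeddings import OpenAIEmbeddings
    from langchain.llms import OpenAI
    from langchain.utilities import SQLDatabase

    from langchain_experimental.sql.prompt import MYSCALE_PROMPT
    from langchain_experimental.sql.vector_sql import (
        VectorSQLDatabaseChain,
        VectorSQLOutputParser,
    )

    db = SQLDatabase.from_uri("clickhouse://user:password@host:8443/default")
    parser = VectorSQLOutputParser.from_embeddings(model=OpenAIEmbeddings())

    chain = VectorSQLDatabaseChain.from_llm(
        llm=OpenAI(temperature=0),
        db=db,
        prompt=MYSCALE_PROMPT,
        sql_cmd_parser=parser,
    )
    chain.run("Please give me 10 papers about Feature Pyramid Networks")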
libs/experimental/poetry.lock (generated, 1125 lines changed)
File diff suppressed because it is too large.
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-experimental"
version = "0.0.14"
version = "0.0.16"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -26,6 +26,7 @@ black = "^23.1.0"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"

[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
@@ -0,0 +1,154 @@
import os
from typing import Iterator, List

import pytest


@pytest.fixture(scope="module", autouse=True)
def check_spacy_model() -> Iterator[None]:
    import spacy

    if not spacy.util.is_package("en_core_web_lg"):
        pytest.skip(reason="Spacy model 'en_core_web_lg' not installed")
    yield


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
@pytest.mark.parametrize(
    "analyzed_fields,should_contain",
    [(["PERSON"], False), (["PHONE_NUMBER"], True), (None, False)],
)
def test_anonymize(analyzed_fields: List[str], should_contain: bool) -> None:
    """Test anonymizing a name in a simple sentence"""
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    text = "Hello, my name is John Doe."
    anonymizer = PresidioReversibleAnonymizer(analyzed_fields=analyzed_fields)
    anonymized_text = anonymizer.anonymize(text)
    assert ("John Doe" in anonymized_text) == should_contain


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_anonymize_multiple() -> None:
    """Test anonymizing multiple items in a sentence"""
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    text = "John Smith's phone number is 313-666-7440 and email is johnsmith@gmail.com"
    anonymizer = PresidioReversibleAnonymizer()
    anonymized_text = anonymizer.anonymize(text)
    for phrase in ["John Smith", "313-666-7440", "johnsmith@gmail.com"]:
        assert phrase not in anonymized_text


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_anonymize_with_custom_operator() -> None:
    """Test anonymize a name with a custom operator"""
    from presidio_anonymizer.entities import OperatorConfig

    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    custom_operator = {"PERSON": OperatorConfig("replace", {"new_value": "<name>"})}
    anonymizer = PresidioReversibleAnonymizer(operators=custom_operator)

    text = "Jane Doe was here."

    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text == "<name> was here."


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_add_recognizer_operator() -> None:
    """
    Test add recognizer and anonymize a new type of entity and with a custom operator
    """
    from presidio_analyzer import PatternRecognizer
    from presidio_anonymizer.entities import OperatorConfig

    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    anonymizer = PresidioReversibleAnonymizer(analyzed_fields=[])
    titles_list = ["Sir", "Madam", "Professor"]
    custom_recognizer = PatternRecognizer(
        supported_entity="TITLE", deny_list=titles_list
    )
    anonymizer.add_recognizer(custom_recognizer)

    # anonymizing with custom recognizer
    text = "Madam Jane Doe was here."
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text == "<TITLE> Jane Doe was here."

    # anonymizing with custom recognizer and operator
    custom_operator = {"TITLE": OperatorConfig("replace", {"new_value": "Dear"})}
    anonymizer.add_operators(custom_operator)
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text == "Dear Jane Doe was here."


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_deanonymizer_mapping() -> None:
    """Test if deanonymizer mapping is correctly populated"""
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    anonymizer = PresidioReversibleAnonymizer(
        analyzed_fields=["PERSON", "PHONE_NUMBER", "EMAIL_ADDRESS", "CREDIT_CARD"]
    )

    anonymizer.anonymize("Hello, my name is John Doe and my number is 444 555 6666.")

    # ["PERSON", "PHONE_NUMBER"]
    assert len(anonymizer.deanonymizer_mapping.keys()) == 2
    assert "John Doe" in anonymizer.deanonymizer_mapping.get("PERSON", {}).values()
    assert (
        "444 555 6666"
        in anonymizer.deanonymizer_mapping.get("PHONE_NUMBER", {}).values()
    )

    text_to_anonymize = (
        "And my name is Jane Doe, my email is jane@gmail.com and "
        "my credit card is 4929 5319 6292 5362."
    )
    anonymizer.anonymize(text_to_anonymize)

    # ["PERSON", "PHONE_NUMBER", "EMAIL_ADDRESS", "CREDIT_CARD"]
    assert len(anonymizer.deanonymizer_mapping.keys()) == 4
    assert "Jane Doe" in anonymizer.deanonymizer_mapping.get("PERSON", {}).values()
    assert (
        "jane@gmail.com"
        in anonymizer.deanonymizer_mapping.get("EMAIL_ADDRESS", {}).values()
    )
    assert (
        "4929 5319 6292 5362"
        in anonymizer.deanonymizer_mapping.get("CREDIT_CARD", {}).values()
    )


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_deanonymize() -> None:
    """Test deanonymizing a name in a simple sentence"""
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    text = "Hello, my name is John Doe."
    anonymizer = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
    anonymized_text = anonymizer.anonymize(text)
    deanonymized_text = anonymizer.deanonymize(anonymized_text)
    assert deanonymized_text == text


@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_save_load_deanonymizer_mapping() -> None:
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    anonymizer = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
    anonymizer.anonymize("Hello, my name is John Doe.")
    try:
        anonymizer.save_deanonymizer_mapping("test_file.json")
        assert os.path.isfile("test_file.json")

        anonymizer = PresidioReversibleAnonymizer()
        anonymizer.load_deanonymizer_mapping("test_file.json")

        assert "John Doe" in anonymizer.deanonymizer_mapping.get("PERSON", {}).values()

    finally:
        os.remove("test_file.json")