mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 06:26:12 +00:00
merge
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Type
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain import OpenAI
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
@@ -46,7 +46,7 @@ class TestUnitCPALChain_MathWordProblems(unittest.TestCase):
|
||||
"""Unit Test the CPAL chain and its component chains on math word problems.
|
||||
|
||||
These tests can't run in the standard unit test directory because of
|
||||
this issue, https://github.com/hwchase17/langchain/issues/7451
|
||||
this issue, https://github.com/langchain-ai/langchain/issues/7451
|
||||
|
||||
"""
|
||||
|
||||
@@ -398,7 +398,7 @@ class TestCPALChain_MathWordProblems(unittest.TestCase):
|
||||
"""
|
||||
Test CPAL chain against the first example in the PAL chain notebook doc:
|
||||
|
||||
https://github.com/hwchase17/langchain/blob/master/docs/extras/modules/chains/additional/pal.ipynb
|
||||
https://github.com/langchain-ai/langchain/blob/master/docs/extras/modules/chains/additional/pal.ipynb
|
||||
"""
|
||||
|
||||
narrative_input = (
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Test PAL chain."""
|
||||
|
||||
from langchain import OpenAI
|
||||
from langchain.chains.pal.base import PALChain
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
|
||||
def test_math_prompt() -> None:
|
||||
|
@@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_experimental.tabular_synthetic_data.base import SyntheticDataGenerator
|
||||
from langchain_experimental.tabular_synthetic_data.openai import (
|
||||
OPENAI_TEMPLATE,
|
||||
create_openai_data_generator,
|
||||
)
|
||||
from langchain_experimental.tabular_synthetic_data.prompts import (
|
||||
SYNTHETIC_FEW_SHOT_PREFIX,
|
||||
SYNTHETIC_FEW_SHOT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
# Define the desired output schema for an individual medical billing record.
# The tests below assert that every generated row is an instance of this model.
# NOTE(review): documented with comments rather than a class docstring, since a
# docstring may surface in the schema handed to the generator - confirm.
class MedicalBilling(BaseModel):
    patient_id: int
    patient_name: str
    diagnosis_code: str  # e.g. "J20.9" in the few-shot examples
    procedure_code: str  # e.g. "99203" in the few-shot examples
    total_charge: float
    insurance_claim_amount: float
|
||||
|
||||
|
||||
# Free-text renderings of billing records used as few-shot samples.
_EXAMPLE_RECORDS = (
    """Patient ID: 123456, Patient Name: John Doe, Diagnosis Code:
J20.9, Procedure Code: 99203, Total Charge: $500, Insurance Claim Amount:
$350""",
    """Patient ID: 789012, Patient Name: Johnson Smith, Diagnosis
Code: M54.5, Procedure Code: 99213, Total Charge: $150, Insurance Claim
Amount: $120""",
    """Patient ID: 345678, Patient Name: Emily Stone, Diagnosis Code:
E11.9, Procedure Code: 99214, Total Charge: $300, Insurance Claim Amount:
$250""",
    """Patient ID: 901234, Patient Name: Robert Miles, Diagnosis Code:
B07.9, Procedure Code: 99204, Total Charge: $200, Insurance Claim Amount:
$160""",
    """Patient ID: 567890, Patient Name: Clara Jensen, Diagnosis Code:
F41.9, Procedure Code: 99205, Total Charge: $450, Insurance Claim Amount:
$310""",
    """Patient ID: 234567, Patient Name: Alan Turing, Diagnosis Code:
G40.909, Procedure Code: 99215, Total Charge: $220, Insurance Claim Amount:
$180""",
)

# Each sample is wrapped under the "example" key, one dict per record.
examples = [{"example": record} for record in _EXAMPLE_RECORDS]
|
||||
|
||||
# Few-shot prompt assembled from the shared prefix/suffix and the billing
# examples above; "subject" and "extra" are supplied at generation time.
prompt_template = FewShotPromptTemplate(
    examples=examples,
    example_prompt=OPENAI_TEMPLATE,
    prefix=SYNTHETIC_FEW_SHOT_PREFIX,
    suffix=SYNTHETIC_FEW_SHOT_SUFFIX,
    input_variables=["subject", "extra"],
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def synthetic_data_generator() -> SyntheticDataGenerator:
    """Return a generator that emits MedicalBilling rows via ChatOpenAI."""
    chat_model = ChatOpenAI(temperature=1)  # replace with your LLM instance
    return create_openai_data_generator(
        llm=chat_model,
        prompt=prompt_template,
        output_schema=MedicalBilling,
    )
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_generate_synthetic(synthetic_data_generator: SyntheticDataGenerator) -> None:
    """Generate ten rows synchronously and check each is a MedicalBilling."""
    expected_runs = 10
    guidance = """the name must be chosen at random. Make it something you wouldn't
normally choose."""

    rows = synthetic_data_generator.generate(
        subject="medical_billing", extra=guidance, runs=expected_runs
    )

    assert len(rows) == expected_runs
    assert all(isinstance(row, MedicalBilling) for row in rows)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
@pytest.mark.asyncio
async def test_agenerate_synthetic(
    synthetic_data_generator: SyntheticDataGenerator,
) -> None:
    """Generate ten rows asynchronously and check each is a MedicalBilling."""
    expected_runs = 10
    guidance = """the name must be chosen at random. Make it something you wouldn't
normally choose."""

    rows = await synthetic_data_generator.agenerate(
        subject="medical_billing", extra=guidance, runs=expected_runs
    )

    assert len(rows) == expected_runs
    assert all(isinstance(row, MedicalBilling) for row in rows)
|
102
libs/experimental/tests/unit_tests/test_bash.py
Normal file
102
libs/experimental/tests/unit_tests/test_bash.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Test the bash utility."""
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_experimental.llm_bash.bash import BashProcess
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_pwd_command() -> None:
    """Check that a one-off `pwd` matches what the system shell reports."""
    reported = BashProcess().run(["pwd"])
    expected = subprocess.check_output("pwd", shell=True).decode()

    assert reported == expected
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_pwd_command_persistent() -> None:
    """Check that a persistent session keeps its working directory between runs."""
    shell = BashProcess(persistent=True, strip_newlines=True)
    cwd_before = shell.run(["pwd"])

    system_cwd = subprocess.check_output("pwd", shell=True).decode().strip()
    assert system_cwd in cwd_before

    shell.run(["cd .."])
    cwd_after = shell.run(["pwd"])
    # After `cd ..` the session must report the parent of the first directory.
    assert Path(cwd_before).parent == Path(cwd_after)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_incorrect_command() -> None:
    """Check the message returned when the command does not exist."""
    expected = "Command 'invalid_command' returned non-zero exit status 127."
    result = BashProcess().run(["invalid_command"])
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_incorrect_command_return_err_output() -> None:
    """Check that the shell's own error text is returned when requested."""
    shell_error = r"^/bin/sh:.*invalid_command.*(?:not found|Permission denied).*$"
    result = BashProcess(return_err_output=True).run(["invalid_command"])
    assert re.match(shell_error, result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_create_directory_and_files(tmp_path: Path) -> None:
    """Create and read files through the shell inside a temporary directory."""
    shell = BashProcess(strip_newlines=True)

    # All shell commands operate inside a fresh subdirectory of tmp_path.
    work_dir = tmp_path / "test_dir"
    work_dir.mkdir()
    first_file = f"{work_dir}/file1.txt"
    second_file = f"{work_dir}/file2.txt"

    script = [
        f"touch {first_file}",
        f"touch {second_file}",
        f"echo 'hello world' > {second_file}",
        f"cat {second_file}",
    ]
    assert shell.run(script) == "hello world"

    # Both files must be visible in the directory listing.
    listing = shell.run([f"ls {work_dir}"])
    assert listing == "file1.txt\nfile2.txt"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_create_bash_persistent() -> None:
    """Run two echo commands in one persistent session and check both outputs."""
    shell = BashProcess(persistent=True)
    combined = shell.run("echo hello") + shell.run("echo world")

    for word in ("hello", "world"):
        assert word in combined
|
@@ -48,6 +48,23 @@ def test_anonymize_multiple() -> None:
|
||||
assert phrase not in anonymized_text
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_check_instances() -> None:
    """Test that repeated mentions of one person get the same replacement.

    All three mentions of "John Smith" must be replaced by a single fake name
    within one call, and a second call must produce a different name.
    """
    from langchain_experimental.data_anonymizer import PresidioAnonymizer

    # Adjacent string literals are joined with nothing in between, so the
    # separating space after "bakery." has to be written explicitly.
    text = (
        "This is John Smith. John Smith works in a bakery. "
        "John Smith is a good guy"
    )
    anonymizer = PresidioAnonymizer(["PERSON"], faker_seed=42)
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text.count("Connie Lawrence") == 3

    # New name should be generated
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text.count("Connie Lawrence") == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
|
||||
def test_anonymize_with_custom_operator() -> None:
|
||||
"""Test anonymize a name with a custom operator"""
|
||||
@@ -55,13 +72,13 @@ def test_anonymize_with_custom_operator() -> None:
|
||||
|
||||
from langchain_experimental.data_anonymizer import PresidioAnonymizer
|
||||
|
||||
custom_operator = {"PERSON": OperatorConfig("replace", {"new_value": "<name>"})}
|
||||
custom_operator = {"PERSON": OperatorConfig("replace", {"new_value": "NAME"})}
|
||||
anonymizer = PresidioAnonymizer(operators=custom_operator)
|
||||
|
||||
text = "Jane Doe was here."
|
||||
|
||||
anonymized_text = anonymizer.anonymize(text)
|
||||
assert anonymized_text == "<name> was here."
|
||||
assert anonymized_text == "NAME was here."
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
|
||||
@@ -91,3 +108,21 @@ def test_add_recognizer_operator() -> None:
|
||||
anonymizer.add_operators(custom_operator)
|
||||
anonymized_text = anonymizer.anonymize(text)
|
||||
assert anonymized_text == "Dear Jane Doe was here."
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_non_faker_values() -> None:
    """Test anonymizing multiple items in a sentence without faker values.

    With the default faker operators disabled, each distinct person is expected
    to be replaced by a numbered placeholder that is reused on later mentions.
    """
    from langchain_experimental.data_anonymizer import PresidioAnonymizer

    # Adjacent string literals are joined with nothing in between, so the
    # space separating the two halves is written explicitly in both the input
    # and the expected output.
    text = (
        "My name is John Smith. Your name is Adam Smith. Her name is Jane Smith. "
        "Our names are: John Smith, Adam Smith, Jane Smith."
    )
    expected_result = (
        "My name is <PERSON>. Your name is <PERSON_2>. Her name is <PERSON_3>. "
        "Our names are: <PERSON>, <PERSON_2>, <PERSON_3>."
    )
    anonymizer = PresidioAnonymizer(add_default_faker_operators=False)
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text == expected_result
|
||||
|
109
libs/experimental/tests/unit_tests/test_llm_bash.py
Normal file
109
libs/experimental/tests/unit_tests/test_llm_bash.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Test LLM Bash functionality."""
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
from langchain_experimental.llm_bash.base import LLMBashChain
|
||||
from langchain_experimental.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
|
||||
from tests.unit_tests.fake_llm import FakeLLM
|
||||
|
||||
# Sample LLM reply: surrounding prose plus one fenced bash block holding a
# single command.
_SAMPLE_CODE = """
Unrelated text
```bash
echo hello
```
Unrelated text
"""


# Sample LLM reply whose fenced bash block holds two commands separated by a
# blank line.
_SAMPLE_CODE_2_LINES = """
Unrelated text
```bash
echo hello

echo world
```
Unrelated text
"""
|
||||
|
||||
|
||||
@pytest.fixture
def output_parser() -> BashOutputParser:
    """Provide a fresh BashOutputParser to each test that requests it."""
    parser = BashOutputParser()
    return parser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_simple_question() -> None:
    """Feed a canned bash reply through the chain and check the command output."""
    question = "Please write a bash script that prints 'Hello World' to the console."
    canned_reply = "```bash\nexpr 1 + 1\n```"

    # The fake LLM answers the fully formatted prompt with the canned reply.
    llm = FakeLLM(queries={_PROMPT_TEMPLATE.format(question=question): canned_reply})
    chain = LLMBashChain.from_llm(llm, input_key="q", output_key="a")

    assert chain.run(question) == "2\n"
|
||||
|
||||
|
||||
def test_get_code(output_parser: BashOutputParser) -> None:
    """Check that the parser extracts only the commands from fenced bash blocks."""
    single_block = output_parser.parse(_SAMPLE_CODE)
    # Parsing must not leave blank or whitespace-only entries behind.
    assert [line for line in single_block if line.strip()] == single_block
    assert single_block == ["echo hello"]

    two_blocks = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
    assert two_blocks == ["echo hello", "echo hello", "echo world"]
|
||||
|
||||
|
||||
def test_parsing_error() -> None:
    """Test that LLM output without a bash block raises an OutputParserException."""
    question = "Please echo 'hello world' to the terminal."
    # The reply is fenced as ```text rather than ```bash.
    reply_without_bash = "\n```text\necho 'hello world'\n```\n"

    llm = FakeLLM(
        queries={_PROMPT_TEMPLATE.format(question=question): reply_without_bash}
    )
    chain = LLMBashChain.from_llm(llm, input_key="q", output_key="a")

    with pytest.raises(OutputParserException):
        chain.run(question)
|
||||
|
||||
|
||||
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
    """Check that only bash blocks are collected when other languages are mixed in."""
    llm_reply = """
Unrelated text
```bash
echo hello
ls && pwd && ls
```

```python
print("hello")
```

```bash
echo goodbye
```
"""
    expected = ["echo hello", "ls && pwd && ls", "echo goodbye"]
    assert output_parser.parse(llm_reply) == expected
|
||||
|
||||
|
||||
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
    """Check that backticks appearing inside a command line do not end the block."""
    llm_reply = """
Unrelated text
```bash
echo hello
echo "```bash is in this string```"
```
"""
    expected = ["echo hello", 'echo "```bash is in this string```"']
    assert output_parser.parse(llm_reply) == expected
|
82
libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
Normal file
82
libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Test LLM Math functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_experimental.llm_symbolic_math.base import (
|
||||
LLMSymbolicMathChain,
|
||||
)
|
||||
from langchain_experimental.llm_symbolic_math.prompt import (
|
||||
_PROMPT_TEMPLATE,
|
||||
)
|
||||
from tests.unit_tests.fake_llm import FakeLLM
|
||||
|
||||
try:
|
||||
import sympy
|
||||
except ImportError:
|
||||
pytest.skip("sympy not installed", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
    """Build a symbolic math chain backed by a FakeLLM with canned replies."""
    # Question -> raw reply the fake LLM should give for that question.
    canned_replies = {
        "What is 1 plus 1?": "Answer: 2",
        "What is the square root of 2?": "```text\nsqrt(2)\n```",
        "What is the limit of sin(x) / x as x goes to 0?": (
            "```text\nlimit(sin(x)/x,x,0)\n```"
        ),
        "What is the integral of e^-x from 0 to infinity?": (
            "```text\nintegrate(exp(-x), (x, 0, oo))\n```"
        ),
        "What are the solutions to this equation x**2 - x?": (
            "```text\nsolveset(x**2 - x, x)\n```"
        ),
        "foo": "foo",
    }
    # The fake LLM is keyed by the fully formatted prompt, not the bare question.
    queries = {
        _PROMPT_TEMPLATE.format(question=question): reply
        for question, reply in canned_replies.items()
    }
    return LLMSymbolicMathChain.from_llm(
        FakeLLM(queries=queries), input_key="q", output_key="a"
    )
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
    """Check a question whose canned reply is already a final answer."""
    answer = fake_llm_symbolic_math_chain.run("What is 1 plus 1?")
    assert answer == "Answer: 2"
|
||||
|
||||
|
||||
def test_root_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
    """Check that an irrational result is reported in sympy's own rendering."""
    expected = f"Answer: {sympy.sqrt(2)}"
    answer = fake_llm_symbolic_math_chain.run("What is the square root of 2?")
    assert answer == expected
|
||||
|
||||
|
||||
def test_limit_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
    """Check that a limit expression in the reply is evaluated to its value."""
    answer = fake_llm_symbolic_math_chain.run(
        "What is the limit of sin(x) / x as x goes to 0?"
    )
    assert answer == "Answer: 1"
|
||||
|
||||
|
||||
def test_integration_question(
    fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
) -> None:
    """Check that an improper integral in the reply is evaluated to its value."""
    answer = fake_llm_symbolic_math_chain.run(
        "What is the integral of e^-x from 0 to infinity?"
    )
    assert answer == "Answer: 1"
|
||||
|
||||
|
||||
def test_solver_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
    """Check that a solveset expression in the reply yields the solution set."""
    answer = fake_llm_symbolic_math_chain.run(
        "What are the solutions to this equation x**2 - x?"
    )
    assert answer == "Answer: {0, 1}"
|
||||
|
||||
|
||||
def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
    """Check that a reply the chain cannot interpret raises ValueError."""
    unparseable_question = "foo"
    with pytest.raises(ValueError):
        fake_llm_symbolic_math_chain.run(unparseable_question)
|
@@ -49,6 +49,32 @@ def test_anonymize_multiple() -> None:
|
||||
assert phrase not in anonymized_text
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_check_instances() -> None:
    """Test that repeated mentions of one person share one reversible mapping.

    All three mentions of "John Smith" must map to a single fake name, the
    same name must be reused on a second call, and anonymizing a different
    person must add a second entry to the deanonymizer mapping.
    """
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    # Adjacent string literals are joined with nothing in between, so the
    # separating space after "bakery." has to be written explicitly.
    text = (
        "This is John Smith. John Smith works in a bakery. "
        "John Smith is a good guy"
    )
    anonymizer = PresidioReversibleAnonymizer(["PERSON"], faker_seed=42)
    anonymized_text = anonymizer.anonymize(text)
    persons = list(anonymizer.deanonymizer_mapping["PERSON"].keys())
    assert len(persons) == 1

    anonymized_name = persons[0]
    assert anonymized_text.count(anonymized_name) == 3

    # A second pass over the same text must reuse the stored fake name.
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text.count(anonymized_name) == 3
    assert anonymizer.deanonymizer_mapping["PERSON"][anonymized_name] == "John Smith"

    # A different person must add a second mapping entry.
    text = "This is Jane Smith"
    anonymized_text = anonymizer.anonymize(text)
    persons = list(anonymizer.deanonymizer_mapping["PERSON"].keys())
    assert len(persons) == 2
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
|
||||
def test_anonymize_with_custom_operator() -> None:
|
||||
"""Test anonymize a name with a custom operator"""
|
||||
@@ -56,13 +82,13 @@ def test_anonymize_with_custom_operator() -> None:
|
||||
|
||||
from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer
|
||||
|
||||
custom_operator = {"PERSON": OperatorConfig("replace", {"new_value": "<name>"})}
|
||||
custom_operator = {"PERSON": OperatorConfig("replace", {"new_value": "NAME"})}
|
||||
anonymizer = PresidioReversibleAnonymizer(operators=custom_operator)
|
||||
|
||||
text = "Jane Doe was here."
|
||||
|
||||
anonymized_text = anonymizer.anonymize(text)
|
||||
assert anonymized_text == "<name> was here."
|
||||
assert anonymized_text == "NAME was here."
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
|
||||
@@ -88,6 +114,8 @@ def test_add_recognizer_operator() -> None:
|
||||
assert anonymized_text == "<TITLE> Jane Doe was here."
|
||||
|
||||
# anonymizing with custom recognizer and operator
|
||||
anonymizer = PresidioReversibleAnonymizer(analyzed_fields=[])
|
||||
anonymizer.add_recognizer(custom_recognizer)
|
||||
custom_operator = {"TITLE": OperatorConfig("replace", {"new_value": "Dear"})}
|
||||
anonymizer.add_operators(custom_operator)
|
||||
anonymized_text = anonymizer.anonymize(text)
|
||||
@@ -161,3 +189,21 @@ def test_save_load_deanonymizer_mapping() -> None:
|
||||
|
||||
finally:
|
||||
os.remove("test_file.json")
|
||||
|
||||
|
||||
@pytest.mark.requires("presidio_analyzer", "presidio_anonymizer", "faker")
def test_non_faker_values() -> None:
    """Test anonymizing multiple items in a sentence without faker values.

    With the default faker operators disabled, each distinct person is expected
    to be replaced by a numbered placeholder that is reused on later mentions.
    """
    from langchain_experimental.data_anonymizer import PresidioReversibleAnonymizer

    # Adjacent string literals are joined with nothing in between, so the
    # space separating the two halves is written explicitly in both the input
    # and the expected output.
    text = (
        "My name is John Smith. Your name is Adam Smith. Her name is Jane Smith. "
        "Our names are: John Smith, Adam Smith, Jane Smith."
    )
    expected_result = (
        "My name is <PERSON>. Your name is <PERSON_2>. Her name is <PERSON_3>. "
        "Our names are: <PERSON>, <PERSON_2>, <PERSON_3>."
    )
    anonymizer = PresidioReversibleAnonymizer(add_default_faker_operators=False)
    anonymized_text = anonymizer.anonymize(text)
    assert anonymized_text == expected_result
|
||||
|
128
libs/experimental/tests/unit_tests/test_sql.py
Normal file
128
libs/experimental/tests/unit_tests/test_sql.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
from langchain_experimental.sql.base import SQLDatabaseChain, SQLDatabaseSequentialChain
|
||||
from tests.unit_tests.fake_llm import FakeLLM
|
||||
|
||||
# In-memory SQLite database shared by the SQL-chain tests in this module.
db = SQLDatabase.from_uri("sqlite:///:memory:")


def create_fake_db(db: SQLDatabase) -> SQLDatabase:
    """Create table `foo` with a single row in *db* and return the same object."""
    setup_statements = (
        """
CREATE TABLE foo (baaz TEXT);
""",
        """
INSERT INTO foo (baaz)
VALUES ('baaz');
""",
    )
    # Statements run in order: the table must exist before the insert.
    for statement in setup_statements:
        db.run(statement)
    return db


db = create_fake_db(db)
|
||||
|
||||
|
||||
def test_sql_chain_without_memory() -> None:
    """Run SQLDatabaseChain against the fake db with canned LLM replies."""
    canned_sql = "SELECT baaz from foo"
    fake_llm = FakeLLM(
        queries={"foo": canned_sql, "foo2": canned_sql},
        sequential_responses=True,
    )
    chain = SQLDatabaseChain.from_llm(fake_llm, db, verbose=True)
    assert chain.run("hello") == "SELECT baaz from foo"
|
||||
|
||||
|
||||
def test_sql_chain_sequential_without_memory() -> None:
    """Run SQLDatabaseSequentialChain against the fake db with canned LLM replies."""
    canned_sql = "SELECT baaz from foo"
    fake_llm = FakeLLM(
        queries={key: canned_sql for key in ("foo", "foo2", "foo3")},
        sequential_responses=True,
    )
    chain = SQLDatabaseSequentialChain.from_llm(fake_llm, db, verbose=True)
    assert chain.run("hello") == "SELECT baaz from foo"
|
||||
|
||||
|
||||
def test_sql_chain_with_memory() -> None:
    """Run SQLDatabaseChain with a history-aware prompt and conversation memory."""
    history_template = """
Only use the following tables:
{table_info}
Question: {input}

Given an input question, first create a syntactically correct
{dialect} query to run.
Always limit your query to at most {top_k} results.

Relevant pieces of previous conversation:
{history}

(You do not need to use these pieces of information if not relevant)
"""
    history_prompt = PromptTemplate(
        template=history_template,
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
    )
    canned_sql = "SELECT baaz from foo"
    fake_llm = FakeLLM(
        queries={"foo": canned_sql, "foo2": canned_sql},
        sequential_responses=True,
    )
    chain = SQLDatabaseChain.from_llm(
        fake_llm,
        db,
        memory=ConversationBufferMemory(),
        prompt=history_prompt,
        verbose=True,
    )
    assert chain.run("hello") == "SELECT baaz from foo"
|
||||
|
||||
|
||||
def test_sql_chain_sequential_with_memory() -> None:
    """Run SQLDatabaseSequentialChain with custom prompts and conversation memory."""
    query_template = """
Only use the following tables:
{table_info}
Question: {input}

Given an input question, first create a syntactically correct
{dialect} query to run.
Always limit your query to at most {top_k} results.

Relevant pieces of previous conversation:
{history}

(You do not need to use these pieces of information
if not relevant)
"""
    decider_template = """Given the below input question and list of potential
tables, output a comma separated list of the
table names that may be necessary to answer this question.

Question: {query}

Table Names: {table_names}

Relevant Table Names:"""

    # The decider prompt returns a comma separated table list, hence the parser.
    decider_prompt = PromptTemplate(
        template=decider_template,
        input_variables=["query", "table_names"],
        output_parser=CommaSeparatedListOutputParser(),
    )
    query_prompt = PromptTemplate(
        template=query_template,
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
    )

    canned_sql = "SELECT baaz from foo"
    fake_llm = FakeLLM(
        queries={key: canned_sql for key in ("foo", "foo2", "foo3")},
        sequential_responses=True,
    )
    chain = SQLDatabaseSequentialChain.from_llm(
        fake_llm,
        db,
        memory=ConversationBufferMemory(memory_key="history", input_key="query"),
        decider_prompt=decider_prompt,
        query_prompt=query_prompt,
        verbose=True,
    )
    assert chain.run("hello") == "SELECT baaz from foo"
|
Reference in New Issue
Block a user