experimental: migrate to external repo (#26879)

Security scanners can't distinguish monorepo sources from each other.
This will resolve issues for folks trying to use e.g. langchain-core but
getting security issues flagged from experimental!
Erick Friis 2024-09-25 19:02:19 -07:00 committed by GitHub
parent c750600d3d
commit 2ea5f60cc5
223 changed files with 570 additions and 25802 deletions

View File

@ -1,7 +1,7 @@
Thank you for contributing to LangChain!
- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
- Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
- Example: "community: add foobar LLM"

View File

@ -15,7 +15,6 @@ LANGCHAIN_DIRS = [
"libs/text-splitters",
"libs/langchain",
"libs/community",
"libs/experimental",
]
# when set to True, we are ignoring core dependents
@ -153,14 +152,19 @@ def _get_pydantic_test_configs(
core_min_pydantic_version = get_min_version_from_toml(
"./libs/core/pyproject.toml", "release", python_version, include=["pydantic"]
)["pydantic"]
core_min_pydantic_minor = core_min_pydantic_version.split(".")[1] if "." in core_min_pydantic_version else "0"
dir_min_pydantic_version = (
get_min_version_from_toml(
f"./{dir_}/pyproject.toml", "release", python_version, include=["pydantic"]
)
.get("pydantic", "0.0.0")
)
dir_min_pydantic_minor = dir_min_pydantic_version.split(".")[1] if "." in dir_min_pydantic_version else "0"
core_min_pydantic_minor = (
core_min_pydantic_version.split(".")[1]
if "." in core_min_pydantic_version
else "0"
)
dir_min_pydantic_version = get_min_version_from_toml(
f"./{dir_}/pyproject.toml", "release", python_version, include=["pydantic"]
).get("pydantic", "0.0.0")
dir_min_pydantic_minor = (
dir_min_pydantic_version.split(".")[1]
if "." in dir_min_pydantic_version
else "0"
)
custom_mins = {
# depends on pydantic-settings 2.4 which requires pydantic 2.7
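
The minor-version handling in this hunk is just a string split on the minimum version read from pyproject.toml; a purely illustrative sketch (the version value is made up):

```python
# Illustrative only: extracting the minor component from a minimum-version string,
# mirroring the logic in the hunk above.
core_min_pydantic_version = "2.7.4"  # hypothetical value from a pyproject.toml
core_min_pydantic_minor = (
    core_min_pydantic_version.split(".")[1]
    if "." in core_min_pydantic_version
    else "0"
)
print(core_min_pydantic_minor)  # -> "7"
```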

View File

@ -31,7 +31,7 @@ jobs:
- name: Install langchain editable
run: |
poetry run pip install -e libs/core libs/langchain libs/community libs/experimental
poetry run pip install -e libs/core libs/langchain libs/community
- name: Check doc imports
shell: bash

View File

@ -26,7 +26,6 @@ from sphinx.util.docutils import SphinxDirective
_DIR = Path(__file__).parent.absolute()
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../../libs/langchain"))
sys.path.insert(0, os.path.abspath("../../libs/experimental"))
with (_DIR.parents[1] / "libs" / "langchain" / "pyproject.toml").open("r") as f:
data = toml.load(f)

View File

@ -73,9 +73,9 @@ make docker_tests
There are also [integration tests and code-coverage](/docs/contributing/testing/) available.
### Only develop langchain_core or langchain_experimental
### Only develop langchain_core or langchain_community
If you are only developing `langchain_core` or `langchain_experimental`, you can simply install the dependencies for the respective projects and run tests:
If you are only developing `langchain_core` or `langchain_community`, you can simply install the dependencies for the respective projects and run tests:
```bash
cd libs/core
@ -86,7 +86,7 @@ make test
Or:
```bash
cd libs/experimental
cd libs/community
poetry install --with test
make test
```

View File

@ -1,7 +1,6 @@
-e ../libs/core
-e ../libs/langchain
-e ../libs/community
-e ../libs/experimental
-e ../libs/text-splitters
langchain-cohere
urllib3==1.26.19

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,67 +0,0 @@
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --now . -- tests/unit_tests
extended_tests:
poetry run pytest --only-extended tests/unit_tests
integration_tests:
poetry run pytest tests/integration_tests
check_imports: $(shell find langchain_experimental -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/experimental --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_experimental
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '----'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'
@echo 'test_watch - run unit tests in watch mode'

View File

@ -1,16 +1,3 @@
# 🦜️🧪 LangChain Experimental
This package has moved!
This package holds experimental LangChain code, intended for research and experimental
uses.
> [!WARNING]
> Portions of the code in this package may be dangerous if not properly deployed
> in a sandboxed environment. Please be wary of deploying experimental code
> to production unless you've taken appropriate precautions and
> have already discussed it with your security team.
Some of the code here may be marked with security notices. However,
given the exploratory and experimental nature of the code in this package,
the lack of a security notice on a piece of code does not mean that
the code in question does not require additional security considerations
in order to be safe to use.
https://github.com/langchain-ai/langchain-experimental/tree/main/libs/experimental
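
For users of the published package, the move should be transparent as long as langchain-experimental keeps being released to PyPI from the new repo; a minimal sketch under that assumption:

```python
# Hedged sketch: install and import paths are expected to stay the same after the
# repo move, assuming langchain-experimental continues to be published to PyPI.
#   pip install -U langchain-experimental
import langchain_experimental

print(langchain_experimental.__version__)  # package metadata resolves as before
```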

View File

@ -1,8 +0,0 @@
presidio-anonymizer>=2.2.352,<3
presidio-analyzer>=2.2.352,<3
faker>=19.3.1,<20
vowpal-wabbit-next==0.7.0
sentence-transformers>=2,<3
jinja2>=3,<4
pandas>=2.0.1,<3
tabulate>=0.9.0,<1

View File

@ -1,8 +0,0 @@
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)

View File

@ -1,23 +0,0 @@
"""**Agent** is a class that uses an LLM to choose
a sequence of actions to take.
In Chains, a sequence of actions is hardcoded. In Agents,
a language model is used as a reasoning engine to determine which actions
to take and in which order.
Agents select and use **Tools** and **Toolkits** for actions.
"""
from langchain_experimental.agents.agent_toolkits import (
create_csv_agent,
create_pandas_dataframe_agent,
create_spark_dataframe_agent,
create_xorbits_agent,
)
__all__ = [
"create_csv_agent",
"create_pandas_dataframe_agent",
"create_spark_dataframe_agent",
"create_xorbits_agent",
]

View File

@ -1,19 +0,0 @@
from langchain_experimental.agents.agent_toolkits.csv.base import create_csv_agent
from langchain_experimental.agents.agent_toolkits.pandas.base import (
create_pandas_dataframe_agent,
)
from langchain_experimental.agents.agent_toolkits.python.base import create_python_agent
from langchain_experimental.agents.agent_toolkits.spark.base import (
create_spark_dataframe_agent,
)
from langchain_experimental.agents.agent_toolkits.xorbits.base import (
create_xorbits_agent,
)
__all__ = [
"create_xorbits_agent",
"create_pandas_dataframe_agent",
"create_spark_dataframe_agent",
"create_python_agent",
"create_csv_agent",
]

View File

@ -1,66 +0,0 @@
from __future__ import annotations
from io import IOBase
from typing import TYPE_CHECKING, Any, List, Optional, Union
from langchain_experimental.agents.agent_toolkits.pandas.base import (
create_pandas_dataframe_agent,
)
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
from langchain_core.language_models import LanguageModelLike
def create_csv_agent(
llm: LanguageModelLike,
path: Union[str, IOBase, List[Union[str, IOBase]]],
pandas_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Create pandas dataframe agent by loading csv to a dataframe.
Args:
llm: Language model to use for the agent.
path: A string path, file-like object or a list of string paths/file-like
objects that can be read in as pandas DataFrames with pd.read_csv().
pandas_kwargs: Named arguments to pass to pd.read_csv().
kwargs: Additional kwargs to pass to langchain_experimental.agents.agent_toolkits.pandas.base.create_pandas_dataframe_agent().
Returns:
An AgentExecutor with the specified agent_type agent and access to
a PythonAstREPLTool with the loaded DataFrame(s) and any user-provided extra_tools.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_experimental.agents import create_csv_agent
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_csv_agent(
llm,
"titanic.csv",
agent_type="openai-tools",
verbose=True,
allow_dangerous_code=True,
)
""" # noqa: E501
try:
import pandas as pd
except ImportError:
raise ImportError(
"pandas package not found, please install with `pip install pandas`."
)
_kwargs = pandas_kwargs or {}
if isinstance(path, (str, IOBase)):
df = pd.read_csv(path, **_kwargs)
elif isinstance(path, list):
df = []
for item in path:
if not isinstance(item, (str, IOBase)):
raise ValueError(f"Expected str or file-like object, got {type(path)}")
df.append(pd.read_csv(item, **_kwargs))
else:
raise ValueError(f"Expected str, list, or file-like object, got {type(path)}")
return create_pandas_dataframe_agent(llm, df, **kwargs)

View File

@ -1,359 +0,0 @@
"""Agent for working with pandas objects."""
import warnings
from typing import Any, Dict, List, Literal, Optional, Sequence, Union, cast
from langchain.agents import (
AgentType,
create_openai_tools_agent,
create_react_agent,
create_tool_calling_agent,
)
from langchain.agents.agent import (
AgentExecutor,
BaseMultiActionAgent,
BaseSingleActionAgent,
RunnableAgent,
RunnableMultiActionAgent,
)
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import (
OpenAIFunctionsAgent,
create_openai_functions_agent,
)
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel, LanguageModelLike
from langchain_core.messages import SystemMessage
from langchain_core.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
PromptTemplate,
)
from langchain_core.tools import BaseTool
from langchain_core.utils.interactive_env import is_interactive_env
from langchain_experimental.agents.agent_toolkits.pandas.prompt import (
FUNCTIONS_WITH_DF,
FUNCTIONS_WITH_MULTI_DF,
MULTI_DF_PREFIX,
MULTI_DF_PREFIX_FUNCTIONS,
PREFIX,
PREFIX_FUNCTIONS,
SUFFIX_NO_DF,
SUFFIX_WITH_DF,
SUFFIX_WITH_MULTI_DF,
)
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def _get_multi_prompt(
dfs: List[Any],
*,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> BasePromptTemplate:
if suffix is not None:
suffix_to_use = suffix
elif include_df_in_prompt:
suffix_to_use = SUFFIX_WITH_MULTI_DF
else:
suffix_to_use = SUFFIX_NO_DF
prefix = prefix if prefix is not None else MULTI_DF_PREFIX
template = "\n\n".join([prefix, "{tools}", FORMAT_INSTRUCTIONS, suffix_to_use])
prompt = PromptTemplate.from_template(template)
partial_prompt = prompt.partial()
if "dfs_head" in partial_prompt.input_variables:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
partial_prompt = partial_prompt.partial(dfs_head=dfs_head)
if "num_dfs" in partial_prompt.input_variables:
partial_prompt = partial_prompt.partial(num_dfs=str(len(dfs)))
return partial_prompt
def _get_single_prompt(
df: Any,
*,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> BasePromptTemplate:
if suffix is not None:
suffix_to_use = suffix
elif include_df_in_prompt:
suffix_to_use = SUFFIX_WITH_DF
else:
suffix_to_use = SUFFIX_NO_DF
prefix = prefix if prefix is not None else PREFIX
template = "\n\n".join([prefix, "{tools}", FORMAT_INSTRUCTIONS, suffix_to_use])
prompt = PromptTemplate.from_template(template)
partial_prompt = prompt.partial()
if "df_head" in partial_prompt.input_variables:
df_head = str(df.head(number_of_head_rows).to_markdown())
partial_prompt = partial_prompt.partial(df_head=df_head)
return partial_prompt
def _get_prompt(df: Any, **kwargs: Any) -> BasePromptTemplate:
return (
_get_multi_prompt(df, **kwargs)
if isinstance(df, list)
else _get_single_prompt(df, **kwargs)
)
def _get_functions_single_prompt(
df: Any,
*,
prefix: Optional[str] = None,
suffix: str = "",
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> ChatPromptTemplate:
if include_df_in_prompt:
df_head = str(df.head(number_of_head_rows).to_markdown())
suffix = (suffix or FUNCTIONS_WITH_DF).format(df_head=df_head)
prefix = prefix if prefix is not None else PREFIX_FUNCTIONS
system_message = SystemMessage(content=prefix + suffix)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt
def _get_functions_multi_prompt(
dfs: Any,
*,
prefix: str = "",
suffix: str = "",
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> ChatPromptTemplate:
if include_df_in_prompt:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
suffix = (suffix or FUNCTIONS_WITH_MULTI_DF).format(dfs_head=dfs_head)
prefix = (prefix or MULTI_DF_PREFIX_FUNCTIONS).format(num_dfs=str(len(dfs)))
system_message = SystemMessage(content=prefix + suffix)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt
def _get_functions_prompt(df: Any, **kwargs: Any) -> ChatPromptTemplate:
return (
_get_functions_multi_prompt(df, **kwargs)
if isinstance(df, list)
else _get_functions_single_prompt(df, **kwargs)
)
def create_pandas_dataframe_agent(
llm: LanguageModelLike,
df: Any,
agent_type: Union[
AgentType, Literal["openai-tools", "tool-calling"]
] = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
extra_tools: Sequence[BaseTool] = (),
engine: Literal["pandas", "modin"] = "pandas",
allow_dangerous_code: bool = False,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Pandas agent from an LLM and dataframe(s).
Security Notice:
This agent relies on access to a python repl tool which can execute
arbitrary code. This can be dangerous and requires a specially sandboxed
environment to be safely used. Failure to run this code in a properly
sandboxed environment can lead to arbitrary code execution vulnerabilities,
which can lead to data breaches, data loss, or other security incidents.
Do not use this code with untrusted inputs, with elevated permissions,
or without consulting your security team about proper sandboxing!
You must opt-in to use this functionality by setting allow_dangerous_code=True.
Args:
llm: Language model to use for the agent. If agent_type is "tool-calling" then
llm is expected to support tool calling.
df: Pandas dataframe or list of Pandas dataframes.
agent_type: One of "tool-calling", "openai-tools", "openai-functions", or
"zero-shot-react-description". Defaults to "zero-shot-react-description".
"tool-calling" is recommended over the legacy "openai-tools" and
"openai-functions" types.
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
instead to pass constructor callbacks to AgentExecutor.
prefix: Prompt prefix string.
suffix: Prompt suffix string.
input_variables: DEPRECATED. Input variables automatically inferred from
constructed prompt.
verbose: AgentExecutor verbosity.
return_intermediate_steps: Passed to AgentExecutor init.
max_iterations: Passed to AgentExecutor init.
max_execution_time: Passed to AgentExecutor init.
early_stopping_method: Passed to AgentExecutor init.
agent_executor_kwargs: Arbitrary additional AgentExecutor args.
include_df_in_prompt: Whether to include the first number_of_head_rows in the
prompt. Must be None if suffix is not None.
number_of_head_rows: Number of initial rows to include in prompt if
include_df_in_prompt is True.
extra_tools: Additional tools to give to agent on top of a PythonAstREPLTool.
engine: One of "modin" or "pandas". Defaults to "pandas".
allow_dangerous_code: bool, default False
This agent relies on access to a python repl tool which can execute
arbitrary code. This can be dangerous and requires a specially sandboxed
environment to be safely used.
Failure to properly sandbox this class can lead to arbitrary code execution
vulnerabilities, which can lead to data breaches, data loss, or
other security incidents.
You must opt in to use this functionality by setting
allow_dangerous_code=True.
**kwargs: DEPRECATED. Not used, kept for backwards compatibility.
Returns:
An AgentExecutor with the specified agent_type agent and access to
a PythonAstREPLTool with the DataFrame(s) and any user-provided extra_tools.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_experimental.agents import create_pandas_dataframe_agent
import pandas as pd
df = pd.read_csv("titanic.csv")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_pandas_dataframe_agent(
llm,
df,
agent_type="tool-calling",
verbose=True,
allow_dangerous_code=True,
)
"""
if not allow_dangerous_code:
raise ValueError(
"This agent relies on access to a python repl tool which can execute "
"arbitrary code. This can be dangerous and requires a specially sandboxed "
"environment to be safely used. Please read the security notice in the "
"doc-string of this function. You must opt-in to use this functionality "
"by setting allow_dangerous_code=True."
"For general security guidelines, please see: "
"https://python.langchain.com/v0.2/docs/security/"
)
try:
if engine == "modin":
import modin.pandas as pd
elif engine == "pandas":
import pandas as pd
else:
raise ValueError(
f"Unsupported engine {engine}. It must be one of 'modin' or 'pandas'."
)
except ImportError as e:
raise ImportError(
f"`{engine}` package not found, please install with `pip install {engine}`"
) from e
if is_interactive_env():
pd.set_option("display.max_columns", None)
for _df in df if isinstance(df, list) else [df]:
if not isinstance(_df, pd.DataFrame):
raise ValueError(f"Expected pandas DataFrame, got {type(_df)}")
if input_variables:
kwargs = kwargs or {}
kwargs["input_variables"] = input_variables
if kwargs:
warnings.warn(
f"Received additional kwargs {kwargs} which are no longer supported."
)
df_locals = {}
if isinstance(df, list):
for i, dataframe in enumerate(df):
df_locals[f"df{i + 1}"] = dataframe
else:
df_locals["df"] = df
tools = [PythonAstREPLTool(locals=df_locals)] + list(extra_tools)
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
if include_df_in_prompt is not None and suffix is not None:
raise ValueError(
"If suffix is specified, include_df_in_prompt should not be."
)
prompt = _get_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] = RunnableAgent(
runnable=create_react_agent(llm, tools, prompt), # type: ignore
input_keys_arg=["input"],
return_keys_arg=["output"],
)
elif agent_type in (AgentType.OPENAI_FUNCTIONS, "openai-tools", "tool-calling"):
prompt = _get_functions_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
if agent_type == AgentType.OPENAI_FUNCTIONS:
runnable = create_openai_functions_agent(
cast(BaseLanguageModel, llm), tools, prompt
)
agent = RunnableAgent(
runnable=runnable,
input_keys_arg=["input"],
return_keys_arg=["output"],
)
else:
if agent_type == "openai-tools":
runnable = create_openai_tools_agent(
cast(BaseLanguageModel, llm), tools, prompt
)
else:
runnable = create_tool_calling_agent(
cast(BaseLanguageModel, llm), tools, prompt
)
agent = RunnableMultiActionAgent(
runnable=runnable,
input_keys_arg=["input"],
return_keys_arg=["output"],
)
else:
raise ValueError(
f"Agent type {agent_type} not supported at the moment. Must be one of "
"'tool-calling', 'openai-tools', 'openai-functions', or "
"'zero-shot-react-description'."
)
return AgentExecutor(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@ -1,44 +0,0 @@
# flake8: noqa
PREFIX = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""
MULTI_DF_PREFIX = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You
should use the tools below to answer the question posed of you:"""
SUFFIX_NO_DF = """
Begin!
Question: {input}
{agent_scratchpad}"""
SUFFIX_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}
Begin!
Question: {input}
{agent_scratchpad}"""
SUFFIX_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}
Begin!
Question: {input}
{agent_scratchpad}"""
PREFIX_FUNCTIONS = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""
MULTI_DF_PREFIX_FUNCTIONS = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""
FUNCTIONS_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}"""
FUNCTIONS_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}"""

View File

@ -1,59 +0,0 @@
"""Python agent."""
from typing import Any, Dict, Optional
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import SystemMessage
from langchain_experimental.agents.agent_toolkits.python.prompt import PREFIX
from langchain_experimental.tools.python.tool import PythonREPLTool
def create_python_agent(
llm: BaseLanguageModel,
tool: PythonREPLTool,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False,
prefix: str = PREFIX,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a python agent from an LLM and tool."""
tools = [tool]
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore[arg-type]
elif agent_type == AgentType.OPENAI_FUNCTIONS:
system_message = SystemMessage(content=prefix)
_prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
agent = OpenAIFunctionsAgent( # type: ignore[call-arg]
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs, # type: ignore[arg-type]
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
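
Unlike the CSV and pandas toolkits, the removed Python agent above carries no usage example in its docstring; a minimal sketch of how it was typically wired up (the model name and question are illustrative, and langchain-openai plus an OpenAI API key are assumed):

```python
# Hedged sketch: construct the (now external) Python agent; the names below follow
# the imports in the removed file, the model and question are illustrative.
from langchain_openai import ChatOpenAI
from langchain_experimental.agents.agent_toolkits import create_python_agent
from langchain_experimental.tools.python.tool import PythonREPLTool

agent_executor = create_python_agent(
    llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
    tool=PythonREPLTool(),
    verbose=True,
)
agent_executor.invoke({"input": "What is the 10th Fibonacci number?"})
```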

View File

@ -1,9 +0,0 @@
# flake8: noqa
PREFIX = """You are an agent designed to write and execute python code to answer questions.
You have access to a python REPL, which you can use to execute python code.
If you get an error, debug your code and try again.
Only use the output of your code to answer the question.
You might know the answer without running any code, but you should still run the code to get the answer.
If it does not seem like you can write code to answer the question, just return "I don't know" as the answer.
"""

View File

@ -1,117 +0,0 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.language_models import BaseLLM
from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def _validate_spark_df(df: Any) -> bool:
try:
from pyspark.sql import DataFrame as SparkLocalDataFrame
return isinstance(df, SparkLocalDataFrame)
except ImportError:
return False
def _validate_spark_connect_df(df: Any) -> bool:
try:
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
return isinstance(df, SparkConnectDataFrame)
except ImportError:
return False
def create_spark_dataframe_agent(
llm: BaseLLM,
df: Any,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
allow_dangerous_code: bool = False,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Spark agent from an LLM and dataframe.
Security Notice:
This agent relies on access to a python repl tool which can execute
arbitrary code. This can be dangerous and requires a specially sandboxed
environment to be safely used. Failure to run this code in a properly
sandboxed environment can lead to arbitrary code execution vulnerabilities,
which can lead to data breaches, data loss, or other security incidents.
Do not use this code with untrusted inputs, with elevated permissions,
or without consulting your security team about proper sandboxing!
You must opt in to use this functionality by setting allow_dangerous_code=True.
Args:
allow_dangerous_code: bool, default False
This agent relies on access to a python repl tool which can execute
arbitrary code. This can be dangerous and requires a specially sandboxed
environment to be safely used.
Failure to properly sandbox this class can lead to arbitrary code execution
vulnerabilities, which can lead to data breaches, data loss, or
other security incidents.
You must opt in to use this functionality by setting
allow_dangerous_code=True.
"""
if not allow_dangerous_code:
raise ValueError(
"This agent relies on access to a python repl tool which can execute "
"arbitrary code. This can be dangerous and requires a specially sandboxed "
"environment to be safely used. Please read the security notice in the "
"doc-string of this function. You must opt-in to use this functionality "
"by setting allow_dangerous_code=True."
"For general security guidelines, please see: "
"https://python.langchain.com/v0.2/docs/security/"
)
if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
raise ImportError("Spark is not installed. run `pip install pyspark`.")
if input_variables is None:
input_variables = ["df", "input", "agent_scratchpad"]
tools = [PythonAstREPLTool(locals={"df": df})]
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
)
partial_prompt = prompt.partial(df=str(df.first()))
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent( # type: ignore[call-arg]
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@ -1,13 +0,0 @@
# flake8: noqa
PREFIX = """
You are working with a spark dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""
SUFFIX = """
This is the result of `print(df.first())`:
{df}
Begin!
Question: {input}
{agent_scratchpad}"""

View File

@ -1,128 +0,0 @@
"""Agent for working with xorbits objects."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.language_models import BaseLLM
from langchain_experimental.agents.agent_toolkits.xorbits.prompt import (
NP_PREFIX,
NP_SUFFIX,
PD_PREFIX,
PD_SUFFIX,
)
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def create_xorbits_agent(
llm: BaseLLM,
data: Any,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = "",
suffix: str = "",
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
allow_dangerous_code: bool = False,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a xorbits agent from an LLM and dataframe.
Security Notice:
This agent relies on access to a python repl tool which can execute
arbitrary code. This can be dangerous and requires a specially sandboxed
environment to be safely used. Failure to run this code in a properly
sandboxed environment can lead to arbitrary code execution vulnerabilities,
which can lead to data breaches, data loss, or other security incidents.
Do not use this code with untrusted inputs, with elevated permissions,
or without consulting your security team about proper sandboxing!
You must opt in to use this functionality by setting allow_dangerous_code=True.
Args:
allow_dangerous_code: bool, default False
This agent relies on access to a python repl tool which can execute
arbitrary code. This can be dangerous and requires a specially sandboxed
environment to be safely used.
Failure to properly sandbox this class can lead to arbitrary code execution
vulnerabilities, which can lead to data breaches, data loss, or
other security incidents.
You must opt in to use this functionality by setting
allow_dangerous_code=True.
"""
if not allow_dangerous_code:
raise ValueError(
"This agent relies on access to a python repl tool which can execute "
"arbitrary code. This can be dangerous and requires a specially sandboxed "
"environment to be safely used. Please read the security notice in the "
"doc-string of this function. You must opt-in to use this functionality "
"by setting allow_dangerous_code=True."
"For general security guidelines, please see: "
"https://python.langchain.com/v0.2/docs/security/"
)
try:
from xorbits import numpy as np
from xorbits import pandas as pd
except ImportError:
raise ImportError(
"Xorbits package not installed, please install with `pip install xorbits`"
)
if not isinstance(data, (pd.DataFrame, np.ndarray)):
raise ValueError(
f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
)
if input_variables is None:
input_variables = ["data", "input", "agent_scratchpad"]
tools = [PythonAstREPLTool(locals={"data": data})]
prompt, partial_input = None, None
if isinstance(data, pd.DataFrame):
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=PD_PREFIX if prefix == "" else prefix,
suffix=PD_SUFFIX if suffix == "" else suffix,
input_variables=input_variables,
)
partial_input = str(data.head())
else:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=NP_PREFIX if prefix == "" else prefix,
suffix=NP_SUFFIX if suffix == "" else suffix,
input_variables=input_variables,
)
partial_input = str(data[: len(data) // 2])
partial_prompt = prompt.partial(data=partial_input)
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent( # type: ignore[call-arg]
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs, # type: ignore[arg-type]
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@ -1,33 +0,0 @@
PD_PREFIX = """
You are working with Xorbits dataframe object in Python.
Before importing Numpy or Pandas in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Pandas, replace the original import statement
`import pandas as pd` with `import xorbits.pandas as pd`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""
PD_SUFFIX = """
This is the result of `print(data)`:
{data}
Begin!
Question: {input}
{agent_scratchpad}"""
NP_PREFIX = """
You are working with Xorbits ndarray object in Python.
Before importing Numpy in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Numpy, replace the original import statement
`import numpy as np` with `import xorbits.numpy as np`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""
NP_SUFFIX = """
This is the result of `print(data)`:
{data}
Begin!
Question: {input}
{agent_scratchpad}"""

View File

@ -1,18 +0,0 @@
"""**Autonomous agents** in the Langchain experimental package include
[AutoGPT](https://github.com/Significant-Gravitas/AutoGPT),
[BabyAGI](https://github.com/yoheinakajima/babyagi),
and [HuggingGPT](https://arxiv.org/abs/2303.17580) agents that
interact with language models autonomously.
These agents have specific functionalities like memory management,
task creation, execution chains, and response generation.
They differ from ordinary agents by their autonomous decision-making capabilities,
memory handling, and specialized functionalities for tasks and response.
"""
from langchain_experimental.autonomous_agents.autogpt.agent import AutoGPT
from langchain_experimental.autonomous_agents.baby_agi.baby_agi import BabyAGI
from langchain_experimental.autonomous_agents.hugginggpt.hugginggpt import HuggingGPT
__all__ = ["BabyAGI", "AutoGPT", "HuggingGPT"]

View File

@ -1,143 +0,0 @@
from __future__ import annotations
from typing import List, Optional
from langchain.chains.llm import LLMChain
from langchain.memory import ChatMessageHistory
from langchain.schema import (
BaseChatMessageHistory,
Document,
)
from langchain_community.tools.human.tool import HumanInputRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import VectorStoreRetriever
from pydantic import ValidationError
from langchain_experimental.autonomous_agents.autogpt.output_parser import (
AutoGPTOutputParser,
BaseAutoGPTOutputParser,
)
from langchain_experimental.autonomous_agents.autogpt.prompt import AutoGPTPrompt
from langchain_experimental.autonomous_agents.autogpt.prompt_generator import (
FINISH_NAME,
)
class AutoGPT:
"""Agent for interacting with AutoGPT."""
def __init__(
self,
ai_name: str,
memory: VectorStoreRetriever,
chain: LLMChain,
output_parser: BaseAutoGPTOutputParser,
tools: List[BaseTool],
feedback_tool: Optional[HumanInputRun] = None,
chat_history_memory: Optional[BaseChatMessageHistory] = None,
):
self.ai_name = ai_name
self.memory = memory
self.next_action_count = 0
self.chain = chain
self.output_parser = output_parser
self.tools = tools
self.feedback_tool = feedback_tool
self.chat_history_memory = chat_history_memory or ChatMessageHistory()
@classmethod
def from_llm_and_tools(
cls,
ai_name: str,
ai_role: str,
memory: VectorStoreRetriever,
tools: List[BaseTool],
llm: BaseChatModel,
human_in_the_loop: bool = False,
output_parser: Optional[BaseAutoGPTOutputParser] = None,
chat_history_memory: Optional[BaseChatMessageHistory] = None,
) -> AutoGPT:
prompt = AutoGPTPrompt( # type: ignore[call-arg, call-arg, call-arg, call-arg]
ai_name=ai_name,
ai_role=ai_role,
tools=tools,
input_variables=["memory", "messages", "goals", "user_input"],
token_counter=llm.get_num_tokens,
)
human_feedback_tool = HumanInputRun() if human_in_the_loop else None
chain = LLMChain(llm=llm, prompt=prompt)
return cls(
ai_name,
memory,
chain,
output_parser or AutoGPTOutputParser(),
tools,
feedback_tool=human_feedback_tool,
chat_history_memory=chat_history_memory,
)
def run(self, goals: List[str]) -> str:
user_input = (
"Determine which next command to use, "
"and respond using the format specified above:"
)
# Interaction Loop
loop_count = 0
while True:
# Discontinue if continuous limit is reached
loop_count += 1
# Send message to AI, get response
assistant_reply = self.chain.run(
goals=goals,
messages=self.chat_history_memory.messages,
memory=self.memory,
user_input=user_input,
)
# Print Assistant thoughts
print(assistant_reply) # noqa: T201
self.chat_history_memory.add_message(HumanMessage(content=user_input))
self.chat_history_memory.add_message(AIMessage(content=assistant_reply))
# Get command name and arguments
action = self.output_parser.parse(assistant_reply)
tools = {t.name: t for t in self.tools}
if action.name == FINISH_NAME:
return action.args["response"]
if action.name in tools:
tool = tools[action.name]
try:
observation = tool.run(action.args)
except ValidationError as e:
observation = (
f"Validation Error in args: {str(e)}, args: {action.args}"
)
except Exception as e:
observation = (
f"Error: {str(e)}, {type(e).__name__}, args: {action.args}"
)
result = f"Command {tool.name} returned: {observation}"
elif action.name == "ERROR":
result = f"Error: {action.args}. "
else:
result = (
f"Unknown command '{action.name}'. "
f"Please refer to the 'COMMANDS' list for available "
f"commands and only respond in the specified JSON format."
)
memory_to_add = (
f"Assistant Reply: {assistant_reply} " f"\nResult: {result} "
)
if self.feedback_tool is not None:
feedback = f"{self.feedback_tool.run('Input: ')}"
if feedback in {"q", "stop"}:
print("EXITING") # noqa: T201
return "EXITING"
memory_to_add += f"\n{feedback}"
self.memory.add_documents([Document(page_content=memory_to_add)])
self.chat_history_memory.add_message(SystemMessage(content=result))

View File

@ -1,32 +0,0 @@
from typing import Any, Dict, List
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.utils import get_prompt_input_key
from langchain_core.vectorstores import VectorStoreRetriever
from pydantic import Field
class AutoGPTMemory(BaseChatMemory):
"""Memory for AutoGPT."""
retriever: VectorStoreRetriever = Field(exclude=True)
"""VectorStoreRetriever object to connect to."""
@property
def memory_variables(self) -> List[str]:
return ["chat_history", "relevant_context"]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.invoke(query)
return {
"chat_history": self.chat_memory.messages[-10:],
"relevant_context": docs,
}

View File

@ -1,66 +0,0 @@
import json
import re
from abc import abstractmethod
from typing import Dict, NamedTuple
from langchain_core.output_parsers import BaseOutputParser
class AutoGPTAction(NamedTuple):
"""Action returned by AutoGPTOutputParser."""
name: str
args: Dict
class BaseAutoGPTOutputParser(BaseOutputParser):
"""Base Output parser for AutoGPT."""
@abstractmethod
def parse(self, text: str) -> AutoGPTAction:
"""Return AutoGPTAction"""
def preprocess_json_input(input_str: str) -> str:
"""Preprocesses a string to be parsed as json.
Replace single backslashes with double backslashes,
while leaving already escaped ones intact.
Args:
input_str: String to be preprocessed
Returns:
Preprocessed string
"""
corrected_str = re.sub(
r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
)
return corrected_str
class AutoGPTOutputParser(BaseAutoGPTOutputParser):
"""Output parser for AutoGPT."""
def parse(self, text: str) -> AutoGPTAction:
try:
parsed = json.loads(text, strict=False)
except json.JSONDecodeError:
preprocessed_text = preprocess_json_input(text)
try:
parsed = json.loads(preprocessed_text, strict=False)
except Exception:
return AutoGPTAction(
name="ERROR",
args={"error": f"Could not parse invalid json: {text}"},
)
try:
return AutoGPTAction(
name=parsed["command"]["name"],
args=parsed["command"]["args"],
)
except (KeyError, TypeError):
# If the command is null or incomplete, return an erroneous tool
return AutoGPTAction(
name="ERROR", args={"error": f"Incomplete command args: {parsed}"}
)
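
Because the backslash handling in preprocess_json_input is easy to misread, here is a small self-contained sketch of its behavior (the helper is copied from the file above; the sample reply is illustrative):

```python
import json
import re


def preprocess_json_input(input_str: str) -> str:
    # Double a backslash only when it does not already begin a valid JSON escape.
    return re.sub(
        r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
    )


# Each "\\" in this Python literal is one backslash in the string, so the raw
# reply contains invalid JSON escapes such as \U and \m.
raw = '{"command": {"name": "write_file", "args": {"path": "C:\\Users\\me"}}}'
fixed = preprocess_json_input(raw)
print(json.loads(fixed)["command"]["args"]["path"])  # C:\Users\me
```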

View File

@ -1,106 +0,0 @@
import time
from typing import Any, Callable, List, cast
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
)
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import VectorStoreRetriever
from pydantic import BaseModel
from langchain_experimental.autonomous_agents.autogpt.prompt_generator import get_prompt
# This class has a metaclass conflict: both `BaseChatPromptTemplate` and `BaseModel`
# define a metaclass to use, and the two metaclasses attempt to define
# the same functions but in mutually-incompatible ways.
# It isn't clear how to resolve this, and this code predates mypy
# beginning to perform that check.
#
# Mypy errors:
# ```
# Definition of "__private_attributes__" in base class "BaseModel" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__repr_name__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__pretty__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__repr_str__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__rich_repr__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Metaclass conflict: the metaclass of a derived class must be
# a (non-strict) subclass of the metaclasses of all its bases [misc]
# ```
#
# TODO: look into refactoring this class in a way that avoids the mypy type errors
class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
"""Prompt for AutoGPT."""
ai_name: str
ai_role: str
tools: List[BaseTool]
token_counter: Callable[[str], int]
send_token_limit: int = 4196
def construct_full_prompt(self, goals: List[str]) -> str:
prompt_start = (
"Your decisions must always be made independently "
"without seeking user assistance.\n"
"Play to your strengths as an LLM and pursue simple "
"strategies with no legal complications.\n"
"If you have completed all your tasks, make sure to "
'use the "finish" command.'
)
# Construct full prompt
full_prompt = (
f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
)
for i, goal in enumerate(goals):
full_prompt += f"{i+1}. {goal}\n"
full_prompt += f"\n\n{get_prompt(self.tools)}"
return full_prompt
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"]))
time_prompt = SystemMessage(
content=f"The current time and date is {time.strftime('%c')}"
)
used_tokens = self.token_counter(
cast(str, base_prompt.content)
) + self.token_counter(cast(str, time_prompt.content))
memory: VectorStoreRetriever = kwargs["memory"]
previous_messages = kwargs["messages"]
relevant_docs = memory.invoke(str(previous_messages[-10:]))
relevant_memory = [d.page_content for d in relevant_docs]
relevant_memory_tokens = sum(
[self.token_counter(doc) for doc in relevant_memory]
)
while used_tokens + relevant_memory_tokens > 2500:
relevant_memory = relevant_memory[:-1]
relevant_memory_tokens = sum(
[self.token_counter(doc) for doc in relevant_memory]
)
content_format = (
f"This reminds you of these events "
f"from your past:\n{relevant_memory}\n\n"
)
memory_message = SystemMessage(content=content_format)
used_tokens += self.token_counter(cast(str, memory_message.content))
historical_messages: List[BaseMessage] = []
for message in previous_messages[-10:][::-1]:
message_tokens = self.token_counter(message.content)
if used_tokens + message_tokens > self.send_token_limit - 1000:
break
historical_messages = [message] + historical_messages
used_tokens += message_tokens
input_message = HumanMessage(content=kwargs["user_input"])
messages: List[BaseMessage] = [base_prompt, time_prompt, memory_message]
messages += historical_messages
messages.append(input_message)
return messages
def pretty_repr(self, html: bool = False) -> str:
raise NotImplementedError

View File

@ -1,186 +0,0 @@
import json
from typing import List
from langchain_core.tools import BaseTool
FINISH_NAME = "finish"
class PromptGenerator:
"""Generator of custom prompt strings.
Does this based on constraints, commands, resources, and performance evaluations.
"""
def __init__(self) -> None:
"""Initialize the PromptGenerator object.
Starts with empty lists of constraints, commands, resources,
and performance evaluations.
"""
self.constraints: List[str] = []
self.commands: List[BaseTool] = []
self.resources: List[str] = []
self.performance_evaluation: List[str] = []
self.response_format = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
def add_constraint(self, constraint: str) -> None:
"""
Add a constraint to the constraints list.
Args:
constraint (str): The constraint to be added.
"""
self.constraints.append(constraint)
def add_tool(self, tool: BaseTool) -> None:
self.commands.append(tool)
def _generate_command_string(self, tool: BaseTool) -> str:
output = f"{tool.name}: {tool.description}"
output += f", args json schema: {json.dumps(tool.args)}"
return output
def add_resource(self, resource: str) -> None:
"""
Add a resource to the resources list.
Args:
resource (str): The resource to be added.
"""
self.resources.append(resource)
def add_performance_evaluation(self, evaluation: str) -> None:
"""
Add a performance evaluation item to the performance_evaluation list.
Args:
evaluation (str): The evaluation item to be added.
"""
self.performance_evaluation.append(evaluation)
def _generate_numbered_list(self, items: list, item_type: str = "list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = [
f"{i + 1}. {self._generate_command_string(item)}"
for i, item in enumerate(items)
]
finish_description = (
"use this to signal that you have finished all your objectives"
)
finish_args = (
'"response": "final response to let '
'people know you have finished your objectives"'
)
finish_string = (
f"{len(items) + 1}. {FINISH_NAME}: "
f"{finish_description}, args: {finish_args}"
)
return "\n".join(command_strings + [finish_string])
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def generate_prompt_string(self) -> str:
"""Generate a prompt string.
Returns:
str: The generated prompt string.
"""
formatted_response_format = json.dumps(self.response_format, indent=4)
prompt_string = (
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
f"Commands:\n"
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
f"Performance Evaluation:\n"
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
f"You should only respond in JSON format as described below "
f"\nResponse Format: \n{formatted_response_format} "
f"\nEnsure the response can be parsed by Python json.loads"
)
return prompt_string
def get_prompt(tools: List[BaseTool]) -> str:
"""Generates a prompt string.
It includes various constraints, commands, resources, and performance evaluations.
Returns:
str: The generated prompt string.
"""
# Initialize the PromptGenerator object
prompt_generator = PromptGenerator()
# Add constraints to the PromptGenerator object
prompt_generator.add_constraint(
"~4000 word limit for short term memory. "
"Your short term memory is short, "
"so immediately save important information to files."
)
prompt_generator.add_constraint(
"If you are unsure how you previously did something "
"or want to recall past events, "
"thinking about similar events will help you remember."
)
prompt_generator.add_constraint("No user assistance")
prompt_generator.add_constraint(
'Exclusively use the commands listed in double quotes e.g. "command name"'
)
# Add commands to the PromptGenerator object
for tool in tools:
prompt_generator.add_tool(tool)
# Add resources to the PromptGenerator object
prompt_generator.add_resource(
"Internet access for searches and information gathering."
)
prompt_generator.add_resource("Long Term memory management.")
prompt_generator.add_resource(
"GPT-3.5 powered Agents for delegation of simple tasks."
)
prompt_generator.add_resource("File output.")
# Add performance evaluations to the PromptGenerator object
prompt_generator.add_performance_evaluation(
"Continuously review and analyze your actions "
"to ensure you are performing to the best of your abilities."
)
prompt_generator.add_performance_evaluation(
"Constructively self-criticize your big-picture behavior constantly."
)
prompt_generator.add_performance_evaluation(
"Reflect on past decisions and strategies to refine your approach."
)
prompt_generator.add_performance_evaluation(
"Every command has a cost, so be smart and efficient. "
"Aim to complete tasks in the least number of steps."
)
# Generate the prompt string
prompt_string = prompt_generator.generate_prompt_string()
return prompt_string

View File

@ -1,17 +0,0 @@
from langchain_experimental.autonomous_agents.baby_agi.baby_agi import BabyAGI
from langchain_experimental.autonomous_agents.baby_agi.task_creation import (
TaskCreationChain,
)
from langchain_experimental.autonomous_agents.baby_agi.task_execution import (
TaskExecutionChain,
)
from langchain_experimental.autonomous_agents.baby_agi.task_prioritization import (
TaskPrioritizationChain,
)
__all__ = [
"BabyAGI",
"TaskPrioritizationChain",
"TaskExecutionChain",
"TaskCreationChain",
]

View File

@ -1,223 +0,0 @@
"""BabyAGI agent."""
from collections import deque
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.vectorstores import VectorStore
from pydantic import BaseModel, ConfigDict, Field
from langchain_experimental.autonomous_agents.baby_agi.task_creation import (
TaskCreationChain,
)
from langchain_experimental.autonomous_agents.baby_agi.task_execution import (
TaskExecutionChain,
)
from langchain_experimental.autonomous_agents.baby_agi.task_prioritization import (
TaskPrioritizationChain,
)
# This class has a metaclass conflict: both `Chain` and `BaseModel` define a metaclass
# to use, and the two metaclasses attempt to define the same functions but
# in mutually-incompatible ways. It isn't clear how to resolve this,
# and this code predates mypy beginning to perform that check.
#
# Mypy errors:
# ```
# Definition of "__repr_str__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__repr_name__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__rich_repr__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__pretty__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Metaclass conflict: the metaclass of a derived class must be
# a (non-strict) subclass of the metaclasses of all its bases [misc]
# ```
#
# TODO: look into refactoring this class in a way that avoids the mypy type errors
class BabyAGI(Chain, BaseModel): # type: ignore[misc]
"""Controller model for the BabyAGI agent."""
task_list: deque = Field(default_factory=deque)
task_creation_chain: Chain = Field(...)
task_prioritization_chain: Chain = Field(...)
execution_chain: Chain = Field(...)
task_id_counter: int = Field(1)
vectorstore: VectorStore = Field(init=False)
max_iterations: Optional[int] = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def add_task(self, task: Dict) -> None:
self.task_list.append(task)
def print_task_list(self) -> None:
print("\033[95m\033[1m" + "\n*****TASK LIST*****\n" + "\033[0m\033[0m") # noqa: T201
for t in self.task_list:
print(str(t["task_id"]) + ": " + t["task_name"]) # noqa: T201
def print_next_task(self, task: Dict) -> None:
print("\033[92m\033[1m" + "\n*****NEXT TASK*****\n" + "\033[0m\033[0m") # noqa: T201
print(str(task["task_id"]) + ": " + task["task_name"]) # noqa: T201
def print_task_result(self, result: str) -> None:
print("\033[93m\033[1m" + "\n*****TASK RESULT*****\n" + "\033[0m\033[0m") # noqa: T201
print(result) # noqa: T201
@property
def input_keys(self) -> List[str]:
return ["objective"]
@property
def output_keys(self) -> List[str]:
return []
def get_next_task(
self, result: str, task_description: str, objective: str, **kwargs: Any
) -> List[Dict]:
"""Get the next task."""
task_names = [t["task_name"] for t in self.task_list]
incomplete_tasks = ", ".join(task_names)
response = self.task_creation_chain.run(
result=result,
task_description=task_description,
incomplete_tasks=incomplete_tasks,
objective=objective,
**kwargs,
)
new_tasks = response.split("\n")
return [
{"task_name": task_name} for task_name in new_tasks if task_name.strip()
]
def prioritize_tasks(
self, this_task_id: int, objective: str, **kwargs: Any
) -> List[Dict]:
"""Prioritize tasks."""
task_names = [t["task_name"] for t in list(self.task_list)]
next_task_id = int(this_task_id) + 1
response = self.task_prioritization_chain.run(
task_names=", ".join(task_names),
next_task_id=str(next_task_id),
objective=objective,
**kwargs,
)
new_tasks = response.split("\n")
prioritized_task_list = []
for task_string in new_tasks:
if not task_string.strip():
continue
task_parts = task_string.strip().split(".", 1)
if len(task_parts) == 2:
task_id = task_parts[0].strip()
task_name = task_parts[1].strip()
prioritized_task_list.append(
{"task_id": task_id, "task_name": task_name}
)
return prioritized_task_list
def _get_top_tasks(self, query: str, k: int) -> List[str]:
"""Get the top k tasks based on the query."""
results = self.vectorstore.similarity_search(query, k=k)
if not results:
return []
return [str(item.metadata["task"]) for item in results]
def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str:
"""Execute a task."""
context = self._get_top_tasks(query=objective, k=k)
return self.execution_chain.run(
objective=objective, context="\n".join(context), task=task, **kwargs
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the agent."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
objective = inputs["objective"]
first_task = inputs.get("first_task", "Make a todo list")
self.add_task({"task_id": 1, "task_name": first_task})
num_iters = 0
while True:
if self.task_list:
self.print_task_list()
# Step 1: Pull the first task
task = self.task_list.popleft()
self.print_next_task(task)
# Step 2: Execute the task
result = self.execute_task(
objective, task["task_name"], callbacks=_run_manager.get_child()
)
this_task_id = int(task["task_id"])
self.print_task_result(result)
# Step 3: Store the result in Pinecone
result_id = f"result_{task['task_id']}_{num_iters}"
self.vectorstore.add_texts(
texts=[result],
metadatas=[{"task": task["task_name"]}],
ids=[result_id],
)
# Step 4: Create new tasks and reprioritize task list
new_tasks = self.get_next_task(
result,
task["task_name"],
objective,
callbacks=_run_manager.get_child(),
)
for new_task in new_tasks:
self.task_id_counter += 1
new_task.update({"task_id": self.task_id_counter})
self.add_task(new_task)
self.task_list = deque(
self.prioritize_tasks(
this_task_id, objective, callbacks=_run_manager.get_child()
)
)
num_iters += 1
if self.max_iterations is not None and num_iters == self.max_iterations:
print( # noqa: T201
"\033[91m\033[1m" + "\n*****TASK ENDING*****\n" + "\033[0m\033[0m"
)
break
return {}
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
verbose: bool = False,
task_execution_chain: Optional[Chain] = None,
**kwargs: Any,
) -> "BabyAGI":
"""Initialize the BabyAGI Controller."""
task_creation_chain = TaskCreationChain.from_llm(llm, verbose=verbose)
task_prioritization_chain = TaskPrioritizationChain.from_llm(
llm, verbose=verbose
)
if task_execution_chain is None:
execution_chain: Chain = TaskExecutionChain.from_llm(llm, verbose=verbose)
else:
execution_chain = task_execution_chain
return cls( # type: ignore[call-arg, call-arg, call-arg, call-arg]
task_creation_chain=task_creation_chain,
task_prioritization_chain=task_prioritization_chain,
execution_chain=execution_chain,
vectorstore=vectorstore,
**kwargs,
)
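A minimal usage sketch of the controller above, assuming `llm` (any BaseLanguageModel) and `vectorstore` (any VectorStore) have already been constructed elsewhere; the objective string is purely illustrative:
# Hypothetical usage; `llm` and `vectorstore` are assumed to exist.
baby_agi = BabyAGI.from_llm(
    llm=llm,
    vectorstore=vectorstore,
    max_iterations=3,  # stop after three plan/execute/reprioritize loops
)
baby_agi.invoke({"objective": "Write a weather report for SF today"})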

View File

@@ -1,31 +0,0 @@
from langchain.chains import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
class TaskCreationChain(LLMChain):
"""Chain generating tasks."""
@classmethod
def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
"""Get the response parser."""
task_creation_template = (
"You are an task creation AI that uses the result of an execution agent"
" to create new tasks with the following objective: {objective},"
" The last completed task has the result: {result}."
" This result was based on this task description: {task_description}."
" These are incomplete tasks: {incomplete_tasks}."
" Based on the result, create new tasks to be completed"
" by the AI system that do not overlap with incomplete tasks."
" Return the tasks as an array."
)
prompt = PromptTemplate(
template=task_creation_template,
input_variables=[
"result",
"task_description",
"incomplete_tasks",
"objective",
],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)

View File

@@ -1,22 +0,0 @@
from langchain.chains import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
class TaskExecutionChain(LLMChain):
"""Chain to execute tasks."""
@classmethod
def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
"""Get the response parser."""
execution_template = (
"You are an AI who performs one task based on the following objective: "
"{objective}."
"Take into account these previously completed tasks: {context}."
" Your task: {task}. Response:"
)
prompt = PromptTemplate(
template=execution_template,
input_variables=["objective", "context", "task"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)

View File

@@ -1,25 +0,0 @@
from langchain.chains import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
class TaskPrioritizationChain(LLMChain):
"""Chain to prioritize tasks."""
@classmethod
def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
"""Get the response parser."""
task_prioritization_template = (
"You are a task prioritization AI tasked with cleaning the formatting of "
"and reprioritizing the following tasks: {task_names}."
" Consider the ultimate objective of your team: {objective}."
" Do not remove any tasks. Return the result as a numbered list, like:"
" #. First task"
" #. Second task"
" Start the task list with number {next_task_id}."
)
prompt = PromptTemplate(
template=task_prioritization_template,
input_variables=["task_names", "next_task_id", "objective"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
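As a rough sketch of how these task chains are driven (this mirrors the calls made by BabyAGI above; `llm` and all argument values are placeholders):
# Hypothetical values; the keyword arguments match each prompt's input_variables.
creation_chain = TaskCreationChain.from_llm(llm)
new_tasks = creation_chain.run(
    result="Collected three candidate data sources",
    task_description="Find data sources",
    incomplete_tasks="Clean the data, Train the model",
    objective="Build a weather model",
)
prioritization_chain = TaskPrioritizationChain.from_llm(llm)
reordered = prioritization_chain.run(
    task_names="Clean the data, Train the model",
    next_task_id="2",
    objective="Build a weather model",
)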

View File

@@ -1,34 +0,0 @@
from typing import List
from langchain.base_language import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_experimental.autonomous_agents.hugginggpt.repsonse_generator import (
load_response_generator,
)
from langchain_experimental.autonomous_agents.hugginggpt.task_executor import (
TaskExecutor,
)
from langchain_experimental.autonomous_agents.hugginggpt.task_planner import (
load_chat_planner,
)
class HuggingGPT:
"""Agent for interacting with HuggingGPT."""
def __init__(self, llm: BaseLanguageModel, tools: List[BaseTool]):
self.llm = llm
self.tools = tools
self.chat_planner = load_chat_planner(llm)
self.response_generator = load_response_generator(llm)
self.task_executor: TaskExecutor
def run(self, input: str) -> str:
plan = self.chat_planner.plan(inputs={"input": input, "hf_tools": self.tools})
self.task_executor = TaskExecutor(plan)
self.task_executor.run()
response = self.response_generator.generate(
{"task_execution": self.task_executor}
)
return response
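A short usage sketch, assuming `llm` and a list of `tools` (BaseTool instances wrapping Hugging Face models) are defined elsewhere:
# Hypothetical usage; `llm` and `tools` are assumed to exist.
agent = HuggingGPT(llm, tools)
answer = agent.run("show me an image of 'a boy is running' and read the text aloud")
print(answer)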

View File

@@ -1,46 +0,0 @@
from typing import Any, List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain_core.callbacks.manager import Callbacks
from langchain_core.prompts import PromptTemplate
class ResponseGenerationChain(LLMChain):
"""Chain to execute tasks."""
@classmethod
def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
execution_template = (
"The AI assistant has parsed the user input into several tasks"
"and executed them. The results are as follows:\n"
"{task_execution}"
"\nPlease summarize the results and generate a response."
)
prompt = PromptTemplate(
template=execution_template,
input_variables=["task_execution"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
class ResponseGenerator:
"""Generates a response based on the input."""
def __init__(self, llm_chain: LLMChain, stop: Optional[List] = None):
self.llm_chain = llm_chain
self.stop = stop
def generate(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Given input, decided what to do."""
llm_response = self.llm_chain.run(**inputs, stop=self.stop, callbacks=callbacks)
return llm_response
def load_response_generator(llm: BaseLanguageModel) -> ResponseGenerator:
"""Load the ResponseGenerator."""
llm_chain = ResponseGenerationChain.from_llm(llm)
return ResponseGenerator(
llm_chain=llm_chain,
)

View File

@@ -1,151 +0,0 @@
import copy
import uuid
from typing import Dict, List
import numpy as np
from langchain_core.tools import BaseTool
from langchain_experimental.autonomous_agents.hugginggpt.task_planner import Plan
class Task:
"""Task to be executed."""
def __init__(self, task: str, id: int, dep: List[int], args: Dict, tool: BaseTool):
self.task = task
self.id = id
self.dep = dep
self.args = args
self.tool = tool
self.status = "pending"
self.message = ""
self.result = ""
def __str__(self) -> str:
return f"{self.task}({self.args})"
def save_product(self) -> None:
import cv2
if self.task == "video_generator":
# ndarray to video
product = np.array(self.product)
nframe, height, width, _ = product.shape
video_filename = uuid.uuid4().hex[:6] + ".mp4"
fps = 30 # Frames per second
fourcc = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
video_out = cv2.VideoWriter(video_filename, fourcc, fps, (width, height))
for frame in self.product:
video_out.write(frame)
video_out.release()
self.result = video_filename
elif self.task == "image_generator":
# PIL.Image to image
filename = uuid.uuid4().hex[:6] + ".png"
self.product.save(filename) # type: ignore
self.result = filename
def completed(self) -> bool:
return self.status == "completed"
def failed(self) -> bool:
return self.status == "failed"
def pending(self) -> bool:
return self.status == "pending"
def run(self) -> str:
from diffusers.utils import load_image
try:
new_args = copy.deepcopy(self.args)
for k, v in new_args.items():
if k == "image":
new_args["image"] = load_image(v)
if self.task in ["video_generator", "image_generator", "text_reader"]:
self.product = self.tool(**new_args)
else:
self.result = self.tool(**new_args)
except Exception as e:
self.status = "failed"
self.message = str(e)
return self.message
self.status = "completed"
self.save_product()
return self.result
class TaskExecutor:
"""Load tools and execute tasks."""
def __init__(self, plan: Plan):
self.plan = plan
self.tasks = []
self.id_task_map = {}
self.status = "pending"
for step in self.plan.steps:
task = Task(step.task, step.id, step.dep, step.args, step.tool)
self.tasks.append(task)
self.id_task_map[step.id] = task
def completed(self) -> bool:
return all(task.completed() for task in self.tasks)
def failed(self) -> bool:
return any(task.failed() for task in self.tasks)
def pending(self) -> bool:
return any(task.pending() for task in self.tasks)
def check_dependency(self, task: Task) -> bool:
for dep_id in task.dep:
if dep_id == -1:
continue
dep_task = self.id_task_map[dep_id]
if dep_task.failed() or dep_task.pending():
return False
return True
def update_args(self, task: Task) -> None:
for dep_id in task.dep:
if dep_id == -1:
continue
dep_task = self.id_task_map[dep_id]
for k, v in task.args.items():
if f"<resource-{dep_id}>" in v:
task.args[k] = task.args[k].replace(
f"<resource-{dep_id}>", dep_task.result
)
def run(self) -> str:
for task in self.tasks:
print(f"running {task}") # noqa: T201
if task.pending() and self.check_dependency(task):
self.update_args(task)
task.run()
if self.completed():
self.status = "completed"
elif self.failed():
self.status = "failed"
else:
self.status = "pending"
return self.status
def __str__(self) -> str:
result = ""
for task in self.tasks:
result += f"{task}\n"
result += f"status: {task.status}\n"
if task.failed():
result += f"message: {task.message}\n"
if task.completed():
result += f"result: {task.result}\n"
return result
def __repr__(self) -> str:
return self.__str__()
def describe(self) -> str:
return self.__str__()
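To make the `<resource-dep_id>` convention concrete, here is a small illustration of the substitution that `update_args` performs before a tool is called (all values are hypothetical):
# Task 0 produced "sunset.png"; task 1 declared dep=[0] and references its output.
dep_id, dep_result = 0, "sunset.png"
args = {"image": "<resource-0>", "question": "What is in the picture?"}
args = {k: v.replace(f"<resource-{dep_id}>", dep_result) for k, v in args.items()}
# args == {"image": "sunset.png", "question": "What is in the picture?"}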

View File

@@ -1,174 +0,0 @@
import json
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain_core.callbacks.manager import Callbacks
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel
DEMONSTRATIONS = [
{
"role": "user",
"content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it", # noqa: E501
},
{
"role": "assistant",
"content": '[{{"task": "video_generator", "id": 0, "dep": [-1], "args": {{"prompt": "a boy is running" }}}}, {{"task": "text_reader", "id": 1, "dep": [-1], "args": {{"text": "a boy is running" }}}}, {{"task": "image_generator", "id": 2, "dep": [-1], "args": {{"prompt": "a boy is running" }}}}]', # noqa: E501
},
{
"role": "user",
"content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?", # noqa: E501
},
{
"role": "assistant",
"content": '[ {{"task": "image_qa", "id": 0, "dep": [-1], "args": {{"image": "e1.jpg", "question": "How many sheep in the picture"}}}}, {{"task": "image_qa", "id": 1, "dep": [-1], "args": {{"image": "e2.jpg", "question": "How many sheep in the picture"}}}}, {{"task": "image_qa", "id": 2, "dep": [-1], "args": {{"image": "e3.jpg", "question": "How many sheep in the picture"}}}}]', # noqa: E501
},
]
class TaskPlaningChain(LLMChain):
"""Chain to execute tasks."""
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
demos: List[Dict] = DEMONSTRATIONS,
verbose: bool = True,
) -> LLMChain:
"""Get the response parser."""
system_template = """#1 Task Planning Stage: The AI assistant can parse user input to several tasks: [{{"task": task, "id": task_id, "dep": dependency_task_id, "args": {{"input name": text may contain <resource-dep_id>}}}}]. The special tag "dep_id" refer to the one generated text/image/audio in the dependency task (Please consider whether the dependency task generates resources of this type.) and "dep_id" must be in "dep" list. The "dep" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The task MUST be selected from the following tools (along with tool description, input name and output type): {tools}. There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply empty JSON [].""" # noqa: E501
human_template = """Now I input: {input}."""
system_message_prompt = SystemMessagePromptTemplate.from_template(
system_template
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
demo_messages: List[
Union[HumanMessagePromptTemplate, AIMessagePromptTemplate]
] = []
for demo in demos:
if demo["role"] == "user":
demo_messages.append(
HumanMessagePromptTemplate.from_template(demo["content"])
)
else:
demo_messages.append(
AIMessagePromptTemplate.from_template(demo["content"])
)
# demo_messages.append(message)
prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, *demo_messages, human_message_prompt]
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
class Step:
"""A step in the plan."""
def __init__(
self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
):
self.task = task
self.id = id
self.dep = dep
self.args = args
self.tool = tool
class Plan:
"""A plan to execute."""
def __init__(self, steps: List[Step]):
self.steps = steps
def __str__(self) -> str:
return str([str(step) for step in self.steps])
def __repr__(self) -> str:
return str(self)
class BasePlanner(BaseModel):
"""Base class for a planner."""
@abstractmethod
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""Given input, decide what to do."""
@abstractmethod
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""Asynchronous Given input, decide what to do."""
class PlanningOutputParser(BaseModel):
"""Parses the output of the planning stage."""
def parse(self, text: str, hf_tools: List[BaseTool]) -> Plan:
"""Parse the output of the planning stage.
Args:
text: The output of the planning stage.
hf_tools: The tools available.
Returns:
The plan.
"""
steps = []
for v in json.loads(re.findall(r"\[.*\]", text)[0]):
choose_tool = None
for tool in hf_tools:
if tool.name == v["task"]:
choose_tool = tool
break
if choose_tool:
steps.append(Step(v["task"], v["id"], v["dep"], v["args"], choose_tool))
return Plan(steps=steps)
class TaskPlanner(BasePlanner):
"""Planner for tasks."""
llm_chain: LLMChain
output_parser: PlanningOutputParser
stop: Optional[List] = None
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""Given input, decided what to do."""
inputs["tools"] = [
f"{tool.name}: {tool.description}" for tool in inputs["hf_tools"]
]
llm_response = self.llm_chain.run(**inputs, stop=self.stop, callbacks=callbacks)
return self.output_parser.parse(llm_response, inputs["hf_tools"])
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""Asynchronous Given input, decided what to do."""
inputs["hf_tools"] = [
f"{tool.name}: {tool.description}" for tool in inputs["hf_tools"]
]
llm_response = await self.llm_chain.arun(
**inputs, stop=self.stop, callbacks=callbacks
)
return self.output_parser.parse(llm_response, inputs["hf_tools"])
def load_chat_planner(llm: BaseLanguageModel) -> TaskPlanner:
"""Load the chat planner."""
llm_chain = TaskPlaningChain.from_llm(llm)
return TaskPlanner(llm_chain=llm_chain, output_parser=PlanningOutputParser())
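A usage sketch that mirrors how HuggingGPT.run drives the planner (assumes `llm` and a list of `tools` exist):
# Hypothetical usage; the planner returns a Plan whose steps carry the chosen tools.
planner = load_chat_planner(llm)
plan = planner.plan(inputs={"input": "count the sheep in e1.jpg", "hf_tools": tools})
for step in plan.steps:
    print(step.id, step.task, step.args)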

View File

@@ -1,27 +0,0 @@
"""**Chat Models** are a variation on language models.
While Chat Models use language models under the hood, the interface they expose
is a bit different. Rather than expose a "text in, text out" API, they expose
an interface where "chat messages" are the inputs and outputs.
**Class hierarchy:**
.. code-block::
BaseLanguageModel --> BaseChatModel --> <name> # Examples: ChatOpenAI, ChatGooglePalm
**Main helpers:**
.. code-block::
AIMessage, BaseMessage, HumanMessage
""" # noqa: E501
from langchain_experimental.chat_models.llm_wrapper import (
Llama2Chat,
Mixtral,
Orca,
Vicuna,
)
__all__ = ["Llama2Chat", "Orca", "Vicuna", "Mixtral"]

View File

@@ -1,196 +0,0 @@
"""Generic Wrapper for chat LLMs, with sample implementations
for Llama-2-chat, Llama-2-instruct and Vicuna models.
"""
from typing import Any, List, Optional, cast
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
SystemMessage,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM, BaseChatModel
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" # noqa: E501
class ChatWrapper(BaseChatModel):
"""Wrapper for chat LLMs."""
llm: LLM
sys_beg: str
sys_end: str
ai_n_beg: str
ai_n_end: str
usr_n_beg: str
usr_n_end: str
usr_0_beg: Optional[str] = None
usr_0_end: Optional[str] = None
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _to_chat_prompt(
self,
messages: List[BaseMessage],
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("at least one HumanMessage must be provided")
if not isinstance(messages[0], SystemMessage):
messages = [self.system_message] + messages
if not isinstance(messages[1], HumanMessage):
raise ValueError(
"messages list must start with a SystemMessage or UserMessage"
)
if not isinstance(messages[-1], HumanMessage):
raise ValueError("last message must be a HumanMessage")
prompt_parts = []
if self.usr_0_beg is None:
self.usr_0_beg = self.usr_n_beg
if self.usr_0_end is None:
self.usr_0_end = self.usr_n_end
prompt_parts.append(
self.sys_beg + cast(str, messages[0].content) + self.sys_end
)
prompt_parts.append(
self.usr_0_beg + cast(str, messages[1].content) + self.usr_0_end
)
for ai_message, human_message in zip(messages[2::2], messages[3::2]):
if not isinstance(ai_message, AIMessage) or not isinstance(
human_message, HumanMessage
):
raise ValueError(
"messages must be alternating human- and ai-messages, "
"optionally prepended by a system message"
)
prompt_parts.append(
self.ai_n_beg + cast(str, ai_message.content) + self.ai_n_end
)
prompt_parts.append(
self.usr_n_beg + cast(str, human_message.content) + self.usr_n_end
)
return "".join(prompt_parts)
@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []
for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
)
chat_generations.append(chat_generation)
return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)
class Llama2Chat(ChatWrapper):
"""Wrapper for Llama-2-chat model."""
@property
def _llm_type(self) -> str:
return "llama-2-chat"
sys_beg: str = "<s>[INST] <<SYS>>\n"
sys_end: str = "\n<</SYS>>\n\n"
ai_n_beg: str = " "
ai_n_end: str = " </s>"
usr_n_beg: str = "<s>[INST] "
usr_n_end: str = " [/INST]"
usr_0_beg: str = ""
usr_0_end: str = " [/INST]"
class Mixtral(ChatWrapper):
"""See https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format"""
@property
def _llm_type(self) -> str:
return "mixtral"
sys_beg: str = "<s>[INST] "
sys_end: str = "\n"
ai_n_beg: str = " "
ai_n_end: str = " </s>"
usr_n_beg: str = " [INST] "
usr_n_end: str = " [/INST]"
usr_0_beg: str = ""
usr_0_end: str = " [/INST]"
class Orca(ChatWrapper):
"""Wrapper for Orca-style models."""
@property
def _llm_type(self) -> str:
return "orca-style"
sys_beg: str = "### System:\n"
sys_end: str = "\n\n"
ai_n_beg: str = "### Assistant:\n"
ai_n_end: str = "\n\n"
usr_n_beg: str = "### User:\n"
usr_n_end: str = "\n\n"
class Vicuna(ChatWrapper):
"""Wrapper for Vicuna-style models."""
@property
def _llm_type(self) -> str:
return "vicuna-style"
sys_beg: str = ""
sys_end: str = " "
ai_n_beg: str = "ASSISTANT: "
ai_n_end: str = " </s>"
usr_n_beg: str = "USER: "
usr_n_end: str = " "

View File

@@ -1,52 +0,0 @@
"""
**Comprehend Moderation** is used to detect and handle `Personally Identifiable Information (PII)`,
`toxicity`, and `prompt safety` in text.
The LangChain experimental package includes the **AmazonComprehendModerationChain** class
for the comprehend moderation tasks. It is based on `Amazon Comprehend` service.
This class can be configured with specific moderation settings like PII labels, redaction,
toxicity thresholds, and prompt safety thresholds.
See more at https://aws.amazon.com/comprehend/
`Amazon Comprehend` service is used by several other classes:
- **ComprehendToxicity** class is used to check the toxicity of text prompts using
`AWS Comprehend service` and take actions based on the configuration
- **ComprehendPromptSafety** class is used to validate the safety of a given prompt
text, raising an error if unsafe content is detected based on the specified threshold
- **ComprehendPII** class is designed to handle
`Personally Identifiable Information (PII)` moderation tasks,
detecting and managing PII entities in text inputs
""" # noqa: E501
from langchain_experimental.comprehend_moderation.amazon_comprehend_moderation import (
AmazonComprehendModerationChain,
)
from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
ModerationPiiConfig,
ModerationPromptSafetyConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
__all__ = [
"BaseModeration",
"ComprehendPII",
"ComprehendPromptSafety",
"ComprehendToxicity",
"BaseModerationConfig",
"ModerationPiiConfig",
"ModerationToxicityConfig",
"ModerationPromptSafetyConfig",
"BaseModerationCallbackHandler",
"AmazonComprehendModerationChain",
]

View File

@@ -1,192 +0,0 @@
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from pydantic import model_validator
from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
)
class AmazonComprehendModerationChain(Chain):
"""Moderation Chain, based on `Amazon Comprehend` service.
See more at https://aws.amazon.com/comprehend/
"""
output_key: str = "output" #: :meta private:
"""Key used to fetch/store the output in data containers. Defaults to `output`"""
input_key: str = "input" #: :meta private:
"""Key used to fetch/store the input in data containers. Defaults to `input`"""
moderation_config: BaseModerationConfig = BaseModerationConfig()
"""
Configuration settings for moderation,
defaults to BaseModerationConfig with default values
"""
client: Optional[Any] = None
"""boto3 client object for connection to Amazon Comprehend"""
region_name: Optional[str] = None
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config in case it is not provided here.
"""
credentials_profile_name: Optional[str] = None
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
has either access keys or role information specified.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""
moderation_callback: Optional[BaseModerationCallbackHandler] = None
"""Callback handler for moderation, this is different
from regular callbacks which can be used in addition to this."""
unique_id: Optional[str] = None
"""A unique id that can be used to identify or group a user or session"""
@model_validator(mode="before")
@classmethod
def create_client(cls, values: Dict[str, Any]) -> Any:
"""
Creates an Amazon Comprehend client.
Args:
values (Dict[str, Any]): A dictionary containing configuration values.
Returns:
Dict[str, Any]: A dictionary with the updated configuration values,
including the Amazon Comprehend client.
Raises:
ModuleNotFoundError: If the 'boto3' package is not installed.
ValueError: If there is an issue importing 'boto3' or loading
AWS credentials.
Example:
.. code-block:: python
config = {
"credentials_profile_name": "my-profile",
"region_name": "us-west-2"
}
updated_config = create_client(config)
comprehend_client = updated_config["client"]
"""
if values.get("client") is not None:
return values
try:
import boto3
if values.get("credentials_profile_name"):
session = boto3.Session(profile_name=values["credentials_profile_name"])
else:
# use default credentials
session = boto3.Session()
client_params = {}
if values.get("region_name"):
client_params["region_name"] = values["region_name"]
values["client"] = session.client("comprehend", **client_params)
return values
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. {e}"
) from e
@property
def output_keys(self) -> List[str]:
"""
Returns a list of output keys.
This method defines the output keys that will be used to access the output
values produced by the chain or function. It ensures that the specified keys
are available to access the outputs.
Returns:
List[str]: A list of output keys.
Note:
This method is considered private and may not be intended for direct
external use.
"""
return [self.output_key]
@property
def input_keys(self) -> List[str]:
"""
Returns a list of input keys expected by the prompt.
This method defines the input keys that the prompt expects in order to perform
its processing. It ensures that the specified keys are available for providing
input to the prompt.
Returns:
List[str]: A list of input keys.
Note:
This method is considered private and may not be intended for direct
external use.
"""
return [self.input_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Executes the moderation process on the input text and returns the processed
output.
This internal method performs the moderation process on the input text. It
converts the input prompt value to plain text, applies the specified filters,
and then converts the filtered output back to a suitable prompt value object.
Additionally, it provides the option to log information about the run using
the provided `run_manager`.
Args:
inputs: A dictionary containing input values
run_manager: A run manager to handle run-related events. Default is None
Returns:
Dict[str, str]: A dictionary containing the processed output of the
moderation process.
Raises:
ValueError: If there is an error during the moderation process
"""
if run_manager:
run_manager.on_text("Running AmazonComprehendModerationChain...\n")
moderation = BaseModeration(
client=self.client,
config=self.moderation_config,
moderation_callback=self.moderation_callback,
unique_id=self.unique_id,
run_manager=run_manager,
)
response = moderation.moderate(prompt=inputs[self.input_keys[0]])
return {self.output_key: response}
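A usage sketch of the chain above, assuming valid AWS credentials for Amazon Comprehend; the filter settings and input text are illustrative only:
# Hypothetical usage; requires working boto3 credentials.
import boto3

from langchain_experimental.comprehend_moderation import (
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationToxicityConfig,
)

comprehend_client = boto3.client("comprehend", region_name="us-east-1")
moderation_config = BaseModerationConfig(
    filters=[
        ModerationPiiConfig(threshold=0.5, redact=True, mask_character="X"),
        ModerationToxicityConfig(threshold=0.5),
    ]
)
moderation_chain = AmazonComprehendModerationChain(
    client=comprehend_client,
    moderation_config=moderation_config,
)
moderation_chain.invoke({"input": "My phone number is 555-0100."})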

View File

@@ -1,185 +0,0 @@
import uuid
from typing import Any, Callable, Optional, cast
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
class BaseModeration:
"""Base class for moderation."""
def __init__(
self,
client: Any,
config: Optional[Any] = None,
moderation_callback: Optional[Any] = None,
unique_id: Optional[str] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
):
self.client = client
self.config = config
self.moderation_callback = moderation_callback
self.unique_id = unique_id
self.chat_message_index = 0
self.run_manager = run_manager
self.chain_id = str(uuid.uuid4())
def _convert_prompt_to_text(self, prompt: Any) -> str:
input_text = str()
if isinstance(prompt, StringPromptValue):
input_text = prompt.text
elif isinstance(prompt, str):
input_text = prompt
elif isinstance(prompt, ChatPromptValue):
"""
We will just check the last message in the message Chain of a
ChatPromptTemplate. The typical chronology is
SystemMessage > HumanMessage > AIMessage and so on. However, assuming the
chain is invoked on every chat turn, we only check the last message, on the
assumption that all previous messages have already been checked.
Only HumanMessage and AIMessage will be checked. We can perhaps
loop through and take advantage of the additional_kwargs property in the
HumanMessage and AIMessage schema to mark messages that have been moderated.
However that means that this class could generate multiple text chunks
and moderate() logics would need to be updated. This also means some
complexity in re-constructing the prompt while keeping the messages in
sequence.
"""
message = prompt.messages[-1]
self.chat_message_index = len(prompt.messages) - 1
if isinstance(message, HumanMessage):
input_text = cast(str, message.content)
if isinstance(message, AIMessage):
input_text = cast(str, message.content)
else:
raise ValueError(
f"Invalid input type {type(input_text)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
return input_text
def _convert_text_to_prompt(self, prompt: Any, text: str) -> Any:
if isinstance(prompt, StringPromptValue):
return StringPromptValue(text=text)
elif isinstance(prompt, str):
return text
elif isinstance(prompt, ChatPromptValue):
# Copy the messages because we may need to mutate them.
# We don't want to mutate data we don't own.
messages = list(prompt.messages)
message = messages[self.chat_message_index]
if isinstance(message, HumanMessage):
messages[self.chat_message_index] = HumanMessage(
content=text,
example=message.example,
additional_kwargs=message.additional_kwargs,
)
if isinstance(message, AIMessage):
messages[self.chat_message_index] = AIMessage(
content=text,
example=message.example,
additional_kwargs=message.additional_kwargs,
)
return ChatPromptValue(messages=messages)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def _moderation_class(self, moderation_class: Any) -> Callable:
return moderation_class(
client=self.client,
callback=self.moderation_callback,
unique_id=self.unique_id,
chain_id=self.chain_id,
).validate
def _log_message_for_verbose(self, message: str) -> None:
if self.run_manager:
self.run_manager.on_text(message)
def moderate(self, prompt: Any) -> str:
"""Moderate the input prompt."""
from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501
ModerationPiiConfig,
ModerationPromptSafetyConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
ModerationPiiError,
ModerationPromptSafetyError,
ModerationToxicityError,
)
try:
# convert prompt to text
input_text = self._convert_prompt_to_text(prompt=prompt)
output_text = str()
# perform moderation
filter_functions = {
"pii": ComprehendPII,
"toxicity": ComprehendToxicity,
"prompt_safety": ComprehendPromptSafety,
}
filters = self.config.filters # type: ignore
for _filter in filters:
filter_name = (
"pii"
if isinstance(_filter, ModerationPiiConfig)
else (
"toxicity"
if isinstance(_filter, ModerationToxicityConfig)
else (
"prompt_safety"
if isinstance(_filter, ModerationPromptSafetyConfig)
else None
)
)
)
if filter_name in filter_functions:
self._log_message_for_verbose(
f"Running {filter_name} Validation...\n"
)
validation_fn = self._moderation_class(
moderation_class=filter_functions[filter_name]
)
input_text = input_text if not output_text else output_text
output_text = validation_fn(
prompt_value=input_text,
config=_filter.dict(),
)
# convert text to prompt and return
return self._convert_text_to_prompt(prompt=prompt, text=output_text)
except ModerationPiiError as e:
self._log_message_for_verbose(f"Found PII content..stopping..\n{str(e)}\n")
raise e
except ModerationToxicityError as e:
self._log_message_for_verbose(
f"Found Toxic content..stopping..\n{str(e)}\n"
)
raise e
except ModerationPromptSafetyError as e:
self._log_message_for_verbose(
f"Found Harmful intention..stopping..\n{str(e)}\n"
)
raise e
except Exception as e:
raise e

View File

@@ -1,67 +0,0 @@
from typing import Any, Callable, Dict
class BaseModerationCallbackHandler:
"""Base class for moderation callback handlers."""
def __init__(self) -> None:
if (
self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_pii, self.on_after_pii
)
and self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
)
and self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_prompt_safety,
self.on_after_prompt_safety,
)
):
raise NotImplementedError(
"Subclasses must override at least one of on_after_pii(), "
"on_after_toxicity(), or on_after_prompt_safety() functions."
)
def _is_method_unchanged(
self, base_method: Callable, derived_method: Callable
) -> bool:
return base_method.__qualname__ == derived_method.__qualname__
async def on_after_pii(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after PII validation is complete."""
pass
async def on_after_toxicity(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
pass
async def on_after_prompt_safety(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Prompt Safety validation is complete."""
pass
@property
def pii_callback(self) -> bool:
return (
self.on_after_pii.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_pii
)
@property
def toxicity_callback(self) -> bool:
return (
self.on_after_toxicity.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_toxicity
)
@property
def prompt_safety_callback(self) -> bool:
return (
self.on_after_prompt_safety.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_prompt_safety
)
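A minimal sketch of a concrete handler (at least one hook must be overridden, as enforced in __init__ above); an instance would be passed to the moderation chain via its `moderation_callback` field:
# Hypothetical subclass that only reacts to PII validation results.
class MyPiiCallback(BaseModerationCallbackHandler):
    async def on_after_pii(self, moderation_beacon, unique_id, **kwargs):
        print("PII moderation status:", moderation_beacon["moderation_status"])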

View File

@@ -1,61 +0,0 @@
from typing import List, Union
from pydantic import BaseModel
class ModerationPiiConfig(BaseModel):
"""Configuration for PII moderation filter."""
threshold: float = 0.5
"""Threshold for PII confidence score, defaults to 0.5 i.e. 50%"""
labels: List[str] = []
"""
List of PII Universal Labels.
Defaults to `list[]`
"""
redact: bool = False
"""Whether to perform redaction of detected PII entities"""
mask_character: str = "*"
"""Redaction mask character in case redact=True, defaults to asterisk (*)"""
class ModerationToxicityConfig(BaseModel):
"""Configuration for Toxicity moderation filter."""
threshold: float = 0.5
"""Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%"""
labels: List[str] = []
"""List of toxic labels, defaults to `list[]`"""
class ModerationPromptSafetyConfig(BaseModel):
"""Configuration for Prompt Safety moderation filter."""
threshold: float = 0.5
"""
Threshold for Prompt Safety classification
confidence score, defaults to 0.5 i.e. 50%
"""
class BaseModerationConfig(BaseModel):
"""Base configuration settings for moderation."""
filters: List[
Union[
ModerationPiiConfig, ModerationToxicityConfig, ModerationPromptSafetyConfig
]
] = [
ModerationPiiConfig(),
ModerationToxicityConfig(),
ModerationPromptSafetyConfig(),
]
"""
Filters applied to the moderation chain, defaults to
`[ModerationPiiConfig(), ModerationToxicityConfig(),
ModerationPromptSafetyConfig()]`
"""

View File

@@ -1,41 +0,0 @@
class ModerationPiiError(Exception):
"""Exception raised if PII entities are detected.
Attributes:
message -- explanation of the error
"""
def __init__(
self, message: str = "The prompt contains PII entities and cannot be processed"
):
self.message = message
super().__init__(self.message)
class ModerationToxicityError(Exception):
"""Exception raised if Toxic entities are detected.
Attributes:
message -- explanation of the error
"""
def __init__(
self, message: str = "The prompt contains toxic content and cannot be processed"
):
self.message = message
super().__init__(self.message)
class ModerationPromptSafetyError(Exception):
"""Exception raised if Unsafe prompts are detected.
Attributes:
message -- explanation of the error
"""
def __init__(
self,
message: str = ("The prompt is unsafe and cannot be processed"),
):
self.message = message
super().__init__(self.message)

View File

@@ -1,166 +0,0 @@
import asyncio
from typing import Any, Dict, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationPiiError,
)
class ComprehendPII:
"""Class to handle Personally Identifiable Information (PII) moderation."""
def __init__(
self,
client: Any,
callback: Optional[Any] = None,
unique_id: Optional[str] = None,
chain_id: Optional[str] = None,
) -> None:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "PII",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
self.unique_id = unique_id
def validate(self, prompt_value: str, config: Any = None) -> str:
redact = config.get("redact")
return (
self._detect_pii(prompt_value=prompt_value, config=config)
if redact
else self._contains_pii(prompt_value=prompt_value, config=config)
)
def _contains_pii(self, prompt_value: str, config: Any = None) -> str:
"""
Checks for Personally Identifiable Information (PII) labels above a
specified threshold. Uses Amazon Comprehend Contains PII Entities API. See -
https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html
Args:
prompt_value (str): The input text to be checked for PII labels.
config (Dict[str, Any]): Configuration for PII check and actions.
Returns:
str: the original prompt
Note:
- The provided client should be initialized with valid AWS credentials.
"""
pii_identified = self.client.contains_pii_entities(
Text=prompt_value, LanguageCode="en"
)
if self.callback and self.callback.pii_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = pii_identified
threshold = config.get("threshold")
pii_labels = config.get("labels")
pii_found = False
for entity in pii_identified["Labels"]:
if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
entity["Score"] >= threshold and not pii_labels
):
pii_found = True
break
if self.callback and self.callback.pii_callback:
if pii_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
if pii_found:
raise ModerationPiiError
return prompt_value
def _detect_pii(self, prompt_value: str, config: Optional[Dict[str, Any]]) -> str:
"""
Detects and handles Personally Identifiable Information (PII) entities in the
given prompt text using Amazon Comprehend's detect_pii_entities API. The
function provides options to redact or stop processing based on the identified
PII entities and a provided configuration. Uses Amazon Comprehend Detect PII
Entities API.
Args:
prompt_value (str): The input text to be checked for PII entities.
config (Dict[str, Any]): A configuration specifying how to handle
PII entities.
Returns:
str: The processed prompt text with redacted PII entities or raised
exceptions.
Raises:
ValueError: If the prompt contains configured PII entities for
stopping processing.
Note:
- If PII is not found in the prompt, the original prompt is returned.
- The client should be initialized with valid AWS credentials.
"""
pii_identified = self.client.detect_pii_entities(
Text=prompt_value, LanguageCode="en"
)
if self.callback and self.callback.pii_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = pii_identified
if (pii_identified["Entities"]) == []:
if self.callback and self.callback.pii_callback:
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
return prompt_value
pii_found = False
if not config and pii_identified["Entities"]:
for entity in pii_identified["Entities"]:
if entity["Score"] >= 0.5:
pii_found = True
break
if self.callback and self.callback.pii_callback:
if pii_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
if pii_found:
raise ModerationPiiError
else:
threshold = config.get("threshold") # type: ignore
pii_labels = config.get("labels") # type: ignore
mask_marker = config.get("mask_character") # type: ignore
pii_found = False
for entity in pii_identified["Entities"]:
if (
pii_labels
and entity["Type"] in pii_labels
and entity["Score"] >= threshold
) or (not pii_labels and entity["Score"] >= threshold):
pii_found = True
char_offset_begin = entity["BeginOffset"]
char_offset_end = entity["EndOffset"]
mask_length = char_offset_end - char_offset_begin + 1
masked_part = mask_marker * mask_length
prompt_value = (
prompt_value[:char_offset_begin]
+ masked_part
+ prompt_value[char_offset_end + 1 :]
)
if self.callback and self.callback.pii_callback:
if pii_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
return prompt_value
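The redaction arithmetic above can be illustrated with hypothetical offsets (this mirrors the masking loop; it is not a real Comprehend response):
prompt_value = "My name is Jane Doe"
begin, end, mask = 11, 18, "*"  # pretend Comprehend flagged "Jane Doe"
masked_part = mask * (end - begin + 1)
prompt_value = prompt_value[:begin] + masked_part + prompt_value[end + 1:]
# prompt_value == "My name is ********"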

View File

@@ -1,89 +0,0 @@
import asyncio
from typing import Any, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationPromptSafetyError,
)
class ComprehendPromptSafety:
"""Class to handle prompt safety moderation."""
def __init__(
self,
client: Any,
callback: Optional[Any] = None,
unique_id: Optional[str] = None,
chain_id: Optional[str] = None,
) -> None:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "PromptSafety",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
self.unique_id = unique_id
def _get_arn(self) -> str:
region_name = self.client.meta.region_name
service = "comprehend"
prompt_safety_endpoint = "document-classifier-endpoint/prompt-safety"
return f"arn:aws:{service}:{region_name}:aws:{prompt_safety_endpoint}"
def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check and validate the safety of the given prompt text.
Args:
prompt_value (str): The input text to be checked for unsafe text.
config (Dict[str, Any]): Configuration settings for prompt safety checks.
Raises:
ValueError: If unsafe prompt is found in the prompt text based
on the specified threshold.
Returns:
str: The input prompt_value.
Note:
This function checks the safety of the provided prompt text using
Comprehend's classify_document API and raises an error if unsafe
text is detected with a score above the specified threshold.
Example:
comprehend_client = boto3.client('comprehend')
prompt_text = "Please tell me your credit card information."
config = {"threshold": 0.7}
checked_prompt = check_prompt_safety(comprehend_client, prompt_text, config)
"""
threshold = config.get("threshold")
unsafe_prompt = False
endpoint_arn = self._get_arn()
response = self.client.classify_document(
Text=prompt_value, EndpointArn=endpoint_arn
)
if self.callback and self.callback.prompt_safety_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = response
for class_result in response["Classes"]:
if (
class_result["Score"] >= threshold
and class_result["Name"] == "UNSAFE_PROMPT"
):
unsafe_prompt = True
break
if self.callback and self.callback.prompt_safety_callback:
if unsafe_prompt:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_prompt_safety(self.moderation_beacon, self.unique_id)
)
if unsafe_prompt:
raise ModerationPromptSafetyError
return prompt_value

View File

@@ -1,167 +0,0 @@
import asyncio
import importlib
from typing import Any, List, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationToxicityError,
)
class ComprehendToxicity:
"""Class to handle toxicity moderation."""
def __init__(
self,
client: Any,
callback: Optional[Any] = None,
unique_id: Optional[str] = None,
chain_id: Optional[str] = None,
) -> None:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "Toxicity",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
self.unique_id = unique_id
def _toxicity_init_validate(self, max_size: int) -> Any:
"""
Validate and initialize toxicity processing configuration.
Args:
max_size (int): Maximum sentence size defined in the
configuration object.
Raises:
Exception: If the maximum sentence size exceeds the 5KB limit.
Note:
This function ensures that the NLTK punkt tokenizer is downloaded
if not already present.
Returns:
The imported nltk module, once the punkt tokenizer is available.
"""
if max_size > 1024 * 5:
raise Exception("The sentence length should not exceed 5KB.")
try:
nltk = importlib.import_module("nltk")
nltk.data.find("tokenizers/punkt")
return nltk
except ImportError:
raise ModuleNotFoundError(
"Could not import nltk python package. "
"Please install it with `pip install nltk`."
)
except LookupError:
nltk.download("punkt")
return nltk
def _split_paragraph(
self, prompt_value: str, max_size: int = 1024 * 4
) -> List[List[str]]:
"""
Split a paragraph into chunks of sentences, respecting the maximum size limit.
Args:
paragraph (str): The input paragraph to be split into chunks.
max_size (int, optional): The maximum size limit in bytes for
each chunk. Defaults to 1024 * 4 (4 KB).
Returns:
List[List[str]]: A list of chunks, where each chunk is a list
of sentences.
Note:
This function validates the maximum sentence size based on service
limits using the 'toxicity_init_validate' function. It uses the NLTK
sentence tokenizer to split the paragraph into sentences.
Example:
paragraph = "This is a sample paragraph. It
contains multiple sentences. ..."
chunks = split_paragraph(paragraph, max_size=2048)
"""
# validate max. sentence size based on Service limits
nltk = self._toxicity_init_validate(max_size)
sentences = nltk.sent_tokenize(prompt_value)
chunks = list() # type: ignore
current_chunk = list() # type: ignore
current_size = 0
for sentence in sentences:
sentence_size = len(sentence.encode("utf-8"))
# If adding a new sentence exceeds max_size
# or current_chunk has 10 sentences, start a new chunk
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
if current_chunk: # Avoid appending empty chunks
chunks.append(current_chunk)
current_chunk = []
current_size = 0
current_chunk.append(sentence)
current_size += sentence_size
# Add any remaining sentences
if current_chunk:
chunks.append(current_chunk)
return chunks
def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check the toxicity of a given text prompt using AWS
Comprehend service and apply actions based on configuration.
Args:
prompt_value (str): The text content to be checked for toxicity.
config (Dict[str, Any]): Configuration for toxicity checks and actions.
Returns:
str: The original prompt_value if allowed or no toxicity found.
Raises:
ValueError: If the prompt contains toxic labels and cannot be
processed based on the configuration.
"""
chunks = self._split_paragraph(prompt_value=prompt_value)
for sentence_list in chunks:
segments = [{"Text": sentence} for sentence in sentence_list]
response = self.client.detect_toxic_content(
TextSegments=segments, LanguageCode="en"
)
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_input"] = segments # type: ignore
self.moderation_beacon["moderation_output"] = response
toxicity_found = False
threshold = config.get("threshold")
toxicity_labels = config.get("labels")
if not toxicity_labels:
for item in response["ResultList"]:
for label in item["Labels"]:
if label["Score"] >= threshold:
toxicity_found = True
break
else:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label["Name"] in toxicity_labels
and label["Score"] >= threshold
):
toxicity_found = True
break
if self.callback and self.callback.toxicity_callback:
if toxicity_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
if toxicity_found:
raise ModerationToxicityError
return prompt_value

View File

@@ -1,4 +0,0 @@
# Causal program-aided language (CPAL) chain
see https://github.com/langchain-ai/langchain/pull/6255

View File

@@ -1,17 +0,0 @@
"""
**Causal program-aided language (CPAL)** is a concept implemented in LangChain as
a chain for causal modeling and narrative decomposition.
CPAL improves upon the program-aided language (**PAL**) by incorporating
causal structure to prevent hallucination in language models,
particularly when dealing with complex narratives and math
problems with nested dependencies.
CPAL involves translating causal narratives into a stack of operations,
setting hypothetical conditions for causal models, and decomposing
narratives into story elements.
It allows for the creation of causal chains that define the relationships
between different elements in a narrative, enabling the modeling and analysis
of causal relationships within a given context.
"""

View File

@@ -1,303 +0,0 @@
"""
CPAL Chain and its subchains
"""
from __future__ import annotations
import json
from typing import Any, ClassVar, Dict, List, Optional, Type
import pydantic
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts.prompt import PromptTemplate
from langchain_experimental.cpal.constants import Constant
from langchain_experimental.cpal.models import (
CausalModel,
InterventionModel,
NarrativeModel,
QueryModel,
StoryModel,
)
from langchain_experimental.cpal.templates.univariate.causal import (
template as causal_template,
)
from langchain_experimental.cpal.templates.univariate.intervention import (
template as intervention_template,
)
from langchain_experimental.cpal.templates.univariate.narrative import (
template as narrative_template,
)
from langchain_experimental.cpal.templates.univariate.query import (
template as query_template,
)
class _BaseStoryElementChain(Chain):
chain: LLMChain
input_key: str = Constant.narrative_input.value #: :meta private:
output_key: str = Constant.chain_answer.value #: :meta private:
pydantic_model: ClassVar[Optional[Type[pydantic.BaseModel]]] = (
None #: :meta private:
)
template: ClassVar[Optional[str]] = None #: :meta private:
@classmethod
def parser(cls) -> PydanticOutputParser:
"""Parse LLM output into a pydantic object."""
if cls.pydantic_model is None:
raise NotImplementedError(
f"pydantic_model not implemented for {cls.__name__}"
)
return PydanticOutputParser(pydantic_object=cls.pydantic_model)
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_univariate_prompt(
cls,
llm: BaseLanguageModel,
**kwargs: Any,
) -> Any:
return cls(
chain=LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=[Constant.narrative_input.value],
template=kwargs.get("template", cls.template),
partial_variables={
"format_instructions": cls.parser().get_format_instructions()
},
),
),
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
completion = self.chain.run(inputs[self.input_key])
pydantic_data = self.__class__.parser().parse(completion)
return {
Constant.chain_data.value: pydantic_data,
Constant.chain_answer.value: None,
}
class NarrativeChain(_BaseStoryElementChain):
"""Decompose the narrative into its story elements.
- causal model
- query
- intervention
"""
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = NarrativeModel
template: ClassVar[str] = narrative_template
class CausalChain(_BaseStoryElementChain):
"""Translate the causal narrative into a stack of operations."""
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = CausalModel
template: ClassVar[str] = causal_template
class InterventionChain(_BaseStoryElementChain):
"""Set the hypothetical conditions for the causal model."""
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = InterventionModel
template: ClassVar[str] = intervention_template
class QueryChain(_BaseStoryElementChain):
"""Query the outcome table using SQL.
*Security note*: This class implements an AI technique that generates SQL code.
If those SQL commands are executed, it's critical to ensure they use credentials
that are narrowly-scoped to only include the permissions this chain needs.
Failure to do so may result in data corruption or loss, since this chain may
attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this chain.
"""
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = QueryModel
template: ClassVar[str] = query_template # TODO: incl. table schema
class CPALChain(_BaseStoryElementChain):
"""Causal program-aided language (CPAL) chain implementation.
*Security note*: The building blocks of this class include the implementation
of an AI technique that generates SQL code. If those SQL commands
are executed, it's critical to ensure they use credentials that
are narrowly-scoped to only include the permissions this chain needs.
Failure to do so may result in data corruption or loss, since this chain may
attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this chain.
"""
llm: BaseLanguageModel
narrative_chain: Optional[NarrativeChain] = None
causal_chain: Optional[CausalChain] = None
intervention_chain: Optional[InterventionChain] = None
query_chain: Optional[QueryChain] = None
_story: StoryModel = pydantic.PrivateAttr(default=None) # TODO: change name ?
@classmethod
def from_univariate_prompt(
cls,
llm: BaseLanguageModel,
**kwargs: Any,
) -> CPALChain:
"""instantiation depends on component chains
*Security note*: The building blocks of this class include the implementation
of an AI technique that generates SQL code. If those SQL commands
are executed, it's critical to ensure they use credentials that
are narrowly-scoped to only include the permissions this chain needs.
Failure to do so may result in data corruption or loss, since this chain may
attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this chain.
"""
return cls(
llm=llm,
chain=LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=["question", "query_result"],
template=(
"Summarize this answer '{query_result}' to this "
"question '{question}'? "
),
),
),
narrative_chain=NarrativeChain.from_univariate_prompt(llm=llm),
causal_chain=CausalChain.from_univariate_prompt(llm=llm),
intervention_chain=InterventionChain.from_univariate_prompt(llm=llm),
query_chain=QueryChain.from_univariate_prompt(llm=llm),
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any,
) -> Dict[str, Any]:
# instantiate component chains
if self.narrative_chain is None:
self.narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.llm)
if self.causal_chain is None:
self.causal_chain = CausalChain.from_univariate_prompt(llm=self.llm)
if self.intervention_chain is None:
self.intervention_chain = InterventionChain.from_univariate_prompt(
llm=self.llm
)
if self.query_chain is None:
self.query_chain = QueryChain.from_univariate_prompt(llm=self.llm)
# decompose narrative into three causal story elements
narrative = self.narrative_chain(inputs[Constant.narrative_input.value])[
Constant.chain_data.value
]
story = StoryModel(
causal_operations=self.causal_chain(narrative.story_plot)[
Constant.chain_data.value
],
intervention=self.intervention_chain(narrative.story_hypothetical)[
Constant.chain_data.value
],
query=self.query_chain(narrative.story_outcome_question)[
Constant.chain_data.value
],
)
self._story = story
def pretty_print_str(title: str, d: str) -> str:
return title + "\n" + d
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(
pretty_print_str("story outcome data", story._outcome_table.to_string()),
color="green",
end="\n\n",
verbose=self.verbose,
)
def pretty_print_dict(title: str, d: dict) -> str:
return title + "\n" + json.dumps(d, indent=4)
_run_manager.on_text(
pretty_print_dict("query data", story.query.dict()),
color="blue",
end="\n\n",
verbose=self.verbose,
)
if story.query._result_table.empty:
# prevent piping bad data into subsequent chains
raise ValueError(
(
"unanswerable, query and outcome are incoherent\n"
"\n"
"outcome:\n"
f"{story._outcome_table}\n"
"query:\n"
f"{story.query.dict()}"
)
)
else:
query_result = float(story.query._result_table.values[0][-1])
if False:
"""TODO: add this back in when demanded by composable chains"""
reporting_chain = self.chain
human_report = reporting_chain.run(
question=story.query.question, query_result=query_result
)
query_result = {
"query_result": query_result,
"human_report": human_report,
}
output = {
Constant.chain_data.value: story,
self.output_key: query_result,
**kwargs,
}
return output
def draw(self, **kwargs: Any) -> None:
"""
CPAL chain can draw its resulting DAG.
Usage in a jupyter notebook:
>>> from IPython.display import SVG
>>> cpal_chain.draw(path="graph.svg")
>>> SVG('graph.svg')
"""
self._story._networkx_wrapper.draw_graphviz(**kwargs)

View File

@ -1,9 +0,0 @@
from enum import Enum
class Constant(Enum):
"""Enum for constants used in the CPAL."""
narrative_input = "narrative_input"
chain_answer = "chain_answer" # natural language answer
chain_data = "chain_data" # pydantic instance

View File

@ -1,279 +0,0 @@
from __future__ import annotations # allows pydantic model to reference itself
import re
from typing import Any, List, Optional, Union
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_validator,
model_validator,
)
from langchain_experimental.cpal.constants import Constant
class NarrativeModel(BaseModel):
"""
Narrative input as three story elements.
"""
story_outcome_question: str
story_hypothetical: str
story_plot: str # causal stack of operations
@field_validator("*", mode="before")
def empty_str_to_none(cls, v: str) -> Union[str, None]:
"""Empty strings are not allowed"""
if v == "":
return None
return v
class EntityModel(BaseModel):
"""Entity in the story."""
name: str = Field(description="entity name")
code: str = Field(description="entity actions")
value: float = Field(description="entity initial value")
depends_on: List[str] = Field(default=[], description="ancestor entities")
# TODO: generalize to multivariate math
# TODO: acyclic graph
model_config = ConfigDict(
validate_assignment=True,
)
@field_validator("name")
def lower_case_name(cls, v: str) -> str:
v = v.lower()
return v
class CausalModel(BaseModel):
"""Casual data."""
attribute: str = Field(description="name of the attribute to be calculated")
entities: List[EntityModel] = Field(description="entities in the story")
# TODO: root validate each `entity.depends_on` using system's entity names
class EntitySettingModel(BaseModel):
"""Entity initial conditions.
Initial conditions for an entity
{"name": "bud", "attribute": "pet_count", "value": 12}
"""
name: str = Field(description="name of the entity")
attribute: str = Field(description="name of the attribute to be calculated")
value: float = Field(description="entity's attribute value (calculated)")
@field_validator("name")
def lower_case_transform(cls, v: str) -> str:
v = v.lower()
return v
class SystemSettingModel(BaseModel):
"""System initial conditions.
Initial global conditions for the system.
{"parameter": "interest_rate", "value": .05}
"""
parameter: str
value: float
class InterventionModel(BaseModel):
"""Intervention data of the story aka initial conditions.
>>> intervention.dict()
{
entity_settings: [
{"name": "bud", "attribute": "pet_count", "value": 12},
{"name": "pat", "attribute": "pet_count", "value": 0},
],
system_settings: None,
}
"""
entity_settings: List[EntitySettingModel]
system_settings: Optional[List[SystemSettingModel]] = None
@field_validator("system_settings")
def lower_case_name(cls, v: str) -> Union[str, None]:
if v is not None:
raise NotImplementedError("system_setting is not implemented yet")
return v
class QueryModel(BaseModel):
"""Query data of the story.
Translate a question about the story outcome into a programmatic expression."""
question: str = Field( # type: ignore[literal-required]
alias=Constant.narrative_input.value
) # input # type: ignore[literal-required]
expression: str # output, part of llm completion
llm_error_msg: str # output, part of llm completion
_result_table: str = PrivateAttr() # result of the executed query
class ResultModel(BaseModel):
"""Result of the story query."""
question: str = Field( # type: ignore[literal-required]
alias=Constant.narrative_input.value
) # input # type: ignore[literal-required]
_result_table: str = PrivateAttr() # result of the executed query
class StoryModel(BaseModel):
"""Story data."""
causal_operations: Any = Field()
intervention: Any = Field()
query: Any = Field()
_outcome_table: Any = PrivateAttr(default=None)
_networkx_wrapper: Any = PrivateAttr(default=None)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._compute()
# TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
# misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214
@model_validator(mode="before")
@classmethod
def check_intervention_is_valid(cls, values: dict) -> Any:
valid_names = [e.name for e in values["causal_operations"].entities]
for setting in values["intervention"].entity_settings:
if setting.name not in valid_names:
error_msg = f"""
Hypothetical question has an invalid entity name.
`{setting.name}` not in `{valid_names}`
"""
raise ValueError(error_msg)
return values
def _block_back_door_paths(self) -> None:
# stop intervention entities from depending on others
intervention_entities = [
entity_setting.name for entity_setting in self.intervention.entity_settings
]
for entity in self.causal_operations.entities:
if entity.name in intervention_entities:
entity.depends_on = []
entity.code = "pass"
def _set_initial_conditions(self) -> None:
for entity_setting in self.intervention.entity_settings:
for entity in self.causal_operations.entities:
if entity.name == entity_setting.name:
entity.value = entity_setting.value
def _make_graph(self) -> None:
self._networkx_wrapper = NetworkxEntityGraph()
for entity in self.causal_operations.entities:
for parent_name in entity.depends_on:
self._networkx_wrapper._graph.add_edge(
parent_name, entity.name, relation=entity.code
)
# TODO: is it correct to drop entities with no impact on the outcome (?)
self.causal_operations.entities = [
entity
for entity in self.causal_operations.entities
if entity.name in self._networkx_wrapper.get_topological_sort()
]
def _sort_entities(self) -> None:
# order the sequence of causal actions
sorted_nodes = self._networkx_wrapper.get_topological_sort()
self.causal_operations.entities.sort(key=lambda x: sorted_nodes.index(x.name))
def _forward_propagate(self) -> None:
try:
import pandas as pd
except ImportError as e:
raise ImportError(
"Unable to import pandas, please install with `pip install pandas`."
) from e
entity_scope = {
entity.name: entity for entity in self.causal_operations.entities
}
for entity in self.causal_operations.entities:
if entity.code == "pass":
continue
else:
# gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
exec(entity.code, globals(), entity_scope)
row_values = [entity.dict() for entity in entity_scope.values()]
self._outcome_table = pd.DataFrame(row_values)
def _run_query(self) -> None:
def humanize_sql_error_msg(error: str) -> str:
pattern = r"column\s+(.*?)\s+not found"
col_match = re.search(pattern, error)
if col_match:
return (
"SQL error: "
+ col_match.group(1)
+ " is not an attribute in your story!"
)
else:
return str(error)
if self.query.llm_error_msg == "":
try:
import duckdb
df = self._outcome_table # noqa
query_result = duckdb.sql(self.query.expression).df()
self.query._result_table = query_result
except duckdb.BinderException as e:
self.query._result_table = humanize_sql_error_msg(str(e))
except ImportError as e:
raise ImportError(
"Unable to import duckdb, please install with `pip install duckdb`."
) from e
except Exception as e:
self.query._result_table = str(e)
else:
msg = "LLM maybe failed to translate question to SQL query."
raise ValueError(
{
"question": self.query.question,
"llm_error_msg": self.query.llm_error_msg,
"msg": msg,
}
)
def _compute(self) -> Any:
self._block_back_door_paths()
self._set_initial_conditions()
self._make_graph()
self._sort_entities()
self._forward_propagate()
self._run_query()
def print_debug_report(self) -> None:
report = {
"outcome": self._outcome_table,
"query": self.query.dict(),
"result": self.query._result_table,
}
from pprint import pprint
pprint(report)
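A short sketch (not in the original file) of how these pydantic models are constructed; the CPAL sub-chains normally produce them from LLM output, but they can also be built directly. The field values below are illustrative:

# Hedged sketch: construct the entity and intervention models by hand.
entity = EntityModel(
    name="Marcia",
    code="marcia.value = cindy.value + 2",
    value=0.0,
    depends_on=["cindy"],
)
print(entity.name)  # 'marcia' -- the field validator lower-cases names
setting = EntitySettingModel(name="Cindy", attribute="pet_count", value=4)
print(setting.name)  # 'cindy'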

View File

@ -1,113 +0,0 @@
# ruff: noqa: E501
# fmt: off
template = (
"""
Transform the math story plot into a JSON object. Don't guess at any of the parts.
{format_instructions}
Story: Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy.
# JSON:
{{
"attribute": "pet_count",
"entities": [
{{
"name": "cindy",
"value": 0,
"depends_on": [],
"code": "pass"
}},
{{
"name": "marcia",
"value": 0,
"depends_on": ["cindy"],
"code": "marcia.value = cindy.value + 2"
}},
{{
"name": "boris",
"value": 0,
"depends_on": ["marcia"],
"code": "boris.value = marcia.value * 7"
}},
{{
"name": "jan",
"value": 0,
"depends_on": ["marcia"],
"code": "jan.value = marcia.value * 3"
}}
]
}}
Story: Boris gives 20 percent of his money to Marcia. Marcia gives 10
percent of her money to Cindy. Cindy gives 5 percent of her money to Jan.
# JSON:
{{
"attribute": "money",
"entities": [
{{
"name": "boris",
"value": 0,
"depends_on": [],
"code": "pass"
}},
{{
"name": "marcia",
"value": 0,
"depends_on": ["boris"],
"code": "
marcia.value = boris.value * 0.2
boris.value = boris.value * 0.8
"
}},
{{
"name": "cindy",
"value": 0,
"depends_on": ["marcia"],
"code": "
cindy.value = marcia.value * 0.1
marcia.value = marcia.value * 0.9
"
}},
{{
"name": "jan",
"value": 0,
"depends_on": ["cindy"],
"code": "
jan.value = cindy.value * 0.05
cindy.value = cindy.value * 0.9
"
}}
]
}}
Story: {narrative_input}
# JSON:
""".strip()
+ "\n"
)
# fmt: on

View File

@ -1,59 +0,0 @@
# ruff: noqa: E501
# fmt: off
template = (
"""
Transform the hypothetical whatif statement into JSON. Don't guess at any of the parts. Write NONE if you are unsure.
{format_instructions}
statement: if cindy's pet count was 4
# JSON:
{{
"entity_settings" : [
{{ "name": "cindy", "attribute": "pet_count", "value": "4" }}
]
}}
statement: Let's say boris has ten dollars and Bill has 20 dollars.
# JSON:
{{
"entity_settings" : [
{{ "name": "boris", "attribute": "dollars", "value": "10" }},
{{ "name": "bill", "attribute": "dollars", "value": "20" }}
]
}}
Statement: {narrative_input}
# JSON:
""".strip()
+ "\n\n\n"
)
# fmt: on

View File

@ -1,79 +0,0 @@
# ruff: noqa: E501
# fmt: off
template = (
"""
Split the given text into three parts: the question, the story_hypothetical, and the logic. Don't guess at any of the parts. Write NONE if you are unsure.
{format_instructions}
Q: Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?
# JSON
{{
"story_outcome_question": "how many total pets do the three have?",
"story_hypothetical": "If Cindy has four pets",
"story_plot": "Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy."
}}
Q: boris gives ten percent of his money to marcia. marcia gives ten
percent of her money to andy. If boris has 100 dollars, how much money
will andy have?
# JSON
{{
"story_outcome_question": "how much money will andy have?",
"story_hypothetical": "If boris has 100 dollars"
"story_plot": "boris gives ten percent of his money to marcia. marcia gives ten percent of her money to andy."
}}
Q: boris gives ten percent of his candy to marcia. marcia gives ten
percent of her candy to andy. If boris has 100 pounds of candy and marcia has
200 pounds of candy, then how many pounds of candy will andy have?
# JSON
{{
"story_outcome_question": "how many pounds of candy will andy have?",
"story_hypothetical": "If boris has 100 pounds of candy and marcia has 200 pounds of candy"
"story_plot": "boris gives ten percent of his candy to marcia. marcia gives ten percent of her candy to andy."
}}
Q: {narrative_input}
# JSON
""".strip()
+ "\n\n\n"
)
# fmt: on

View File

@ -1,270 +0,0 @@
# ruff: noqa: E501
# fmt: off
template = (
"""
Transform the narrative_input into an SQL expression. If you are
unsure, then do not guess; instead add an llm_error_msg that explains why you are unsure.
{format_instructions}
narrative_input: how much money will boris have?
# JSON:
{{
"narrative_input": "how much money will boris have?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'boris'"
}}
narrative_input: How much money does ted have?
# JSON:
{{
"narrative_input": "How much money does ted have?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'ted'"
}}
narrative_input: what is the sum of pet count for all the people?
# JSON:
{{
"narrative_input": "what is the sum of pet count for all the people?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df"
}}
narrative_input: what's the average of the pet counts for all the people?
# JSON:
{{
"narrative_input": "what's the average of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT AVG(value) FROM df"
}}
narrative_input: what's the maximum of the pet counts for all the people?
# JSON:
{{
"narrative_input": "what's the maximum of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT MAX(value) FROM df"
}}
narrative_input: what's the minimum of the pet counts for all the people?
# JSON:
{{
"narrative_input": "what's the minimum of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT MIN(value) FROM df"
}}
narrative_input: what's the number of people with pet counts greater than 10?
# JSON:
{{
"narrative_input": "what's the number of people with pet counts greater than 10?",
"llm_error_msg": "",
"expression": "SELECT COUNT(*) FROM df WHERE value > 10"
}}
narrative_input: what's the pet count for boris?
# JSON:
{{
"narrative_input": "what's the pet count for boris?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'boris'"
}}
narrative_input: what's the pet count for cindy and marcia?
# JSON:
{{
"narrative_input": "what's the pet count for cindy and marcia?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name IN ('cindy', 'marcia')"
}}
narrative_input: what's the total pet count for cindy and marcia?
# JSON:
{{
"narrative_input": "what's the total pet count for cindy and marcia?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('cindy', 'marcia')"
}}
narrative_input: what's the total pet count for TED?
# JSON:
{{
"narrative_input": "what's the total pet count for TED?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name = 'TED'"
}}
narrative_input: what's the total dollar count for TED and cindy?
# JSON:
{{
"narrative_input": "what's the total dollar count for TED and cindy?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('TED', 'cindy')"
}}
narrative_input: what's the total pet count for TED and cindy?
# JSON:
{{
"narrative_input": "what's the total pet count for TED and cindy?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('TED', 'cindy')"
}}
narrative_input: what's the best for TED and cindy?
# JSON:
{{
"narrative_input": "what's the best for TED and cindy?",
"llm_error_msg": "ambiguous narrative_input, not sure what 'best' means",
"expression": ""
}}
narrative_input: what's the value?
# JSON:
{{
"narrative_input": "what's the value?",
"llm_error_msg": "ambiguous narrative_input, not sure what entity is being asked about",
"expression": ""
}}
narrative_input: how many total pets do the three have?
# JSON:
{{
"narrative_input": "how many total pets do the three have?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df"
}}
narrative_input: {narrative_input}
# JSON:
""".strip()
+ "\n\n\n"
)
# fmt: on

View File

@ -1,17 +0,0 @@
"""**Data anonymizer** contains both Anonymizers and Deanonymizers.
It uses the [Microsoft Presidio](https://microsoft.github.io/presidio/) library.
**Anonymizers** are used to replace a `Personally Identifiable Information (PII)`
entity text with some other
value by applying a certain operator (e.g. replace, mask, redact, encrypt).
**Deanonymizers** are used to revert the anonymization operation
(e.g. to decrypt an encrypted text).
"""
from langchain_experimental.data_anonymizer.presidio import (
PresidioAnonymizer,
PresidioReversibleAnonymizer,
)
__all__ = ["PresidioAnonymizer", "PresidioReversibleAnonymizer"]

View File

@ -1,61 +0,0 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from langchain_experimental.data_anonymizer.deanonymizer_mapping import MappingDataType
from langchain_experimental.data_anonymizer.deanonymizer_matching_strategies import (
exact_matching_strategy,
)
DEFAULT_DEANONYMIZER_MATCHING_STRATEGY = exact_matching_strategy
class AnonymizerBase(ABC):
"""Base abstract class for anonymizers.
It is public and non-virtual because it allows
wrapping the behavior for all methods in a base class.
"""
def anonymize(
self,
text: str,
language: Optional[str] = None,
allow_list: Optional[List[str]] = None,
) -> str:
"""Anonymize text."""
return self._anonymize(text, language, allow_list)
@abstractmethod
def _anonymize(
self, text: str, language: Optional[str], allow_list: Optional[List[str]] = None
) -> str:
"""Abstract method to anonymize text"""
class ReversibleAnonymizerBase(AnonymizerBase):
"""
Base abstract class for reversible anonymizers.
"""
def deanonymize(
self,
text_to_deanonymize: str,
deanonymizer_matching_strategy: Callable[
[str, MappingDataType], str
] = DEFAULT_DEANONYMIZER_MATCHING_STRATEGY,
) -> str:
"""Deanonymize text"""
return self._deanonymize(text_to_deanonymize, deanonymizer_matching_strategy)
@abstractmethod
def _deanonymize(
self,
text_to_deanonymize: str,
deanonymizer_matching_strategy: Callable[[str, MappingDataType], str],
) -> str:
"""Abstract method to deanonymize text"""
@abstractmethod
def reset_deanonymizer_mapping(self) -> None:
"""Abstract method to reset deanonymizer mapping"""

View File

@ -1,144 +0,0 @@
import re
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List
if TYPE_CHECKING:
from presidio_analyzer import RecognizerResult
from presidio_anonymizer.entities import EngineResult
MappingDataType = Dict[str, Dict[str, str]]
def format_duplicated_operator(operator_name: str, count: int) -> str:
"""Format the operator name with the count."""
clean_operator_name = re.sub(r"[<>]", "", operator_name)
clean_operator_name = re.sub(r"_\d+$", "", clean_operator_name)
if operator_name.startswith("<") and operator_name.endswith(">"):
return f"<{clean_operator_name}_{count}>"
else:
return f"{clean_operator_name}_{count}"
@dataclass
class DeanonymizerMapping:
"""Deanonymizer mapping."""
mapping: MappingDataType = field(
default_factory=lambda: defaultdict(lambda: defaultdict(str))
)
@property
def data(self) -> MappingDataType:
"""Return the deanonymizer mapping."""
return {k: dict(v) for k, v in self.mapping.items()}
def update(self, new_mapping: MappingDataType) -> None:
"""Update the deanonymizer mapping with new values.
Duplicated values will not be added
If there are multiple entities of the same type, the mapping will
include a count to differentiate them. For example, if there are
two names in the input text, the mapping will include NAME_1 and NAME_2.
"""
seen_values = set()
for entity_type, values in new_mapping.items():
count = len(self.mapping[entity_type]) + 1
for key, value in values.items():
if (
value not in seen_values
and value not in self.mapping[entity_type].values()
):
new_key = (
format_duplicated_operator(key, count)
if key in self.mapping[entity_type]
else key
)
self.mapping[entity_type][new_key] = value
seen_values.add(value)
count += 1
def create_anonymizer_mapping(
original_text: str,
analyzer_results: List["RecognizerResult"],
anonymizer_results: "EngineResult",
is_reversed: bool = False,
) -> MappingDataType:
"""Create or update the mapping used to anonymize and/or
deanonymize a text.
This method exploits the results returned by the
analysis and anonymization processes.
If is_reversed is True, it constructs a mapping from each original
entity to its anonymized value.
If is_reversed is False, it constructs a mapping from each
anonymized entity back to its original text value.
If there are multiple entities of the same type, the mapping will
include a count to differentiate them. For example, if there are
two names in the input text, the mapping will include NAME_1 and NAME_2.
Example of mapping:
{
"PERSON": {
"<original>": "<anonymized>",
"John Doe": "Slim Shady"
},
"PHONE_NUMBER": {
"111-111-1111": "555-555-5555"
}
...
}
"""
# We are able to zip and loop through both lists because we expect
# them to return corresponding entities for each identified piece
# of analyzable data from our input.
# We sort them by their 'start' attribute because it allows us to
# match corresponding entities by their position in the input text.
analyzer_results.sort(key=lambda d: d.start)
anonymizer_results.items.sort(key=lambda d: d.start)
mapping: MappingDataType = defaultdict(dict)
count: dict = defaultdict(int)
for analyzed, anonymized in zip(analyzer_results, anonymizer_results.items):
original_value = original_text[analyzed.start : analyzed.end]
entity_type = anonymized.entity_type
if is_reversed:
cond = original_value in mapping[entity_type].values()
else:
cond = original_value in mapping[entity_type]
if cond:
continue
if (
anonymized.text in mapping[entity_type].values()
or anonymized.text in mapping[entity_type]
):
anonymized_value = format_duplicated_operator(
anonymized.text, count[entity_type] + 2
)
count[entity_type] += 1
else:
anonymized_value = anonymized.text
mapping_key, mapping_value = (
(anonymized_value, original_value)
if is_reversed
else (original_value, anonymized_value)
)
mapping[entity_type][mapping_key] = mapping_value
return mapping
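A small sketch (not in the original file) of the counting behaviour described in DeanonymizerMapping.update(): a repeated key of the same entity type gets a numeric suffix. The names and values below are made up:

from langchain_experimental.data_anonymizer.deanonymizer_mapping import DeanonymizerMapping

mapping = DeanonymizerMapping()
mapping.update({"PERSON": {"Slim Shady": "John Doe"}})
mapping.update({"PERSON": {"Slim Shady": "Jane Roe"}})  # same anonymized key, new original
print(mapping.data)
# expected shape: {'PERSON': {'Slim Shady': 'John Doe', 'Slim Shady_2': 'Jane Roe'}}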

View File

@ -1,185 +0,0 @@
import re
from typing import List
from langchain_experimental.data_anonymizer.deanonymizer_mapping import MappingDataType
def exact_matching_strategy(text: str, deanonymizer_mapping: MappingDataType) -> str:
"""Exact matching strategy for deanonymization.
It replaces all the anonymized entities with the original ones.
Args:
text: text to deanonymize
deanonymizer_mapping: mapping between anonymized entities and original ones"""
# Iterate over all the entities (PERSON, EMAIL_ADDRESS, etc.)
for entity_type in deanonymizer_mapping:
for anonymized, original in deanonymizer_mapping[entity_type].items():
text = text.replace(anonymized, original)
return text
def case_insensitive_matching_strategy(
text: str, deanonymizer_mapping: MappingDataType
) -> str:
"""Case insensitive matching strategy for deanonymization.
It replaces all the anonymized entities with the original ones
irrespective of their letter case.
Args:
text: text to deanonymize
deanonymizer_mapping: mapping between anonymized entities and original ones
Examples of matching:
keanu reeves -> Keanu Reeves
JOHN F. KENNEDY -> John F. Kennedy
"""
# Iterate over all the entities (PERSON, EMAIL_ADDRESS, etc.)
for entity_type in deanonymizer_mapping:
for anonymized, original in deanonymizer_mapping[entity_type].items():
# Use regular expressions for case-insensitive matching and replacing
text = re.sub(anonymized, original, text, flags=re.IGNORECASE)
return text
def fuzzy_matching_strategy(
text: str, deanonymizer_mapping: MappingDataType, max_l_dist: int = 3
) -> str:
"""Fuzzy matching strategy for deanonymization.
It uses fuzzy matching to find the position of the anonymized entity in the text.
It replaces all the anonymized entities with the original ones.
Args:
text: text to deanonymize
deanonymizer_mapping: mapping between anonymized entities and original ones
max_l_dist: maximum Levenshtein distance between the anonymized entity and the
text segment to consider it a match
Examples of matching:
Kaenu Reves -> Keanu Reeves
John F. Kennedy -> John Kennedy
"""
try:
from fuzzysearch import find_near_matches
except ImportError as e:
raise ImportError(
"Could not import fuzzysearch, please install with "
"`pip install fuzzysearch`."
) from e
for entity_type in deanonymizer_mapping:
for anonymized, original in deanonymizer_mapping[entity_type].items():
matches = find_near_matches(anonymized, text, max_l_dist=max_l_dist)
new_text = ""
last_end = 0
for m in matches:
# add the text that isn't part of a match
new_text += text[last_end : m.start]
# add the replacement text
new_text += original
last_end = m.end
# add the remaining text that wasn't part of a match
new_text += text[last_end:]
text = new_text
return text
def combined_exact_fuzzy_matching_strategy(
text: str, deanonymizer_mapping: MappingDataType, max_l_dist: int = 3
) -> str:
"""Combined exact and fuzzy matching strategy for deanonymization.
It is a RECOMMENDED STRATEGY.
Args:
text: text to deanonymize
deanonymizer_mapping: mapping between anonymized entities and original ones
max_l_dist: maximum Levenshtein distance between the anonymized entity and the
text segment to consider it a match
Examples of matching:
Kaenu Reves -> Keanu Reeves
John F. Kennedy -> John Kennedy
"""
text = exact_matching_strategy(text, deanonymizer_mapping)
text = fuzzy_matching_strategy(text, deanonymizer_mapping, max_l_dist)
return text
def ngram_fuzzy_matching_strategy(
text: str,
deanonymizer_mapping: MappingDataType,
fuzzy_threshold: int = 85,
use_variable_length: bool = True,
) -> str:
"""N-gram fuzzy matching strategy for deanonymization.
It replaces all the anonymized entities with the original ones.
It uses fuzzy matching to find the position of the anonymized entity in the text.
It generates n-grams of the same length as the anonymized entity from the text and
uses fuzzy matching to find the position of the anonymized entity in the text.
Args:
text: text to deanonymize
deanonymizer_mapping: mapping between anonymized entities and original ones
fuzzy_threshold: fuzzy matching threshold
use_variable_length: whether to use (n-1, n, n+1)-grams or just n-grams
"""
def generate_ngrams(words_list: List[str], n: int) -> list:
"""Generate n-grams from a list of words"""
return [
" ".join(words_list[i : i + n]) for i in range(len(words_list) - (n - 1))
]
try:
from fuzzywuzzy import fuzz
except ImportError as e:
raise ImportError(
"Could not import fuzzywuzzy, please install with "
"`pip install fuzzywuzzy`."
) from e
text_words = text.split()
replacements = []
matched_indices: List[int] = []
for entity_type in deanonymizer_mapping:
for anonymized, original in deanonymizer_mapping[entity_type].items():
anonymized_words = anonymized.split()
if use_variable_length:
gram_lengths = [
len(anonymized_words) - 1,
len(anonymized_words),
len(anonymized_words) + 1,
]
else:
gram_lengths = [len(anonymized_words)]
for n in gram_lengths:
if n > 0: # Take only positive values
segments = generate_ngrams(text_words, n)
for i, segment in enumerate(segments):
if (
fuzz.ratio(anonymized.lower(), segment.lower())
> fuzzy_threshold
and i not in matched_indices
):
replacements.append((i, n, original))
# Add the matched segment indices to the list
matched_indices.extend(range(i, i + n))
# Sort replacements by index in reverse order
replacements.sort(key=lambda x: x[0], reverse=True)
# Apply replacements in reverse order to not affect subsequent indices
for start, length, replacement in replacements:
text_words[start : start + length] = replacement.split()
return " ".join(text_words)

View File

@ -1,62 +0,0 @@
import string
from typing import Callable, Dict, Optional
def get_pseudoanonymizer_mapping(seed: Optional[int] = None) -> Dict[str, Callable]:
"""Get a mapping of entities to pseudo anonymize them."""
try:
from faker import Faker
except ImportError as e:
raise ImportError(
"Could not import faker, please install it with `pip install Faker`."
) from e
fake = Faker()
fake.seed_instance(seed)
# Listed entities supported by Microsoft Presidio (for now, global and US only)
# Source: https://microsoft.github.io/presidio/supported_entities/
return {
# Global entities
"PERSON": lambda _: fake.name(),
"EMAIL_ADDRESS": lambda _: fake.email(),
"PHONE_NUMBER": lambda _: fake.phone_number(),
"IBAN_CODE": lambda _: fake.iban(),
"CREDIT_CARD": lambda _: fake.credit_card_number(),
"CRYPTO": lambda _: "bc1"
+ "".join(
fake.random_choices(string.ascii_lowercase + string.digits, length=26)
),
"IP_ADDRESS": lambda _: fake.ipv4_public(),
"LOCATION": lambda _: fake.city(),
"DATE_TIME": lambda _: fake.date(),
"NRP": lambda _: str(fake.random_number(digits=8, fix_len=True)),
"MEDICAL_LICENSE": lambda _: fake.bothify(text="??######").upper(),
"URL": lambda _: fake.url(),
# US-specific entities
"US_BANK_NUMBER": lambda _: fake.bban(),
"US_DRIVER_LICENSE": lambda _: str(fake.random_number(digits=9, fix_len=True)),
"US_ITIN": lambda _: fake.bothify(text="9##-7#-####"),
"US_PASSPORT": lambda _: fake.bothify(text="#####??").upper(),
"US_SSN": lambda _: fake.ssn(),
# UK-specific entities
"UK_NHS": lambda _: str(fake.random_number(digits=10, fix_len=True)),
# Spain-specific entities
"ES_NIF": lambda _: fake.bothify(text="########?").upper(),
# Italy-specific entities
"IT_FISCAL_CODE": lambda _: fake.bothify(text="??????##?##?###?").upper(),
"IT_DRIVER_LICENSE": lambda _: fake.bothify(text="?A#######?").upper(),
"IT_VAT_CODE": lambda _: fake.bothify(text="IT???????????"),
"IT_PASSPORT": lambda _: str(fake.random_number(digits=9, fix_len=True)),
"IT_IDENTITY_CARD": lambda _: lambda _: str(
fake.random_number(digits=7, fix_len=True)
),
# Singapore-specific entities
"SG_NRIC_FIN": lambda _: fake.bothify(text="????####?").upper(),
# Australia-specific entities
"AU_ABN": lambda _: str(fake.random_number(digits=11, fix_len=True)),
"AU_ACN": lambda _: str(fake.random_number(digits=9, fix_len=True)),
"AU_TFN": lambda _: str(fake.random_number(digits=9, fix_len=True)),
"AU_MEDICARE": lambda _: str(fake.random_number(digits=10, fix_len=True)),
}
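A short sketch (not in the original file): the mapping returns one callable per Presidio entity type, each taking a single ignored argument; a fixed seed makes the fake values repeatable:

fakes = get_pseudoanonymizer_mapping(seed=42)
print(fakes["PERSON"](None))         # deterministic fake name
print(fakes["EMAIL_ADDRESS"](None))  # deterministic fake e-mail address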

View File

@ -1,468 +0,0 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
import yaml
from langchain_experimental.data_anonymizer.base import (
DEFAULT_DEANONYMIZER_MATCHING_STRATEGY,
AnonymizerBase,
ReversibleAnonymizerBase,
)
from langchain_experimental.data_anonymizer.deanonymizer_mapping import (
DeanonymizerMapping,
MappingDataType,
create_anonymizer_mapping,
)
from langchain_experimental.data_anonymizer.deanonymizer_matching_strategies import (
exact_matching_strategy,
)
from langchain_experimental.data_anonymizer.faker_presidio_mapping import (
get_pseudoanonymizer_mapping,
)
if TYPE_CHECKING:
from presidio_analyzer import AnalyzerEngine, EntityRecognizer
from presidio_analyzer.nlp_engine import NlpEngineProvider
from presidio_anonymizer import AnonymizerEngine
from presidio_anonymizer.entities import ConflictResolutionStrategy, OperatorConfig
def _import_analyzer_engine() -> "AnalyzerEngine":
try:
from presidio_analyzer import AnalyzerEngine
except ImportError as e:
raise ImportError(
"Could not import presidio_analyzer, please install with "
"`pip install presidio-analyzer`. You will also need to download a "
"spaCy model to use the analyzer, e.g. "
"`python -m spacy download en_core_web_lg`."
) from e
return AnalyzerEngine
def _import_nlp_engine_provider() -> "NlpEngineProvider":
try:
from presidio_analyzer.nlp_engine import NlpEngineProvider
except ImportError as e:
raise ImportError(
"Could not import presidio_analyzer, please install with "
"`pip install presidio-analyzer`. You will also need to download a "
"spaCy model to use the analyzer, e.g. "
"`python -m spacy download en_core_web_lg`."
) from e
return NlpEngineProvider
def _import_anonymizer_engine() -> "AnonymizerEngine":
try:
from presidio_anonymizer import AnonymizerEngine
except ImportError as e:
raise ImportError(
"Could not import presidio_anonymizer, please install with "
"`pip install presidio-anonymizer`."
) from e
return AnonymizerEngine
def _import_operator_config() -> "OperatorConfig":
try:
from presidio_anonymizer.entities import OperatorConfig
except ImportError as e:
raise ImportError(
"Could not import presidio_anonymizer, please install with "
"`pip install presidio-anonymizer`."
) from e
return OperatorConfig
# Configuring Anonymizer for multiple languages
# Detailed description and examples can be found here:
# langchain/docs/extras/guides/privacy/multi_language_anonymization.ipynb
DEFAULT_LANGUAGES_CONFIG = {
# You can also use Stanza or transformers library.
# See https://microsoft.github.io/presidio/analyzer/customizing_nlp_models/
"nlp_engine_name": "spacy",
"models": [
{"lang_code": "en", "model_name": "en_core_web_lg"},
# {"lang_code": "de", "model_name": "de_core_news_md"},
# {"lang_code": "es", "model_name": "es_core_news_md"},
# ...
# List of available models: https://spacy.io/usage/models
],
}
class PresidioAnonymizerBase(AnonymizerBase):
"""Base Anonymizer using Microsoft Presidio.
See more: https://microsoft.github.io/presidio/
"""
def __init__(
self,
analyzed_fields: Optional[List[str]] = None,
operators: Optional[Dict[str, OperatorConfig]] = None,
languages_config: Optional[Dict] = None,
add_default_faker_operators: bool = True,
faker_seed: Optional[int] = None,
):
"""
Args:
analyzed_fields: List of fields to detect and then anonymize.
Defaults to all entities supported by Microsoft Presidio.
operators: Operators to use for anonymization.
Operators allow for custom anonymization of detected PII.
Learn more:
https://microsoft.github.io/presidio/tutorial/10_simple_anonymization/
languages_config: Configuration for the NLP engine.
First language in the list will be used as the main language
in self.anonymize(...) when no language is specified.
Learn more:
https://microsoft.github.io/presidio/analyzer/customizing_nlp_models/
faker_seed: Seed used to initialize faker.
Defaults to None, in which case faker will be seeded randomly
and provide random values.
"""
if languages_config is None:
languages_config = DEFAULT_LANGUAGES_CONFIG
OperatorConfig = _import_operator_config()
AnalyzerEngine = _import_analyzer_engine()
NlpEngineProvider = _import_nlp_engine_provider()
AnonymizerEngine = _import_anonymizer_engine()
self.analyzed_fields = (
analyzed_fields
if analyzed_fields is not None
else list(get_pseudoanonymizer_mapping().keys())
)
if add_default_faker_operators:
self.operators = {
field: OperatorConfig(
operator_name="custom", params={"lambda": faker_function}
)
for field, faker_function in get_pseudoanonymizer_mapping(
faker_seed
).items()
}
else:
self.operators = {}
if operators:
self.add_operators(operators)
provider = NlpEngineProvider(nlp_configuration=languages_config)
nlp_engine = provider.create_engine()
self.supported_languages = list(nlp_engine.nlp.keys())
self._analyzer = AnalyzerEngine(
supported_languages=self.supported_languages, nlp_engine=nlp_engine
)
self._anonymizer = AnonymizerEngine()
def add_recognizer(self, recognizer: EntityRecognizer) -> None:
"""Add a recognizer to the analyzer
Args:
recognizer: Recognizer to add to the analyzer.
"""
self._analyzer.registry.add_recognizer(recognizer)
self.analyzed_fields.extend(recognizer.supported_entities)
def add_operators(self, operators: Dict[str, OperatorConfig]) -> None:
"""Add operators to the anonymizer
Args:
operators: Operators to add to the anonymizer.
"""
self.operators.update(operators)
class PresidioAnonymizer(PresidioAnonymizerBase):
"""Anonymizer using Microsoft Presidio."""
def _anonymize(
self,
text: str,
language: Optional[str] = None,
allow_list: Optional[List[str]] = None,
conflict_resolution: Optional[ConflictResolutionStrategy] = None,
) -> str:
"""Anonymize text.
Each PII entity is replaced with a fake value.
Each time fake values will be different, as they are generated randomly.
PresidioAnonymizer has no built-in memory -
so it will not remember the effects of anonymizing previous texts.
>>> anonymizer = PresidioAnonymizer()
>>> anonymizer.anonymize("My name is John Doe. Hi John Doe!")
'My name is Noah Rhodes. Hi Noah Rhodes!'
>>> anonymizer.anonymize("My name is John Doe. Hi John Doe!")
'My name is Brett Russell. Hi Brett Russell!'
Args:
text: text to anonymize
language: language to use for analysis of PII
If None, the first (main) language in the list
of languages specified in the configuration will be used.
"""
if language is None:
language = self.supported_languages[0]
elif language not in self.supported_languages:
raise ValueError(
f"Language '{language}' is not supported. "
f"Supported languages are: {self.supported_languages}. "
"Change your language configuration file to add more languages."
)
# Check supported entities for given language
# e.g. IT_FISCAL_CODE is not supported for English in Presidio by default
# If you want to use it, you need to add a recognizer manually
supported_entities = []
for recognizer in self._analyzer.get_recognizers(language):
recognizer_dict = recognizer.to_dict()
supported_entities.extend(
[recognizer_dict["supported_entity"]]
if "supported_entity" in recognizer_dict
else recognizer_dict["supported_entities"]
)
entities_to_analyze = list(
set(supported_entities).intersection(set(self.analyzed_fields))
)
analyzer_results = self._analyzer.analyze(
text,
entities=entities_to_analyze,
language=language,
allow_list=allow_list,
)
filtered_analyzer_results = (
self._anonymizer._remove_conflicts_and_get_text_manipulation_data(
analyzer_results, conflict_resolution
)
)
anonymizer_results = self._anonymizer.anonymize(
text,
analyzer_results=analyzer_results,
operators=self.operators,
)
anonymizer_mapping = create_anonymizer_mapping(
text,
filtered_analyzer_results,
anonymizer_results,
)
return exact_matching_strategy(text, anonymizer_mapping)
class PresidioReversibleAnonymizer(PresidioAnonymizerBase, ReversibleAnonymizerBase):
"""Reversible Anonymizer using Microsoft Presidio."""
def __init__(
self,
analyzed_fields: Optional[List[str]] = None,
operators: Optional[Dict[str, OperatorConfig]] = None,
languages_config: Optional[Dict] = None,
add_default_faker_operators: bool = True,
faker_seed: Optional[int] = None,
):
if languages_config is None:
languages_config = DEFAULT_LANGUAGES_CONFIG
super().__init__(
analyzed_fields,
operators,
languages_config,
add_default_faker_operators,
faker_seed,
)
self._deanonymizer_mapping = DeanonymizerMapping()
@property
def deanonymizer_mapping(self) -> MappingDataType:
"""Return the deanonymizer mapping"""
return self._deanonymizer_mapping.data
@property
def anonymizer_mapping(self) -> MappingDataType:
"""Return the anonymizer mapping
This is just the reverse version of the deanonymizer mapping."""
return {
key: {v: k for k, v in inner_dict.items()}
for key, inner_dict in self.deanonymizer_mapping.items()
}
def _anonymize(
self,
text: str,
language: Optional[str] = None,
allow_list: Optional[List[str]] = None,
conflict_resolution: Optional[ConflictResolutionStrategy] = None,
) -> str:
"""Anonymize text.
Each PII entity is replaced with a fake value.
Each time fake values will be different, as they are generated randomly.
At the same time, we will create a mapping from each anonymized entity
back to its original text value.
Thanks to the built-in memory, all previously anonymized entities
will be remembered and replaced by the same fake values:
>>> anonymizer = PresidioReversibleAnonymizer()
>>> anonymizer.anonymize("My name is John Doe. Hi John Doe!")
'My name is Noah Rhodes. Hi Noah Rhodes!'
>>> anonymizer.anonymize("My name is John Doe. Hi John Doe!")
'My name is Noah Rhodes. Hi Noah Rhodes!'
Args:
text: text to anonymize
language: language to use for analysis of PII
If None, the first (main) language in the list
of languages specified in the configuration will be used.
"""
if language is None:
language = self.supported_languages[0]
if language not in self.supported_languages:
raise ValueError(
f"Language '{language}' is not supported. "
f"Supported languages are: {self.supported_languages}. "
"Change your language configuration file to add more languages."
)
# Check supported entities for given language
# e.g. IT_FISCAL_CODE is not supported for English in Presidio by default
# If you want to use it, you need to add a recognizer manually
supported_entities = []
for recognizer in self._analyzer.get_recognizers(language):
recognizer_dict = recognizer.to_dict()
supported_entities.extend(
[recognizer_dict["supported_entity"]]
if "supported_entity" in recognizer_dict
else recognizer_dict["supported_entities"]
)
entities_to_analyze = list(
set(supported_entities).intersection(set(self.analyzed_fields))
)
analyzer_results = self._analyzer.analyze(
text,
entities=entities_to_analyze,
language=language,
allow_list=allow_list,
)
filtered_analyzer_results = (
self._anonymizer._remove_conflicts_and_get_text_manipulation_data(
analyzer_results, conflict_resolution
)
)
anonymizer_results = self._anonymizer.anonymize(
text,
analyzer_results=analyzer_results,
operators=self.operators,
)
new_deanonymizer_mapping = create_anonymizer_mapping(
text,
filtered_analyzer_results,
anonymizer_results,
is_reversed=True,
)
self._deanonymizer_mapping.update(new_deanonymizer_mapping)
return exact_matching_strategy(text, self.anonymizer_mapping)
def _deanonymize(
self,
text_to_deanonymize: str,
deanonymizer_matching_strategy: Callable[
[str, MappingDataType], str
] = DEFAULT_DEANONYMIZER_MATCHING_STRATEGY,
) -> str:
"""Deanonymize text.
Each anonymized entity is replaced with its original value.
This method exploits the mapping created during the anonymization process.
Args:
text_to_deanonymize: text to deanonymize
deanonymizer_matching_strategy: function to use to match
anonymized entities with their original values and replace them.
"""
if not self._deanonymizer_mapping:
raise ValueError(
"Deanonymizer mapping is empty.",
"Please call anonymize() and anonymize some text first.",
)
text_to_deanonymize = deanonymizer_matching_strategy(
text_to_deanonymize, self.deanonymizer_mapping
)
return text_to_deanonymize
def reset_deanonymizer_mapping(self) -> None:
"""Reset the deanonymizer mapping"""
self._deanonymizer_mapping = DeanonymizerMapping()
def save_deanonymizer_mapping(self, file_path: Union[Path, str]) -> None:
"""Save the deanonymizer mapping to a JSON or YAML file.
Args:
file_path: Path to file to save the mapping to.
Example:
.. code-block:: python
anonymizer.save_deanonymizer_mapping(file_path="path/mapping.json")
"""
save_path = Path(file_path)
if save_path.suffix not in [".json", ".yaml"]:
raise ValueError(f"{save_path} must have an extension of .json or .yaml")
# Make sure parent directories exist
save_path.parent.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(save_path, "w") as f:
json.dump(self.deanonymizer_mapping, f, indent=2)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(save_path, "w") as f:
yaml.dump(self.deanonymizer_mapping, f, default_flow_style=False)
def load_deanonymizer_mapping(self, file_path: Union[Path, str]) -> None:
"""Load the deanonymizer mapping from a JSON or YAML file.
Args:
file_path: Path to file to load the mapping from.
Example:
.. code-block:: python
anonymizer.load_deanonymizer_mapping(file_path="path/mapping.json")
"""
load_path = Path(file_path)
if load_path.suffix not in [".json", ".yaml"]:
raise ValueError(f"{load_path} must have an extension of .json or .yaml")
if load_path.suffix == ".json":
with open(load_path, "r") as f:
loaded_mapping = json.load(f)
elif load_path.suffix.endswith((".yaml", ".yml")):
with open(load_path, "r") as f:
loaded_mapping = yaml.load(f, Loader=yaml.FullLoader)
self._deanonymizer_mapping.update(loaded_mapping)
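A hedged multi-language sketch (not in the original file), following the commented models list in DEFAULT_LANGUAGES_CONFIG; the German spaCy model name and the sample sentence are illustrative assumptions:

config = {
    "nlp_engine_name": "spacy",
    "models": [
        {"lang_code": "en", "model_name": "en_core_web_lg"},
        {"lang_code": "de", "model_name": "de_core_news_md"},
    ],
}
anonymizer = PresidioReversibleAnonymizer(languages_config=config)
print(anonymizer.anonymize("Mein Name ist John Doe.", language="de"))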

View File

@ -1,7 +0,0 @@
"""**Fallacy Removal** Chain runs a self-review of logical fallacies
as determined by paper
[Robust and Explainable Identification of Logical Fallacies in Natural
Language Arguments](https://arxiv.org/pdf/2212.07425.pdf).
It is modeled after `Constitutional AI` and follows the same format, but applies
logical fallacies as generalized rules in order to remove them from the output.
"""

View File

@ -1,183 +0,0 @@
"""Chain for applying removals of logical fallacies."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.fallacy_removal.fallacies import FALLACIES
from langchain_experimental.fallacy_removal.models import LogicalFallacy
from langchain_experimental.fallacy_removal.prompts import (
FALLACY_CRITIQUE_PROMPT,
FALLACY_REVISION_PROMPT,
)
class FallacyChain(Chain):
"""Chain for applying logical fallacy evaluations.
It is modeled after Constitutional AI and follows the same format, but
applies logical fallacies as generalized rules to remove them from the output.
Example:
.. code-block:: python
from langchain_community.llms import OpenAI
from langchain.chains import LLMChain
from langchain_experimental.fallacy import FallacyChain
from langchain_experimental.fallacy_removal.models import LogicalFallacy
llm = OpenAI()
qa_prompt = PromptTemplate(
template="Q: {question} A:",
input_variables=["question"],
)
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
fallacy_chain = FallacyChain.from_llm(
llm=llm,
chain=qa_chain,
logical_fallacies=[
LogicalFallacy(
fallacy_critique_request="Tell if this answer meets criteria.",
fallacy_revision_request=\
"Give an answer that meets better criteria.",
)
],
)
fallacy_chain.run(question="How do I know if the earth is round?")
"""
chain: LLMChain
logical_fallacies: List[LogicalFallacy]
fallacy_critique_chain: LLMChain
fallacy_revision_chain: LLMChain
return_intermediate_steps: bool = False
@classmethod
def get_fallacies(cls, names: Optional[List[str]] = None) -> List[LogicalFallacy]:
if names is None:
return list(FALLACIES.values())
else:
return [FALLACIES[name] for name in names]
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
chain: LLMChain,
fallacy_critique_prompt: BasePromptTemplate = FALLACY_CRITIQUE_PROMPT,
fallacy_revision_prompt: BasePromptTemplate = FALLACY_REVISION_PROMPT,
**kwargs: Any,
) -> "FallacyChain":
"""Create a chain from an LLM."""
fallacy_critique_chain = LLMChain(llm=llm, prompt=fallacy_critique_prompt)
fallacy_revision_chain = LLMChain(llm=llm, prompt=fallacy_revision_prompt)
return cls(
chain=chain,
fallacy_critique_chain=fallacy_critique_chain,
fallacy_revision_chain=fallacy_revision_chain,
**kwargs,
)
@property
def input_keys(self) -> List[str]:
"""Input keys."""
return self.chain.input_keys
@property
def output_keys(self) -> List[str]:
"""Output keys."""
if self.return_intermediate_steps:
return ["output", "fallacy_critiques_and_revisions", "initial_output"]
return ["output"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run(
**inputs,
callbacks=_run_manager.get_child("original"),
)
initial_response = response
input_prompt = self.chain.prompt.format(**inputs)
_run_manager.on_text(
text="Initial response: " + response + "\n\n",
verbose=self.verbose,
color="yellow",
)
fallacy_critiques_and_revisions = []
for logical_fallacy in self.logical_fallacies:
# Fallacy critique below
fallacy_raw_critique = self.fallacy_critique_chain.run(
input_prompt=input_prompt,
output_from_model=response,
fallacy_critique_request=logical_fallacy.fallacy_critique_request,
callbacks=_run_manager.get_child("fallacy_critique"),
)
fallacy_critique = self._parse_critique(
output_string=fallacy_raw_critique,
).strip()
# if fallacy critique contains "No fallacy critique needed" then done
if "no fallacy critique needed" in fallacy_critique.lower():
fallacy_critiques_and_revisions.append((fallacy_critique, ""))
continue
fallacy_revision = self.fallacy_revision_chain.run(
input_prompt=input_prompt,
output_from_model=response,
fallacy_critique_request=logical_fallacy.fallacy_critique_request,
fallacy_critique=fallacy_critique,
revision_request=logical_fallacy.fallacy_revision_request,
callbacks=_run_manager.get_child("fallacy_revision"),
).strip()
response = fallacy_revision
fallacy_critiques_and_revisions.append((fallacy_critique, fallacy_revision))
_run_manager.on_text(
text=f"Applying {logical_fallacy.name}..." + "\n\n",
verbose=self.verbose,
color="green",
)
_run_manager.on_text(
text="Logical Fallacy: " + fallacy_critique + "\n\n",
verbose=self.verbose,
color="blue",
)
_run_manager.on_text(
text="Updated response: " + fallacy_revision + "\n\n",
verbose=self.verbose,
color="yellow",
)
final_output: Dict[str, Any] = {"output": response}
if self.return_intermediate_steps:
final_output["initial_output"] = initial_response
final_output["fallacy_critiques_and_revisions"] = (
fallacy_critiques_and_revisions
)
return final_output
@staticmethod
def _parse_critique(output_string: str) -> str:
if "Fallacy Revision request:" not in output_string:
return output_string
output_string = output_string.split("Fallacy Revision request:")[0]
if "\n\n" in output_string:
output_string = output_string.split("\n\n")[0]
return output_string

View File

@ -1,209 +0,0 @@
"""Classification of logical fallacies in Natural Language Arguments \
from https://arxiv.org/pdf/2212.07425.pdf
"""
from typing import Dict
from langchain_experimental.fallacy_removal.models import LogicalFallacy
FALLACIES: Dict[str, LogicalFallacy] = {
"adhominem": LogicalFallacy(
name="adhominem",
fallacy_critique_request="Identify any feasible ways in which \
the assistant's last response is attacking the character or \
personal traits of the person making an argument rather than \
addressing the actual argument and evidence.",
fallacy_revision_request="Please rewrite the assistant response\
to remove any attacking the character or personal traits of the\
person making an argument rather than addressing the actual\
argument and evidence.",
),
"adpopulum": LogicalFallacy(
name="adpopulum",
fallacy_critique_request="Identify ways in which the assistants\
last response may be asserting that something must be true or \
correct simply because many people believe it or do it, without \
actual facts or evidence to support the conclusion.",
fallacy_revision_request="Please rewrite the assistant response \
to remove any assertion that something must be true or correct \
simply because many people believe it or do it, without actual \
facts or evidence to support the conclusion.",
),
"appealtoemotion": LogicalFallacy(
name="appealtoemotion",
fallacy_critique_request="Identify all ways in which the \
assistant's last response is an attempt to win support for an \
argument by exploiting or manipulating people's emotions rather \
than using facts and reason.",
fallacy_revision_request="Please rewrite the assistant response \
to remove any attempt to win support for an argument by \
exploiting or manipulating people's emotions rather than using \
facts and reason.",
),
"fallacyofextension": LogicalFallacy(
name="fallacyofextension",
fallacy_critique_request="Identify any ways in which the \
assistant's last response is making broad, sweeping generalizations\
and extending the implications of an argument far beyond what the \
initial premises support.",
fallacy_revision_request="Rewrite the assistant response to remove\
all broad, sweeping generalizations and extending the implications\
of an argument far beyond what the initial premises support.",
),
"intentionalfallacy": LogicalFallacy(
name="intentionalfallacy",
fallacy_critique_request="Identify any way in which the assistants\
last response may be falsely supporting a conclusion by claiming to\
understand an author or creator's subconscious intentions without \
clear evidence.",
fallacy_revision_request="Revise the assistants last response to \
remove any false support of a conclusion by claiming to understand\
an author or creator's subconscious intentions without clear \
evidence.",
),
"falsecausality": LogicalFallacy(
name="falsecausality",
fallacy_critique_request="Think carefully about whether the \
assistant's last response is jumping to conclusions about causation\
between events or circumstances without adequate evidence to infer \
a causal relationship.",
fallacy_revision_request="Please write a new version of the \
assistant's response that removes jumping to conclusions about\
causation between events or circumstances without adequate \
evidence to infer a causal relationship.",
),
"falsedilemma": LogicalFallacy(
name="falsedilemma",
fallacy_critique_request="Identify any way in which the \
assistant's last response may be presenting only two possible options\
or sides to a situation when there are clearly other alternatives \
that have not been considered or addressed.",
fallacy_revision_request="Amend the assistants last response to \
remove any presentation of only two possible options or sides to a \
situation when there are clearly other alternatives that have not \
been considered or addressed.",
),
"hastygeneralization": LogicalFallacy(
name="hastygeneralization",
fallacy_critique_request="Identify any way in which the assistants\
last response is making a broad inference or generalization to \
situations, people, or circumstances that are not sufficiently \
similar based on a specific example or limited evidence.",
fallacy_revision_request="Please rewrite the assistant response to\
remove a broad inference or generalization to situations, people, \
or circumstances that are not sufficiently similar based on a \
specific example or limited evidence.",
),
"illogicalarrangement": LogicalFallacy(
name="illogicalarrangement",
fallacy_critique_request="Think carefully about any ways in which \
the assistant's last response is constructing an argument in a \
flawed, illogical way, so the premises do not connect to or lead\
to the conclusion properly.",
fallacy_revision_request="Please rewrite the assistants response\
so as to remove any construction of an argument that is flawed and\
illogical or if the premises do not connect to or lead to the \
conclusion properly.",
),
"fallacyofcredibility": LogicalFallacy(
name="fallacyofcredibility",
fallacy_critique_request="Discuss whether the assistant's last \
response was dismissing or attacking the credibility of the person\
making an argument rather than directly addressing the argument \
itself.",
fallacy_revision_request="Revise the assistants response so as \
that it refrains from dismissing or attacking the credibility of\
the person making an argument rather than directly addressing \
the argument itself.",
),
"circularreasoning": LogicalFallacy(
name="circularreasoning",
fallacy_critique_request="Discuss ways in which the assistants\
last response may be supporting a premise by simply repeating \
the premise as the conclusion without giving actual proof or \
evidence.",
fallacy_revision_request="Revise the assistants response if \
possible so that its not supporting a premise by simply \
repeating the premise as the conclusion without giving actual\
proof or evidence.",
),
"beggingthequestion": LogicalFallacy(
name="beggingthequestion",
fallacy_critique_request="Discuss ways in which the assistant's\
last response is restating the conclusion of an argument as a \
premise without providing actual support for the conclusion in \
the first place.",
fallacy_revision_request="Write a revision of the assistants \
response that refrains from restating the conclusion of an \
argument as a premise without providing actual support for the \
conclusion in the first place.",
),
"trickquestion": LogicalFallacy(
name="trickquestion",
fallacy_critique_request="Identify ways in which the \
assistant's last response is asking a question that \
contains or assumes information that has not been proven or \
substantiated.",
fallacy_revision_request="Please write a new assistant \
response so that it does not ask a question that contains \
or assumes information that has not been proven or \
substantiated.",
),
"overapplier": LogicalFallacy(
name="overapplier",
fallacy_critique_request="Identify ways in which the assistants\
last response is applying a general rule or generalization to a \
specific case it was not meant to apply to.",
fallacy_revision_request="Please write a new response that does\
not apply a general rule or generalization to a specific case \
it was not meant to apply to.",
),
"equivocation": LogicalFallacy(
name="equivocation",
fallacy_critique_request="Read the assistants last response \
carefully and identify if it is using the same word or phrase \
in two different senses or contexts within an argument.",
fallacy_revision_request="Rewrite the assistant response so \
that it does not use the same word or phrase in two different \
senses or contexts within an argument.",
),
"amphiboly": LogicalFallacy(
name="amphiboly",
fallacy_critique_request="Critique the assistants last response\
to see if it is constructing sentences such that the grammar \
or structure is ambiguous, leading to multiple interpretations.",
fallacy_revision_request="Please rewrite the assistant response\
to remove any construction of sentences where the grammar or \
structure is ambiguous or leading to multiple interpretations.",
),
"accent": LogicalFallacy(
name="accent",
fallacy_critique_request="Discuss whether the assitant's response\
is misrepresenting an argument by shifting the emphasis of a word\
or phrase to give it a different meaning than intended.",
fallacy_revision_request="Please rewrite the AI model's response\
so that it is not misrepresenting an argument by shifting the \
emphasis of a word or phrase to give it a different meaning than\
intended.",
),
"composition": LogicalFallacy(
name="composition",
fallacy_critique_request="Discuss whether the assistant's \
response is erroneously inferring that something is true of \
the whole based on the fact that it is true of some part or \
parts.",
fallacy_revision_request="Please rewrite the assitant's response\
so that it is not erroneously inferring that something is true \
of the whole based on the fact that it is true of some part or \
parts.",
),
"division": LogicalFallacy(
name="division",
fallacy_critique_request="Discuss whether the assistant's last \
response is erroneously inferring that something is true of the \
parts based on the fact that it is true of the whole.",
fallacy_revision_request="Please rewrite the assitant's response\
so that it is not erroneously inferring that something is true \
of the parts based on the fact that it is true of the whole.",
),
}
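# Illustrative sketch (not part of the original module): selecting a subset
# of the fallacies above by key for use with a fallacy-removal chain. The
# chosen keys are examples of entries defined in FALLACIES.
selected_fallacies = [FALLACIES[name] for name in ("adhominem", "falsedilemma")]
for fallacy in selected_fallacies:
    print(fallacy.name, "->", fallacy.fallacy_critique_request[:60], "...")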

View File

@ -1,11 +0,0 @@
"""Models for the Logical Fallacy Chain"""
from pydantic import BaseModel
class LogicalFallacy(BaseModel):
"""Logical fallacy."""
fallacy_critique_request: str
fallacy_revision_request: str
name: str = "Logical Fallacy"
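# Illustrative sketch (not part of the original module): defining a custom
# fallacy with the model above. The "strawman" entry is hypothetical and is
# not included in the packaged FALLACIES mapping.
strawman = LogicalFallacy(
    name="strawman",
    fallacy_critique_request=(
        "Identify any way in which the assistant's last response "
        "misrepresents an argument in order to make it easier to attack."
    ),
    fallacy_revision_request=(
        "Rewrite the assistant's response so that it addresses the actual "
        "argument rather than a misrepresented version of it."
    ),
)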

View File

@ -1,135 +0,0 @@
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
fallacy_critique_example = PromptTemplate(
template="""Human: {input_prompt}
Model: {output_from_model}
Fallacy Critique Request: {fallacy_critique_request}
Fallacy Critique: {fallacy_critique}""",
input_variables=[
"input_prompt",
"output_from_model",
"fallacy_critique_request",
"fallacy_critique",
],
)
examples = [
{
"input_prompt": "If everyone says the Earth is round, \
how do I know that's correct?",
"output_from_model": "The earth is round because your \
teacher says it is",
"fallacy_critique_request": "Identify specific ways in\
which the model's previous response had a logical fallacy. \
Also point out potential logical fallacies in the human's \
questions and responses. Examples of logical fallacies \
include but are not limited to ad hominem, ad populum, \
appeal to emotion and false causality.",
"fallacy_critique": "This statement contains the logical \
fallacy of Ad Verecundiam or Appeal to Authority. It is \
a fallacy because it asserts something to be true purely \
based on the authority of the source making the claim, \
without any actual evidence to support it. Fallacy \
Critique Needed",
"fallacy_revision_request": "Please rewrite the model \
response to remove all logical fallacies, and to \
politely point out any logical fallacies from the \
human.",
"fallacy_revision": "The earth is round based on \
evidence from observations of its curvature from high \
altitudes, photos from space showing its spherical shape, \
circumnavigation, and the fact that we see its rounded \
shadow on the moon during lunar eclipses.",
},
{
"input_prompt": "Should we invest more in our school \
music program? After all, studies show students \
involved in music perform better academically.",
"output_from_model": "I don't think we should invest \
more in the music program. Playing the piccolo won't \
teach someone better math skills.",
"fallacy_critique_request": "Identify specific ways \
in which the model's previous response had a logical \
fallacy. Also point out potential logical fallacies \
in the human's questions and responses. Examples of \
logical fallacies include but are not limited to ad \
hominem, ad populum, appeal to emotion and false \
causality.",
"fallacy_critique": "This answer commits the division \
fallacy by rejecting the argument based on assuming \
capabilities true of the parts (playing an instrument \
like piccolo) also apply to the whole \
(the full music program). The answer focuses only on \
part of the music program rather than considering it \
as a whole. Fallacy Critique Needed.",
"fallacy_revision_request": "Please rewrite the model \
response to remove all logical fallacies, and to \
politely point out any logical fallacies from the human.",
"fallacy_revision": "While playing an instrument may \
teach discipline, more evidence is needed on whether \
music education courses improve critical thinking \
skills across subjects before determining if increased \
investment in the whole music program is warranted.",
},
]
FALLACY_CRITIQUE_PROMPT = FewShotPromptTemplate(
example_prompt=fallacy_critique_example,
examples=[
{k: v for k, v in e.items() if k != "fallacy_revision_request"}
for e in examples
],
prefix="Below is a conversation between a human and an \
AI assistant. If there is no material critique of the \
model output, append to the end of the Fallacy Critique: \
'No fallacy critique needed.' If there is material \
critique \
of the model output, append to the end of the Fallacy \
Critique: 'Fallacy Critique needed.'",
suffix="""Human: {input_prompt}
Model: {output_from_model}
Fallacy Critique Request: {fallacy_critique_request}
Fallacy Critique:""",
example_separator="\n === \n",
input_variables=["input_prompt", "output_from_model", "fallacy_critique_request"],
)
FALLACY_REVISION_PROMPT = FewShotPromptTemplate(
example_prompt=fallacy_critique_example,
examples=examples,
prefix="Below is a conversation between a human and \
an AI assistant.",
suffix="""Human: {input_prompt}
Model: {output_from_model}
Fallacy Critique Request: {fallacy_critique_request}
Fallacy Critique: {fallacy_critique}
If the fallacy critique does not identify anything worth \
changing, ignore the Fallacy Revision Request and do not \
make any revisions. Instead, return "No revisions needed".
If the fallacy critique does identify something worth \
changing, please revise the model response based on the \
Fallacy Revision Request.
Fallacy Revision Request: {fallacy_revision_request}
Fallacy Revision:""",
example_separator="\n === \n",
input_variables=[
"input_prompt",
"output_from_model",
"fallacy_critique_request",
"fallacy_critique",
"fallacy_revision_request",
],
)
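# Illustrative sketch (not part of the original module): rendering the
# critique prompt above with hypothetical inputs. format() fills the three
# input_variables declared on FALLACY_CRITIQUE_PROMPT; the few-shot examples
# are prepended automatically.
rendered = FALLACY_CRITIQUE_PROMPT.format(
    input_prompt="Is the moon made of cheese?",
    output_from_model="Yes, because everyone online says so.",
    fallacy_critique_request=(
        "Identify any logical fallacies in the model's previous response."
    ),
)
print(rendered)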

View File

@ -1,6 +0,0 @@
"""**Generative Agent** primitives."""
from langchain_experimental.generative_agents.generative_agent import GenerativeAgent
from langchain_experimental.generative_agents.memory import GenerativeAgentMemory
__all__ = ["GenerativeAgent", "GenerativeAgentMemory"]

View File

@ -1,249 +0,0 @@
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from langchain.chains import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, ConfigDict, Field
from langchain_experimental.generative_agents.memory import GenerativeAgentMemory
class GenerativeAgent(BaseModel):
"""Agent as a character with memory and innate characteristics."""
name: str
"""The character's name."""
age: Optional[int] = None
"""The optional age of the character."""
traits: str = "N/A"
"""Permanent traits to ascribe to the character."""
status: str
"""The traits of the character you wish not to change."""
memory: GenerativeAgentMemory
"""The memory object that combines relevance, recency, and 'importance'."""
llm: BaseLanguageModel
"""The underlying language model."""
verbose: bool = False
summary: str = "" #: :meta private:
"""Stateful self-summary generated via reflection on the character's memory."""
summary_refresh_seconds: int = 3600 #: :meta private:
"""How frequently to re-generate the summary."""
last_refreshed: datetime = Field(default_factory=datetime.now) # : :meta private:
"""The last time the character's summary was regenerated."""
daily_summaries: List[str] = Field(default_factory=list) # : :meta private:
"""Summary of the events in the plan that the agent took."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
# LLM-related methods
@staticmethod
def _parse_list(text: str) -> List[str]:
"""Parse a newline-separated string into a list of strings."""
lines = re.split(r"\n", text.strip())
return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]
def chain(self, prompt: PromptTemplate) -> LLMChain:
"""Create a chain with the same settings as the agent."""
return LLMChain(
llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
)
def _get_entity_from_observation(self, observation: str) -> str:
prompt = PromptTemplate.from_template(
"What is the observed entity in the following observation? {observation}"
+ "\nEntity="
)
return self.chain(prompt).run(observation=observation).strip()
def _get_entity_action(self, observation: str, entity_name: str) -> str:
prompt = PromptTemplate.from_template(
"What is the {entity} doing in the following observation? {observation}"
+ "\nThe {entity} is"
)
return (
self.chain(prompt).run(entity=entity_name, observation=observation).strip()
)
def summarize_related_memories(self, observation: str) -> str:
"""Summarize memories that are most relevant to an observation."""
prompt = PromptTemplate.from_template(
"""
{q1}?
Context from memory:
{relevant_memories}
Relevant context:
"""
)
entity_name = self._get_entity_from_observation(observation)
entity_action = self._get_entity_action(observation, entity_name)
q1 = f"What is the relationship between {self.name} and {entity_name}"
q2 = f"{entity_name} is {entity_action}"
return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()
def _generate_reaction(
self, observation: str, suffix: str, now: Optional[datetime] = None
) -> str:
"""React to a given observation or dialogue act."""
prompt = PromptTemplate.from_template(
"{agent_summary_description}"
+ "\nIt is {current_time}."
+ "\n{agent_name}'s status: {agent_status}"
+ "\nSummary of relevant context from {agent_name}'s memory:"
+ "\n{relevant_memories}"
+ "\nMost recent observations: {most_recent_memories}"
+ "\nObservation: {observation}"
+ "\n\n"
+ suffix
)
agent_summary_description = self.get_summary(now=now)
relevant_memories_str = self.summarize_related_memories(observation)
current_time_str = (
datetime.now().strftime("%B %d, %Y, %I:%M %p")
if now is None
else now.strftime("%B %d, %Y, %I:%M %p")
)
kwargs: Dict[str, Any] = dict(
agent_summary_description=agent_summary_description,
current_time=current_time_str,
relevant_memories=relevant_memories_str,
agent_name=self.name,
observation=observation,
agent_status=self.status,
)
consumed_tokens = self.llm.get_num_tokens(
prompt.format(most_recent_memories="", **kwargs)
)
kwargs[self.memory.most_recent_memories_token_key] = consumed_tokens
return self.chain(prompt=prompt).run(**kwargs).strip()
def _clean_response(self, text: str) -> str:
return re.sub(f"^{self.name} ", "", text.strip()).strip()
def generate_reaction(
self, observation: str, now: Optional[datetime] = None
) -> Tuple[bool, str]:
"""React to a given observation."""
call_to_action_template = (
"Should {agent_name} react to the observation, and if so,"
+ " what would be an appropriate reaction? Respond in one line."
+ ' If the action is to engage in dialogue, write:\nSAY: "what to say"'
+ "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)."
+ "\nEither do nothing, react, or say something but not both.\n\n"
)
full_result = self._generate_reaction(
observation, call_to_action_template, now=now
)
result = full_result.strip().split("\n")[0]
# AAA
self.memory.save_context(
{},
{
self.memory.add_memory_key: f"{self.name} observed "
f"{observation} and reacted by {result}",
self.memory.now_key: now,
},
)
if "REACT:" in result:
reaction = self._clean_response(result.split("REACT:")[-1])
return False, f"{self.name} {reaction}"
if "SAY:" in result:
said_value = self._clean_response(result.split("SAY:")[-1])
return True, f"{self.name} said {said_value}"
else:
return False, result
def generate_dialogue_response(
self, observation: str, now: Optional[datetime] = None
) -> Tuple[bool, str]:
"""React to a given observation."""
call_to_action_template = (
"What would {agent_name} say? To end the conversation, write:"
' GOODBYE: "what to say". Otherwise to continue the conversation,'
' write: SAY: "what to say next"\n\n'
)
full_result = self._generate_reaction(
observation, call_to_action_template, now=now
)
result = full_result.strip().split("\n")[0]
if "GOODBYE:" in result:
farewell = self._clean_response(result.split("GOODBYE:")[-1])
self.memory.save_context(
{},
{
self.memory.add_memory_key: f"{self.name} observed "
f"{observation} and said {farewell}",
self.memory.now_key: now,
},
)
return False, f"{self.name} said {farewell}"
if "SAY:" in result:
response_text = self._clean_response(result.split("SAY:")[-1])
self.memory.save_context(
{},
{
self.memory.add_memory_key: f"{self.name} observed "
f"{observation} and said {response_text}",
self.memory.now_key: now,
},
)
return True, f"{self.name} said {response_text}"
else:
return False, result
######################################################
# Agent's stateful summary methods. #
# Each dialog or response prompt includes a header #
# summarizing the agent's self-description. This is #
# updated periodically through probing its memories #
######################################################
def _compute_agent_summary(self) -> str:
""""""
prompt = PromptTemplate.from_template(
"How would you summarize {name}'s core characteristics given the"
+ " following statements:\n"
+ "{relevant_memories}"
+ "Do not embellish."
+ "\n\nSummary: "
)
# The agent seeks to think about their core characteristics.
return (
self.chain(prompt)
.run(name=self.name, queries=[f"{self.name}'s core characteristics"])
.strip()
)
def get_summary(
self, force_refresh: bool = False, now: Optional[datetime] = None
) -> str:
"""Return a descriptive summary of the agent."""
current_time = datetime.now() if now is None else now
since_refresh = (current_time - self.last_refreshed).seconds
if (
not self.summary
or since_refresh >= self.summary_refresh_seconds
or force_refresh
):
self.summary = self._compute_agent_summary()
self.last_refreshed = current_time
age = self.age if self.age is not None else "N/A"
return (
f"Name: {self.name} (age: {age})"
+ f"\nInnate traits: {self.traits}"
+ f"\n{self.summary}"
)
def get_full_header(
self, force_refresh: bool = False, now: Optional[datetime] = None
) -> str:
"""Return a full header of the agent's status, summary, and current time."""
now = datetime.now() if now is None else now
summary = self.get_summary(force_refresh=force_refresh, now=now)
current_time_str = now.strftime("%B %d, %Y, %I:%M %p")
return (
f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
)

View File

@ -1,296 +0,0 @@
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain.chains import LLMChain
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseMemory, Document
from langchain.utils import mock_now
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
logger = logging.getLogger(__name__)
class GenerativeAgentMemory(BaseMemory):
"""Memory for the generative agent."""
llm: BaseLanguageModel
"""The core language model."""
memory_retriever: TimeWeightedVectorStoreRetriever
"""The retriever to fetch related memories."""
verbose: bool = False
reflection_threshold: Optional[float] = None
"""When aggregate_importance exceeds reflection_threshold, stop to reflect."""
current_plan: List[str] = []
"""The current plan of the agent."""
# A weight of 0.15 makes this less important than it
# would be otherwise, relative to salience and time
importance_weight: float = 0.15
"""How much weight to assign the memory importance."""
aggregate_importance: float = 0.0 # : :meta private:
"""Track the sum of the 'importance' of recent memories.
Triggers reflection when it reaches reflection_threshold."""
max_tokens_limit: int = 1200 # : :meta private:
# input keys
queries_key: str = "queries"
most_recent_memories_token_key: str = "recent_memories_token"
add_memory_key: str = "add_memory"
# output keys
relevant_memories_key: str = "relevant_memories"
relevant_memories_simple_key: str = "relevant_memories_simple"
most_recent_memories_key: str = "most_recent_memories"
now_key: str = "now"
reflecting: bool = False
def chain(self, prompt: PromptTemplate) -> LLMChain:
return LLMChain(llm=self.llm, prompt=prompt, verbose=self.verbose)
@staticmethod
def _parse_list(text: str) -> List[str]:
"""Parse a newline-separated string into a list of strings."""
lines = re.split(r"\n", text.strip())
lines = [line for line in lines if line.strip()] # remove empty lines
return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]
def _get_topics_of_reflection(self, last_k: int = 50) -> List[str]:
"""Return the 3 most salient high-level questions about recent observations."""
prompt = PromptTemplate.from_template(
"{observations}\n\n"
"Given only the information above, what are the 3 most salient "
"high-level questions we can answer about the subjects in the statements?\n"
"Provide each question on a new line."
)
observations = self.memory_retriever.memory_stream[-last_k:]
observation_str = "\n".join(
[self._format_memory_detail(o) for o in observations]
)
result = self.chain(prompt).run(observations=observation_str)
return self._parse_list(result)
def _get_insights_on_topic(
self, topic: str, now: Optional[datetime] = None
) -> List[str]:
"""Generate 'insights' on a topic of reflection, based on pertinent memories."""
prompt = PromptTemplate.from_template(
"Statements relevant to: '{topic}'\n"
"---\n"
"{related_statements}\n"
"---\n"
"What 5 high-level novel insights can you infer from the above statements "
"that are relevant for answering the following question?\n"
"Do not include any insights that are not relevant to the question.\n"
"Do not repeat any insights that have already been made.\n\n"
"Question: {topic}\n\n"
"(example format: insight (because of 1, 5, 3))\n"
)
related_memories = self.fetch_memories(topic, now=now)
related_statements = "\n".join(
[
self._format_memory_detail(memory, prefix=f"{i+1}. ")
for i, memory in enumerate(related_memories)
]
)
result = self.chain(prompt).run(
topic=topic, related_statements=related_statements
)
# TODO: Parse the connections between memories and insights
return self._parse_list(result)
def pause_to_reflect(self, now: Optional[datetime] = None) -> List[str]:
"""Reflect on recent observations and generate 'insights'."""
if self.verbose:
logger.info("Character is reflecting")
new_insights = []
topics = self._get_topics_of_reflection()
for topic in topics:
insights = self._get_insights_on_topic(topic, now=now)
for insight in insights:
self.add_memory(insight, now=now)
new_insights.extend(insights)
return new_insights
def _score_memory_importance(self, memory_content: str) -> float:
"""Score the absolute importance of the given memory."""
prompt = PromptTemplate.from_template(
"On the scale of 1 to 10, where 1 is purely mundane"
+ " (e.g., brushing teeth, making bed) and 10 is"
+ " extremely poignant (e.g., a break up, college"
+ " acceptance), rate the likely poignancy of the"
+ " following piece of memory. Respond with a single integer."
+ "\nMemory: {memory_content}"
+ "\nRating: "
)
score = self.chain(prompt).run(memory_content=memory_content).strip()
if self.verbose:
logger.info(f"Importance score: {score}")
match = re.search(r"^\D*(\d+)", score)
if match:
return (float(match.group(1)) / 10) * self.importance_weight
else:
return 0.0
def _score_memories_importance(self, memory_content: str) -> List[float]:
"""Score the absolute importance of the given memory."""
prompt = PromptTemplate.from_template(
"On the scale of 1 to 10, where 1 is purely mundane"
+ " (e.g., brushing teeth, making bed) and 10 is"
+ " extremely poignant (e.g., a break up, college"
+ " acceptance), rate the likely poignancy of the"
+ " following piece of memory. Always answer with only a list of numbers."
+ " If just given one memory still respond in a list."
+ " Memories are separated by semi colans (;)"
+ "\nMemories: {memory_content}"
+ "\nRating: "
)
scores = self.chain(prompt).run(memory_content=memory_content).strip()
if self.verbose:
logger.info(f"Importance scores: {scores}")
# Split into list of strings and convert to floats
scores_list = [float(x) for x in scores.split(";")]
return scores_list
def add_memories(
self, memory_content: str, now: Optional[datetime] = None
) -> List[str]:
"""Add an observations or memories to the agent's memory."""
importance_scores = self._score_memories_importance(memory_content)
self.aggregate_importance += max(importance_scores)
memory_list = memory_content.split(";")
documents = []
for i in range(len(memory_list)):
documents.append(
Document(
page_content=memory_list[i],
metadata={"importance": importance_scores[i]},
)
)
result = self.memory_retriever.add_documents(documents, current_time=now)
# After an agent has processed a certain amount of memories (as measured by
# aggregate importance), it is time to reflect on recent events to add
# more synthesized memories to the agent's memory stream.
if (
self.reflection_threshold is not None
and self.aggregate_importance > self.reflection_threshold
and not self.reflecting
):
self.reflecting = True
self.pause_to_reflect(now=now)
# Hack to clear the importance from reflection
self.aggregate_importance = 0.0
self.reflecting = False
return result
def add_memory(
self, memory_content: str, now: Optional[datetime] = None
) -> List[str]:
"""Add an observation or memory to the agent's memory."""
importance_score = self._score_memory_importance(memory_content)
self.aggregate_importance += importance_score
document = Document(
page_content=memory_content, metadata={"importance": importance_score}
)
result = self.memory_retriever.add_documents([document], current_time=now)
# After an agent has processed a certain amount of memories (as measured by
# aggregate importance), it is time to reflect on recent events to add
# more synthesized memories to the agent's memory stream.
if (
self.reflection_threshold is not None
and self.aggregate_importance > self.reflection_threshold
and not self.reflecting
):
self.reflecting = True
self.pause_to_reflect(now=now)
# Hack to clear the importance from reflection
self.aggregate_importance = 0.0
self.reflecting = False
return result
def fetch_memories(
self, observation: str, now: Optional[datetime] = None
) -> List[Document]:
"""Fetch related memories."""
if now is not None:
with mock_now(now):
return self.memory_retriever.invoke(observation)
else:
return self.memory_retriever.invoke(observation)
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
content = []
for mem in relevant_memories:
content.append(self._format_memory_detail(mem, prefix="- "))
return "\n".join([f"{mem}" for mem in content])
def _format_memory_detail(self, memory: Document, prefix: str = "") -> str:
created_time = memory.metadata["created_at"].strftime("%B %d, %Y, %I:%M %p")
return f"{prefix}[{created_time}] {memory.page_content.strip()}"
def format_memories_simple(self, relevant_memories: List[Document]) -> str:
return "; ".join([f"{mem.page_content}" for mem in relevant_memories])
def _get_memories_until_limit(self, consumed_tokens: int) -> str:
"""Reduce the number of tokens in the documents."""
result = []
for doc in self.memory_retriever.memory_stream[::-1]:
if consumed_tokens >= self.max_tokens_limit:
break
consumed_tokens += self.llm.get_num_tokens(doc.page_content)
if consumed_tokens < self.max_tokens_limit:
result.append(doc)
return self.format_memories_simple(result)
@property
def memory_variables(self) -> List[str]:
"""Input keys this memory class will load dynamically."""
return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain."""
queries = inputs.get(self.queries_key)
now = inputs.get(self.now_key)
if queries is not None:
relevant_memories = [
mem for query in queries for mem in self.fetch_memories(query, now=now)
]
return {
self.relevant_memories_key: self.format_memories_detail(
relevant_memories
),
self.relevant_memories_simple_key: self.format_memories_simple(
relevant_memories
),
}
most_recent_memories_token = inputs.get(self.most_recent_memories_token_key)
if most_recent_memories_token is not None:
return {
self.most_recent_memories_key: self._get_memories_until_limit(
most_recent_memories_token
)
}
return {}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
"""Save the context of this model run to memory."""
# TODO: fix the save memory key
mem = outputs.get(self.add_memory_key)
now = outputs.get(self.now_key)
if mem:
self.add_memory(mem, now=now)
def clear(self) -> None:
"""Clear memory contents."""
# TODO
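# Illustrative wiring sketch (not part of the original module), assuming an
# existing chat model `llm` and embeddings object `embeddings`. The FAISS and
# retriever setup mirrors the generative-agents tutorial; import paths and the
# embedding dimension (1536) are assumptions and may differ across versions.
import faiss
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain_community.docstore import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_experimental.generative_agents import (
    GenerativeAgent,
    GenerativeAgentMemory,
)

index = faiss.IndexFlatL2(1536)  # assumed embedding dimension
vectorstore = FAISS(embeddings, index, InMemoryDocstore({}), {})
retriever = TimeWeightedVectorStoreRetriever(
    vectorstore=vectorstore, other_score_keys=["importance"], k=15
)
memory = GenerativeAgentMemory(
    llm=llm, memory_retriever=retriever, reflection_threshold=8
)
agent = GenerativeAgent(
    name="Tommie",
    age=25,
    traits="anxious, likes design",
    status="looking for a job",
    llm=llm,
    memory=memory,
)
memory.add_memory("Tommie remembers his dog, Bruno, from when he was a kid")
print(agent.get_summary(force_refresh=True))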

View File

@ -1,13 +0,0 @@
"""**Graph Transformers** transform Documents into Graph Documents."""
from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer
from langchain_experimental.graph_transformers.gliner import GlinerGraphTransformer
from langchain_experimental.graph_transformers.llm import LLMGraphTransformer
from langchain_experimental.graph_transformers.relik import RelikGraphTransformer
__all__ = [
"DiffbotGraphTransformer",
"LLMGraphTransformer",
"RelikGraphTransformer",
"GlinerGraphTransformer",
]

View File

@ -1,365 +0,0 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import requests
from langchain.utils import get_from_env
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
class TypeOption(str, Enum):
FACTS = "facts"
ENTITIES = "entities"
SENTIMENT = "sentiment"
def format_property_key(s: str) -> str:
"""Formats a string to be used as a property key."""
words = s.split()
if not words:
return s
first_word = words[0].lower()
capitalized_words = [word.capitalize() for word in words[1:]]
return "".join([first_word] + capitalized_words)
class NodesList:
"""List of nodes with associated properties.
Attributes:
nodes (Dict[Tuple, Any]): Stores nodes as keys and their properties as values.
Each key is a tuple where the first element is the
node ID and the second is the node type.
"""
def __init__(self) -> None:
self.nodes: Dict[Tuple[Union[str, int], str], Any] = dict()
def add_node_property(
self, node: Tuple[Union[str, int], str], properties: Dict[str, Any]
) -> None:
"""
Adds or updates node properties.
If the node does not exist in the list, it's added along with its properties.
If the node already exists, its properties are updated with the new values.
Args:
node (Tuple): A tuple containing the node ID and node type.
properties (Dict): A dictionary of properties to add or update for the node.
"""
if node not in self.nodes:
self.nodes[node] = properties
else:
self.nodes[node].update(properties)
def return_node_list(self) -> List[Node]:
"""
Returns the nodes as a list of Node objects.
Each Node object will have its ID, type, and properties populated.
Returns:
List[Node]: A list of Node objects.
"""
nodes = [
Node(id=key[0], type=key[1], properties=self.nodes[key])
for key in self.nodes
]
return nodes
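# Illustrative sketch (not part of the original module): properties added for
# the same (id, type) key are merged, so repeated mentions of an entity do not
# create duplicate nodes.
nodes_list = NodesList()
nodes_list.add_node_property(("Diffbot", "Organization"), {"name": "Diffbot"})
nodes_list.add_node_property(("Diffbot", "Organization"), {"sentiment": 0.8})
print(nodes_list.return_node_list())
# -> a single Node with merged properties {"name": "Diffbot", "sentiment": 0.8}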
# Properties that should be treated as node properties instead of relationships
FACT_TO_PROPERTY_TYPE = [
"Date",
"Number",
"Job title",
"Cause of death",
"Organization type",
"Academic title",
]
schema_mapping = [
("HEADQUARTERS", "ORGANIZATION_LOCATIONS"),
("RESIDENCE", "PERSON_LOCATION"),
("ALL_PERSON_LOCATIONS", "PERSON_LOCATION"),
("CHILD", "HAS_CHILD"),
("PARENT", "HAS_PARENT"),
("CUSTOMERS", "HAS_CUSTOMER"),
("SKILLED_AT", "INTERESTED_IN"),
]
class SimplifiedSchema:
"""Simplified schema mapping.
Attributes:
schema (Dict): A dictionary containing the mapping to simplified schema types.
"""
def __init__(self) -> None:
"""Initializes the schema dictionary based on the predefined list."""
self.schema = dict()
for row in schema_mapping:
self.schema[row[0]] = row[1]
def get_type(self, type: str) -> str:
"""
Retrieves the simplified schema type for a given original type.
Args:
type (str): The original schema type to find the simplified type for.
Returns:
str: The simplified schema type if it exists;
otherwise, returns the original type.
"""
try:
return self.schema[type]
except KeyError:
return type
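# Illustrative sketch (not part of the original module): mapped relationship
# types are simplified, and unmapped types fall back to themselves.
schema = SimplifiedSchema()
print(schema.get_type("HEADQUARTERS"))  # -> "ORGANIZATION_LOCATIONS"
print(schema.get_type("FOUNDED_BY"))    # -> "FOUNDED_BY" (no mapping defined)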
class DiffbotGraphTransformer:
"""Transform documents into graph documents using Diffbot NLP API.
A graph document transformation system takes a sequence of Documents and returns a
sequence of Graph Documents.
Example:
.. code-block:: python
from langchain_experimental.graph_transformers import DiffbotGraphTransformer
from langchain_core.documents import Document
diffbot_api_key = "DIFFBOT_API_KEY"
diffbot_nlp = DiffbotGraphTransformer(diffbot_api_key=diffbot_api_key)
document = Document(page_content="Mike Tunge is the CEO of Diffbot.")
graph_documents = diffbot_nlp.convert_to_graph_documents([document])
"""
def __init__(
self,
diffbot_api_key: Optional[str] = None,
fact_confidence_threshold: float = 0.7,
include_qualifiers: bool = True,
include_evidence: bool = True,
simplified_schema: bool = True,
extract_types: List[TypeOption] = [TypeOption.FACTS],
*,
include_confidence: bool = False,
) -> None:
"""
Initialize the graph transformer with various options.
Args:
diffbot_api_key (str):
The API key for Diffbot's NLP services.
fact_confidence_threshold (float):
Minimum confidence level for facts to be included.
include_qualifiers (bool):
Whether to include qualifiers in the relationships.
include_evidence (bool):
Whether to include evidence for the relationships.
simplified_schema (bool):
Whether to use a simplified schema for relationships.
extract_types (List[TypeOption]):
A list of data types to extract. Facts, entities, and
sentiment are supported. By default, the option is
set to facts. A fact represents a combination of
source and target nodes with a relationship type.
include_confidence (bool):
Whether to include confidence scores on nodes and rels
"""
self.diffbot_api_key = diffbot_api_key or get_from_env(
"diffbot_api_key", "DIFFBOT_API_KEY"
)
self.fact_threshold_confidence = fact_confidence_threshold
self.include_qualifiers = include_qualifiers
self.include_evidence = include_evidence
self.include_confidence = include_confidence
self.simplified_schema = None
if simplified_schema:
self.simplified_schema = SimplifiedSchema()
if not extract_types:
raise ValueError(
"`extract_types` cannot be an empty array. "
"Allowed values are 'facts', 'entities', or both."
)
self.extract_types = extract_types
def nlp_request(self, text: str) -> Dict[str, Any]:
"""
Make an API request to the Diffbot NLP endpoint.
Args:
text (str): The text to be processed.
Returns:
Dict[str, Any]: The JSON response from the API.
"""
# Relationship extraction only works for English
payload = {
"content": text,
"lang": "en",
}
FIELDS = ",".join(self.extract_types)
HOST = "nl.diffbot.com"
url = (
f"https://{HOST}/v1/?fields={FIELDS}&"
f"token={self.diffbot_api_key}&language=en"
)
result = requests.post(url, data=payload)
return result.json()
def process_response(
self, payload: Dict[str, Any], document: Document
) -> GraphDocument:
"""
Transform the Diffbot NLP response into a GraphDocument.
Args:
payload (Dict[str, Any]): The JSON response from Diffbot's NLP API.
document (Document): The original document.
Returns:
GraphDocument: The transformed document as a graph.
"""
# Return empty result if there are no facts
if ("facts" not in payload or not payload["facts"]) and (
"entities" not in payload or not payload["entities"]
):
return GraphDocument(nodes=[], relationships=[], source=document)
# Nodes are a custom class because we need to deduplicate
nodes_list = NodesList()
if "entities" in payload and payload["entities"]:
for record in payload["entities"]:
# Ignore if it doesn't have a type
if not record["allTypes"]:
continue
# Define source node
source_id = (
record["allUris"][0] if record["allUris"] else record["name"]
)
source_label = record["allTypes"][0]["name"].capitalize()
source_name = record["name"]
nodes_list.add_node_property(
(source_id, source_label), {"name": source_name}
)
if record.get("sentiment") is not None:
nodes_list.add_node_property(
(source_id, source_label),
{"sentiment": record.get("sentiment")},
)
if self.include_confidence:
nodes_list.add_node_property(
(source_id, source_label),
{"confidence": record.get("confidence")},
)
relationships = list()
# Relationships are kept in a plain list because we don't deduplicate them or do any other post-processing
if "facts" in payload and payload["facts"]:
for record in payload["facts"]:
# Skip if the fact is below the threshold confidence
if record["confidence"] < self.fact_threshold_confidence:
continue
# TODO: It should probably be treated as a node property
if not record["value"]["allTypes"]:
continue
# Define source node
source_id = (
record["entity"]["allUris"][0]
if record["entity"]["allUris"]
else record["entity"]["name"]
)
source_label = record["entity"]["allTypes"][0]["name"].capitalize()
source_name = record["entity"]["name"]
source_node = Node(id=source_id, type=source_label)
nodes_list.add_node_property(
(source_id, source_label), {"name": source_name}
)
# Define target node
target_id = (
record["value"]["allUris"][0]
if record["value"]["allUris"]
else record["value"]["name"]
)
target_label = record["value"]["allTypes"][0]["name"].capitalize()
target_name = record["value"]["name"]
# Some facts are better suited as node properties
if target_label in FACT_TO_PROPERTY_TYPE:
nodes_list.add_node_property(
(source_id, source_label),
{format_property_key(record["property"]["name"]): target_name},
)
else: # Define relationship
# Define target node object
target_node = Node(id=target_id, type=target_label)
nodes_list.add_node_property(
(target_id, target_label), {"name": target_name}
)
# Define relationship type
rel_type = record["property"]["name"].replace(" ", "_").upper()
if self.simplified_schema:
rel_type = self.simplified_schema.get_type(rel_type)
# Relationship qualifiers/properties
rel_properties = dict()
relationship_evidence = [
el["passage"] for el in record["evidence"]
][0]
if self.include_evidence:
rel_properties.update({"evidence": relationship_evidence})
if self.include_confidence:
rel_properties.update({"confidence": record["confidence"]})
if self.include_qualifiers and record.get("qualifiers"):
for property in record["qualifiers"]:
prop_key = format_property_key(property["property"]["name"])
rel_properties[prop_key] = property["value"]["name"]
relationship = Relationship(
source=source_node,
target=target_node,
type=rel_type,
properties=rel_properties,
)
relationships.append(relationship)
return GraphDocument(
nodes=nodes_list.return_node_list(),
relationships=relationships,
source=document,
)
def convert_to_graph_documents(
self, documents: Sequence[Document]
) -> List[GraphDocument]:
"""Convert a sequence of documents into graph documents.
Args:
documents (Sequence[Document]): The original documents.
kwargs: Additional keyword arguments.
Returns:
Sequence[GraphDocument]: The transformed documents as graphs.
"""
results = []
for document in documents:
raw_results = self.nlp_request(document.page_content)
graph_document = self.process_response(raw_results, document)
results.append(graph_document)
return results

View File

@ -1,175 +0,0 @@
from typing import Any, Dict, List, Sequence, Union
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
DEFAULT_NODE_TYPE = "Node"
class GlinerGraphTransformer:
"""
A transformer class for converting documents into graph structures
using the GLiNER and GLiREL models.
This class leverages GLiNER for named entity recognition and GLiREL for
relationship extraction from text documents, converting them into a graph format.
The extracted entities and relationships are filtered based on specified
confidence thresholds and allowed types.
For more details on GLiNER and GLiREL, visit their respective repositories:
GLiNER: https://github.com/urchade/GLiNER
GLiREL: https://github.com/jackboyla/GLiREL/tree/main
Args:
allowed_nodes (List[str]): A list of allowed node types for entity extraction.
allowed_relationships (Union[List[str], Dict[str, Any]]): A list of allowed
relationship types or a dictionary with additional configuration for
relationship extraction.
gliner_model (str): The name of the pretrained GLiNER model to use.
Default is "urchade/gliner_mediumv2.1".
glirel_model (str): The name of the pretrained GLiREL model to use.
Default is "jackboyla/glirel_beta".
entity_confidence_threshold (float): The confidence threshold for
filtering extracted entities. Default is 0.1.
relationship_confidence_threshold (float): The confidence threshold for
filtering extracted relationships. Default is 0.1.
device (str): The device to use for model inference ('cpu' or 'cuda').
Default is "cpu".
ignore_self_loops (bool): Whether to ignore relationships where the
source and target nodes are the same. Default is True.
"""
def __init__(
self,
allowed_nodes: List[str],
allowed_relationships: Union[List[str], Dict[str, Any]],
gliner_model: str = "urchade/gliner_mediumv2.1",
glirel_model: str = "jackboyla/glirel_beta",
entity_confidence_threshold: float = 0.1,
relationship_confidence_threshold: float = 0.1,
device: str = "cpu",
ignore_self_loops: bool = True,
) -> None:
try:
import gliner_spacy # type: ignore # noqa: F401
except ImportError:
raise ImportError(
"Could not import relik python package. "
"Please install it with `pip install gliner-spacy`."
)
try:
import spacy # type: ignore
except ImportError:
raise ImportError(
"Could not import relik python package. "
"Please install it with `pip install spacy`."
)
try:
import glirel # type: ignore # noqa: F401
except ImportError:
raise ImportError(
"Could not import relik python package. "
"Please install it with `pip install gliner`."
)
gliner_config = {
"gliner_model": gliner_model,
"chunk_size": 250,
"labels": allowed_nodes,
"style": "ent",
"threshold": entity_confidence_threshold,
"map_location": device,
}
glirel_config = {"model": glirel_model, "device": device}
self.nlp = spacy.blank("en")
# Add the GliNER component to the pipeline
self.nlp.add_pipe("gliner_spacy", config=gliner_config)
# Add the GLiREL component to the pipeline
self.nlp.add_pipe("glirel", after="gliner_spacy", config=glirel_config)
self.allowed_relationships = (
{"glirel_labels": allowed_relationships}
if isinstance(allowed_relationships, list)
else allowed_relationships
)
self.relationship_confidence_threshold = relationship_confidence_threshold
self.ignore_self_loops = ignore_self_loops
def process_document(self, document: Document) -> GraphDocument:
# Extraction as SpaCy pipeline
docs = list(
self.nlp.pipe(
[(document.page_content, self.allowed_relationships)], as_tuples=True
)
)
# Convert nodes
nodes = []
for node in docs[0][0].ents:
nodes.append(
Node(
id=node.text,
type=node.label_,
)
)
# Convert relationships
relationships = []
relations = docs[0][0]._.relations
# Deduplicate based on label, head text, and tail text
# Keep only the highest-scoring relation for each key
deduplicated_rels = []
seen = set()
for item in relations:
key = (tuple(item["head_text"]), tuple(item["tail_text"]), item["label"])
if key not in seen:
seen.add(key)
# Find all items matching the current key
matching_items = [
rel
for rel in relations
if (tuple(rel["head_text"]), tuple(rel["tail_text"]), rel["label"])
== key
]
# Find the item with the maximum score
max_item = max(matching_items, key=lambda x: x["score"])
deduplicated_rels.append(max_item)
for rel in deduplicated_rels:
# Relationship confidence threshold
if rel["score"] < self.relationship_confidence_threshold:
continue
source_id = docs[0][0][rel["head_pos"][0] : rel["head_pos"][1]].text
target_id = docs[0][0][rel["tail_pos"][0] : rel["tail_pos"][1]].text
# Ignore self loops
if self.ignore_self_loops and source_id == target_id:
continue
source_node = [n for n in nodes if n.id == source_id][0]
target_node = [n for n in nodes if n.id == target_id][0]
relationships.append(
Relationship(
source=source_node,
target=target_node,
type=rel["label"].replace(" ", "_").upper(),
)
)
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
def convert_to_graph_documents(
self, documents: Sequence[Document]
) -> List[GraphDocument]:
"""Convert a sequence of documents into graph documents.
Args:
documents (Sequence[Document]): The original documents.
kwargs: Additional keyword arguments.
Returns:
Sequence[GraphDocument]: The transformed documents as graphs.
"""
results = []
for document in documents:
graph_document = self.process_document(document)
results.append(graph_document)
return results
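# Illustrative usage sketch (not part of the original module), assuming the
# gliner-spacy, glirel, and spacy packages are installed. The label lists and
# the sample sentence are hypothetical; the model names default to the ones
# documented above.
from langchain_core.documents import Document

transformer = GlinerGraphTransformer(
    allowed_nodes=["Person", "Company"],
    allowed_relationships=["WORKS_FOR", "FOUNDED"],
)
docs = [
    Document(
        page_content="Ada Lovelace worked with Charles Babbage on the Analytical Engine."
    )
]
graph_docs = transformer.convert_to_graph_documents(docs)
print(graph_docs[0].nodes)
print(graph_docs[0].relationships)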

View File

@ -1,887 +0,0 @@
import asyncio
import json
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field, create_model
examples = [
{
"text": (
"Adam is a software engineer in Microsoft since 2009, "
"and last year he got an award as the Best Talent"
),
"head": "Adam",
"head_type": "Person",
"relation": "WORKS_FOR",
"tail": "Microsoft",
"tail_type": "Company",
},
{
"text": (
"Adam is a software engineer in Microsoft since 2009, "
"and last year he got an award as the Best Talent"
),
"head": "Adam",
"head_type": "Person",
"relation": "HAS_AWARD",
"tail": "Best Talent",
"tail_type": "Award",
},
{
"text": (
"Microsoft is a tech company that provide "
"several products such as Microsoft Word"
),
"head": "Microsoft Word",
"head_type": "Product",
"relation": "PRODUCED_BY",
"tail": "Microsoft",
"tail_type": "Company",
},
{
"text": "Microsoft Word is a lightweight app that accessible offline",
"head": "Microsoft Word",
"head_type": "Product",
"relation": "HAS_CHARACTERISTIC",
"tail": "lightweight app",
"tail_type": "Characteristic",
},
{
"text": "Microsoft Word is a lightweight app that accessible offline",
"head": "Microsoft Word",
"head_type": "Product",
"relation": "HAS_CHARACTERISTIC",
"tail": "accessible offline",
"tail_type": "Characteristic",
},
]
system_prompt = (
"# Knowledge Graph Instructions for GPT-4\n"
"## 1. Overview\n"
"You are a top-tier algorithm designed for extracting information in structured "
"formats to build a knowledge graph.\n"
"Try to capture as much information from the text as possible without "
"sacrificing accuracy. Do not add any information that is not explicitly "
"mentioned in the text.\n"
"- **Nodes** represent entities and concepts.\n"
"- The aim is to achieve simplicity and clarity in the knowledge graph, making it\n"
"accessible for a vast audience.\n"
"## 2. Labeling Nodes\n"
"- **Consistency**: Ensure you use available types for node labels.\n"
"Ensure you use basic or elementary types for node labels.\n"
"- For example, when you identify an entity representing a person, "
"always label it as **'person'**. Avoid using more specific terms "
"like 'mathematician' or 'scientist'."
"- **Node IDs**: Never utilize integers as node IDs. Node IDs should be "
"names or human-readable identifiers found in the text.\n"
"- **Relationships** represent connections between entities or concepts.\n"
"Ensure consistency and generality in relationship types when constructing "
"knowledge graphs. Instead of using specific and momentary types "
"such as 'BECAME_PROFESSOR', use more general and timeless relationship types "
"like 'PROFESSOR'. Make sure to use general and timeless relationship types!\n"
"## 3. Coreference Resolution\n"
"- **Maintain Entity Consistency**: When extracting entities, it's vital to "
"ensure consistency.\n"
'If an entity, such as "John Doe", is mentioned multiple times in the text '
'but is referred to by different names or pronouns (e.g., "Joe", "he"),'
"always use the most complete identifier for that entity throughout the "
'knowledge graph. In this example, use "John Doe" as the entity ID.\n'
"Remember, the knowledge graph should be coherent and easily understandable, "
"so maintaining consistency in entity references is crucial.\n"
"## 4. Strict Compliance\n"
"Adhere to the rules strictly. Non-compliance will result in termination."
)
default_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
system_prompt,
),
(
"human",
(
"Tip: Make sure to answer in the correct format and do "
"not include any explanations. "
"Use the given format to extract information from the "
"following input: {input}"
),
),
]
)
def _get_additional_info(input_type: str) -> str:
# Check if the input_type is one of the allowed values
if input_type not in ["node", "relationship", "property"]:
raise ValueError("input_type must be 'node', 'relationship', or 'property'")
# Perform actions based on the input_type
if input_type == "node":
return (
"Ensure you use basic or elementary types for node labels.\n"
"For example, when you identify an entity representing a person, "
"always label it as **'Person'**. Avoid using more specific terms "
"like 'Mathematician' or 'Scientist'"
)
elif input_type == "relationship":
return (
"Instead of using specific and momentary types such as "
"'BECAME_PROFESSOR', use more general and timeless relationship types "
"like 'PROFESSOR'. However, do not sacrifice any accuracy for generality"
)
elif input_type == "property":
return ""
return ""
def optional_enum_field(
enum_values: Optional[List[str]] = None,
description: str = "",
input_type: str = "node",
llm_type: Optional[str] = None,
**field_kwargs: Any,
) -> Any:
"""Utility function to conditionally create a field with an enum constraint."""
# Only openai supports enum param
if enum_values and llm_type == "openai-chat":
return Field(
...,
enum=enum_values, # type: ignore[call-arg]
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
elif enum_values:
return Field(
...,
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
else:
additional_info = _get_additional_info(input_type)
return Field(..., description=description + additional_info, **field_kwargs)
class _Graph(BaseModel):
nodes: Optional[List]
relationships: Optional[List]
class UnstructuredRelation(BaseModel):
head: str = Field(
description=(
"extracted head entity like Microsoft, Apple, John. "
"Must use human-readable unique identifier."
)
)
head_type: str = Field(
description="type of the extracted head entity like Person, Company, etc"
)
relation: str = Field(description="relation between the head and the tail entities")
tail: str = Field(
description=(
"extracted tail entity like Microsoft, Apple, John. "
"Must use human-readable unique identifier."
)
)
tail_type: str = Field(
description="type of the extracted tail entity like Person, Company, etc"
)
def create_unstructured_prompt(
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
) -> ChatPromptTemplate:
node_labels_str = str(node_labels) if node_labels else ""
rel_types_str = str(rel_types) if rel_types else ""
base_string_parts = [
"You are a top-tier algorithm designed for extracting information in "
"structured formats to build a knowledge graph. Your task is to identify "
"the entities and relations requested with the user prompt from a given "
"text. You must generate the output in a JSON format containing a list "
'with JSON objects. Each object should have the keys: "head", '
'"head_type", "relation", "tail", and "tail_type". The "head" '
"key must contain the text of the extracted entity with one of the types "
"from the provided list in the user prompt.",
f'The "head_type" key must contain the type of the extracted head entity, '
f"which must be one of the types from {node_labels_str}."
if node_labels
else "",
f'The "relation" key must contain the type of relation between the "head" '
f'and the "tail", which must be one of the relations from {rel_types_str}.'
if rel_types
else "",
f'The "tail" key must represent the text of an extracted entity which is '
f'the tail of the relation, and the "tail_type" key must contain the type '
f"of the tail entity from {node_labels_str}."
if node_labels
else "",
"Attempt to extract as many entities and relations as you can. Maintain "
"Entity Consistency: When extracting entities, it's vital to ensure "
'consistency. If an entity, such as "John Doe", is mentioned multiple '
"times in the text but is referred to by different names or pronouns "
'(e.g., "Joe", "he"), always use the most complete identifier for '
"that entity. The knowledge graph should be coherent and easily "
"understandable, so maintaining consistency in entity references is "
"crucial.",
"IMPORTANT NOTES:\n- Don't add any explanation and text.",
]
system_prompt = "\n".join(filter(None, base_string_parts))
system_message = SystemMessage(content=system_prompt)
parser = JsonOutputParser(pydantic_object=UnstructuredRelation)
human_string_parts = [
"Based on the following example, extract entities and "
"relations from the provided text.\n\n",
"Use the following entity types, don't use other entity "
"that is not defined below:"
"# ENTITY TYPES:"
"{node_labels}"
if node_labels
else "",
"Use the following relation types, don't use other relation "
"that is not defined below:"
"# RELATION TYPES:"
"{rel_types}"
if rel_types
else "",
"Below are a number of examples of text and their extracted "
"entities and relationships."
"{examples}\n"
"For the following text, extract entities and relations as "
"in the provided example."
"{format_instructions}\nText: {input}",
]
human_prompt_string = "\n".join(filter(None, human_string_parts))
human_prompt = PromptTemplate(
template=human_prompt_string,
input_variables=["input"],
partial_variables={
"format_instructions": parser.get_format_instructions(),
"node_labels": node_labels,
"rel_types": rel_types,
"examples": examples,
},
)
human_message_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message, human_message_prompt]
)
return chat_prompt
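# Illustrative sketch only: building the fallback prompt with example label and
# relation constraints. The label and relation names are placeholders, and the
# prompt relies on the module-level `examples` few-shot list defined earlier in
# this file.
example_prompt = create_unstructured_prompt(
    node_labels=["Person", "Organization"],
    rel_types=["WORKS_AT", "FOUNDED"],
)
# The resulting ChatPromptTemplate only expects the "input" variable; the other
# placeholders are already filled in as partial variables.
example_messages = example_prompt.format_messages(input="Elon Musk is suing OpenAI")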
def create_simple_model(
node_labels: Optional[List[str]] = None,
rel_types: Optional[List[str]] = None,
node_properties: Union[bool, List[str]] = False,
llm_type: Optional[str] = None,
relationship_properties: Union[bool, List[str]] = False,
) -> Type[_Graph]:
"""
Create a simple graph model with optional constraints on node
and relationship types.
Args:
node_labels (Optional[List[str]]): Specifies the allowed node types.
Defaults to None, allowing all node types.
rel_types (Optional[List[str]]): Specifies the allowed relationship types.
Defaults to None, allowing all relationship types.
node_properties (Union[bool, List[str]]): Specifies if node properties should
be included. If a list is provided, only properties with keys in the list
will be included. If True, all properties are included. Defaults to False.
relationship_properties (Union[bool, List[str]]): Specifies if relationship
properties should be included. If a list is provided, only properties with
keys in the list will be included. If True, all properties are included.
Defaults to False.
llm_type (Optional[str]): The type of the language model. Defaults to None.
Only OpenAI chat models (llm_type "openai-chat") support the enum parameter.
Returns:
Type[_Graph]: A graph model with the specified constraints.
Raises:
ValueError: If 'id' is included in the node or relationship properties list.
"""
node_fields: Dict[str, Tuple[Any, Any]] = {
"id": (
str,
Field(..., description="Name or human-readable unique identifier."),
),
"type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the node.",
input_type="node",
llm_type=llm_type,
),
),
}
if node_properties:
if isinstance(node_properties, list) and "id" in node_properties:
raise ValueError("The node property 'id' is reserved and cannot be used.")
# Map True to empty array
node_properties_mapped: List[str] = (
[] if node_properties is True else node_properties
)
class Property(BaseModel):
"""A single property consisting of key and value"""
key: str = optional_enum_field(
node_properties_mapped,
description="Property key.",
input_type="property",
llm_type=llm_type,
)
value: str = Field(..., description="value")
node_fields["properties"] = (
Optional[List[Property]],
Field(None, description="List of node properties"),
)
SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore
relationship_fields: Dict[str, Tuple[Any, Any]] = {
"source_node_id": (
str,
Field(
...,
description="Name or human-readable unique identifier of source node",
),
),
"source_node_type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the source node.",
input_type="node",
llm_type=llm_type,
),
),
"target_node_id": (
str,
Field(
...,
description="Name or human-readable unique identifier of target node",
),
),
"target_node_type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the target node.",
input_type="node",
llm_type=llm_type,
),
),
"type": (
str,
optional_enum_field(
rel_types,
description="The type of the relationship.",
input_type="relationship",
llm_type=llm_type,
),
),
}
if relationship_properties:
if (
isinstance(relationship_properties, list)
and "id" in relationship_properties
):
raise ValueError(
"The relationship property 'id' is reserved and cannot be used."
)
# Map True to empty array
relationship_properties_mapped: List[str] = (
[] if relationship_properties is True else relationship_properties
)
class RelationshipProperty(BaseModel):
"""A single property consisting of key and value"""
key: str = optional_enum_field(
relationship_properties_mapped,
description="Property key.",
input_type="property",
llm_type=llm_type,
)
value: str = Field(..., description="value")
relationship_fields["properties"] = (
Optional[List[RelationshipProperty]],
Field(None, description="List of relationship properties"),
)
SimpleRelationship = create_model("SimpleRelationship", **relationship_fields) # type: ignore
class DynamicGraph(_Graph):
"""Represents a graph document consisting of nodes and relationships."""
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore
relationships: Optional[List[SimpleRelationship]] = Field( # type: ignore
description="List of relationships"
)
return DynamicGraph
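# Hedged sketch of how the dynamically generated schema could be combined with a
# chat model that supports structured output. The model, labels, and property
# names are illustrative and require valid OpenAI credentials.
from langchain_openai import ChatOpenAI

example_schema = create_simple_model(
    node_labels=["Person", "Organization"],
    rel_types=["WORKS_AT"],
    node_properties=["role"],
)
example_structured_llm = ChatOpenAI(temperature=0).with_structured_output(
    example_schema, include_raw=True
)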
def map_to_base_node(node: Any) -> Node:
"""Map the SimpleNode to the base Node."""
properties = {}
if hasattr(node, "properties") and node.properties:
for p in node.properties:
properties[format_property_key(p.key)] = p.value
return Node(id=node.id, type=node.type, properties=properties)
def map_to_base_relationship(rel: Any) -> Relationship:
"""Map the SimpleRelationship to the base Relationship."""
source = Node(id=rel.source_node_id, type=rel.source_node_type)
target = Node(id=rel.target_node_id, type=rel.target_node_type)
properties = {}
if hasattr(rel, "properties") and rel.properties:
for p in rel.properties:
properties[format_property_key(p.key)] = p.value
return Relationship(
source=source, target=target, type=rel.type, properties=properties
)
def _parse_and_clean_json(
argument_json: Dict[str, Any],
) -> Tuple[List[Node], List[Relationship]]:
nodes = []
for node in argument_json["nodes"]:
if not node.get("id"): # Id is mandatory, skip this node
continue
node_properties = {}
if "properties" in node and node["properties"]:
for p in node["properties"]:
node_properties[format_property_key(p["key"])] = p["value"]
nodes.append(
Node(
id=node["id"],
type=node.get("type", "Node"),
properties=node_properties,
)
)
relationships = []
for rel in argument_json["relationships"]:
# Mandatory props
if (
not rel.get("source_node_id")
or not rel.get("target_node_id")
or not rel.get("type")
):
continue
# Node type copying if needed from node list
if not rel.get("source_node_type"):
try:
rel["source_node_type"] = [
el.get("type")
for el in argument_json["nodes"]
if el["id"] == rel["source_node_id"]
][0]
except IndexError:
rel["source_node_type"] = None
if not rel.get("target_node_type"):
try:
rel["target_node_type"] = [
el.get("type")
for el in argument_json["nodes"]
if el["id"] == rel["target_node_id"]
][0]
except IndexError:
rel["target_node_type"] = None
rel_properties = {}
if "properties" in rel and rel["properties"]:
for p in rel["properties"]:
rel_properties[format_property_key(p["key"])] = p["value"]
source_node = Node(
id=rel["source_node_id"],
type=rel["source_node_type"],
)
target_node = Node(
id=rel["target_node_id"],
type=rel["target_node_type"],
)
relationships.append(
Relationship(
source=source_node,
target=target_node,
type=rel["type"],
properties=rel_properties,
)
)
return nodes, relationships
def _format_nodes(nodes: List[Node]) -> List[Node]:
return [
Node(
id=el.id.title() if isinstance(el.id, str) else el.id,
type=el.type.capitalize() # type: ignore[arg-type]
if el.type
else None, # handle empty strings # type: ignore[arg-type]
properties=el.properties,
)
for el in nodes
]
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
return [
Relationship(
source=_format_nodes([el.source])[0],
target=_format_nodes([el.target])[0],
type=el.type.replace(" ", "_").upper(),
properties=el.properties,
)
for el in rels
]
def format_property_key(s: str) -> str:
words = s.split()
if not words:
return s
first_word = words[0].lower()
capitalized_words = [word.capitalize() for word in words[1:]]
return "".join([first_word] + capitalized_words)
def _convert_to_graph_document(
raw_schema: Dict[Any, Any],
) -> Tuple[List[Node], List[Relationship]]:
# If there are validation errors
if not raw_schema["parsed"]:
try:
try: # OpenAI type response
argument_json = json.loads(
raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][
"arguments"
]
)
except Exception: # Google type response
try:
argument_json = json.loads(
raw_schema["raw"].additional_kwargs["function_call"][
"arguments"
]
)
except Exception: # Ollama type response
argument_json = raw_schema["raw"].tool_calls[0]["args"]
if isinstance(argument_json["nodes"], str):
argument_json["nodes"] = json.loads(argument_json["nodes"])
if isinstance(argument_json["relationships"], str):
argument_json["relationships"] = json.loads(
argument_json["relationships"]
)
nodes, relationships = _parse_and_clean_json(argument_json)
except Exception: # If we can't parse JSON
return ([], [])
else: # If there are no validation errors use parsed pydantic object
parsed_schema: _Graph = raw_schema["parsed"]
nodes = (
[map_to_base_node(node) for node in parsed_schema.nodes if node.id]
if parsed_schema.nodes
else []
)
relationships = (
[
map_to_base_relationship(rel)
for rel in parsed_schema.relationships
if rel.type and rel.source_node_id and rel.target_node_id
]
if parsed_schema.relationships
else []
)
# Title / Capitalize
return _format_nodes(nodes), _format_relationships(relationships)
class LLMGraphTransformer:
"""Transform documents into graph-based documents using a LLM.
It allows specifying constraints on the types of nodes and relationships to include
in the output graph. The class supports extracting properties for both nodes and
relationships.
Args:
llm (BaseLanguageModel): An instance of a language model supporting structured
output.
allowed_nodes (List[str], optional): Specifies which node types are
allowed in the graph. Defaults to an empty list, allowing all node types.
allowed_relationships (List[str], optional): Specifies which relationship types
are allowed in the graph. Defaults to an empty list, allowing all relationship
types.
prompt (Optional[ChatPromptTemplate], optional): The prompt to pass to
the LLM with additional instructions.
strict_mode (bool, optional): Determines whether the transformer should apply
filtering to strictly adhere to `allowed_nodes` and `allowed_relationships`.
Defaults to True.
node_properties (Union[bool, List[str]]): If True, the LLM can extract any
node properties from text. Alternatively, a list of valid properties can
be provided for the LLM to extract, restricting extraction to those specified.
relationship_properties (Union[bool, List[str]]): If True, the LLM can extract
any relationship properties from text. Alternatively, a list of valid
properties can be provided for the LLM to extract, restricting extraction to
those specified.
ignore_tool_usage (bool): Indicates whether the transformer should
bypass the use of structured output functionality of the language model.
If set to True, the transformer will not use the language model's native
function calling capabilities to handle structured output. Defaults to False.
Example:
.. code-block:: python
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
llm=ChatOpenAI(temperature=0)
transformer = LLMGraphTransformer(
llm=llm,
allowed_nodes=["Person", "Organization"])
doc = Document(page_content="Elon Musk is suing OpenAI")
graph_documents = transformer.convert_to_graph_documents([doc])
"""
def __init__(
self,
llm: BaseLanguageModel,
allowed_nodes: List[str] = [],
allowed_relationships: List[str] = [],
prompt: Optional[ChatPromptTemplate] = None,
strict_mode: bool = True,
node_properties: Union[bool, List[str]] = False,
relationship_properties: Union[bool, List[str]] = False,
ignore_tool_usage: bool = False,
) -> None:
self.allowed_nodes = allowed_nodes
self.allowed_relationships = allowed_relationships
self.strict_mode = strict_mode
self._function_call = not ignore_tool_usage
# Check if the LLM really supports structured output
if self._function_call:
try:
llm.with_structured_output(_Graph)
except NotImplementedError:
self._function_call = False
if not self._function_call:
if node_properties or relationship_properties:
raise ValueError(
"The 'node_properties' and 'relationship_properties' parameters "
"cannot be used in combination with a LLM that doesn't support "
"native function calling."
)
try:
import json_repair # type: ignore
self.json_repair = json_repair
except ImportError:
raise ImportError(
"Could not import json_repair python package. "
"Please install it with `pip install json-repair`."
)
prompt = prompt or create_unstructured_prompt(
allowed_nodes, allowed_relationships
)
self.chain = prompt | llm
else:
# Define chain
try:
llm_type = llm._llm_type # type: ignore
except AttributeError:
llm_type = None
schema = create_simple_model(
allowed_nodes,
allowed_relationships,
node_properties,
llm_type,
relationship_properties,
)
structured_llm = llm.with_structured_output(schema, include_raw=True)
prompt = prompt or default_prompt
self.chain = prompt | structured_llm
def process_response(
self, document: Document, config: Optional[RunnableConfig] = None
) -> GraphDocument:
"""
Processes a single document, transforming it into a graph document using
an LLM based on the model's schema and constraints.
"""
text = document.page_content
raw_schema = self.chain.invoke({"input": text}, config=config)
if self._function_call:
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
else:
nodes_set = set()
relationships = []
if not isinstance(raw_schema, str):
raw_schema = raw_schema.content
parsed_json = self.json_repair.loads(raw_schema)
if isinstance(parsed_json, dict):
parsed_json = [parsed_json]
for rel in parsed_json:
# Check if mandatory properties are there
if (
not rel.get("head")
or not rel.get("tail")
or not rel.get("relation")
):
continue
# Nodes need to be deduplicated using a set
# Use default Node label for nodes if missing
nodes_set.add((rel["head"], rel.get("head_type", "Node")))
nodes_set.add((rel["tail"], rel.get("tail_type", "Node")))
source_node = Node(id=rel["head"], type=rel.get("head_type", "Node"))
target_node = Node(id=rel["tail"], type=rel.get("tail_type", "Node"))
relationships.append(
Relationship(
source=source_node, target=target_node, type=rel["relation"]
)
)
# Create nodes list
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
# Strict mode filtering
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes:
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
nodes = [
node for node in nodes if node.type.lower() in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel.source.type.lower() in lower_allowed_nodes
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
def convert_to_graph_documents(
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
) -> List[GraphDocument]:
"""Convert a sequence of documents into graph documents.
Args:
documents (Sequence[Document]): The original documents.
config (Optional[RunnableConfig]): Optional runnable config to use.
Returns:
Sequence[GraphDocument]: The transformed documents as graphs.
"""
return [self.process_response(document, config) for document in documents]
async def aprocess_response(
self, document: Document, config: Optional[RunnableConfig] = None
) -> GraphDocument:
"""
Asynchronously processes a single document, transforming it into a
graph document.
"""
text = document.page_content
raw_schema = await self.chain.ainvoke({"input": text}, config=config)
if self._function_call:
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
else:
nodes_set = set()
relationships = []
if not isinstance(raw_schema, str):
raw_schema = raw_schema.content
parsed_json = self.json_repair.loads(raw_schema)
if isinstance(parsed_json, dict):
parsed_json = [parsed_json]
for rel in parsed_json:
# Check if mandatory properties are there
if (
not rel.get("head")
or not rel.get("tail")
or not rel.get("relation")
):
continue
# Nodes need to be deduplicated using a set
# Use default Node label for nodes if missing
nodes_set.add((rel["head"], rel.get("head_type", "Node")))
nodes_set.add((rel["tail"], rel.get("tail_type", "Node")))
source_node = Node(id=rel["head"], type=rel.get("head_type", "Node"))
target_node = Node(id=rel["tail"], type=rel.get("tail_type", "Node"))
relationships.append(
Relationship(
source=source_node, target=target_node, type=rel["relation"]
)
)
# Create nodes list
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes:
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
nodes = [
node for node in nodes if node.type.lower() in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel.source.type.lower() in lower_allowed_nodes
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
async def aconvert_to_graph_documents(
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
) -> List[GraphDocument]:
"""
Asynchronously convert a sequence of documents into graph documents.
"""
tasks = [
asyncio.create_task(self.aprocess_response(document, config))
for document in documents
]
results = await asyncio.gather(*tasks)
return results
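# Minimal async usage sketch, assuming an OpenAI chat model and a placeholder
# document; requires valid OpenAI credentials.
import asyncio

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

async def _example_async_conversion() -> None:
    transformer = LLMGraphTransformer(llm=ChatOpenAI(temperature=0))
    docs = [Document(page_content="Marie Curie won two Nobel Prizes.")]
    graph_documents = await transformer.aconvert_to_graph_documents(docs)
    print(graph_documents[0].nodes, graph_documents[0].relationships)

# asyncio.run(_example_async_conversion())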

View File

@ -1,115 +0,0 @@
import logging
from typing import Any, Dict, List, Sequence
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
DEFAULT_NODE_TYPE = "Node"
class RelikGraphTransformer:
"""
A transformer class for converting documents into graph structures
using the Relik library and models.
This class leverages relik models for extracting relationships
and nodes from text documents and converting them into a graph format.
The relationships are filtered based on a specified confidence threshold.
For more details on the Relik library, visit their GitHub repository:
https://github.com/SapienzaNLP/relik
Args:
model (str): The name of the pretrained Relik model to use.
Default is "relik-ie/relik-relation-extraction-small-wikipedia".
relationship_confidence_threshold (float): The confidence threshold for
filtering relationships. Default is 0.1.
model_config (Dict[str, any]): Additional configuration options for the
Relik model. Default is an empty dictionary.
ignore_self_loops (bool): Whether to ignore relationships where the
source and target nodes are the same. Default is True.
"""
def __init__(
self,
model: str = "relik-ie/relik-relation-extraction-small",
relationship_confidence_threshold: float = 0.1,
model_config: Dict[str, Any] = {},
ignore_self_loops: bool = True,
) -> None:
try:
import relik # type: ignore
# Remove default INFO logging
logging.getLogger("relik").setLevel(logging.WARNING)
except ImportError:
raise ImportError(
"Could not import relik python package. "
"Please install it with `pip install relik`."
)
self.relik_model = relik.Relik.from_pretrained(model, **model_config)
self.relationship_confidence_threshold = relationship_confidence_threshold
self.ignore_self_loops = ignore_self_loops
def process_document(self, document: Document) -> GraphDocument:
relik_out = self.relik_model(document.page_content)
nodes = []
# Extract nodes
for node in relik_out.spans:
nodes.append(
Node(
id=node.text,
type=DEFAULT_NODE_TYPE
if node.label.strip() == "--NME--"
else node.label.strip(),
)
)
relationships = []
# Extract relationships
for triple in relik_out.triplets:
# Ignore relationship if below confidence threshold
if triple.confidence < self.relationship_confidence_threshold:
continue
# Ignore self loops
if self.ignore_self_loops and triple.subject.text == triple.object.text:
continue
source_node = Node(
id=triple.subject.text,
type=DEFAULT_NODE_TYPE
if triple.subject.label.strip() == "--NME--"
else triple.subject.label.strip(),
)
target_node = Node(
id=triple.object.text,
type=DEFAULT_NODE_TYPE
if triple.object.label.strip() == "--NME--"
else triple.object.label.strip(),
)
relationship = Relationship(
source=source_node,
target=target_node,
type=triple.label.replace(" ", "_").upper(),
)
relationships.append(relationship)
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
def convert_to_graph_documents(
self, documents: Sequence[Document]
) -> List[GraphDocument]:
"""Convert a sequence of documents into graph documents.
Args:
documents (Sequence[Document]): The original documents.
Returns:
Sequence[GraphDocument]: The transformed documents as graphs.
"""
results = []
for document in documents:
graph_document = self.process_document(document)
results.append(graph_document)
return results
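# Illustrative usage sketch; requires the optional `relik` package and downloads
# the pretrained model named in the constructor default. The document text is a
# placeholder.
from langchain_core.documents import Document

example_transformer = RelikGraphTransformer(relationship_confidence_threshold=0.2)
example_docs = [Document(page_content="Amazon was founded by Jeff Bezos in Seattle.")]
example_graph_documents = example_transformer.convert_to_graph_documents(example_docs)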

View File

@ -1,2 +0,0 @@
"""**LLM bash** is a chain that uses LLM to interpret a prompt and
executes **bash** code."""

View File

@ -1,127 +0,0 @@
"""Chain that interprets a prompt and executes bash operations."""
from __future__ import annotations
import logging
import warnings
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from pydantic import ConfigDict, Field, model_validator
from langchain_experimental.llm_bash.bash import BashProcess
from langchain_experimental.llm_bash.prompt import PROMPT
logger = logging.getLogger(__name__)
class LLMBashChain(Chain):
"""Chain that interprets a prompt and executes bash operations.
Example:
.. code-block:: python
from langchain.chains import LLMBashChain
from langchain_community.llms import OpenAI
llm_bash = LLMBashChain.from_llm(OpenAI())
"""
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
prompt: BasePromptTemplate = PROMPT
"""[Deprecated]"""
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMBashChain with an llm is deprecated. "
"Please instantiate with llm_chain or using the from_llm class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@model_validator(mode="before")
@classmethod
def validate_prompt(cls, values: Dict) -> Any:
if values["llm_chain"].prompt.output_parser is None:
raise ValueError(
"The prompt used by llm_chain is expected to have an output_parser."
)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
t = self.llm_chain.predict(
question=inputs[self.input_key], callbacks=_run_manager.get_child()
)
_run_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()
try:
parser = self.llm_chain.prompt.output_parser
command_list = parser.parse(t) # type: ignore[union-attr]
except OutputParserException as e:
_run_manager.on_chain_error(e, verbose=self.verbose)
raise e
if self.verbose:
_run_manager.on_text("\nCode: ", verbose=self.verbose)
_run_manager.on_text(
str(command_list), color="yellow", verbose=self.verbose
)
output = self.bash_process.run(command_list)
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
return {self.output_key: output}
@property
def _chain_type(self) -> str:
return "llm_bash_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMBashChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)

View File

@ -1,185 +0,0 @@
"""Wrapper around subprocess to run commands."""
from __future__ import annotations
import platform
import re
import subprocess
from typing import TYPE_CHECKING, List, Union
from uuid import uuid4
if TYPE_CHECKING:
import pexpect
class BashProcess:
"""Wrapper for starting subprocesses.
Uses the Python built-in subprocess.run().
Persistent processes are **not** available
on Windows systems, as pexpect makes use of
Unix pseudoterminals (ptys). macOS and Linux
are okay.
Example:
.. code-block:: python
from langchain_community.utilities.bash import BashProcess
bash = BashProcess(
strip_newlines = False,
return_err_output = False,
persistent = False
)
bash.run('echo \'hello world\'')
"""
strip_newlines: bool = False
"""Whether or not to run .strip() on the output"""
return_err_output: bool = False
"""Whether or not to return the output of a failed
command, or just the error message and stacktrace"""
persistent: bool = False
"""Whether or not to spawn a persistent session
NOTE: Unavailable for Windows environments"""
def __init__(
self,
strip_newlines: bool = False,
return_err_output: bool = False,
persistent: bool = False,
):
"""
Initializes with default settings
"""
self.strip_newlines = strip_newlines
self.return_err_output = return_err_output
self.prompt = ""
self.process = None
if persistent:
self.prompt = str(uuid4())
self.process = self._initialize_persistent_process(self, self.prompt)
@staticmethod
def _lazy_import_pexpect() -> pexpect:
"""Import pexpect only when needed."""
if platform.system() == "Windows":
raise ValueError(
"Persistent bash processes are not yet supported on Windows."
)
try:
import pexpect
except ImportError:
raise ImportError(
"pexpect required for persistent bash processes."
" To install, run `pip install pexpect`."
)
return pexpect
@staticmethod
def _initialize_persistent_process(self: BashProcess, prompt: str) -> pexpect.spawn:
# Start bash in a clean environment
# Doesn't work on windows
"""
Initializes a persistent bash session in a
clean environment.
NOTE: Unavailable on Windows
Args:
prompt (str): the custom prompt string used to detect command completion
"""
pexpect = self._lazy_import_pexpect()
process = pexpect.spawn(
"env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
)
# Set the custom prompt
process.sendline("PS1=" + prompt)
process.expect_exact(prompt, timeout=10)
return process
def run(self, commands: Union[str, List[str]]) -> str:
"""
Run commands in either an existing persistent
subprocess or in a new subprocess environment.
Args:
commands(List[str]): a list of commands to
execute in the session
"""
if isinstance(commands, str):
commands = [commands]
commands = ";".join(commands)
if self.process is not None:
return self._run_persistent(
commands,
)
else:
return self._run(commands)
def _run(self, command: str) -> str:
"""
Runs a command in a subprocess and returns
the output.
Args:
command: The command to run
"""
try:
output = subprocess.run(
command,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
).stdout.decode()
except subprocess.CalledProcessError as error:
if self.return_err_output:
return error.stdout.decode()
return str(error)
if self.strip_newlines:
output = output.strip()
return output
def process_output(self, output: str, command: str) -> str:
"""
Uses regex to remove the command from the output
Args:
output: a process' output string
command: the executed command
"""
pattern = re.escape(command) + r"\s*\n"
output = re.sub(pattern, "", output, count=1)
return output.strip()
def _run_persistent(self, command: str) -> str:
"""
Runs commands in a persistent environment
and returns the output.
Args:
command: the command to execute
"""
pexpect = self._lazy_import_pexpect()
if self.process is None:
raise ValueError("Process not initialized")
self.process.sendline(command)
# Clear the output with an empty string
self.process.expect(self.prompt, timeout=10)
self.process.sendline("")
try:
self.process.expect([self.prompt, pexpect.EOF], timeout=10)
except pexpect.TIMEOUT:
return f"Timeout error while executing command {command}"
if self.process.after == pexpect.EOF:
return f"Exited with error status: {self.process.exitstatus}"
output = self.process.before
output = self.process_output(output, command)
if self.strip_newlines:
return output.strip()
return output

View File

@ -1,67 +0,0 @@
# flake8: noqa
from __future__ import annotations
import re
from typing import List
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.exceptions import OutputParserException
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'"
I need to take the following actions:
- List all files in the directory
- Create a new directory
- Copy the files from the first directory into the second directory
```bash
ls
mkdir myNewDirectory
cp -r target/* myNewDirectory
```
That is the format. Begin!
Question: {question}"""
class BashOutputParser(BaseOutputParser):
"""Parser for bash output."""
def parse(self, text: str) -> List[str]:
"""Parse the output of a bash command."""
if "```bash" in text:
return self.get_code_blocks(text)
else:
raise OutputParserException(
f"Failed to parse bash output. Got: {text}",
)
@staticmethod
def get_code_blocks(t: str) -> List[str]:
"""Get multiple code blocks from the LLM result."""
code_blocks: List[str] = []
# Bash markdown code blocks
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
for match in pattern.finditer(t):
matched = match.group(1).strip()
if matched:
code_blocks.extend(
[line for line in matched.split("\n") if line.strip()]
)
return code_blocks
@property
def _type(self) -> str:
return "bash"
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
output_parser=BashOutputParser(),
)
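# A small sketch of the parser on a typical completion; the commands below are
# illustrative only.
_example_completion = (
    "I need to list files and create a directory.\n"
    "```bash\n"
    "ls\n"
    "mkdir myNewDirectory\n"
    "```"
)
# Returns ["ls", "mkdir myNewDirectory"]
print(BashOutputParser().parse(_example_completion))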

View File

@ -1,4 +0,0 @@
"""Chain that interprets a prompt and **executes python code to do math**.
Heavily borrowed from `llm_math`, uses the [SymPy](https://www.sympy.org/) package.
"""

View File

@ -1,250 +0,0 @@
"""Chain that interprets a prompt and executes python code to do symbolic math."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.prompts.base import BasePromptTemplate
from pydantic import ConfigDict
from langchain_experimental.llm_symbolic_math.prompt import PROMPT
class LLMSymbolicMathChain(Chain):
"""Chain that interprets a prompt and executes python code to do symbolic math.
It is based on the sympy library and can be used to evaluate
mathematical expressions.
See https://www.sympy.org/ for more information.
Example:
.. code-block:: python
from langchain.chains import LLMSymbolicMathChain
from langchain_community.llms import OpenAI
llm_symbolic_math = LLMSymbolicMathChain.from_llm(OpenAI(), allow_dangerous_requests=False)
"""
llm_chain: LLMChain
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
allow_dangerous_requests: bool # Assign no default.
"""Must be set by the user to allow dangerous requests or not.
We recommend a default of False to allow only pre-defined symbolic operations.
When set to True, the chain will allow any kind of input. This is
STRONGLY DISCOURAGED unless you fully trust the input (and believe that
the LLM itself cannot behave in a malicious way).
You should absolutely NOT be deploying this in a production environment
with allow_dangerous_requests=True. As this would allow a malicious actor
to execute arbitrary code on your system.
Set allow_dangerous_requests=True at your own risk.
When set to False, the chain will only allow pre-defined symbolic operations.
If some symbolic expressions fail to evaluate, you can open a PR
to add them to extend the list of allowed operations.
"""
def __init__(self, **kwargs: Any) -> None:
if "allow_dangerous_requests" not in kwargs:
raise ValueError(
"LLMSymbolicMathChain requires allow_dangerous_requests to be set. "
"We recommend that you set `allow_dangerous_requests=False` to allow "
"only pre-defined symbolic operations. "
"If the some symbolic expressions are failing to evaluate, you can "
"open a PR to add them to extend the list of allowed operations. "
"Alternatively, you can set `allow_dangerous_requests=True` to allow "
"any kind of input but this is STRONGLY DISCOURAGED unless you "
"fully trust the input (and believe that the LLM itself cannot behave "
"in a malicious way)."
"You should absolutely NOT be deploying this in a production "
"environment with allow_dangerous_requests=True. As "
"this would allow a malicious actor to execute arbitrary code on "
"your system."
)
super().__init__(**kwargs)
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, expression: str) -> str:
try:
import sympy
except ImportError as e:
raise ImportError(
"Unable to import sympy, please install it with `pip install sympy`."
) from e
try:
if self.allow_dangerous_requests:
output = str(sympy.sympify(expression, evaluate=True))
else:
allowed_symbols = {
# Basic arithmetic and trigonometry
"sin": sympy.sin,
"cos": sympy.cos,
"tan": sympy.tan,
"cot": sympy.cot,
"sec": sympy.sec,
"csc": sympy.csc,
"asin": sympy.asin,
"acos": sympy.acos,
"atan": sympy.atan,
# Hyperbolic functions
"sinh": sympy.sinh,
"cosh": sympy.cosh,
"tanh": sympy.tanh,
"asinh": sympy.asinh,
"acosh": sympy.acosh,
"atanh": sympy.atanh,
# Exponentials and logarithms
"exp": sympy.exp,
"log": sympy.log,
"ln": sympy.log, # natural log sympy defaults to natural log
"log10": lambda x: sympy.log(x, 10), # log base 10 (use sympy.log)
# Powers and roots
"sqrt": sympy.sqrt,
"cbrt": lambda x: sympy.Pow(x, sympy.Rational(1, 3)),
# Combinatorics and other math functions
"factorial": sympy.factorial,
"binomial": sympy.binomial,
"gcd": sympy.gcd,
"lcm": sympy.lcm,
"abs": sympy.Abs,
"sign": sympy.sign,
"mod": sympy.Mod,
# Constants
"pi": sympy.pi,
"e": sympy.E,
"I": sympy.I,
"oo": sympy.oo,
"NaN": sympy.nan,
}
# Use parse_expr with strict settings
output = str(
sympy.parse_expr(
expression, local_dict=allowed_symbols, evaluate=True
)
)
except Exception as e:
raise ValueError(
f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
" Please try again with a valid numerical expression"
)
# Remove any leading and trailing brackets from the output
return re.sub(r"^\[|\]$", "", output)
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_symbolic_math_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMSymbolicMathChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
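# Hedged usage sketch; the OpenAI wrapper is only an example LLM and requires
# valid credentials. Keeping allow_dangerous_requests=False means only the
# pre-defined symbolic operations listed in _evaluate_expression are used.
from langchain_community.llms import OpenAI

example_chain = LLMSymbolicMathChain.from_llm(
    OpenAI(temperature=0), allow_dangerous_requests=False
)
# example_chain.invoke({"question": "What is the limit of sin(x)/x as x goes to 0?"})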

View File

@ -1,51 +0,0 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
_PROMPT_TEMPLATE = """Translate a math problem into a expression that can be executed using Python's SymPy library. Use the output of running this code to answer the question.
Question: ${{Question with math problem.}}
```text
${{single line sympy expression that solves the problem}}
```
...sympy.sympify(text, evaluate=True)...
```output
${{Output of running the code}}
```
Answer: ${{Answer}}
Begin.
Question: What is the limit of sin(x) / x as x goes to 0
```text
limit(sin(x)/x, x, 0)
```
...sympy.sympify("limit(sin(x)/x, x, 0)")...
```output
1
```
Answer: 1
Question: What is the integral of e^-x from 0 to infinity
```text
integrate(exp(-x), (x, 0, oo))
```
...sympy.sympify("integrate(exp(-x), (x, 0, oo))")...
```output
1
```
Question: What are the solutions to this equation x**2 - x?
```text
solveset(x**2 - x, x)
```
...sympy.sympify("solveset(x**2 - x, x)")...
```output
[0, 1]
```
Question: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)

View File

@ -1,10 +0,0 @@
"""Experimental **LLM** classes provide
access to the large language model (**LLM**) APIs and services.
"""
from langchain_experimental.llms.jsonformer_decoder import JsonFormer
from langchain_experimental.llms.llamaapi import ChatLlamaAPI
from langchain_experimental.llms.lmformatenforcer_decoder import LMFormatEnforcer
from langchain_experimental.llms.rellm_decoder import RELLM
__all__ = ["RELLM", "JsonFormer", "ChatLlamaAPI", "LMFormatEnforcer"]

View File

@ -1,228 +0,0 @@
import json
from collections import defaultdict
from html.parser import HTMLParser
from typing import Any, DefaultDict, Dict, List, Optional, cast
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain_community.chat_models.anthropic import ChatAnthropic
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
SystemMessage,
)
from pydantic import model_validator
prompt = """In addition to responding, you can use tools. \
You have access to the following tools.
{tools}
In order to use a tool, you can use <tool></tool> to specify the name, \
and the <tool_input></tool_input> tags to specify the parameters. \
Each parameter should be passed in as <$param_name>$value</$param_name>, \
Where $param_name is the name of the specific parameter, and $value \
is the value for that parameter.
You will then get back a response in the form <observation></observation>
For example, if you have a tool called 'search' that accepts a single \
parameter 'query' that could run a google search, in order to search \
for the weather in SF you would respond:
<tool>search</tool><tool_input><query>weather in SF</query></tool_input>
<observation>64 degrees</observation>"""
class TagParser(HTMLParser):
"""Parser for the tool tags."""
def __init__(self) -> None:
"""A heavy-handed solution, but it's fast for prototyping.
Might be re-implemented later to restrict scope to the limited grammar, and
to improve efficiency.
Uses an HTML parser to parse a limited grammar that allows
for syntax of the form:
INPUT -> JUNK? VALUE*
JUNK -> JUNK_CHARACTER+
JUNK_CHARACTER -> whitespace | ,
VALUE -> <IDENTIFIER>DATA</IDENTIFIER> | OBJECT
OBJECT -> <IDENTIFIER>VALUE+</IDENTIFIER>
IDENTIFIER -> [a-Z][a-Z0-9_]*
DATA -> .*
Interprets the data to allow repetition of tags and recursion
to support representation of complex types.
^ Just another approximately wrong grammar specification.
"""
super().__init__()
self.parse_data: DefaultDict[str, List[Any]] = defaultdict(list)
self.stack: List[DefaultDict[str, List[str]]] = [self.parse_data]
self.success = True
self.depth = 0
self.data: Optional[str] = None
def handle_starttag(self, tag: str, attrs: Any) -> None:
"""Hook when a new tag is encountered."""
self.depth += 1
self.stack.append(defaultdict(list))
self.data = None
def handle_endtag(self, tag: str) -> None:
"""Hook when a tag is closed."""
self.depth -= 1
top_of_stack = dict(self.stack.pop(-1)) # Pop the dictionary; we no longer need it
# If a leaf node
is_leaf = self.data is not None
# Annoying to type here, code is tested, hopefully OK
value = self.data if is_leaf else top_of_stack
# Difficult to type this correctly with mypy (maybe impossible?)
# Can be nested indefinitely, so requires self referencing type
self.stack[-1][tag].append(value) # type: ignore
# Reset the data so that if we encounter a sequence of end tags, we
# don't confuse an outer end tag for belonging to a leaf node.
self.data = None
def handle_data(self, data: str) -> None:
"""Hook when handling data."""
stripped_data = data.strip()
# The only data that's allowed is whitespace or a comma surrounded by whitespace
if self.depth == 0 and stripped_data not in (",", ""):
# If this is triggered the parse should be considered invalid.
self.success = False
if stripped_data: # ignore whitespace-only strings
self.data = stripped_data
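# A short illustrative run of TagParser on the tool-call syntax described in the
# prompt above.
_example_parser = TagParser()
_example_parser.feed(
    "<tool>search</tool><tool_input><query>weather in SF</query></tool_input>"
)
# _example_parser.parse_data now maps "tool" -> ["search"] and
# "tool_input" -> [{"query": ["weather in SF"]}]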
def _destrip(tool_input: Any) -> Any:
if isinstance(tool_input, dict):
return {k: _destrip(v) for k, v in tool_input.items()}
elif isinstance(tool_input, list):
if isinstance(tool_input[0], str):
if len(tool_input) == 1:
return tool_input[0]
else:
raise ValueError
elif isinstance(tool_input[0], dict):
return [_destrip(v) for v in tool_input]
else:
raise ValueError
else:
raise ValueError
@deprecated(
since="0.0.54",
removal="1.0",
alternative_import="langchain_anthropic.experimental.ChatAnthropicTools",
)
class AnthropicFunctions(BaseChatModel):
"""Chat model for interacting with Anthropic functions."""
llm: BaseChatModel
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
values["llm"] = values.get("llm") or ChatAnthropic(**values)
return values
@property
def model(self) -> BaseChatModel:
"""For backwards compatibility."""
return self.llm
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
forced = False
function_call = ""
if "functions" in kwargs:
# get the function call method
if "function_call" in kwargs:
function_call = kwargs["function_call"]
del kwargs["function_call"]
else:
function_call = "auto"
# should function calling be used
if function_call != "none":
content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2))
system = SystemMessage(content=content)
messages = [system] + messages
# is the function call a dictionary (forced function calling)
if isinstance(function_call, dict):
forced = True
function_call_name = function_call["name"]
messages.append(AIMessage(content=f"<tool>{function_call_name}</tool>"))
del kwargs["functions"]
if stop is None:
stop = ["</tool_input>"]
else:
stop.append("</tool_input>")
else:
if "function_call" in kwargs:
raise ValueError(
"if `function_call` provided, `functions` must also be"
)
response = self.model.invoke(
messages, stop=stop, callbacks=run_manager, **kwargs
)
completion = cast(str, response.content)
if forced:
tag_parser = TagParser()
if "<tool_input>" in completion:
tag_parser.feed(completion.strip() + "</tool_input>")
v1 = tag_parser.parse_data["tool_input"][0]
arguments = json.dumps(_destrip(v1))
else:
v1 = completion
arguments = ""
kwargs = {
"function_call": {
"name": function_call_name, # type: ignore[has-type]
"arguments": arguments,
}
}
message = AIMessage(content="", additional_kwargs=kwargs)
return ChatResult(generations=[ChatGeneration(message=message)])
elif "<tool>" in completion:
tag_parser = TagParser()
tag_parser.feed(completion.strip() + "</tool_input>")
msg = completion.split("<tool>")[0].strip()
v1 = tag_parser.parse_data["tool_input"][0]
kwargs = {
"function_call": {
"name": tag_parser.parse_data["tool"][0],
"arguments": json.dumps(_destrip(v1)),
}
}
message = AIMessage(content=msg, additional_kwargs=kwargs)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
response.content = cast(str, response.content).strip()
return ChatResult(generations=[ChatGeneration(message=response)])
@property
def _llm_type(self) -> str:
return "anthropic_functions"

View File

@ -1,67 +0,0 @@
"""Experimental implementation of jsonformer wrapped LLM."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, List, Optional, cast
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field, model_validator
if TYPE_CHECKING:
import jsonformer
def import_jsonformer() -> jsonformer:
"""Lazily import of the jsonformer package."""
try:
import jsonformer
except ImportError:
raise ImportError(
"Could not import jsonformer python package. "
"Please install it with `pip install jsonformer`."
)
return jsonformer
class JsonFormer(HuggingFacePipeline):
"""Jsonformer wrapped LLM using HuggingFace Pipeline API.
This pipeline is experimental and not yet stable.
"""
json_schema: dict = Field(..., description="The JSON Schema to complete.")
max_new_tokens: int = Field(
default=200, description="Maximum number of new tokens to generate."
)
debug: bool = Field(default=False, description="Debug mode.")
@model_validator(mode="before")
@classmethod
def check_jsonformer_installation(cls, values: dict) -> Any:
import_jsonformer()
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
jsonformer = import_jsonformer()
from transformers import Text2TextGenerationPipeline
pipeline = cast(Text2TextGenerationPipeline, self.pipeline)
model = jsonformer.Jsonformer(
model=pipeline.model,
tokenizer=pipeline.tokenizer,
json_schema=self.json_schema,
prompt=prompt,
max_number_tokens=self.max_new_tokens,
debug=self.debug,
)
text = model()
return json.dumps(text)
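# Hedged sketch of wrapping a local HuggingFace pipeline; requires the optional
# `jsonformer` and `transformers` packages, and the model name and schema are
# illustrative placeholders.
from transformers import pipeline

_example_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
}
_example_pipeline = pipeline("text-generation", model="gpt2")
json_former = JsonFormer(pipeline=_example_pipeline, json_schema=_example_schema)
# json_former.invoke("Describe a fictional person:")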

View File

@ -1,126 +0,0 @@
import json
import logging
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Tuple,
)
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
logger = logging.getLogger(__name__)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content") or ""
if _dict.get("function_call"):
_dict["function_call"]["arguments"] = json.dumps(
_dict["function_call"]["arguments"]
)
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class ChatLlamaAPI(BaseChatModel):
"""Chat model using the Llama API."""
client: Any #: :meta private:
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
_params = {"messages": message_dicts}
final_params = {**params, **kwargs, **_params}
response = self.client.run(final_params).json()
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = dict(self._client_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
)
generations.append(gen)
return ChatResult(generations=generations)
@property
def _client_params(self) -> Mapping[str, Any]:
"""Get the parameters used for the client."""
return {}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "llama-api"

View File

@ -1,83 +0,0 @@
"""Experimental implementation of lm-format-enforcer wrapped LLM."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional
from langchain.schema import LLMResult
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field
if TYPE_CHECKING:
import lmformatenforcer
def import_lmformatenforcer() -> lmformatenforcer:
"""Lazily import of the lmformatenforcer package."""
try:
import lmformatenforcer
except ImportError:
raise ImportError(
"Could not import lmformatenforcer python package. "
"Please install it with `pip install lm-format-enforcer`."
)
return lmformatenforcer
class LMFormatEnforcer(HuggingFacePipeline):
"""LMFormatEnforcer wrapped LLM using HuggingFace Pipeline API.
This pipeline is experimental and not yet stable.
"""
json_schema: Optional[dict] = Field(
description="The JSON Schema to complete.", default=None
)
regex: Optional[str] = Field(
description="The regular expression to complete.", default=None
)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
lmformatenforcer = import_lmformatenforcer()
import lmformatenforcer.integrations.transformers as hf_integration
# We integrate lmformatenforcer by adding a prefix_allowed_tokens_fn.
# It has to be done on each call, because the prefix function is stateful.
if "prefix_allowed_tokens_fn" in self.pipeline._forward_params:
raise ValueError(
"prefix_allowed_tokens_fn param is forbidden with LMFormatEnforcer."
)
has_json_schema = self.json_schema is not None
has_regex = self.regex is not None
if has_json_schema == has_regex:
raise ValueError(
"You must specify exactly one of json_schema or a regex, but not both."
)
if has_json_schema:
parser = lmformatenforcer.JsonSchemaParser(self.json_schema)
else:
parser = lmformatenforcer.RegexParser(self.regex)
prefix_function = hf_integration.build_transformers_prefix_allowed_tokens_fn(
self.pipeline.tokenizer, parser
)
self.pipeline._forward_params["prefix_allowed_tokens_fn"] = prefix_function
result = super()._generate(
prompts,
stop=stop,
run_manager=run_manager,
**kwargs,
)
del self.pipeline._forward_params["prefix_allowed_tokens_fn"]
return result
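# Illustrative sketch, mirroring the JsonFormer wrapper above; requires the
# optional `lm-format-enforcer` and `transformers` packages, and the regex and
# model are placeholders.
from transformers import pipeline

_example_pipeline = pipeline("text-generation", model="gpt2")
constrained_llm = LMFormatEnforcer(
    pipeline=_example_pipeline, regex=r"Answer: (yes|no)"
)
# constrained_llm.invoke("Is the sky blue? ")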

View File

@ -1,462 +0,0 @@
import json
import uuid
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
TypedDict,
TypeVar,
Union,
)
from langchain_community.chat_models.ollama import ChatOllama
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolCall,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.runnables.base import RunnableMap
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from pydantic import (
BaseModel,
)
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
{tools}
You must always select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
""" # noqa: E501
DEFAULT_RESPONSE_FUNCTION = {
"name": "__conversational_response",
"description": (
"Respond conversationally if no other tools should be called for a given query."
),
"parameters": {
"type": "object",
"properties": {
"response": {
"type": "string",
"description": "Conversational response to the user.",
},
},
"required": ["response"],
},
}
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydantic = Union[Dict, _BM]
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and (
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
)
def convert_to_ollama_tool(tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
if _is_pydantic_class(tool):
schema = tool.model_construct().model_json_schema()
name = schema["title"]
elif isinstance(tool, BaseTool):
schema = tool.tool_call_schema.model_json_schema()
name = tool.get_name()
description = tool.description
elif is_basemodel_instance(tool):
schema = tool.get_input_schema().model_json_schema()
name = tool.get_name()
description = tool.description
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
return tool.copy()
else:
raise ValueError(
f"""Cannot convert {tool} to an Ollama tool.
{tool} needs to be a Pydantic class, model, or a dict."""
)
definition = {"name": name, "parameters": schema}
if description:
definition["description"] = description
return definition
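# --- Illustrative example (not part of the original module) ---
# convert_to_ollama_tool turns a Pydantic class into the {"name", "parameters"}
# dict shape used below; the Weather model is hypothetical.
def _demo_convert_to_ollama_tool() -> Dict:
    class Weather(BaseModel):
        """Get the current weather for a location."""

        location: str
        unit: str = "celsius"

    # For a Pydantic class this returns only "name" and "parameters",
    # e.g. {"name": "Weather", "parameters": {...Weather JSON schema...}}.
    return convert_to_ollama_tool(Weather)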
class _AllReturnType(TypedDict):
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
def parse_response(message: BaseMessage) -> str:
"""Extract `function_call` from `AIMessage`."""
if isinstance(message, AIMessage):
kwargs = message.additional_kwargs
tool_calls = message.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[-1]
args = tool_call.get("args")
return json.dumps(args)
elif "function_call" in kwargs:
if "arguments" in kwargs["function_call"]:
return kwargs["function_call"]["arguments"]
raise ValueError(
f"`arguments` missing from `function_call` within AIMessage: {message}"
)
else:
raise ValueError("`tool_calls` missing from AIMessage: {message}")
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
@deprecated( # type: ignore[arg-type]
since="0.0.64", removal="1.0", alternative_import="langchain_ollama.ChatOllama"
)
class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API."""
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
Example: Pydantic schema (include_raw=False):
.. code-block:: python
from langchain_experimental.llms import OllamaFunctions
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = OllamaFunctions(model="phi3", format="json", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example: Pydantic schema (include_raw=True):
.. code-block:: python
from langchain_experimental.llms import OllamaFunctions
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = OllamaFunctions(model="phi3", format="json", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
Example: dict schema (include_raw=False):
.. code-block:: python
from langchain_experimental.llms import OllamaFunctions, convert_to_ollama_tool
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_ollama_tool(AnswerWithJustification)
llm = OllamaFunctions(model="phi3", format="json", temperature=0)
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools(tools=[schema], format="json")
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticOutputParser( # type: ignore[type-var]
pydantic_object=schema # type: ignore[arg-type]
)
else:
output_parser = JsonOutputParser()
parser_chain = RunnableLambda(parse_response) | output_parser
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | parser_chain, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | parser_chain
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
functions = kwargs.get("functions", [])
if "functions" in kwargs:
del kwargs["functions"]
if "function_call" in kwargs:
functions = [
fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
]
if not functions:
raise ValueError(
"If `function_call` is specified, you must also pass a "
"matching function in `functions`."
)
del kwargs["function_call"]
functions = [convert_to_ollama_tool(fn) for fn in functions]
functions.append(DEFAULT_RESPONSE_FUNCTION)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
self.tool_system_prompt_template
)
system_message = system_message_prompt_template.format(
tools=json.dumps(functions, indent=2)
)
response_message = super()._generate(
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
)
chat_generation_content = response_message.generations[0].text
if not isinstance(chat_generation_content, str):
raise ValueError("OllamaFunctions does not support non-string output.")
try:
parsed_chat_result = json.loads(chat_generation_content)
except json.JSONDecodeError:
raise ValueError(
f"""'{self.model}' did not respond with valid JSON.
Please try again.
Response: {chat_generation_content}"""
)
called_tool_name = (
parsed_chat_result["tool"] if "tool" in parsed_chat_result else None
)
called_tool = next(
(fn for fn in functions if fn["name"] == called_tool_name), None
)
if (
called_tool is None
or called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]
):
if (
"tool_input" in parsed_chat_result
and "response" in parsed_chat_result["tool_input"]
):
response = parsed_chat_result["tool_input"]["response"]
elif "response" in parsed_chat_result:
response = parsed_chat_result["response"]
else:
raise ValueError(
f"Failed to parse a response from {self.model} output: "
f"{chat_generation_content}"
)
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content=response,
)
)
]
)
called_tool_arguments = (
parsed_chat_result["tool_input"]
if "tool_input" in parsed_chat_result
else {}
)
response_message_with_functions = AIMessage(
content="",
tool_calls=[
ToolCall(
name=called_tool_name,
args=called_tool_arguments if called_tool_arguments else {},
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)
],
)
return ChatResult(
generations=[ChatGeneration(message=response_message_with_functions)]
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
functions = kwargs.get("functions", [])
if "functions" in kwargs:
del kwargs["functions"]
if "function_call" in kwargs:
functions = [
fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
]
if not functions:
raise ValueError(
"If `function_call` is specified, you must also pass a "
"matching function in `functions`."
)
del kwargs["function_call"]
elif not functions:
functions.append(DEFAULT_RESPONSE_FUNCTION)
if _is_pydantic_class(functions[0]):
functions = [convert_to_ollama_tool(fn) for fn in functions]
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
self.tool_system_prompt_template
)
system_message = system_message_prompt_template.format(
tools=json.dumps(functions, indent=2)
)
response_message = await super()._agenerate(
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
)
chat_generation_content = response_message.generations[0].text
if not isinstance(chat_generation_content, str):
raise ValueError("OllamaFunctions does not support non-string output.")
try:
parsed_chat_result = json.loads(chat_generation_content)
except json.JSONDecodeError:
raise ValueError(
f"""'{self.model}' did not respond with valid JSON.
Please try again.
Response: {chat_generation_content}"""
)
called_tool_name = parsed_chat_result["tool"]
called_tool_arguments = parsed_chat_result["tool_input"]
called_tool = next(
(fn for fn in functions if fn["name"] == called_tool_name), None
)
if called_tool is None:
raise ValueError(
f"Failed to parse a function call from {self.model} output: "
f"{chat_generation_content}"
)
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content=called_tool_arguments["response"],
)
)
]
)
response_message_with_functions = AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": called_tool_name,
"arguments": json.dumps(called_tool_arguments)
if called_tool_arguments
else "",
},
},
)
return ChatResult(
generations=[ChatGeneration(message=response_message_with_functions)]
)
@property
def _llm_type(self) -> str:
return "ollama_functions"

View File

@ -1,71 +0,0 @@
"""Experimental implementation of RELLM wrapped LLM."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional, cast
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field, model_validator
if TYPE_CHECKING:
import rellm
from regex import Pattern as RegexPattern
else:
try:
from regex import Pattern as RegexPattern
except ImportError:
pass
def import_rellm() -> rellm:
"""Lazily import of the rellm package."""
try:
import rellm
except ImportError:
raise ImportError(
"Could not import rellm python package. "
"Please install it with `pip install rellm`."
)
return rellm
class RELLM(HuggingFacePipeline):
"""RELLM wrapped LLM using HuggingFace Pipeline API."""
regex: RegexPattern = Field(..., description="The structured format to complete.")
max_new_tokens: int = Field(
default=200, description="Maximum number of new tokens to generate."
)
@model_validator(mode="before")
@classmethod
def check_rellm_installation(cls, values: dict) -> Any:
import_rellm()
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
rellm = import_rellm()
from transformers import Text2TextGenerationPipeline
pipeline = cast(Text2TextGenerationPipeline, self.pipeline)
text = rellm.complete_re(
prompt,
self.regex,
tokenizer=pipeline.tokenizer,
model=pipeline.model,
max_new_tokens=self.max_new_tokens,
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
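# --- Usage sketch (illustrative; not part of the original module) ---
# Constraining a seq2seq pipeline to a regex with RELLM. The model name and
# pattern are assumptions for illustration only.
def _demo_rellm() -> str:
    import regex
    from transformers import pipeline as hf_pipeline_factory

    hf_pipe = hf_pipeline_factory(
        "text2text-generation", model="google/flan-t5-small"
    )
    llm = RELLM(pipeline=hf_pipe, regex=regex.compile(r"\d{4}-\d{2}-\d{2}"))
    # Output is forced to match the date pattern, e.g. "1969-07-20".
    return llm.invoke("On what date did Apollo 11 land on the moon?")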

View File

@ -1,12 +0,0 @@
"""**OpenCLIP Embeddings** model.
OpenCLIP is a multimodal model that can encode text and images into a shared space.
See the paper (https://arxiv.org/abs/2103.00020) and
[this repository](https://github.com/mlfoundations/open_clip) for details.
"""
from .open_clip import OpenCLIPEmbeddings
__all__ = ["OpenCLIPEmbeddings"]

Some files were not shown because too many files have changed in this diff.