community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463)

Moved the following modules to new package langchain-community in a backwards compatible fashion:

```
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
```

Moved the following to core
```
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py
```

See .scripts/community_split/script_integrations.sh for all changes
This commit is contained in:
Bagatur
2023-12-11 13:53:30 -08:00
committed by GitHub
parent c0f4b95aa9
commit ed58eeb9c5
2446 changed files with 171805 additions and 137118 deletions

View File

@@ -51,10 +51,12 @@ check_imports: langchain/**/*.py
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
@@ -62,7 +64,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)

View File

@@ -1,399 +1,37 @@
from __future__ import annotations
import importlib
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Mapping,
Sequence,
Union,
overload,
from langchain_community.adapters.openai import (
Chat,
ChatCompletion,
ChatCompletionChunk,
ChatCompletions,
Choice,
ChoiceChunk,
Completions,
IndexableBaseModel,
_convert_message_chunk,
_convert_message_chunk_to_delta,
_has_assistant_message,
chat,
convert_dict_to_message,
convert_message_to_dict,
convert_messages_for_finetuning,
convert_openai_messages,
)
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from typing_extensions import Literal
async def aenumerate(
iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
"""Async version of enumerate function."""
i = start
async for x in iterable:
yield i, x
i += 1
class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing"""
def __getitem__(self, item: str) -> Any:
return getattr(self, item)
class Choice(IndexableBaseModel):
message: dict
class ChatCompletions(IndexableBaseModel):
choices: List[Choice]
class ChoiceChunk(IndexableBaseModel):
delta: dict
class ChatCompletionChunk(IndexableBaseModel):
choices: List[ChoiceChunk]
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
additional_kwargs["function_call"] = dict(_dict["function_call"])
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
elif role == "tool":
return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
"""Convert dictionaries representing OpenAI messages to LangChain format.
Args:
messages: List of dictionaries representing OpenAI messages
Returns:
List of LangChain BaseMessage objects.
"""
return [convert_dict_to_message(m) for m in messages]
def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
_dict: Dict[str, Any] = {}
if isinstance(chunk, AIMessageChunk):
if i == 0:
# Only shows up in the first chunk
_dict["role"] = "assistant"
if "function_call" in chunk.additional_kwargs:
_dict["function_call"] = chunk.additional_kwargs["function_call"]
# If the first chunk is a function call, the content is not empty string,
# not missing, but None.
if i == 0:
_dict["content"] = None
else:
_dict["content"] = chunk.content
else:
raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
# This only happens at the end of streams, and OpenAI returns as empty dict
if _dict == {"content": ""}:
_dict = {}
return _dict
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
_dict = _convert_message_chunk(chunk, i)
return {"choices": [{"delta": _dict}]}
class ChatCompletion:
"""Chat completion."""
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict:
...
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> Iterable:
...
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, Iterable]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
for i, c in enumerate(model_config.stream(converted_messages))
)
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict:
...
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> AsyncIterator:
...
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, AsyncIterator]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
async for i, c in aenumerate(model_config.astream(converted_messages))
)
def _has_assistant_message(session: ChatSession) -> bool:
"""Check if chat session has an assistant message."""
return any([isinstance(m, AIMessage) for m in session["messages"]])
def convert_messages_for_finetuning(
sessions: Iterable[ChatSession],
) -> List[List[dict]]:
"""Convert messages to a list of lists of dictionaries for fine-tuning.
Args:
sessions: The chat sessions.
Returns:
The list of lists of dictionaries.
"""
return [
[convert_message_to_dict(s) for s in session["messages"]]
for session in sessions
if _has_assistant_message(session)
]
class Completions:
"""Completion."""
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> ChatCompletions:
...
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> Iterable:
...
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[ChatCompletions, Iterable]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return ChatCompletions(
choices=[Choice(message=convert_message_to_dict(result))]
)
else:
return (
ChatCompletionChunk(
choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
)
for i, c in enumerate(model_config.stream(converted_messages))
)
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> ChatCompletions:
...
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> AsyncIterator:
...
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[ChatCompletions, AsyncIterator]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return ChatCompletions(
choices=[Choice(message=convert_message_to_dict(result))]
)
else:
return (
ChatCompletionChunk(
choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
)
async for i, c in aenumerate(model_config.astream(converted_messages))
)
class Chat:
def __init__(self) -> None:
self.completions = Completions()
chat = Chat()
__all__ = [
"IndexableBaseModel",
"Choice",
"ChatCompletions",
"ChoiceChunk",
"ChatCompletionChunk",
"convert_dict_to_message",
"convert_message_to_dict",
"convert_openai_messages",
"_convert_message_chunk",
"_convert_message_chunk_to_delta",
"ChatCompletion",
"_has_assistant_message",
"convert_messages_for_finetuning",
"Completions",
"Chat",
"chat",
]

View File

@@ -32,6 +32,7 @@ from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from langchain.agents.agent_iterator import AgentExecutorIterator
@@ -47,7 +48,6 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
logger = logging.getLogger(__name__)

View File

@@ -1,53 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
from typing import TYPE_CHECKING, List, Literal, Optional
from langchain_core.pydantic_v1 import root_validator
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.ainetwork.app import AINAppOps
from langchain.tools.ainetwork.owner import AINOwnerOps
from langchain.tools.ainetwork.rule import AINRuleOps
from langchain.tools.ainetwork.transfer import AINTransfer
from langchain.tools.ainetwork.utils import authenticate
from langchain.tools.ainetwork.value import AINValueOps
if TYPE_CHECKING:
from ain.ain import Ain
class AINetworkToolkit(BaseToolkit):
"""Toolkit for interacting with AINetwork Blockchain.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
See https://python.langchain.com/docs/security for more information.
"""
network: Optional[Literal["mainnet", "testnet"]] = "testnet"
interface: Optional[Ain] = None
@root_validator(pre=True)
def set_interface(cls, values: dict) -> dict:
if not values.get("interface"):
values["interface"] = authenticate(network=values.get("network", "testnet"))
return values
class Config:
"""Pydantic config."""
validate_all = True
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
AINAppOps(),
AINOwnerOps(),
AINRuleOps(),
AINTransfer(),
AINValueOps(),
]
__all__ = ["AINetworkToolkit"]

View File

@@ -1,32 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.amadeus.toolkit import AmadeusToolkit
from typing import TYPE_CHECKING, List
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.amadeus.closest_airport import AmadeusClosestAirport
from langchain.tools.amadeus.flight_search import AmadeusFlightSearch
from langchain.tools.amadeus.utils import authenticate
if TYPE_CHECKING:
from amadeus import Client
class AmadeusToolkit(BaseToolkit):
"""Toolkit for interacting with Amadeus which offers APIs for travel."""
client: Client = Field(default_factory=authenticate)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
AmadeusClosestAirport(),
AmadeusFlightSearch(),
]
__all__ = ["AmadeusToolkit"]

View File

@@ -1,33 +1,5 @@
from __future__ import annotations
import sys
from typing import List
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools.azure_cognitive_services import (
AzureCogsFormRecognizerTool,
AzureCogsImageAnalysisTool,
AzureCogsSpeech2TextTool,
AzureCogsText2SpeechTool,
AzureCogsTextAnalyticsHealthTool,
from langchain_community.agent_toolkits.azure_cognitive_services import (
AzureCognitiveServicesToolkit,
)
from langchain.tools.base import BaseTool
class AzureCognitiveServicesToolkit(BaseToolkit):
"""Toolkit for Azure Cognitive Services."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tools: List[BaseTool] = [
AzureCogsFormRecognizerTool(),
AzureCogsSpeech2TextTool(),
AzureCogsText2SpeechTool(),
AzureCogsTextAnalyticsHealthTool(),
]
# TODO: Remove check once azure-ai-vision supports MacOS.
if sys.platform.startswith("linux") or sys.platform.startswith("win"):
tools.append(AzureCogsImageAnalysisTool())
return tools
__all__ = ["AzureCognitiveServicesToolkit"]

View File

@@ -1,15 +1,3 @@
"""Toolkits for agents."""
from abc import ABC, abstractmethod
from typing import List
from langchain_community.agent_toolkits.base import BaseToolkit
from langchain_core.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
class BaseToolkit(BaseModel, ABC):
"""Base Toolkit representing a collection of related tools."""
@abstractmethod
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
__all__ = ["BaseToolkit"]

View File

@@ -1,108 +1,3 @@
from typing import Dict, List
from langchain_community.agent_toolkits.clickup.toolkit import ClickupToolkit
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.clickup.prompt import (
CLICKUP_FOLDER_CREATE_PROMPT,
CLICKUP_GET_ALL_TEAMS_PROMPT,
CLICKUP_GET_FOLDERS_PROMPT,
CLICKUP_GET_LIST_PROMPT,
CLICKUP_GET_SPACES_PROMPT,
CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
CLICKUP_GET_TASK_PROMPT,
CLICKUP_LIST_CREATE_PROMPT,
CLICKUP_TASK_CREATE_PROMPT,
CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
CLICKUP_UPDATE_TASK_PROMPT,
)
from langchain.tools.clickup.tool import ClickupAction
from langchain.utilities.clickup import ClickupAPIWrapper
class ClickupToolkit(BaseToolkit):
"""Clickup Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
See https://python.langchain.com/docs/security for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_clickup_api_wrapper(
cls, clickup_api_wrapper: ClickupAPIWrapper
) -> "ClickupToolkit":
operations: List[Dict] = [
{
"mode": "get_task",
"name": "Get task",
"description": CLICKUP_GET_TASK_PROMPT,
},
{
"mode": "get_task_attribute",
"name": "Get task attribute",
"description": CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
},
{
"mode": "get_teams",
"name": "Get Teams",
"description": CLICKUP_GET_ALL_TEAMS_PROMPT,
},
{
"mode": "create_task",
"name": "Create Task",
"description": CLICKUP_TASK_CREATE_PROMPT,
},
{
"mode": "create_list",
"name": "Create List",
"description": CLICKUP_LIST_CREATE_PROMPT,
},
{
"mode": "create_folder",
"name": "Create Folder",
"description": CLICKUP_FOLDER_CREATE_PROMPT,
},
{
"mode": "get_list",
"name": "Get all lists in the space",
"description": CLICKUP_GET_LIST_PROMPT,
},
{
"mode": "get_folders",
"name": "Get all folders in the workspace",
"description": CLICKUP_GET_FOLDERS_PROMPT,
},
{
"mode": "get_spaces",
"name": "Get all spaces in the workspace",
"description": CLICKUP_GET_SPACES_PROMPT,
},
{
"mode": "update_task",
"name": "Update task",
"description": CLICKUP_UPDATE_TASK_PROMPT,
},
{
"mode": "update_task_assignees",
"name": "Update task assignees",
"description": CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
},
]
tools = [
ClickupAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=clickup_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
__all__ = ["ClickupToolkit"]

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, List, Optional # noqa: E501
from langchain_core.language_models import BaseLanguageModel
from langchain_core.memory import BaseMemory

View File

@@ -1,81 +1,6 @@
from __future__ import annotations
from langchain_community.agent_toolkits.file_management.toolkit import (
_FILE_TOOLS,
FileManagementToolkit,
)
from typing import List, Optional
from langchain_core.pydantic_v1 import root_validator
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.file_management.copy import CopyFileTool
from langchain.tools.file_management.delete import DeleteFileTool
from langchain.tools.file_management.file_search import FileSearchTool
from langchain.tools.file_management.list_dir import ListDirectoryTool
from langchain.tools.file_management.move import MoveFileTool
from langchain.tools.file_management.read import ReadFileTool
from langchain.tools.file_management.write import WriteFileTool
_FILE_TOOLS = {
# "Type[Runnable[Any, Any]]" has no attribute "__fields__" [attr-defined]
tool_cls.__fields__["name"].default: tool_cls # type: ignore[attr-defined]
for tool_cls in [
CopyFileTool,
DeleteFileTool,
FileSearchTool,
MoveFileTool,
ReadFileTool,
WriteFileTool,
ListDirectoryTool,
]
}
class FileManagementToolkit(BaseToolkit):
"""Toolkit for interacting with local files.
*Security Notice*: This toolkit provides methods to interact with local files.
If providing this toolkit to an agent on an LLM, ensure you scope
the agent's permissions to only include the necessary permissions
to perform the desired operations.
By **default** the agent will have access to all files within
the root dir and will be able to Copy, Delete, Move, Read, Write
and List files in that directory.
Consider the following:
- Limit access to particular directories using `root_dir`.
- Use filesystem permissions to restrict access and permissions to only
the files and directories required by the agent.
- Limit the tools available to the agent to only the file operations
necessary for the agent's intended use.
- Sandbox the agent by running it in a container.
See https://python.langchain.com/docs/security for more information.
"""
root_dir: Optional[str] = None
"""If specified, all file operations are made relative to root_dir."""
selected_tools: Optional[List[str]] = None
"""If provided, only provide the selected tools. Defaults to all."""
@root_validator
def validate_tools(cls, values: dict) -> dict:
selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools:
if tool_name not in _FILE_TOOLS:
raise ValueError(
f"File Tool of name {tool_name} not supported."
f" Permitted tools: {list(_FILE_TOOLS)}"
)
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
tools: List[BaseTool] = []
for tool in allowed_tools:
tool_cls = _FILE_TOOLS[tool]
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
return tools
__all__ = ["FileManagementToolkit"]
__all__ = ["_FILE_TOOLS", "FileManagementToolkit"]

View File

@@ -1,288 +1,35 @@
"""GitHub Toolkit."""
from typing import Dict, List
from langchain_core.pydantic_v1 import BaseModel
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.pydantic_v1 import Field
from langchain.tools import BaseTool
from langchain.tools.github.prompt import (
COMMENT_ON_ISSUE_PROMPT,
CREATE_BRANCH_PROMPT,
CREATE_FILE_PROMPT,
CREATE_PULL_REQUEST_PROMPT,
CREATE_REVIEW_REQUEST_PROMPT,
DELETE_FILE_PROMPT,
GET_FILES_FROM_DIRECTORY_PROMPT,
GET_ISSUE_PROMPT,
GET_ISSUES_PROMPT,
GET_PR_PROMPT,
LIST_BRANCHES_IN_REPO_PROMPT,
LIST_PRS_PROMPT,
LIST_PULL_REQUEST_FILES,
OVERVIEW_EXISTING_FILES_BOT_BRANCH,
OVERVIEW_EXISTING_FILES_IN_MAIN,
READ_FILE_PROMPT,
SEARCH_CODE_PROMPT,
SEARCH_ISSUES_AND_PRS_PROMPT,
SET_ACTIVE_BRANCH_PROMPT,
UPDATE_FILE_PROMPT,
from langchain_community.agent_toolkits.github.toolkit import (
BranchName,
CommentOnIssue,
CreateFile,
CreatePR,
CreateReviewRequest,
DeleteFile,
DirectoryPath,
GetIssue,
GetPR,
GitHubToolkit,
NoInput,
ReadFile,
SearchCode,
SearchIssuesAndPRs,
UpdateFile,
)
from langchain.tools.github.tool import GitHubAction
from langchain.utilities.github import GitHubAPIWrapper
class NoInput(BaseModel):
no_input: str = Field("", description="No input required, e.g. `` (empty string).")
class GetIssue(BaseModel):
issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`")
class CommentOnIssue(BaseModel):
input: str = Field(..., description="Follow the required formatting.")
class GetPR(BaseModel):
pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`")
class CreatePR(BaseModel):
formatted_pr: str = Field(..., description="Follow the required formatting.")
class CreateFile(BaseModel):
formatted_file: str = Field(..., description="Follow the required formatting.")
class ReadFile(BaseModel):
formatted_filepath: str = Field(
...,
description=(
"The full file path of the file you would like to read where the "
"path must NOT start with a slash, e.g. `some_dir/my_file.py`."
),
)
class UpdateFile(BaseModel):
formatted_file_update: str = Field(
..., description="Strictly follow the provided rules."
)
class DeleteFile(BaseModel):
formatted_filepath: str = Field(
...,
description=(
"The full file path of the file you would like to delete"
" where the path must NOT start with a slash, e.g."
" `some_dir/my_file.py`. Only input a string,"
" not the param name."
),
)
class DirectoryPath(BaseModel):
input: str = Field(
"",
description=(
"The path of the directory, e.g. `some_dir/inner_dir`."
" Only input a string, do not include the parameter name."
),
)
class BranchName(BaseModel):
branch_name: str = Field(
..., description="The name of the branch, e.g. `my_branch`."
)
class SearchCode(BaseModel):
search_query: str = Field(
...,
description=(
"A keyword-focused natural language search"
"query for code, e.g. `MyFunctionName()`."
),
)
class CreateReviewRequest(BaseModel):
username: str = Field(
...,
description="GitHub username of the user being requested, e.g. `my_username`.",
)
class SearchIssuesAndPRs(BaseModel):
search_query: str = Field(
...,
description="Natural language search query, e.g. `My issue title or topic`.",
)
class GitHubToolkit(BaseToolkit):
"""GitHub Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, or updating,
reading underlying data.
For example, this toolkit can be used to create issues, pull requests,
and comments on GitHub.
See [Security](https://python.langchain.com/docs/security) for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_github_api_wrapper(
cls, github_api_wrapper: GitHubAPIWrapper
) -> "GitHubToolkit":
operations: List[Dict] = [
{
"mode": "get_issues",
"name": "Get Issues",
"description": GET_ISSUES_PROMPT,
"args_schema": NoInput,
},
{
"mode": "get_issue",
"name": "Get Issue",
"description": GET_ISSUE_PROMPT,
"args_schema": GetIssue,
},
{
"mode": "comment_on_issue",
"name": "Comment on Issue",
"description": COMMENT_ON_ISSUE_PROMPT,
"args_schema": CommentOnIssue,
},
{
"mode": "list_open_pull_requests",
"name": "List open pull requests (PRs)",
"description": LIST_PRS_PROMPT,
"args_schema": NoInput,
},
{
"mode": "get_pull_request",
"name": "Get Pull Request",
"description": GET_PR_PROMPT,
"args_schema": GetPR,
},
{
"mode": "list_pull_request_files",
"name": "Overview of files included in PR",
"description": LIST_PULL_REQUEST_FILES,
"args_schema": GetPR,
},
{
"mode": "create_pull_request",
"name": "Create Pull Request",
"description": CREATE_PULL_REQUEST_PROMPT,
"args_schema": CreatePR,
},
{
"mode": "list_pull_request_files",
"name": "List Pull Requests' Files",
"description": LIST_PULL_REQUEST_FILES,
"args_schema": GetPR,
},
{
"mode": "create_file",
"name": "Create File",
"description": CREATE_FILE_PROMPT,
"args_schema": CreateFile,
},
{
"mode": "read_file",
"name": "Read File",
"description": READ_FILE_PROMPT,
"args_schema": ReadFile,
},
{
"mode": "update_file",
"name": "Update File",
"description": UPDATE_FILE_PROMPT,
"args_schema": UpdateFile,
},
{
"mode": "delete_file",
"name": "Delete File",
"description": DELETE_FILE_PROMPT,
"args_schema": DeleteFile,
},
{
"mode": "list_files_in_main_branch",
"name": "Overview of existing files in Main branch",
"description": OVERVIEW_EXISTING_FILES_IN_MAIN,
"args_schema": NoInput,
},
{
"mode": "list_files_in_bot_branch",
"name": "Overview of files in current working branch",
"description": OVERVIEW_EXISTING_FILES_BOT_BRANCH,
"args_schema": NoInput,
},
{
"mode": "list_branches_in_repo",
"name": "List branches in this repository",
"description": LIST_BRANCHES_IN_REPO_PROMPT,
"args_schema": NoInput,
},
{
"mode": "set_active_branch",
"name": "Set active branch",
"description": SET_ACTIVE_BRANCH_PROMPT,
"args_schema": BranchName,
},
{
"mode": "create_branch",
"name": "Create a new branch",
"description": CREATE_BRANCH_PROMPT,
"args_schema": BranchName,
},
{
"mode": "get_files_from_directory",
"name": "Get files from a directory",
"description": GET_FILES_FROM_DIRECTORY_PROMPT,
"args_schema": DirectoryPath,
},
{
"mode": "search_issues_and_prs",
"name": "Search issues and pull requests",
"description": SEARCH_ISSUES_AND_PRS_PROMPT,
"args_schema": SearchIssuesAndPRs,
},
{
"mode": "search_code",
"name": "Search code",
"description": SEARCH_CODE_PROMPT,
"args_schema": SearchCode,
},
{
"mode": "create_review_request",
"name": "Create review request",
"description": CREATE_REVIEW_REQUEST_PROMPT,
"args_schema": CreateReviewRequest,
},
]
tools = [
GitHubAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=github_api_wrapper,
args_schema=action.get("args_schema", None),
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
__all__ = [
"NoInput",
"GetIssue",
"CommentOnIssue",
"GetPR",
"CreatePR",
"CreateFile",
"ReadFile",
"UpdateFile",
"DeleteFile",
"DirectoryPath",
"BranchName",
"SearchCode",
"CreateReviewRequest",
"SearchIssuesAndPRs",
"GitHubToolkit",
]

View File

@@ -1,94 +1,3 @@
"""GitHub Toolkit."""
from typing import Dict, List
from langchain_community.agent_toolkits.gitlab.toolkit import GitLabToolkit
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.gitlab.prompt import (
COMMENT_ON_ISSUE_PROMPT,
CREATE_FILE_PROMPT,
CREATE_PULL_REQUEST_PROMPT,
DELETE_FILE_PROMPT,
GET_ISSUE_PROMPT,
GET_ISSUES_PROMPT,
READ_FILE_PROMPT,
UPDATE_FILE_PROMPT,
)
from langchain.tools.gitlab.tool import GitLabAction
from langchain.utilities.gitlab import GitLabAPIWrapper
class GitLabToolkit(BaseToolkit):
"""GitLab Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, or updating,
reading underlying data.
For example, this toolkit can be used to create issues, pull requests,
and comments on GitLab.
See https://python.langchain.com/docs/security for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_gitlab_api_wrapper(
cls, gitlab_api_wrapper: GitLabAPIWrapper
) -> "GitLabToolkit":
operations: List[Dict] = [
{
"mode": "get_issues",
"name": "Get Issues",
"description": GET_ISSUES_PROMPT,
},
{
"mode": "get_issue",
"name": "Get Issue",
"description": GET_ISSUE_PROMPT,
},
{
"mode": "comment_on_issue",
"name": "Comment on Issue",
"description": COMMENT_ON_ISSUE_PROMPT,
},
{
"mode": "create_pull_request",
"name": "Create Pull Request",
"description": CREATE_PULL_REQUEST_PROMPT,
},
{
"mode": "create_file",
"name": "Create File",
"description": CREATE_FILE_PROMPT,
},
{
"mode": "read_file",
"name": "Read File",
"description": READ_FILE_PROMPT,
},
{
"mode": "update_file",
"name": "Update File",
"description": UPDATE_FILE_PROMPT,
},
{
"mode": "delete_file",
"name": "Delete File",
"description": DELETE_FILE_PROMPT,
},
]
tools = [
GitLabAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=gitlab_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
__all__ = ["GitLabToolkit"]

View File

@@ -1,58 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.gmail.toolkit import SCOPES, GmailToolkit
from typing import TYPE_CHECKING, List
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.gmail.create_draft import GmailCreateDraft
from langchain.tools.gmail.get_message import GmailGetMessage
from langchain.tools.gmail.get_thread import GmailGetThread
from langchain.tools.gmail.search import GmailSearch
from langchain.tools.gmail.send_message import GmailSendMessage
from langchain.tools.gmail.utils import build_resource_service
if TYPE_CHECKING:
# This is for linting and IDE typehints
from googleapiclient.discovery import Resource
else:
try:
# We do this so pydantic can resolve the types when instantiating
from googleapiclient.discovery import Resource
except ImportError:
pass
SCOPES = ["https://mail.google.com/"]
class GmailToolkit(BaseToolkit):
"""Toolkit for interacting with Gmail.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
For example, this toolkit can be used to send emails on behalf of the
associated account.
See https://python.langchain.com/docs/security for more information.
"""
api_resource: Resource = Field(default_factory=build_resource_service)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
GmailCreateDraft(api_resource=self.api_resource),
GmailSendMessage(api_resource=self.api_resource),
GmailSearch(api_resource=self.api_resource),
GmailGetMessage(api_resource=self.api_resource),
GmailGetThread(api_resource=self.api_resource),
]
__all__ = ["SCOPES", "GmailToolkit"]

View File

@@ -1,70 +1,3 @@
from typing import Dict, List
from langchain_community.agent_toolkits.jira.toolkit import JiraToolkit
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.jira.prompt import (
JIRA_CATCH_ALL_PROMPT,
JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
JIRA_GET_ALL_PROJECTS_PROMPT,
JIRA_ISSUE_CREATE_PROMPT,
JIRA_JQL_PROMPT,
)
from langchain.tools.jira.tool import JiraAction
from langchain.utilities.jira import JiraAPIWrapper
class JiraToolkit(BaseToolkit):
"""Jira Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, or updating,
reading underlying data.
See https://python.langchain.com/docs/security for more information.
"""
tools: List[BaseTool] = []
@classmethod
def from_jira_api_wrapper(cls, jira_api_wrapper: JiraAPIWrapper) -> "JiraToolkit":
operations: List[Dict] = [
{
"mode": "jql",
"name": "JQL Query",
"description": JIRA_JQL_PROMPT,
},
{
"mode": "get_projects",
"name": "Get Projects",
"description": JIRA_GET_ALL_PROJECTS_PROMPT,
},
{
"mode": "create_issue",
"name": "Create Issue",
"description": JIRA_ISSUE_CREATE_PROMPT,
},
{
"mode": "other",
"name": "Catch all Jira API call",
"description": JIRA_CATCH_ALL_PROMPT,
},
{
"mode": "create_page",
"name": "Create confluence page",
"description": JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
},
]
tools = [
JiraAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=jira_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
__all__ = ["JiraToolkit"]

View File

@@ -1,49 +1,3 @@
"""Json agent."""
from typing import Any, Dict, List, Optional
from langchain_community.agent_toolkits.json.base import create_json_agent
from langchain_core.language_models import BaseLanguageModel
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
def create_json_agent(
llm: BaseLanguageModel,
toolkit: JsonToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = JSON_PREFIX,
suffix: str = JSON_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a json agent from an LLM and tools."""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
__all__ = ["create_json_agent"]

View File

@@ -1,25 +1,3 @@
# flake8: noqa
from langchain_community.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
JSON_PREFIX = """You are an agent designed to interact with JSON.
Your goal is to return a final answer by interacting with the JSON.
You have access to the following tools which help you learn more about the JSON you are interacting with.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
Do not make up any information that is not contained in the JSON.
Your input to the tools should be in the form of `data["key"][0]` where `data` is the JSON blob you are interacting with, and the syntax used is Python.
You should only use keys that you know for a fact exist. You must validate that a key exists by seeing it previously when calling `json_spec_list_keys`.
If you have not seen a key in one of those responses, you cannot use it.
You should only add one key at a time to the path. You cannot add multiple keys at once.
If you encounter a "KeyError", go back to the previous key, look at the available keys, and try again.
If the question does not seem to be related to the JSON, just return "I don't know" as the answer.
Always begin your interaction with the `json_spec_list_keys` tool with input "data" to see what keys exist in the JSON.
Note that sometimes the value at a given path is large. In this case, you will get an error "Value is a large dictionary, should explore its keys directly".
In this case, you should ALWAYS follow up by using the `json_spec_list_keys` tool to see what keys exist at that path.
Do not simply refer the user to the JSON or a section of the JSON, as this is not a valid answer. Keep digging until you find the answer and explicitly return it.
"""
JSON_SUFFIX = """Begin!"
Question: {input}
Thought: I should look at the keys that exist in data to see what I have access to
{agent_scratchpad}"""
__all__ = ["JSON_PREFIX", "JSON_SUFFIX"]

View File

@@ -1,20 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
from typing import List
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.json.tool import JsonGetValueTool, JsonListKeysTool, JsonSpec
class JsonToolkit(BaseToolkit):
"""Toolkit for interacting with a JSON spec."""
spec: JsonSpec
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
JsonListKeysTool(spec=self.spec),
JsonGetValueTool(spec=self.spec),
]
__all__ = ["JsonToolkit"]

View File

@@ -1,33 +1,3 @@
"""MultiOn agent."""
from __future__ import annotations
from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit
from typing import List
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.multion.close_session import MultionCloseSession
from langchain.tools.multion.create_session import MultionCreateSession
from langchain.tools.multion.update_session import MultionUpdateSession
class MultionToolkit(BaseToolkit):
"""Toolkit for interacting with the Browser Agent.
**Security Note**: This toolkit contains tools that interact with the
user's browser via the multion API which grants an agent
access to the user's browser.
Please review the documentation for the multion API to understand
the security implications of using this toolkit.
See https://python.langchain.com/docs/security for more information.
"""
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [MultionCreateSession(), MultionUpdateSession(), MultionCloseSession()]
__all__ = ["MultionToolkit"]

View File

@@ -1,57 +1,3 @@
from typing import Dict, List
from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.nasa.prompt import (
NASA_CAPTIONS_PROMPT,
NASA_MANIFEST_PROMPT,
NASA_METADATA_PROMPT,
NASA_SEARCH_PROMPT,
)
from langchain.tools.nasa.tool import NasaAction
from langchain.utilities.nasa import NasaAPIWrapper
class NasaToolkit(BaseToolkit):
"""Nasa Toolkit."""
tools: List[BaseTool] = []
@classmethod
def from_nasa_api_wrapper(cls, nasa_api_wrapper: NasaAPIWrapper) -> "NasaToolkit":
operations: List[Dict] = [
{
"mode": "search_media",
"name": "Search NASA Image and Video Library media",
"description": NASA_SEARCH_PROMPT,
},
{
"mode": "get_media_metadata_manifest",
"name": "Get NASA Image and Video Library media metadata manifest",
"description": NASA_MANIFEST_PROMPT,
},
{
"mode": "get_media_metadata_location",
"name": "Get NASA Image and Video Library media metadata location",
"description": NASA_METADATA_PROMPT,
},
{
"mode": "get_video_captions_location",
"name": "Get NASA Image and Video Library video captions location",
"description": NASA_CAPTIONS_PROMPT,
},
]
tools = [
NasaAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=nasa_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
__all__ = ["NasaToolkit"]

View File

@@ -1,55 +1,3 @@
"""Tool for interacting with a single API with natural language definition."""
from langchain_community.agent_toolkits.nla.tool import NLATool
from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain.agents.tools import Tool
from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain.tools.openapi.utils.api_models import APIOperation
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain.utilities.requests import Requests
class NLATool(Tool):
"""Natural Language API Tool."""
@classmethod
def from_open_api_endpoint_chain(
cls, chain: OpenAPIEndpointChain, api_title: str
) -> "NLATool":
"""Convert an endpoint chain to an API endpoint tool."""
expanded_name = (
f'{api_title.replace(" ", "_")}.{chain.api_operation.operation_id}'
)
description = (
f"I'm an AI from {api_title}. Instruct what you want,"
" and I'll assist via an API with description:"
f" {chain.api_operation.description}"
)
return cls(name=expanded_name, func=chain.run, description=description)
@classmethod
def from_llm_and_method(
cls,
llm: BaseLanguageModel,
path: str,
method: str,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
**kwargs: Any,
) -> "NLATool":
"""Instantiate the tool from the specified path and method."""
api_operation = APIOperation.from_openapi_spec(spec, path, method)
chain = OpenAPIEndpointChain.from_api_operation(
api_operation,
llm,
requests=requests,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
**kwargs,
)
return cls.from_open_api_endpoint_chain(chain, spec.info.title)
__all__ = ["NLATool"]

View File

@@ -1,127 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit
from typing import Any, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.agents.agent_toolkits.nla.tool import NLATool
from langchain.tools.base import BaseTool
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain.tools.plugin import AIPlugin
from langchain.utilities.requests import Requests
class NLAToolkit(BaseToolkit):
"""Natural Language API Toolkit.
*Security Note*: This toolkit creates tools that enable making calls
to an Open API compliant API.
The tools created by this toolkit may be able to make GET, POST,
PATCH, PUT, DELETE requests to any of the exposed endpoints on
the API.
Control access to who can use this toolkit.
See https://python.langchain.com/docs/security for more information.
"""
nla_tools: Sequence[NLATool] = Field(...)
"""List of API Endpoint Tools."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools for all the API operations."""
return list(self.nla_tools)
@staticmethod
def _get_http_operation_tools(
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> List[NLATool]:
"""Get the tools for all the API operations."""
if not spec.paths:
return []
http_operation_tools = []
for path in spec.paths:
for method in spec.get_methods_for_path(path):
endpoint_tool = NLATool.from_llm_and_method(
llm=llm,
path=path,
method=method,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
http_operation_tools.append(endpoint_tool)
return http_operation_tools
@classmethod
def from_llm_and_spec(
cls,
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit by creating tools for each operation."""
http_operation_tools = cls._get_http_operation_tools(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
return cls(nla_tools=http_operation_tools)
@classmethod
def from_llm_and_url(
cls,
llm: BaseLanguageModel,
open_api_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(open_api_url)
return cls.from_llm_and_spec(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
@classmethod
def from_llm_and_ai_plugin(
cls,
llm: BaseLanguageModel,
ai_plugin: AIPlugin,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(ai_plugin.api.url)
# TODO: Merge optional Auth information with the `requests` argument
return cls.from_llm_and_spec(
llm=llm,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
@classmethod
def from_llm_and_ai_plugin_url(
cls,
llm: BaseLanguageModel,
ai_plugin_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
plugin = AIPlugin.from_url(ai_plugin_url)
return cls.from_llm_and_ai_plugin(
llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs
)
__all__ = ["NLAToolkit"]

View File

@@ -1,51 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit
from typing import TYPE_CHECKING, List
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.office365.create_draft_message import O365CreateDraftMessage
from langchain.tools.office365.events_search import O365SearchEvents
from langchain.tools.office365.messages_search import O365SearchEmails
from langchain.tools.office365.send_event import O365SendEvent
from langchain.tools.office365.send_message import O365SendMessage
from langchain.tools.office365.utils import authenticate
if TYPE_CHECKING:
from O365 import Account
class O365Toolkit(BaseToolkit):
"""Toolkit for interacting with Office 365.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
For example, this toolkit can be used search through emails and events,
send messages and event invites, and create draft messages.
Please make sure that the permissions given by this toolkit
are appropriate for your use case.
See https://python.langchain.com/docs/security for more information.
"""
account: Account = Field(default_factory=authenticate)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
O365SearchEvents(),
O365CreateDraftMessage(),
O365SearchEmails(),
O365SendEvent(),
O365SendMessage(),
]
__all__ = ["O365Toolkit"]

View File

@@ -1,73 +1,3 @@
"""OpenAPI spec agent."""
from typing import Any, Dict, List, Optional
from langchain_community.agent_toolkits.openapi.base import create_openapi_agent
from langchain_core.language_models import BaseLanguageModel
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.openapi.prompt import (
OPENAPI_PREFIX,
OPENAPI_SUFFIX,
)
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
def create_openapi_agent(
llm: BaseLanguageModel,
toolkit: OpenAPIToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = OPENAPI_PREFIX,
suffix: str = OPENAPI_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
return_intermediate_steps: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct an OpenAPI agent from an LLM and tools.
*Security Note*: When creating an OpenAPI agent, check the permissions
and capabilities of the underlying toolkit.
For example, if the default implementation of OpenAPIToolkit
uses the RequestsToolkit which contains tools to make arbitrary
network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE),
Control access to who can submit issue requests using this toolkit and
what network access it has.
See https://python.langchain.com/docs/security for more information.
"""
tools = toolkit.get_tools()
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
__all__ = ["create_openapi_agent"]

View File

@@ -1,358 +1,29 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional
import yaml
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
API_CONTROLLER_PROMPT,
API_CONTROLLER_TOOL_DESCRIPTION,
API_CONTROLLER_TOOL_NAME,
API_ORCHESTRATOR_PROMPT,
API_PLANNER_PROMPT,
API_PLANNER_TOOL_DESCRIPTION,
API_PLANNER_TOOL_NAME,
PARSING_DELETE_PROMPT,
PARSING_GET_PROMPT,
PARSING_PATCH_PROMPT,
PARSING_POST_PROMPT,
PARSING_PUT_PROMPT,
REQUESTS_DELETE_TOOL_DESCRIPTION,
REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_PATCH_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION,
REQUESTS_PUT_TOOL_DESCRIPTION,
from langchain_community.agent_toolkits.openapi.planner import (
MAX_RESPONSE_LENGTH,
RequestsDeleteToolWithParsing,
RequestsGetToolWithParsing,
RequestsPatchToolWithParsing,
RequestsPostToolWithParsing,
RequestsPutToolWithParsing,
_create_api_controller_agent,
_create_api_controller_tool,
_create_api_planner_tool,
_get_default_llm_chain,
_get_default_llm_chain_factory,
create_openapi_agent,
)
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.output_parsers.json import parse_json_markdown
from langchain.tools.base import BaseTool
from langchain.tools.requests.tool import BaseRequestsTool
from langchain.utilities.requests import RequestsWrapper
#
# Requests tools with LLM-instructed extraction of truncated responses.
#
# Of course, truncating so bluntly may lose a lot of valuable
# information in the response.
# However, the goal for now is to have only a single inference step.
MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned."""
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
return LLMChain(
llm=OpenAI(),
prompt=prompt,
)
def _get_default_llm_chain_factory(
prompt: BasePromptTemplate,
) -> Callable[[], LLMChain]:
"""Returns a default LLMChain factory."""
return partial(_get_default_llm_chain, prompt)
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_get"
"""Tool name."""
description = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
data_params = data.get("params")
response = self.requests_wrapper.get(data["url"], params=data_params)
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_post"
"""Tool name."""
description = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.post(data["url"], data["data"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_patch"
"""Tool name."""
description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.patch(data["url"], data["data"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PUT tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_put"
"""Tool name."""
description = REQUESTS_PUT_TOOL_DESCRIPTION
"""Tool description."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.put(data["url"], data["data"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
"""A tool that sends a DELETE request and parses the response."""
name: str = "requests_delete"
"""The name of the tool."""
description = REQUESTS_DELETE_TOOL_DESCRIPTION
"""The description of the tool."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""The maximum length of the response."""
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
)
"""The LLM chain used to parse the response."""
def _run(self, text: str) -> str:
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.delete(data["url"])
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
#
# Orchestrator, planner, controller.
#
def _create_api_planner_tool(
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
) -> Tool:
endpoint_descriptions = [
f"{name} {description}" for name, description, _ in api_spec.endpoints
]
prompt = PromptTemplate(
template=API_PLANNER_PROMPT,
input_variables=["query"],
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
)
chain = LLMChain(llm=llm, prompt=prompt)
tool = Tool(
name=API_PLANNER_TOOL_NAME,
description=API_PLANNER_TOOL_DESCRIPTION,
func=chain.run,
)
return tool
def _create_api_controller_agent(
api_url: str,
api_docs: str,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> AgentExecutor:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [
RequestsGetToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
),
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
),
]
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,
input_variables=["input", "agent_scratchpad"],
partial_variables={
"api_url": api_url,
"api_docs": api_docs,
"tool_names": ", ".join([tool.name for tool in tools]),
"tool_descriptions": "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
),
},
)
agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=prompt),
allowed_tools=[tool.name for tool in tools],
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def _create_api_controller_tool(
api_spec: ReducedOpenAPISpec,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> Tool:
"""Expose controller as a tool.
The tool is invoked with a plan from the planner, and dynamically
creates a controller agent with relevant documentation only to
constrain the context.
"""
base_url = api_spec.servers[0]["url"] # TODO: do better.
def _create_and_run_api_controller_agent(plan_str: str) -> str:
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
matches = re.findall(pattern, plan_str)
endpoint_names = [
"{method} {route}".format(method=method, route=route.split("?")[0])
for method, route in matches
]
docs_str = ""
for endpoint_name in endpoint_names:
found_match = False
for name, _, docs in api_spec.endpoints:
regex_name = re.compile(re.sub("\{.*?\}", ".*", name))
if regex_name.match(endpoint_name):
found_match = True
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
if not found_match:
raise ValueError(f"{endpoint_name} endpoint does not exist.")
agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
return agent.run(plan_str)
return Tool(
name=API_CONTROLLER_TOOL_NAME,
func=_create_and_run_api_controller_agent,
description=API_CONTROLLER_TOOL_DESCRIPTION,
)
def create_openapi_agent(
api_spec: ReducedOpenAPISpec,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
shared_memory: Optional[ReadOnlySharedMemory] = None,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = True,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Instantiate OpenAI API planner and controller for a given spec.
Inject credentials via requests_wrapper.
We use a top-level "orchestrator" agent to invoke the planner and controller,
rather than a top-level planner
that invokes a controller with its plan. This is to keep the planner simple.
"""
tools = [
_create_api_planner_tool(api_spec, llm),
_create_api_controller_tool(api_spec, requests_wrapper, llm),
]
prompt = PromptTemplate(
template=API_ORCHESTRATOR_PROMPT,
input_variables=["input", "agent_scratchpad"],
partial_variables={
"tool_names": ", ".join([tool.name for tool in tools]),
"tool_descriptions": "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
),
},
)
agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
allowed_tools=[tool.name for tool in tools],
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
__all__ = [
"MAX_RESPONSE_LENGTH",
"_get_default_llm_chain",
"_get_default_llm_chain_factory",
"RequestsGetToolWithParsing",
"RequestsPostToolWithParsing",
"RequestsPatchToolWithParsing",
"RequestsPutToolWithParsing",
"RequestsDeleteToolWithParsing",
"_create_api_planner_tool",
"_create_api_controller_agent",
"_create_api_controller_tool",
"create_openapi_agent",
]

View File

@@ -1,235 +1,39 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
You should:
1) evaluate whether the user query can be solved by the API documentated below. If no, say why.
2) if yes, generate a plan of API calls and say what they are doing step by step.
3) If the plan includes a DELETE call, you should always return an ask from the User for authorization first unless the User has specifically asked to delete something.
You should only use API endpoints documented below ("Endpoints you can use:").
You can only use the DELETE tool if the User has specifically asked to delete something. Otherwise, you should return a request authorization from the User first.
Some user queries can be resolved in a single API call, but some will require several API calls.
The plan will be passed to an API controller that can format it into web requests and return the responses.
----
Here are some examples:
Fake endpoints for examples:
GET /user to get information about the current user
GET /products/search search across products
POST /users/{{id}}/cart to add products to a user's cart
PATCH /users/{{id}}/cart to update a user's cart
PUT /users/{{id}}/coupon to apply idempotent coupon to a user's cart
DELETE /users/{{id}}/cart to delete a user's cart
User query: tell me a joke
Plan: Sorry, this API's domain is shopping, not comedy.
User query: I want to buy a couch
Plan: 1. GET /products with a query param to search for couches
2. GET /user to find the user's id
3. POST /users/{{id}}/cart to add a couch to the user's cart
User query: I want to add a lamp to my cart
Plan: 1. GET /products with a query param to search for lamps
2. GET /user to find the user's id
3. PATCH /users/{{id}}/cart to add a lamp to the user's cart
User query: I want to add a coupon to my cart
Plan: 1. GET /user to find the user's id
2. PUT /users/{{id}}/coupon to apply the coupon
User query: I want to delete my cart
Plan: 1. GET /user to find the user's id
2. DELETE required. Did user specify DELETE or previously authorize? Yes, proceed.
3. DELETE /users/{{id}}/cart to delete the user's cart
User query: I want to start a new cart
Plan: 1. GET /user to find the user's id
2. DELETE required. Did user specify DELETE or previously authorize? No, ask for authorization.
3. Are you sure you want to delete your cart?
----
Here are endpoints you can use. Do not reference any of the endpoints above.
{endpoints}
----
User query: {query}
Plan:"""
API_PLANNER_TOOL_NAME = "api_planner"
API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to call the API controller."
# Execution.
API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.
Here is documentation on the API:
Base url: {api_url}
Endpoints:
{api_docs}
Here are tools to execute requests against the API: {tool_descriptions}
Starting below, you should follow this format:
Plan: the plan of API calls to execute
Thought: you should always think about what to do
Action: the action to take, should be one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the output of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.
Begin!
Plan: {input}
Thought:
{agent_scratchpad}
"""
API_CONTROLLER_TOOL_NAME = "api_controller"
API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)."
# Orchestrate planning + execution.
# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while
# keeping planning (and specifically the planning prompt) simple.
API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against API, things like querying information or creating resources.
Some user queries can be resolved in a single API call, particularly if you can find appropriate params from the OpenAPI spec; though some require several API calls.
You should always plan your API calls first, and then execute the plan second.
If the plan includes a DELETE call, be sure to ask the User for authorization first unless the User has specifically asked to delete something.
You should never return information without executing the api_controller tool.
Here are the tools to plan and execute API requests: {tool_descriptions}
Starting below, you should follow this format:
User query: the query a User wants help with related to the API
Thought: you should always think about what to do
Action: the action to take, should be one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing a plan and have the information the user asked for or the data the user asked to create
Final Answer: the final output from executing the plan
Example:
User query: can you add some trendy stuff to my shopping cart.
Thought: I should plan API calls first.
Action: api_planner
Action Input: I need to find the right API calls to add trendy items to the users shopping cart
Observation: 1) GET /items with params 'trending' is 'True' to get trending item ids
2) GET /user to get user
3) POST /cart to post the trending items to the user's cart
Thought: I'm ready to execute the API calls.
Action: api_controller
Action Input: 1) GET /items params 'trending' is 'True' to get trending item ids
2) GET /user to get user
3) POST /cart to post the trending items to the user's cart
...
Begin!
User query: {input}
Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller.
{agent_scratchpad}"""
REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website.
Input to the tool should be a json string with 3 keys: "url", "params" and "output_instructions".
The value of "url" should be a string.
The value of "params" should be a dict of the needed and available parameters from the OpenAPI spec related to the endpoint.
If parameters are not needed, or not available, leave it empty.
The value of "output_instructions" should be instructions on what information to extract from the response,
for example the id(s) for a resource(s) that the GET request fetches.
"""
PARSING_GET_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
from langchain_community.agent_toolkits.openapi.planner_prompt import (
API_CONTROLLER_PROMPT,
API_CONTROLLER_TOOL_DESCRIPTION,
API_CONTROLLER_TOOL_NAME,
API_ORCHESTRATOR_PROMPT,
API_PLANNER_PROMPT,
API_PLANNER_TOOL_DESCRIPTION,
API_PLANNER_TOOL_NAME,
PARSING_DELETE_PROMPT,
PARSING_GET_PROMPT,
PARSING_PATCH_PROMPT,
PARSING_POST_PROMPT,
PARSING_PUT_PROMPT,
REQUESTS_DELETE_TOOL_DESCRIPTION,
REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_PATCH_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION,
REQUESTS_PUT_TOOL_DESCRIPTION,
)
REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs you want to POST to the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates.
Always use double quotes for strings in the json string."""
PARSING_POST_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_PATCH_TOOL_DESCRIPTION = """Use this when you want to PATCH content on a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs of the body params available in the OpenAPI spec you want to PATCH the content with at the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PATCH request creates.
Always use double quotes for strings in the json string."""
PARSING_PATCH_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_PUT_TOOL_DESCRIPTION = """Use this when you want to PUT to a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs you want to PUT to the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PUT request creates.
Always use double quotes for strings in the json string."""
PARSING_PUT_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_DELETE_TOOL_DESCRIPTION = """ONLY USE THIS TOOL WHEN THE USER HAS SPECIFICALLY REQUESTED TO DELETE CONTENT FROM A WEBSITE.
Input to the tool should be a json string with 2 keys: "url", and "output_instructions".
The value of "url" should be a string.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the DELETE request creates.
Always use double quotes for strings in the json string.
ONLY USE THIS TOOL IF THE USER HAS SPECIFICALLY REQUESTED TO DELETE SOMETHING."""
PARSING_DELETE_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
__all__ = [
"API_PLANNER_PROMPT",
"API_PLANNER_TOOL_NAME",
"API_PLANNER_TOOL_DESCRIPTION",
"API_CONTROLLER_PROMPT",
"API_CONTROLLER_TOOL_NAME",
"API_CONTROLLER_TOOL_DESCRIPTION",
"API_ORCHESTRATOR_PROMPT",
"REQUESTS_GET_TOOL_DESCRIPTION",
"PARSING_GET_PROMPT",
"REQUESTS_POST_TOOL_DESCRIPTION",
"PARSING_POST_PROMPT",
"REQUESTS_PATCH_TOOL_DESCRIPTION",
"PARSING_PATCH_PROMPT",
"REQUESTS_PUT_TOOL_DESCRIPTION",
"PARSING_PUT_PROMPT",
"REQUESTS_DELETE_TOOL_DESCRIPTION",
"PARSING_DELETE_PROMPT",
]

View File

@@ -1,29 +1,7 @@
# flake8: noqa
from langchain_community.agent_toolkits.openapi.prompt import (
DESCRIPTION,
OPENAPI_PREFIX,
OPENAPI_SUFFIX,
)
OPENAPI_PREFIX = """You are an agent designed to answer questions by making web requests to an API given the openapi spec.
If the question does not seem related to the API, return I don't know. Do not make up an answer.
Only use information provided by the tools to construct your response.
First, find the base URL needed to make the request.
Second, find the relevant paths needed to answer the question. Take note that, sometimes, you might need to make more than one request to more than one path to answer the question.
Third, find the required parameters needed to make the request. For GET requests, these are usually URL parameters and for POST requests, these are request body parameters.
Fourth, make the requests needed to answer the question. Ensure that you are sending the correct parameters to the request by checking which parameters are required. For parameters with a fixed set of values, please use the spec to look at which values are allowed.
Use the exact parameter names as listed in the spec, do not make up any names or abbreviate the names of parameters.
If you get a not found error, ensure that you are using a path that actually exists in the spec.
"""
OPENAPI_SUFFIX = """Begin!
Question: {input}
Thought: I should explore the spec to find the base url for the API.
{agent_scratchpad}"""
DESCRIPTION = """Can be used to answer questions about the openapi spec for the API. Always use this tool before trying to make a request.
Example inputs to this tool:
'What are the required query parameters for a GET request to the /bar endpoint?`
'What are the required parameters in the request body for a POST request to the /foo endpoint?'
Always give this tool a specific question."""
__all__ = ["OPENAPI_PREFIX", "OPENAPI_SUFFIX", "DESCRIPTION"]

View File

@@ -1,75 +1,6 @@
"""Quick and dirty representation for OpenAPI specs."""
from langchain_community.agent_toolkits.openapi.spec import (
ReducedOpenAPISpec,
reduce_openapi_spec,
)
from dataclasses import dataclass
from typing import List, Tuple
from langchain.utils.json_schema import dereference_refs
@dataclass(frozen=True)
class ReducedOpenAPISpec:
"""A reduced OpenAPI spec.
This is a quick and dirty representation for OpenAPI specs.
Attributes:
servers: The servers in the spec.
description: The description of the spec.
endpoints: The endpoints in the spec.
"""
servers: List[dict]
description: str
endpoints: List[Tuple[str, str, dict]]
def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec:
"""Simplify/distill/minify a spec somehow.
I want a smaller target for retrieval and (more importantly)
I want smaller results from retrieval.
I was hoping https://openapi.tools/ would have some useful bits
to this end, but doesn't seem so.
"""
# 1. Consider only get, post, patch, put, delete endpoints.
endpoints = [
(f"{operation_name.upper()} {route}", docs.get("description"), docs)
for route, operation in spec["paths"].items()
for operation_name, docs in operation.items()
if operation_name in ["get", "post", "patch", "put", "delete"]
]
# 2. Replace any refs so that complete docs are retrieved.
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
if dereference:
endpoints = [
(name, description, dereference_refs(docs, full_schema=spec))
for name, description, docs in endpoints
]
# 3. Strip docs down to required request args + happy path response.
def reduce_endpoint_docs(docs: dict) -> dict:
out = {}
if docs.get("description"):
out["description"] = docs.get("description")
if docs.get("parameters"):
out["parameters"] = [
parameter
for parameter in docs.get("parameters", [])
if parameter.get("required")
]
if "200" in docs["responses"]:
out["responses"] = docs["responses"]["200"]
if docs.get("requestBody"):
out["requestBody"] = docs.get("requestBody")
return out
endpoints = [
(name, description, reduce_endpoint_docs(docs))
for name, description, docs in endpoints
]
return ReducedOpenAPISpec(
servers=spec["servers"],
description=spec["info"].get("description", ""),
endpoints=endpoints,
)
__all__ = ["ReducedOpenAPISpec", "reduce_openapi_spec"]

View File

@@ -1,91 +1,6 @@
"""Requests toolkit."""
from __future__ import annotations
from typing import Any, List
from langchain_core.language_models import BaseLanguageModel
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.agents.agent_toolkits.json.base import create_json_agent
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain.agents.tools import Tool
from langchain.tools import BaseTool
from langchain.tools.json.tool import JsonSpec
from langchain.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool,
RequestsPatchTool,
RequestsPostTool,
RequestsPutTool,
from langchain_community.agent_toolkits.openapi.toolkit import (
OpenAPIToolkit,
RequestsToolkit,
)
from langchain.utilities.requests import TextRequestsWrapper
class RequestsToolkit(BaseToolkit):
"""Toolkit for making REST requests.
*Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT,
and DELETE requests to an API.
Exercise care in who is allowed to use this toolkit. If exposing
to end users, consider that users will be able to make arbitrary
requests on behalf of the server hosting the code. For example,
users could ask the server to make a request to a private API
that is only accessible from the server.
Control access to who can submit issue requests using this toolkit and
what network access it has.
See https://python.langchain.com/docs/security for more information.
"""
requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]:
"""Return a list of tools."""
return [
RequestsGetTool(requests_wrapper=self.requests_wrapper),
RequestsPostTool(requests_wrapper=self.requests_wrapper),
RequestsPatchTool(requests_wrapper=self.requests_wrapper),
RequestsPutTool(requests_wrapper=self.requests_wrapper),
RequestsDeleteTool(requests_wrapper=self.requests_wrapper),
]
class OpenAPIToolkit(BaseToolkit):
"""Toolkit for interacting with an OpenAPI API.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, or updating,
reading underlying data.
For example, this toolkit can be used to delete data exposed via
an OpenAPI compliant API.
"""
json_agent: AgentExecutor
requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
json_agent_tool = Tool(
name="json_explorer",
func=self.json_agent.run,
description=DESCRIPTION,
)
request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
return [*request_toolkit.get_tools(), json_agent_tool]
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
json_spec: JsonSpec,
requests_wrapper: TextRequestsWrapper,
**kwargs: Any,
) -> OpenAPIToolkit:
"""Create json agent from llm, then initialize."""
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
__all__ = ["RequestsToolkit", "OpenAPIToolkit"]

View File

@@ -1,108 +1,5 @@
"""Playwright web browser toolkit."""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, Type, cast
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools.base import BaseTool
from langchain.tools.playwright.base import (
BaseBrowserTool,
lazy_import_playwright_browsers,
from langchain_community.agent_toolkits.playwright.toolkit import (
PlayWrightBrowserToolkit,
)
from langchain.tools.playwright.click import ClickTool
from langchain.tools.playwright.current_page import CurrentWebPageTool
from langchain.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool
from langchain.tools.playwright.extract_text import ExtractTextTool
from langchain.tools.playwright.get_elements import GetElementsTool
from langchain.tools.playwright.navigate import NavigateTool
from langchain.tools.playwright.navigate_back import NavigateBackTool
if TYPE_CHECKING:
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
else:
try:
# We do this so pydantic can resolve the types when instantiating
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
except ImportError:
pass
class PlayWrightBrowserToolkit(BaseToolkit):
"""Toolkit for PlayWright browser tools.
**Security Note**: This toolkit provides code to control a web-browser.
Careful if exposing this toolkit to end-users. The tools in the toolkit
are capable of navigating to arbitrary webpages, clicking on arbitrary
elements, and extracting arbitrary text and hyperlinks from webpages.
Specifically, by default this toolkit allows navigating to:
- Any URL (including any internal network URLs)
- And local files
If exposing to end-users, consider limiting network access to the
server that hosts the agent; in addition, consider it is advised
to create a custom NavigationTool wht an args_schema that limits the URLs
that can be navigated to (e.g., only allow navigating to URLs that
start with a particular prefix).
Remember to scope permissions to the minimal permissions necessary for
the application. If the default tool selection is not appropriate for
the application, consider creating a custom toolkit with the appropriate
tools.
See https://python.langchain.com/docs/security for more information.
"""
sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
lazy_import_playwright_browsers()
if values.get("async_browser") is None and values.get("sync_browser") is None:
raise ValueError("Either async_browser or sync_browser must be specified.")
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tool_classes: List[Type[BaseBrowserTool]] = [
ClickTool,
NavigateTool,
NavigateBackTool,
ExtractTextTool,
ExtractHyperlinksTool,
GetElementsTool,
CurrentWebPageTool,
]
tools = [
tool_cls.from_browser(
sync_browser=self.sync_browser, async_browser=self.async_browser
)
for tool_cls in tool_classes
]
return cast(List[BaseTool], tools)
@classmethod
def from_browser(
cls,
sync_browser: Optional[SyncBrowser] = None,
async_browser: Optional[AsyncBrowser] = None,
) -> PlayWrightBrowserToolkit:
"""Instantiate the toolkit."""
# This is to raise a better error than the forward ref ones Pydantic would have
lazy_import_playwright_browsers()
return cls(sync_browser=sync_browser, async_browser=async_browser)
__all__ = ["PlayWrightBrowserToolkit"]

View File

@@ -1,63 +1,3 @@
"""Power BI agent."""
from typing import Any, Dict, List, Optional
from langchain_community.agent_toolkits.powerbi.base import create_pbi_agent
from langchain_core.language_models import BaseLanguageModel
from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits.powerbi.prompt import (
POWERBI_PREFIX,
POWERBI_SUFFIX,
)
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.utilities.powerbi import PowerBIDataset
def create_pbi_agent(
llm: BaseLanguageModel,
toolkit: Optional[PowerBIToolkit] = None,
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = POWERBI_PREFIX,
suffix: str = POWERBI_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
examples: Optional[str] = None,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Power BI agent from an LLM and tools."""
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
tools = toolkit.get_tools()
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
agent = ZeroShotAgent(
llm_chain=LLMChain(
llm=llm,
prompt=ZeroShotAgent.create_prompt(
tools,
prefix=prefix.format(top_k=top_k).format(tables=tables),
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
),
callback_manager=callback_manager, # type: ignore
verbose=verbose,
),
allowed_tools=[tool.name for tool in tools],
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
__all__ = ["create_pbi_agent"]

View File

@@ -1,64 +1,3 @@
"""Power BI agent."""
from typing import Any, Dict, List, Optional
from langchain_community.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
from langchain.agents import AgentExecutor
from langchain.agents.agent import AgentOutputParser
from langchain.agents.agent_toolkits.powerbi.prompt import (
POWERBI_CHAT_PREFIX,
POWERBI_CHAT_SUFFIX,
)
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.utilities.powerbi import PowerBIDataset
def create_pbi_chat_agent(
llm: BaseChatModel,
toolkit: Optional[PowerBIToolkit] = None,
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = POWERBI_CHAT_PREFIX,
suffix: str = POWERBI_CHAT_SUFFIX,
examples: Optional[str] = None,
input_variables: Optional[List[str]] = None,
memory: Optional[BaseChatMemory] = None,
top_k: int = 10,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Power BI agent from a Chat LLM and tools.
If you supply only a toolkit and no Power BI dataset, the same LLM is used for both.
"""
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
tools = toolkit.get_tools()
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
agent = ConversationalChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
system_message=prefix.format(top_k=top_k).format(tables=tables),
human_message=suffix,
input_variables=input_variables,
callback_manager=callback_manager,
output_parser=output_parser,
verbose=verbose,
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
memory=memory
or ConversationBufferMemory(memory_key="chat_history", return_messages=True),
verbose=verbose,
**(agent_executor_kwargs or {}),
)
__all__ = ["create_pbi_chat_agent"]

View File

@@ -1,38 +1,13 @@
# flake8: noqa
"""Prompts for PowerBI agent."""
from langchain_community.agent_toolkits.powerbi.prompt import (
POWERBI_CHAT_PREFIX,
POWERBI_CHAT_SUFFIX,
POWERBI_PREFIX,
POWERBI_SUFFIX,
)
POWERBI_PREFIX = """You are an agent designed to help users interact with a PowerBI Dataset.
Agent has access to a tool that can write a query based on the question and then run those against PowerBI, Microsofts business intelligence tool. The questions from the users should be interpreted as related to the dataset that is available and not general questions about the world. If the question does not seem related to the dataset, return "This does not appear to be part of this dataset." as the answer.
Given an input question, ask to run the questions against the dataset, then look at the results and return the answer, the answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in a easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
"""
POWERBI_SUFFIX = """Begin!
Question: {input}
Thought: I can first ask which tables I have, then how each table is defined and then ask the query tool the question I need, and finally create a nice sentence that answers the question.
{agent_scratchpad}"""
POWERBI_CHAT_PREFIX = """Assistant is a large language model built to help users interact with a PowerBI Dataset.
Assistant should try to create a correct and complete answer to the question from the user. If the user asks a question not related to the dataset it should return "This does not appear to be part of this dataset." as the answer. The user might make a mistake with the spelling of certain values, if you think that is the case, ask the user to confirm the spelling of the value and then run the query again. Unless the user specifies a specific number of examples they wish to obtain, and the results are too large, limit your query to at most {top_k} results, but make it clear when answering which field was used for the filtering. The user has access to these tables: {{tables}}.
The answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in a easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000.
"""
POWERBI_CHAT_SUFFIX = """TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:
{{tools}}
{format_instructions}
USER'S INPUT
--------------------
Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{{{{input}}}}
"""
__all__ = [
"POWERBI_PREFIX",
"POWERBI_SUFFIX",
"POWERBI_CHAT_PREFIX",
"POWERBI_CHAT_SUFFIX",
]

View File

@@ -1,102 +1,3 @@
"""Toolkit for interacting with a Power BI dataset."""
from typing import List, Optional, Union
from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.tools import BaseTool
from langchain.tools.powerbi.prompt import (
QUESTION_TO_QUERY_BASE,
SINGLE_QUESTION_TO_QUERY,
USER_INPUT,
)
from langchain.tools.powerbi.tool import (
InfoPowerBITool,
ListPowerBITool,
QueryPowerBITool,
)
from langchain.utilities.powerbi import PowerBIDataset
class PowerBIToolkit(BaseToolkit):
"""Toolkit for interacting with Power BI dataset.
*Security Note*: This toolkit interacts with an external service.
Control access to who can use this toolkit.
Make sure that the capabilities given by this toolkit to the calling
code are appropriately scoped to the application.
See https://python.langchain.com/docs/security for more information.
"""
powerbi: PowerBIDataset = Field(exclude=True)
llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)
examples: Optional[str] = None
max_iterations: int = 5
callback_manager: Optional[BaseCallbackManager] = None
output_token_limit: Optional[int] = None
tiktoken_model_name: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QueryPowerBITool(
llm_chain=self._get_chain(),
powerbi=self.powerbi,
examples=self.examples,
max_iterations=self.max_iterations,
output_token_limit=self.output_token_limit,
tiktoken_model_name=self.tiktoken_model_name,
),
InfoPowerBITool(powerbi=self.powerbi),
ListPowerBITool(powerbi=self.powerbi),
]
def _get_chain(self) -> LLMChain:
"""Construct the chain based on the callback manager and model type."""
if isinstance(self.llm, BaseLanguageModel):
return LLMChain(
llm=self.llm,
callback_manager=self.callback_manager
if self.callback_manager
else None,
prompt=PromptTemplate(
template=SINGLE_QUESTION_TO_QUERY,
input_variables=["tool_input", "tables", "schemas", "examples"],
),
)
system_prompt = SystemMessagePromptTemplate(
prompt=PromptTemplate(
template=QUESTION_TO_QUERY_BASE,
input_variables=["tables", "schemas", "examples"],
)
)
human_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template=USER_INPUT,
input_variables=["tool_input"],
)
)
return LLMChain(
llm=self.llm,
callback_manager=self.callback_manager if self.callback_manager else None,
prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]),
)
__all__ = ["PowerBIToolkit"]

View File

@@ -1,35 +1,3 @@
from __future__ import annotations
from langchain_community.agent_toolkits.slack.toolkit import SlackToolkit
from typing import TYPE_CHECKING, List
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.pydantic_v1 import Field
from langchain.tools import BaseTool
from langchain.tools.slack.get_channel import SlackGetChannel
from langchain.tools.slack.get_message import SlackGetMessage
from langchain.tools.slack.schedule_message import SlackScheduleMessage
from langchain.tools.slack.send_message import SlackSendMessage
from langchain.tools.slack.utils import login
if TYPE_CHECKING:
from slack_sdk import WebClient
class SlackToolkit(BaseToolkit):
"""Toolkit for interacting with Slack."""
client: WebClient = Field(default_factory=login)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
SlackGetChannel(),
SlackGetMessage(),
SlackScheduleMessage(),
SlackSendMessage(),
]
__all__ = ["SlackToolkit"]

View File

@@ -1,60 +1,3 @@
"""Spark SQL agent."""
from typing import Any, Dict, List, Optional
from langchain_community.agent_toolkits.spark_sql.base import create_spark_sql_agent
from langchain_core.language_models import BaseLanguageModel
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.chains.llm import LLMChain
def create_spark_sql_agent(
llm: BaseLanguageModel,
toolkit: SparkSQLToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Spark SQL agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(top_k=top_k)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
callbacks=callbacks,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
callbacks=callbacks,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
__all__ = ["create_spark_sql_agent"]

View File

@@ -1,21 +1,3 @@
# flake8: noqa
from langchain_community.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}"""
__all__ = ["SQL_PREFIX", "SQL_SUFFIX"]

View File

@@ -1,36 +1,3 @@
"""Toolkit for interacting with Spark SQL."""
from typing import List
from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.spark_sql.tool import (
InfoSparkSQLTool,
ListSparkSQLTool,
QueryCheckerTool,
QuerySparkSQLTool,
)
from langchain.utilities.spark_sql import SparkSQL
class SparkSQLToolkit(BaseToolkit):
"""Toolkit for interacting with Spark SQL."""
db: SparkSQL = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QuerySparkSQLTool(db=self.db),
InfoSparkSQLTool(db=self.db),
ListSparkSQLTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
]
__all__ = ["SparkSQLToolkit"]

View File

@@ -1,96 +1,3 @@
"""SQL agent."""
from typing import Any, Dict, List, Optional, Sequence
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.tools import BaseTool
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
suffix: Optional[str] = None,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (),
**kwargs: Any,
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
tools = toolkit.get_tools() + list(extra_tools)
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
__all__ = ["create_sql_agent"]

View File

@@ -1,23 +1,7 @@
# flake8: noqa
from langchain_community.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
__all__ = ["SQL_PREFIX", "SQL_SUFFIX", "SQL_FUNCTIONS_SUFFIX"]

View File

@@ -1,71 +1,3 @@
"""Toolkit for interacting with an SQL database."""
from typing import List
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
)
from langchain.utilities.sql_database import SQLDatabase
class SQLDatabaseToolkit(BaseToolkit):
"""Toolkit for interacting with SQL databases."""
db: SQLDatabase = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
@property
def dialect(self) -> str:
"""Return string representation of SQL dialect to use."""
return self.db.dialect
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: table1, table2, table3"
)
info_sql_database_tool = InfoSQLDatabaseTool(
db=self.db, description=info_sql_database_tool_description
)
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDataBaseTool(
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing "
"it. Always use this tool before executing a query with "
f"{query_sql_database_tool.name}!"
)
query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
)
return [
query_sql_database_tool,
info_sql_database_tool,
list_sql_database_tool,
query_sql_checker_tool,
]
__all__ = ["SQLDatabaseToolkit"]

View File

@@ -1,48 +1,3 @@
"""Steam Toolkit."""
from typing import List
from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.steam.prompt import (
STEAM_GET_GAMES_DETAILS,
STEAM_GET_RECOMMENDED_GAMES,
)
from langchain.tools.steam.tool import SteamWebAPIQueryRun
from langchain.utilities.steam import SteamWebAPIWrapper
class SteamToolkit(BaseToolkit):
"""Steam Toolkit."""
tools: List[BaseTool] = []
@classmethod
def from_steam_api_wrapper(
cls, steam_api_wrapper: SteamWebAPIWrapper
) -> "SteamToolkit":
operations: List[dict] = [
{
"mode": "get_games_details",
"name": "Get Games Details",
"description": STEAM_GET_GAMES_DETAILS,
},
{
"mode": "get_recommended_games",
"name": "Get Recommended Games",
"description": STEAM_GET_RECOMMENDED_GAMES,
},
]
tools = [
SteamWebAPIQueryRun(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=steam_api_wrapper,
)
for action in operations
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
__all__ = ["SteamToolkit"]

View File

@@ -1,60 +1,3 @@
"""[DEPRECATED] Zapier Toolkit."""
from typing import List
from langchain_community.agent_toolkits.zapier.toolkit import ZapierToolkit
from langchain_core._api import warn_deprecated
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.zapier.tool import ZapierNLARunAction
from langchain.utilities.zapier import ZapierNLAWrapper
class ZapierToolkit(BaseToolkit):
"""Zapier Toolkit."""
tools: List[BaseTool] = []
@classmethod
def from_zapier_nla_wrapper(
cls, zapier_nla_wrapper: ZapierNLAWrapper
) -> "ZapierToolkit":
"""Create a toolkit from a ZapierNLAWrapper."""
actions = zapier_nla_wrapper.list()
tools = [
ZapierNLARunAction(
action_id=action["id"],
zapier_description=action["description"],
params_schema=action["params"],
api_wrapper=zapier_nla_wrapper,
)
for action in actions
]
return cls(tools=tools)
@classmethod
async def async_from_zapier_nla_wrapper(
cls, zapier_nla_wrapper: ZapierNLAWrapper
) -> "ZapierToolkit":
"""Create a toolkit from a ZapierNLAWrapper."""
actions = await zapier_nla_wrapper.alist()
tools = [
ZapierNLARunAction(
action_id=action["id"],
zapier_description=action["description"],
params_schema=action["params"],
api_wrapper=zapier_nla_wrapper,
)
for action in actions
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
warn_deprecated(
since="0.0.319",
message=(
"This tool will be deprecated on 2023-11-17. See "
"https://nla.zapier.com/sunset/ for details"
),
)
return self.tools
__all__ = ["ZapierToolkit"]

View File

@@ -9,6 +9,7 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser
@@ -21,7 +22,6 @@ from langchain.agents.chat.prompt import (
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.tools.base import BaseTool
class ChatAgent(Agent):

View File

@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType
@@ -14,7 +15,6 @@ from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX,
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.tools.base import BaseTool
class ConversationalAgent(Agent):

View File

@@ -15,6 +15,7 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
@@ -26,7 +27,6 @@ from langchain.agents.conversational_chat.prompt import (
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.tools.base import BaseTool
class ConversationalChatAgent(Agent):

View File

@@ -2,12 +2,12 @@
from typing import Any, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
from langchain.callbacks.base import BaseCallbackManager
from langchain.tools.base import BaseTool
def initialize_agent(

View File

@@ -30,7 +30,7 @@ from langchain.utilities.requests import TextRequestsWrapper
from langchain.tools.arxiv.tool import ArxivQueryRun
from langchain.tools.golden_query.tool import GoldenQueryRun
from langchain.tools.pubmed.tool import PubmedQueryRun
from langchain.tools.base import BaseTool
from langchain_core.tools import BaseTool
from langchain.tools.bing_search.tool import BingSearchRun
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
from langchain.tools.google_cloud.texttospeech import GoogleCloudTextToSpeechTool

View File

@@ -6,6 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
@@ -15,7 +16,6 @@ from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.tools.base import BaseTool
class ChainConfig(NamedTuple):

View File

@@ -9,9 +9,9 @@ from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.tools import BaseTool
from langchain.callbacks.manager import CallbackManager
from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_tool
if TYPE_CHECKING:

View File

@@ -15,6 +15,7 @@ from langchain_core.prompts.chat import (
MessagesPlaceholder,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain.agents import BaseSingleActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
@@ -25,7 +26,6 @@ from langchain.agents.output_parsers.openai_functions import (
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function

View File

@@ -5,6 +5,7 @@ from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
@@ -14,7 +15,6 @@ from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.docstore.base import Docstore
from langchain.tools.base import BaseTool
class ReActDocstoreAgent(Agent):

View File

@@ -4,6 +4,7 @@ from typing import Any, Sequence, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
@@ -11,7 +12,6 @@ from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputPar
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.searchapi import SearchApiAPIWrapper
from langchain.utilities.serpapi import SerpAPIWrapper

View File

@@ -1,11 +1,12 @@
"""Interface for tools."""
from typing import List, Optional
from langchain_core.tools import BaseTool, Tool, tool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool, Tool, tool
class InvalidTool(BaseTool):

View File

@@ -1,6 +1,6 @@
from typing import Sequence
from langchain.tools.base import BaseTool
from langchain_core.tools import BaseTool
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:

View File

@@ -2,13 +2,13 @@ from typing import Any, List, Tuple, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.prompts.chat import AIMessagePromptTemplate, ChatPromptTemplate
from langchain_core.tools import BaseTool
from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
from langchain.agents.xml.prompt import agent_instructions
from langchain.callbacks.base import Callbacks
from langchain.chains.llm import LLMChain
from langchain.tools.base import BaseTool
class XMLAgent(BaseSingleActionAgent):

File diff suppressed because it is too large Load Diff

View File

@@ -7,42 +7,53 @@
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
"""
from langchain_core.tracers.langchain import LangChainTracer
from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain.callbacks.arize_callback import ArizeCallbackHandler
from langchain.callbacks.arthur_callback import ArthurCallbackHandler
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.context_callback import ContextCallbackHandler
from langchain.callbacks.file import FileCallbackHandler
from langchain.callbacks.flyte_callback import FlyteCallbackHandler
from langchain.callbacks.human import HumanApprovalCallbackHandler
from langchain.callbacks.infino_callback import InfinoCallbackHandler
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
from langchain.callbacks.llmonitor_callback import LLMonitorCallbackHandler
from langchain.callbacks.manager import (
collect_runs,
from langchain_community.callbacks.aim_callback import AimCallbackHandler
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_community.callbacks.arize_callback import ArizeCallbackHandler
from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler
from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler
from langchain_community.callbacks.context_callback import ContextCallbackHandler
from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
from langchain_community.callbacks.human import HumanApprovalCallbackHandler
from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
from langchain_community.callbacks.labelstudio_callback import (
LabelStudioCallbackHandler,
)
from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler
from langchain_community.callbacks.manager import (
get_openai_callback,
tracing_enabled,
tracing_v2_enabled,
wandb_tracing_enabled,
)
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.promptlayer_callback import PromptLayerCallbackHandler
from langchain.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.promptlayer_callback import (
PromptLayerCallbackHandler,
)
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain_community.callbacks.streamlit import (
LLMThoughtLabeler,
StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler
from langchain_community.callbacks.wandb_callback import WandbCallbackHandler
from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler
from langchain_core.callbacks import (
StdOutCallbackHandler,
StreamingStdOutCallbackHandler,
)
from langchain_core.tracers.context import (
collect_runs,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.tracers.langchain import LangChainTracer
from langchain.callbacks.file import FileCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import (
FinalStreamingStdOutCallbackHandler,
)
from langchain.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler
from langchain.callbacks.trubrics_callback import TrubricsCallbackHandler
from langchain.callbacks.wandb_callback import WandbCallbackHandler
from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler
__all__ = [
"AimCallbackHandler",

View File

@@ -1,431 +1,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_community.callbacks.aim_callback import (
AimCallbackHandler,
BaseMetadataCallbackHandler,
import_aim,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
def import_aim() -> Any:
"""Import the aim python package and raise an error if it is not installed."""
try:
import aim
except ImportError:
raise ImportError(
"To use the Aim callback manager you need to have the"
" `aim` python package installed."
"Please install it with `pip install aim`"
)
return aim
class BaseMetadataCallbackHandler:
"""This class handles the metadata and associated function states for callbacks.
Attributes:
step (int): The current step.
starts (int): The number of times the start method has been called.
ends (int): The number of times the end method has been called.
errors (int): The number of times the error method has been called.
text_ctr (int): The number of times the text method has been called.
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
llm_starts (int): The number of times the llm start method has been called.
llm_ends (int): The number of times the llm end method has been called.
llm_streams (int): The number of times the text method has been called.
tool_starts (int): The number of times the tool start method has been called.
tool_ends (int): The number of times the tool end method has been called.
agent_ends (int): The number of times the agent end method has been called.
"""
def __init__(self) -> None:
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,
"starts": self.starts,
"ends": self.ends,
"errors": self.errors,
"text_ctr": self.text_ctr,
"chain_starts": self.chain_starts,
"chain_ends": self.chain_ends,
"llm_starts": self.llm_starts,
"llm_ends": self.llm_ends,
"llm_streams": self.llm_streams,
"tool_starts": self.tool_starts,
"tool_ends": self.tool_ends,
"agent_ends": self.agent_ends,
}
def reset_callback_meta(self) -> None:
"""Reset the callback metadata."""
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
return None
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Aim.
Parameters:
repo (:obj:`str`, optional): Aim repository path or Repo object to which
Run object is bound. If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
'default' if not specified. Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
params such as installed packages, git info, environment variables, etc.
This handler will utilize the associated callback method called and formats
the input of each callback function with metadata regarding the state of LLM run
and then logs the response to Aim.
"""
def __init__(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 10,
log_system_params: bool = True,
) -> None:
"""Initialize callback handler."""
super().__init__()
aim = import_aim()
self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
self.action_records: list = []
def setup(self, **kwargs: Any) -> None:
aim = import_aim()
if not self._run:
if self._run_hash:
self._run = aim.Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
)
else:
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
if kwargs:
for key, value in kwargs.items():
self._run.set(key, value, strict=False)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
aim = import_aim()
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = {"action": "on_llm_start"}
resp.update(self.get_custom_callback_meta())
prompts_res = deepcopy(prompts)
self._run.track(
[aim.Text(prompt) for prompt in prompts_res],
name="on_llm_start",
context=resp,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
aim = import_aim()
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = {"action": "on_llm_end"}
resp.update(self.get_custom_callback_meta())
response_res = deepcopy(response)
generated = [
aim.Text(generation.text)
for generations in response_res.generations
for generation in generations
]
self._run.track(
generated,
name="on_llm_end",
context=resp,
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
aim = import_aim()
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = {"action": "on_chain_start"}
resp.update(self.get_custom_callback_meta())
inputs_res = deepcopy(inputs)
self._run.track(
aim.Text(inputs_res["input"]), name="on_chain_start", context=resp
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
aim = import_aim()
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = {"action": "on_chain_end"}
resp.update(self.get_custom_callback_meta())
outputs_res = deepcopy(outputs)
self._run.track(
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
aim = import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = {"action": "on_tool_start"}
resp.update(self.get_custom_callback_meta())
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
aim = import_aim()
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = {"action": "on_tool_end"}
resp.update(self.get_custom_callback_meta())
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
aim = import_aim()
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = {"action": "on_agent_finish"}
resp.update(self.get_custom_callback_meta())
finish_res = deepcopy(finish)
text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
finish_res.return_values["output"], finish_res.log
)
self._run.track(aim.Text(text), name="on_agent_finish", context=resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
aim = import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = {
"action": "on_agent_action",
"tool": action.tool,
}
resp.update(self.get_custom_callback_meta())
action_res = deepcopy(action)
text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
action_res.tool_input, action_res.log
)
self._run.track(aim.Text(text), name="on_agent_action", context=resp)
def flush_tracker(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 10,
log_system_params: bool = True,
langchain_asset: Any = None,
reset: bool = True,
finish: bool = False,
) -> None:
"""Flush the tracker and reset the session.
Args:
repo (:obj:`str`, optional): Aim repository path or Repo object to which
Run object is bound. If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
'default' if not specified. Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
params such as installed packages, git info, environment variables, etc.
langchain_asset: The langchain asset to save.
reset: Whether to reset the session.
finish: Whether to finish the run.
Returns:
None
"""
if langchain_asset:
try:
for key, value in langchain_asset.dict().items():
self._run.set(key, value, strict=False)
except Exception:
pass
if finish or reset:
self._run.close()
self.reset_callback_meta()
if reset:
self.__init__( # type: ignore
repo=repo if repo else self.repo,
experiment_name=experiment_name
if experiment_name
else self.experiment_name,
system_tracking_interval=system_tracking_interval
if system_tracking_interval
else self.system_tracking_interval,
log_system_params=log_system_params
if log_system_params
else self.log_system_params,
)
__all__ = ["import_aim", "BaseMetadataCallbackHandler", "AimCallbackHandler"]

View File

@@ -1,353 +1,3 @@
import os
import warnings
from typing import Any, Dict, List, Optional
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from packaging.version import parse
from langchain.callbacks.base import BaseCallbackHandler
class ArgillaCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs into Argilla.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset` in
Argilla, please visit
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default will be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
Examples:
>>> from langchain.llms import OpenAI
>>> from langchain.callbacks import ArgillaCallbackHandler
>>> argilla_callback = ArgillaCallbackHandler(
... dataset_name="my-dataset",
... workspace_name="my-workspace",
... api_url="http://localhost:6900",
... api_key="argilla.apikey",
... )
>>> llm = OpenAI(
... temperature=0,
... callbacks=[argilla_callback],
... verbose=True,
... openai_api_key="API_KEY_HERE",
... )
>>> llm.generate([
... "What is the best NLP-annotation tool out there? (no bias at all)",
... ])
"Argilla, no doubt about it."
"""
REPO_URL: str = "https://github.com/argilla-io/argilla"
ISSUES_URL: str = f"{REPO_URL}/issues"
BLOG_URL: str = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
DEFAULT_API_URL: str = "http://localhost:6900"
def __init__(
self,
dataset_name: str,
workspace_name: Optional[str] = None,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
"""Initializes the `ArgillaCallbackHandler`.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset`
in Argilla, please visit
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default will be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
"""
super().__init__()
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try:
import argilla as rg # noqa: F401
self.ARGILLA_VERSION = rg.__version__
except ImportError:
raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`"
)
# Check whether the Argilla version is compatible
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
raise ImportError(
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
"upgrade `argilla` with `pip install --upgrade argilla`."
)
# Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn(
(
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
" default API URL in Argilla Quickstart."
),
)
api_url = self.DEFAULT_API_URL
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
self.DEFAULT_API_KEY = (
"admin.apikey"
if parse(self.ARGILLA_VERSION) < parse("1.11.0")
else "owner.apikey"
)
warnings.warn(
(
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
" default API key in Argilla Quickstart."
),
)
api_url = self.DEFAULT_API_URL
# Connect to Argilla with the provided credentials, if applicable
try:
rg.init(api_key=api_key, api_url=api_url)
except Exception as e:
raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists "
f"please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
# Set the Argilla variables
self.dataset_name = dataset_name
self.workspace_name = workspace_name or rg.get_workspace()
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try:
extra_args = {}
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
warnings.warn(
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
" higher is recommended.",
UserWarning,
)
extra_args = {"with_records": False}
self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name,
workspace=self.workspace_name,
**extra_args,
)
except Exception as e:
raise FileNotFoundError(
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
f"\nPlease check that the dataset with name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in"
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
f" please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
f"{self.workspace_name} had fields that are not supported yet for the"
f"`langchain` integration. Supported fields are: {supported_fields},"
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
" For more information on how to create a `langchain`-compatible"
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
)
self.prompts: Dict[str, List[str]] = {}
warnings.warn(
(
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Save the prompts in memory when an LLM starts."""
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log records to Argilla when an LLM ends."""
# Do nothing if there's a parent_run_id, since we will log the records when
# the chain ends
if kwargs["parent_run_id"]:
return
# Creates the records and adds them to the `FeedbackDataset`
prompts = self.prompts[str(kwargs["run_id"])]
for prompt, generations in zip(prompts, response.generations):
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": generation.text.strip(),
},
}
for generation in generations
]
)
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
either the `parent_run_id` or the `run_id` as the key. This is done so that
we don't log the same input prompt twice, once when the LLM starts and once
when the chain starts.
"""
if "input" in inputs:
self.prompts.update(
{
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
inputs["input"]
if isinstance(inputs["input"], list)
else [inputs["input"]]
)
}
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
differs if the output is a list or not.
"""
if not any(
key in self.prompts
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
str(kwargs["run_id"])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(
prompts, # type: ignore
chain_output_val,
)
]
)
else:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts), # type: ignore
"response": chain_output_val.strip(),
},
}
]
)
# Pop current run from `self.runs`
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass
__all__ = ["ArgillaCallbackHandler"]

View File

@@ -1,213 +1,3 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain_community.callbacks.arize_callback import ArizeCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas
class ArizeCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arize."""
def __init__(
self,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
SPACE_KEY: Optional[str] = None,
API_KEY: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
super().__init__()
self.model_id = model_id
self.model_version = model_version
self.space_key = SPACE_KEY
self.api_key = API_KEY
self.prompt_records: List[str] = []
self.response_records: List[str] = []
self.prediction_ids: List[str] = []
self.pred_timestamps: List[int] = []
self.response_embeddings: List[float] = []
self.prompt_embeddings: List[float] = []
self.prompt_tokens = 0
self.completion_tokens = 0
self.total_tokens = 0
self.step = 0
from arize.pandas.embeddings import EmbeddingGenerator, UseCases
from arize.pandas.logger import Client
self.generator = EmbeddingGenerator.from_use_case(
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
model_name="distilbert-base-uncased",
tokenizer_max_length=512,
batch_size=256,
)
self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
raise ValueError("❌ CHANGE SPACE AND API KEYS")
else:
print("✅ Arize client setup done! Now you can start using Arize!")
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
for prompt in prompts:
self.prompt_records.append(prompt.replace("\n", ""))
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pd = import_pandas()
from arize.utils.types import (
EmbeddingColumnNames,
Environments,
ModelTypes,
Schema,
)
# Safe check if 'llm_output' and 'token_usage' exist
if response.llm_output and "token_usage" in response.llm_output:
self.prompt_tokens = response.llm_output["token_usage"].get(
"prompt_tokens", 0
)
self.total_tokens = response.llm_output["token_usage"].get(
"total_tokens", 0
)
self.completion_tokens = response.llm_output["token_usage"].get(
"completion_tokens", 0
)
else:
self.prompt_tokens = (
self.total_tokens
) = self.completion_tokens = 0 # assign default value
for generations in response.generations:
for generation in generations:
prompt = self.prompt_records[self.step]
self.step = self.step + 1
prompt_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(prompt.replace("\n", " "))
).reset_index(drop=True)
)
# Assigning text to response_text instead of response
response_text = generation.text.replace("\n", " ")
response_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(generation.text.replace("\n", " "))
).reset_index(drop=True)
)
pred_timestamp = datetime.now().timestamp()
# Define the columns and data
columns = [
"prediction_ts",
"response",
"prompt",
"response_vector",
"prompt_vector",
"prompt_token",
"completion_token",
"total_token",
]
data = [
[
pred_timestamp,
response_text,
prompt,
response_embedding[0],
prompt_embedding[0],
self.prompt_tokens,
self.total_tokens,
self.completion_tokens,
]
]
# Create the DataFrame
df = pd.DataFrame(data, columns=columns)
# Declare prompt and response columns
prompt_columns = EmbeddingColumnNames(
vector_column_name="prompt_vector", data_column_name="prompt"
)
response_columns = EmbeddingColumnNames(
vector_column_name="response_vector", data_column_name="response"
)
schema = Schema(
timestamp_column_name="prediction_ts",
tag_column_names=[
"prompt_token",
"completion_token",
"total_token",
],
prompt_column_names=prompt_columns,
response_column_names=response_columns,
)
response_from_arize = self.arize_client.log(
dataframe=df,
schema=schema,
model_id=self.model_id,
model_version=self.model_version,
model_type=ModelTypes.GENERATIVE_LLM,
environment=Environments.PRODUCTION,
)
if response_from_arize.status_code == 200:
print("✅ Successfully logged data to Arize!")
else:
print(f'❌ Logging failed "{response_from_arize.text}"')
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_text(self, text: str, **kwargs: Any) -> None:
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
pass
__all__ = ["ArizeCallbackHandler"]

View File

@@ -1,297 +1,19 @@
"""ArthurAI's Callback Handler."""
from __future__ import annotations
from langchain_community.callbacks.arthur_callback import (
COMPLETION_TOKENS,
DURATION,
FINISH_REASON,
PROMPT_TOKENS,
TOKEN_USAGE,
ArthurCallbackHandler,
_lazy_load_arthur,
)
import os
import uuid
from collections import defaultdict
from datetime import datetime
from time import time
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
import numpy as np
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
import arthurai
from arthurai.core.models import ArthurModel
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOKEN_USAGE = "token_usage"
FINISH_REASON = "finish_reason"
DURATION = "duration"
def _lazy_load_arthur() -> arthurai:
"""Lazy load Arthur."""
try:
import arthurai
except ImportError as e:
raise ImportError(
"To use the ArthurCallbackHandler you need the"
" `arthurai` package. Please install it with"
" `pip install arthurai`.",
e,
)
return arthurai
class ArthurCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arthur platform.
Arthur helps enterprise teams optimize model operations
and performance at scale. The Arthur API tracks model
performance, explainability, and fairness across tabular,
NLP, and CV models. Our API is model- and platform-agnostic,
and continuously scales with complex and dynamic enterprise needs.
To learn more about Arthur, visit our website at
https://www.arthur.ai/ or read the Arthur docs at
https://docs.arthur.ai/
"""
def __init__(
self,
arthur_model: ArthurModel,
) -> None:
"""Initialize callback handler."""
super().__init__()
arthurai = _lazy_load_arthur()
Stage = arthurai.common.constants.Stage
ValueType = arthurai.common.constants.ValueType
self.arthur_model = arthur_model
# save the attributes of this model to be used when preparing
# inferences to log to Arthur in on_llm_end()
self.attr_names = set([a.name for a in self.arthur_model.get_attributes()])
self.input_attr = [
x
for x in self.arthur_model.get_attributes()
if x.stage == Stage.ModelPipelineInput
and x.value_type == ValueType.Unstructured_Text
][0].name
self.output_attr = [
x
for x in self.arthur_model.get_attributes()
if x.stage == Stage.PredictedValue
and x.value_type == ValueType.Unstructured_Text
][0].name
self.token_likelihood_attr = None
if (
len(
[
x
for x in self.arthur_model.get_attributes()
if x.value_type == ValueType.TokenLikelihoods
]
)
> 0
):
self.token_likelihood_attr = [
x
for x in self.arthur_model.get_attributes()
if x.value_type == ValueType.TokenLikelihoods
][0].name
self.run_map: DefaultDict[str, Any] = defaultdict(dict)
@classmethod
def from_credentials(
cls,
model_id: str,
arthur_url: Optional[str] = "https://app.arthur.ai",
arthur_login: Optional[str] = None,
arthur_password: Optional[str] = None,
) -> ArthurCallbackHandler:
"""Initialize callback handler from Arthur credentials.
Args:
model_id (str): The ID of the arthur model to log to.
arthur_url (str, optional): The URL of the Arthur instance to log to.
Defaults to "https://app.arthur.ai".
arthur_login (str, optional): The login to use to connect to Arthur.
Defaults to None.
arthur_password (str, optional): The password to use to connect to
Arthur. Defaults to None.
Returns:
ArthurCallbackHandler: The initialized callback handler.
"""
arthurai = _lazy_load_arthur()
ArthurAI = arthurai.ArthurAI
ResponseClientError = arthurai.common.exceptions.ResponseClientError
# connect to Arthur
if arthur_login is None:
try:
arthur_api_key = os.environ["ARTHUR_API_KEY"]
except KeyError:
raise ValueError(
"No Arthur authentication provided. Either give"
" a login to the ArthurCallbackHandler"
" or set an ARTHUR_API_KEY as an environment variable."
)
arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
else:
if arthur_password is None:
arthur = ArthurAI(url=arthur_url, login=arthur_login)
else:
arthur = ArthurAI(
url=arthur_url, login=arthur_login, password=arthur_password
)
# get model from Arthur by the provided model ID
try:
arthur_model = arthur.get_model(model_id)
except ResponseClientError:
raise ValueError(
f"Was unable to retrieve model with id {model_id} from Arthur."
" Make sure the ID corresponds to a model that is currently"
" registered with your Arthur account."
)
return cls(arthur_model)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""On LLM start, save the input prompts"""
run_id = kwargs["run_id"]
self.run_map[run_id]["input_texts"] = prompts
self.run_map[run_id]["start_time"] = time()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""On LLM end, send data to Arthur."""
try:
import pytz # type: ignore[import]
except ImportError as e:
raise ImportError(
"Could not import pytz. Please install it with 'pip install pytz'."
) from e
run_id = kwargs["run_id"]
# get the run params from this run ID,
# or raise an error if this run ID has no corresponding metadata in self.run_map
try:
run_map_data = self.run_map[run_id]
except KeyError as e:
raise KeyError(
"This function has been called with a run_id"
" that was never registered in on_llm_start()."
" Restart and try running the LLM again"
) from e
# mark the duration time between on_llm_start() and on_llm_end()
time_from_start_to_end = time() - run_map_data["start_time"]
# create inferences to log to Arthur
inferences = []
for i, generations in enumerate(response.generations):
for generation in generations:
inference = {
"partner_inference_id": str(uuid.uuid4()),
"inference_timestamp": datetime.now(tz=pytz.UTC),
self.input_attr: run_map_data["input_texts"][i],
self.output_attr: generation.text,
}
if generation.generation_info is not None:
# add finish reason to the inference
# if generation info contains a finish reason and
# if the ArthurModel was registered to monitor finish_reason
if (
FINISH_REASON in generation.generation_info
and FINISH_REASON in self.attr_names
):
inference[FINISH_REASON] = generation.generation_info[
FINISH_REASON
]
# add token likelihoods data to the inference if the ArthurModel
# was registered to monitor token likelihoods
logprobs_data = generation.generation_info["logprobs"]
if (
logprobs_data is not None
and self.token_likelihood_attr is not None
):
logprobs = logprobs_data["top_logprobs"]
likelihoods = [
{k: np.exp(v) for k, v in logprobs[i].items()}
for i in range(len(logprobs))
]
inference[self.token_likelihood_attr] = likelihoods
# add token usage counts to the inference if the
# ArthurModel was registered to monitor token usage
if (
isinstance(response.llm_output, dict)
and TOKEN_USAGE in response.llm_output
):
token_usage = response.llm_output[TOKEN_USAGE]
if (
PROMPT_TOKENS in token_usage
and PROMPT_TOKENS in self.attr_names
):
inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
if (
COMPLETION_TOKENS in token_usage
and COMPLETION_TOKENS in self.attr_names
):
inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]
# add inference duration to the inference if the ArthurModel
# was registered to monitor inference duration
if DURATION in self.attr_names:
inference[DURATION] = time_from_start_to_end
inferences.append(inference)
# send inferences to arthur
self.arthur_model.send_inferences(inferences)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""On chain start, do nothing."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""On chain end, do nothing."""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""On new token, pass."""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
__all__ = [
"PROMPT_TOKENS",
"COMPLETION_TOKENS",
"TOKEN_USAGE",
"FINISH_REASON",
"DURATION",
"_lazy_load_arthur",
"ArthurCallbackHandler",
]

View File

@@ -1,527 +1,6 @@
from __future__ import annotations
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
load_json,
from langchain_community.callbacks.clearml_callback import (
ClearMLCallbackHandler,
import_clearml,
)
if TYPE_CHECKING:
import pandas as pd
def import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed."""
try:
import clearml # noqa: F401
except ImportError:
raise ImportError(
"To use the clearml callback manager you need to have the `clearml` python "
"package installed. Please install it with `pip install clearml`"
)
return clearml
class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to ClearML.
Parameters:
job_type (str): The type of clearml task such as "inference", "testing" or "qc"
project_name (str): The clearml project name
tags (list): Tags to add to the task
task_name (str): Name of the clearml task
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics
stream_logs (bool): Whether to stream callback actions to ClearML
This handler will utilize the associated callback method and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response to the ClearML console.
"""
def __init__(
self,
task_type: Optional[str] = "inference",
project_name: Optional[str] = "langchain_callback_demo",
tags: Optional[Sequence] = None,
task_name: Optional[str] = None,
visualize: bool = False,
complexity_metrics: bool = False,
stream_logs: bool = False,
) -> None:
"""Initialize callback handler."""
clearml = import_clearml()
spacy = import_spacy()
super().__init__()
self.task_type = task_type
self.project_name = project_name
self.tags = tags
self.task_name = task_name
self.visualize = visualize
self.complexity_metrics = complexity_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
# Check if ClearML task already exists (e.g. in pipeline)
if clearml.Task.current_task():
self.task = clearml.Task.current_task()
else:
self.task = clearml.Task.init( # type: ignore
task_type=self.task_type,
project_name=self.project_name,
tags=self.tags,
task_name=self.task_name,
output_uri=True,
)
self.logger = self.task.get_logger()
warning = (
"The clearml callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/allegroai/clearml/issues with the tag `langchain`."
)
self.logger.report_text(warning, level=30, print_console=True)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
self.visualize = visualize
self.nlp = spacy.load("en_core_web_sm")
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(resp)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self.logger.report_text(prompt_resp)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.on_llm_token_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(self.analyze_text(generation.text))
self.on_llm_end_records.append(generation_resp)
self.action_records.append(generation_resp)
if self.stream_logs:
self.logger.report_text(generation_resp)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs.get("input", inputs.get("human_input"))
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
else:
raise ValueError("Unexpected data format provided!")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_chain_end",
"outputs": outputs.get("output", outputs.get("text")),
}
)
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.on_tool_start_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.on_tool_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.on_text_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_finish_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_action_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def analyze_text(self, text: str) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
Returns:
(dict): A dictionary containing the complexity metrics.
"""
resp = {}
textstat = import_textstat()
spacy = import_spacy()
if self.complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(
text
),
"dale_chall_readability_score": textstat.dale_chall_readability_score(
text
),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update(text_complexity_metrics)
if self.visualize and self.nlp and self.temp_dir.name is not None:
doc = self.nlp(text)
dep_out = spacy.displacy.render( # type: ignore
doc, style="dep", jupyter=False, page=True
)
dep_output_path = Path(
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
)
dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render( # type: ignore
doc, style="ent", jupyter=False, page=True
)
ent_output_path = Path(
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
)
ent_output_path.open("w", encoding="utf-8").write(ent_out)
self.logger.report_media(
"Dependencies Plot", text, local_path=dep_output_path
)
self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
return resp
@staticmethod
def _build_llm_df(
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
base_df_fields = [field for field in base_df_fields if field in base_df]
rename_map = {
map_entry_k: map_entry_v
for map_entry_k, map_entry_v in rename_map.items()
if map_entry_k in base_df_fields
}
llm_df = base_df[base_df_fields].dropna(axis=1)
if rename_map:
llm_df = llm_df.rename(rename_map, axis=1)
return llm_df
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
base_df=on_llm_end_records_df,
base_df_fields=["step", "prompts"]
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
rename_map={"step": "prompt_step"},
)
complexity_metrics_columns = []
visualizations_columns: List = []
if self.complexity_metrics:
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
"text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
on_llm_end_records_df,
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns,
{"step": "output_step", "text": "output"},
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
return session_analysis_df
def flush_tracker(
self,
name: Optional[str] = None,
langchain_asset: Any = None,
finish: bool = False,
) -> None:
"""Flush the tracker and setup the session.
Everything after this will be a new table.
Args:
name: Name of the performed session so far so it is identifiable
langchain_asset: The langchain asset to save.
finish: Whether to finish the run.
Returns:
None
"""
pd = import_pandas()
clearml = import_clearml()
# Log the action records
self.logger.report_table(
"Action Records", name, table_plot=pd.DataFrame(self.action_records)
)
# Session analysis
session_analysis_df = self._create_session_analysis_df()
self.logger.report_table(
"Session Analysis", name, table_plot=session_analysis_df
)
if self.stream_logs:
self.logger.report_text(
{
"action_records": pd.DataFrame(self.action_records),
"session_analysis": session_analysis_df,
}
)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
try:
langchain_asset.save(langchain_asset_path)
# Create output model and connect it to the task
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except NotImplementedError as e:
print("Could not save model.")
print(repr(e))
pass
# Cleanup after adding everything to ClearML
self.task.flush(wait_for_uploads=True)
self.temp_dir.cleanup()
self.temp_dir = tempfile.TemporaryDirectory()
self.reset_callback_meta()
if finish:
self.task.close()
__all__ = ["import_clearml", "ClearMLCallbackHandler"]

View File

@@ -1,645 +1,17 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import Generation, LLMResult
import langchain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
from langchain_community.callbacks.comet_ml_callback import (
LANGCHAIN_MODEL_NAME,
CometCallbackHandler,
_fetch_text_complexity_metrics,
_get_experiment,
_summarize_metrics_for_generated_outputs,
import_comet_ml,
)
LANGCHAIN_MODEL_NAME = "langchain-model"
def import_comet_ml() -> Any:
"""Import comet_ml and raise an error if it is not installed."""
try:
import comet_ml # noqa: F401
except ImportError:
raise ImportError(
"To use the comet_ml callback manager you need to have the "
"`comet_ml` python package installed. Please install it with"
" `pip install comet_ml`"
)
return comet_ml
def _get_experiment(
workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
comet_ml = import_comet_ml()
experiment = comet_ml.Experiment( # type: ignore
workspace=workspace,
project_name=project_name,
)
return experiment
def _fetch_text_complexity_metrics(text: str) -> dict:
textstat = import_textstat()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
return text_complexity_metrics
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
pd = import_pandas()
metrics_df = pd.DataFrame(metrics)
metrics_summary = metrics_df.describe()
return metrics_summary.to_dict()
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Comet.
Parameters:
job_type (str): The type of comet_ml task such as "inference",
"testing" or "qc"
project_name (str): The comet_ml project name
tags (list): Tags to add to the task
task_name (str): Name of the comet_ml task
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics
stream_logs (bool): Whether to stream callback actions to Comet
This handler will utilize the associated callback method and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response to Comet.
"""
def __init__(
self,
task_type: Optional[str] = "inference",
workspace: Optional[str] = None,
project_name: Optional[str] = None,
tags: Optional[Sequence] = None,
name: Optional[str] = None,
visualizations: Optional[List[str]] = None,
complexity_metrics: bool = False,
custom_metrics: Optional[Callable] = None,
stream_logs: bool = True,
) -> None:
"""Initialize callback handler."""
self.comet_ml = import_comet_ml()
super().__init__()
self.task_type = task_type
self.workspace = workspace
self.project_name = project_name
self.tags = tags
self.visualizations = visualizations
self.complexity_metrics = complexity_metrics
self.custom_metrics = custom_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
self.experiment = _get_experiment(workspace, project_name)
self.experiment.log_other("Created from", "langchain")
if tags:
self.experiment.add_tags(tags)
self.name = name
if self.name:
self.experiment.set_name(self.name)
warning = (
"The comet_ml callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/comet-ml/issue-tracking/issues with the tag "
"`langchain`."
)
self.comet_ml.LOGGER.warning(warning)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
if self.visualizations:
spacy = import_spacy()
self.nlp = spacy.load("en_core_web_sm")
else:
self.nlp = None
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
metadata = self._init_resp()
metadata.update({"action": "on_llm_start"})
metadata.update(flatten_dict(serialized))
metadata.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(metadata)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self._log_stream(prompt, metadata, self.step)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.action_records.append(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
metadata = self._init_resp()
metadata.update({"action": "on_llm_end"})
metadata.update(flatten_dict(response.llm_output or {}))
metadata.update(self.get_custom_callback_meta())
output_complexity_metrics = []
output_custom_metrics = []
for prompt_idx, generations in enumerate(response.generations):
for gen_idx, generation in enumerate(generations):
text = generation.text
generation_resp = deepcopy(metadata)
generation_resp.update(flatten_dict(generation.dict()))
complexity_metrics = self._get_complexity_metrics(text)
if complexity_metrics:
output_complexity_metrics.append(complexity_metrics)
generation_resp.update(complexity_metrics)
custom_metrics = self._get_custom_metrics(
generation, prompt_idx, gen_idx
)
if custom_metrics:
output_custom_metrics.append(custom_metrics)
generation_resp.update(custom_metrics)
if self.stream_logs:
self._log_stream(text, metadata, self.step)
self.action_records.append(generation_resp)
self.on_llm_end_records.append(generation_resp)
self._log_text_metrics(output_complexity_metrics, step=self.step)
self._log_text_metrics(output_custom_metrics, step=self.step)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for chain_input_key, chain_input_val in inputs.items():
if isinstance(chain_input_val, str):
input_resp = deepcopy(resp)
if self.stream_logs:
self._log_stream(chain_input_val, resp, self.step)
input_resp.update({chain_input_key: chain_input_val})
self.action_records.append(input_resp)
else:
self.comet_ml.LOGGER.warning(
f"Unexpected data format provided! "
f"Input Value for {chain_input_key} will not be logged"
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_chain_end"})
resp.update(self.get_custom_callback_meta())
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, str):
output_resp = deepcopy(resp)
if self.stream_logs:
self._log_stream(chain_output_val, resp, self.step)
output_resp.update({chain_output_key: chain_output_val})
self.action_records.append(output_resp)
else:
self.comet_ml.LOGGER.warning(
f"Unexpected data format provided! "
f"Output Value for {chain_output_key} will not be logged"
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(input_str, resp, self.step)
resp.update({"input_str": input_str})
self.action_records.append(resp)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end"})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(output, resp, self.step)
resp.update({"output": output})
self.action_records.append(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text"})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(text, resp, self.step)
resp.update({"text": text})
self.action_records.append(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
output = finish.return_values["output"]
log = finish.log
resp.update({"action": "on_agent_finish", "log": log})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(output, resp, self.step)
resp.update({"output": output})
self.action_records.append(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
tool = action.tool
tool_input = str(action.tool_input)
log = action.log
resp = self._init_resp()
resp.update({"action": "on_agent_action", "log": log, "tool": tool})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(tool_input, resp, self.step)
resp.update({"tool_input": tool_input})
self.action_records.append(resp)
def _get_complexity_metrics(self, text: str) -> dict:
"""Compute text complexity metrics using textstat.
Parameters:
text (str): The text to analyze.
Returns:
(dict): A dictionary containing the complexity metrics.
"""
resp = {}
if self.complexity_metrics:
text_complexity_metrics = _fetch_text_complexity_metrics(text)
resp.update(text_complexity_metrics)
return resp
def _get_custom_metrics(
self, generation: Generation, prompt_idx: int, gen_idx: int
) -> dict:
"""Compute Custom Metrics for an LLM Generated Output
Args:
generation (LLMResult): Output generation from an LLM
prompt_idx (int): List index of the input prompt
gen_idx (int): List index of the generated output
Returns:
dict: A dictionary containing the custom metrics.
"""
resp = {}
if self.custom_metrics:
custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
resp.update(custom_metrics)
return resp
def flush_tracker(
self,
langchain_asset: Any = None,
task_type: Optional[str] = "inference",
workspace: Optional[str] = None,
project_name: Optional[str] = "comet-langchain-demo",
tags: Optional[Sequence] = None,
name: Optional[str] = None,
visualizations: Optional[List[str]] = None,
complexity_metrics: bool = False,
custom_metrics: Optional[Callable] = None,
finish: bool = False,
reset: bool = False,
) -> None:
"""Flush the tracker and setup the session.
Everything after this will be a new table.
Args:
name: Name of the performed session so far so it is identifiable
langchain_asset: The langchain asset to save.
finish: Whether to finish the run.
Returns:
None
"""
self._log_session(langchain_asset)
if langchain_asset:
try:
self._log_model(langchain_asset)
except Exception:
self.comet_ml.LOGGER.error(
"Failed to export agent or LLM to Comet",
exc_info=True,
extra={"show_traceback": True},
)
if finish:
self.experiment.end()
if reset:
self._reset(
task_type,
workspace,
project_name,
tags,
name,
visualizations,
complexity_metrics,
custom_metrics,
)
def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
self.experiment.log_text(prompt, metadata=metadata, step=step)
def _log_model(self, langchain_asset: Any) -> None:
model_parameters = self._get_llm_parameters(langchain_asset)
self.experiment.log_parameters(model_parameters, prefix="model")
langchain_asset_path = Path(self.temp_dir.name, "model.json")
model_name = self.name if self.name else LANGCHAIN_MODEL_NAME
try:
if hasattr(langchain_asset, "save"):
langchain_asset.save(langchain_asset_path)
self.experiment.log_model(model_name, str(langchain_asset_path))
except (ValueError, AttributeError, NotImplementedError) as e:
if hasattr(langchain_asset, "save_agent"):
langchain_asset.save_agent(langchain_asset_path)
self.experiment.log_model(model_name, str(langchain_asset_path))
else:
self.comet_ml.LOGGER.error(
f"{e}"
" Could not save Langchain Asset "
f"for {langchain_asset.__class__.__name__}"
)
def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
try:
llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
# Log the cleaned dataframe as a table
self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
except Exception:
self.comet_ml.LOGGER.warning(
"Failed to log session data to Comet",
exc_info=True,
extra={"show_traceback": True},
)
try:
metadata = {"langchain_version": str(langchain.__version__)}
# Log the langchain low-level records as a JSON file directly
self.experiment.log_asset_data(
self.action_records, "langchain-action_records.json", metadata=metadata
)
except Exception:
self.comet_ml.LOGGER.warning(
"Failed to log session data to Comet",
exc_info=True,
extra={"show_traceback": True},
)
try:
self._log_visualizations(llm_session_df)
except Exception:
self.comet_ml.LOGGER.warning(
"Failed to log visualizations to Comet",
exc_info=True,
extra={"show_traceback": True},
)
def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
if not metrics:
return
metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
for key, value in metrics_summary.items():
self.experiment.log_metrics(value, prefix=key, step=step)
def _log_visualizations(self, session_df: Any) -> None:
if not (self.visualizations and self.nlp):
return
spacy = import_spacy()
prompts = session_df["prompts"].tolist()
outputs = session_df["text"].tolist()
for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
doc = self.nlp(output)
sentence_spans = list(doc.sents)
for visualization in self.visualizations:
try:
html = spacy.displacy.render(
sentence_spans,
style=visualization,
options={"compact": True},
jupyter=False,
page=True,
)
self.experiment.log_asset_data(
html,
name=f"langchain-viz-{visualization}-{idx}.html",
metadata={"prompt": prompt},
step=idx,
)
except Exception as e:
self.comet_ml.LOGGER.warning(
e, exc_info=True, extra={"show_traceback": True}
)
return
def _reset(
self,
task_type: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
tags: Optional[Sequence] = None,
name: Optional[str] = None,
visualizations: Optional[List[str]] = None,
complexity_metrics: bool = False,
custom_metrics: Optional[Callable] = None,
) -> None:
_task_type = task_type if task_type else self.task_type
_workspace = workspace if workspace else self.workspace
_project_name = project_name if project_name else self.project_name
_tags = tags if tags else self.tags
_name = name if name else self.name
_visualizations = visualizations if visualizations else self.visualizations
_complexity_metrics = (
complexity_metrics if complexity_metrics else self.complexity_metrics
)
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
self.__init__( # type: ignore
task_type=_task_type,
workspace=_workspace,
project_name=_project_name,
tags=_tags,
name=_name,
visualizations=_visualizations,
complexity_metrics=_complexity_metrics,
custom_metrics=_custom_metrics,
)
self.reset_callback_meta()
self.temp_dir = tempfile.TemporaryDirectory()
def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
pd = import_pandas()
llm_parameters = self._get_llm_parameters(langchain_asset)
num_generations_per_prompt = llm_parameters.get("n", 1)
llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
# Repeat each input row based on the number of outputs generated per prompt
llm_start_records_df = llm_start_records_df.loc[
llm_start_records_df.index.repeat(num_generations_per_prompt)
].reset_index(drop=True)
llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_session_df = pd.merge(
llm_start_records_df,
llm_end_records_df,
left_index=True,
right_index=True,
suffixes=["_llm_start", "_llm_end"],
)
return llm_session_df
def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
if not langchain_asset:
return {}
try:
if hasattr(langchain_asset, "agent"):
llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
elif hasattr(langchain_asset, "llm_chain"):
llm_parameters = langchain_asset.llm_chain.llm.dict()
elif hasattr(langchain_asset, "llm"):
llm_parameters = langchain_asset.llm.dict()
else:
llm_parameters = langchain_asset.dict()
except Exception:
return {}
return llm_parameters
__all__ = [
"LANGCHAIN_MODEL_NAME",
"import_comet_ml",
"_get_experiment",
"_fetch_text_complexity_metrics",
"_summarize_metrics_for_generated_outputs",
"CometCallbackHandler",
]

View File

@@ -1,183 +1,3 @@
# flake8: noqa
import os
import warnings
from typing import Any, Dict, List, Optional, Union
from langchain_community.callbacks.confident_callback import DeepEvalCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
class DeepEvalCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs into deepeval.
Args:
implementation_name: name of the `implementation` in deepeval
metrics: A list of metrics
Raises:
ImportError: if the `deepeval` package is not installed.
Examples:
>>> from langchain.llms import OpenAI
>>> from langchain.callbacks import DeepEvalCallbackHandler
>>> from deepeval.metrics import AnswerRelevancy
>>> metric = AnswerRelevancy(minimum_score=0.3)
>>> deepeval_callback = DeepEvalCallbackHandler(
... implementation_name="exampleImplementation",
... metrics=[metric],
... )
>>> llm = OpenAI(
... temperature=0,
... callbacks=[deepeval_callback],
... verbose=True,
... openai_api_key="API_KEY_HERE",
... )
>>> llm.generate([
... "What is the best evaluation tool out there? (no bias at all)",
... ])
"Deepeval, no doubt about it."
"""
REPO_URL: str = "https://github.com/confident-ai/deepeval"
ISSUES_URL: str = f"{REPO_URL}/issues"
BLOG_URL: str = "https://docs.confident-ai.com" # noqa: E501
def __init__(
self,
metrics: List[Any],
implementation_name: Optional[str] = None,
) -> None:
"""Initializes the `deepevalCallbackHandler`.
Args:
implementation_name: Name of the implementation you want.
metrics: What metrics do you want to track?
Raises:
ImportError: if the `deepeval` package is not installed.
ConnectionError: if the connection to deepeval fails.
"""
super().__init__()
# Import deepeval (not via `import_deepeval` to keep hints in IDEs)
try:
import deepeval # ignore: F401,I001
except ImportError:
raise ImportError(
"""To use the deepeval callback manager you need to have the
`deepeval` Python package installed. Please install it with
`pip install deepeval`"""
)
if os.path.exists(".deepeval"):
warnings.warn(
"""You are currently not logging anything to the dashboard, we
recommend using `deepeval login`."""
)
# Set the deepeval variables
self.implementation_name = implementation_name
self.metrics = metrics
warnings.warn(
(
"The `DeepEvalCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Store the prompts"""
self.prompts = prompts
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log records to deepeval when an LLM ends."""
from deepeval.metrics.answer_relevancy import AnswerRelevancy
from deepeval.metrics.bias_classifier import UnBiasedMetric
from deepeval.metrics.metric import Metric
from deepeval.metrics.toxic_classifier import NonToxicMetric
for metric in self.metrics:
for i, generation in enumerate(response.generations):
# Here, we only measure the first generation's output
output = generation[0].text
query = self.prompts[i]
if isinstance(metric, AnswerRelevancy):
result = metric.measure(
output=output,
query=query,
)
print(f"Answer Relevancy: {result}")
elif isinstance(metric, UnBiasedMetric):
score = metric.measure(output)
print(f"Bias Score: {score}")
elif isinstance(metric, NonToxicMetric):
score = metric.measure(output)
print(f"Toxic Score: {score}")
else:
raise ValueError(
f"""Metric {metric.__name__} is not supported by deepeval
callbacks."""
)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when chain starts"""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when chain ends."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass
__all__ = ["DeepEvalCallbackHandler"]

View File

@@ -1,195 +1,6 @@
"""Callback handler for Context AI"""
import os
from typing import Any, Dict, List
from uuid import UUID
from langchain_community.callbacks.context_callback import (
ContextCallbackHandler,
import_context,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
def import_context() -> Any:
"""Import the `getcontext` package."""
try:
import getcontext # noqa: F401
from getcontext.generated.models import (
Conversation,
Message,
MessageRole,
Rating,
)
from getcontext.token import Credential # noqa: F401
except ImportError:
raise ImportError(
"To use the context callback manager you need to have the "
"`getcontext` python package installed (version >=0.3.0). "
"Please install it with `pip install --upgrade python-context`"
)
return getcontext, Credential, Conversation, Message, MessageRole, Rating
class ContextCallbackHandler(BaseCallbackHandler):
"""Callback Handler that records transcripts to the Context service.
(https://context.ai).
Keyword Args:
token (optional): The token with which to authenticate requests to Context.
Visit https://with.context.ai/settings to generate a token.
If not provided, the value of the `CONTEXT_TOKEN` environment
variable will be used.
Raises:
ImportError: if the `context-python` package is not installed.
Chat Example:
>>> from langchain.llms import ChatOpenAI
>>> from langchain.callbacks import ContextCallbackHandler
>>> context_callback = ContextCallbackHandler(
... token="<CONTEXT_TOKEN_HERE>",
... )
>>> chat = ChatOpenAI(
... temperature=0,
... headers={"user_id": "123"},
... callbacks=[context_callback],
... openai_api_key="API_KEY_HERE",
... )
>>> messages = [
... SystemMessage(content="You translate English to French."),
... HumanMessage(content="I love programming with LangChain."),
... ]
>>> chat(messages)
Chain Example:
>>> from langchain.chains import LLMChain
>>> from langchain.chat_models import ChatOpenAI
>>> from langchain.callbacks import ContextCallbackHandler
>>> context_callback = ContextCallbackHandler(
... token="<CONTEXT_TOKEN_HERE>",
... )
>>> human_message_prompt = HumanMessagePromptTemplate(
... prompt=PromptTemplate(
... template="What is a good name for a company that makes {product}?",
... input_variables=["product"],
... ),
... )
>>> chat_prompt_template = ChatPromptTemplate.from_messages(
... [human_message_prompt]
... )
>>> callback = ContextCallbackHandler(token)
>>> # Note: the same callback object must be shared between the
... LLM and the chain.
>>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback])
>>> chain = LLMChain(
... llm=chat,
... prompt=chat_prompt_template,
... callbacks=[callback]
... )
>>> chain.run("colorful socks")
"""
def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None:
(
self.context,
self.credential,
self.conversation_model,
self.message_model,
self.message_role_model,
self.rating_model,
) = import_context()
token = token or os.environ.get("CONTEXT_TOKEN") or ""
self.client = self.context.ContextAPI(credential=self.credential(token))
self.chain_run_id = None
self.llm_model = None
self.messages: List[Any] = []
self.metadata: Dict[str, str] = {}
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
**kwargs: Any,
) -> Any:
"""Run when the chat model is started."""
llm_model = kwargs.get("invocation_params", {}).get("model", None)
if llm_model is not None:
self.metadata["model"] = llm_model
if len(messages) == 0:
return
for message in messages[0]:
role = self.message_role_model.SYSTEM
if message.type == "human":
role = self.message_role_model.USER
elif message.type == "system":
role = self.message_role_model.SYSTEM
elif message.type == "ai":
role = self.message_role_model.ASSISTANT
self.messages.append(
self.message_model(
message=message.content,
role=role,
)
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends."""
if len(response.generations) == 0 or len(response.generations[0]) == 0:
return
if not self.chain_run_id:
generation = response.generations[0][0]
self.messages.append(
self.message_model(
message=generation.text,
role=self.message_role_model.ASSISTANT,
)
)
self._log_conversation()
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts."""
self.chain_run_id = kwargs.get("run_id", None)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends."""
self.messages.append(
self.message_model(
message=outputs["text"],
role=self.message_role_model.ASSISTANT,
)
)
self._log_conversation()
self.chain_run_id = None
def _log_conversation(self) -> None:
"""Log the conversation to the context API."""
if len(self.messages) == 0:
return
self.client.log.conversation_upsert(
body={
"conversation": self.conversation_model(
messages=self.messages,
metadata=self.metadata,
)
}
)
self.messages = []
self.metadata = {}
__all__ = ["import_context", "ContextCallbackHandler"]

View File

@@ -1,371 +1,7 @@
"""FlyteKit callback handler."""
from __future__ import annotations
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
from langchain_community.callbacks.flyte_callback import (
FlyteCallbackHandler,
analyze_text,
import_flytekit,
)
if TYPE_CHECKING:
import flytekit
from flytekitplugins.deck import renderer
logger = logging.getLogger(__name__)
def import_flytekit() -> Tuple[flytekit, renderer]:
"""Import flytekit and flytekitplugins-deck-standard."""
try:
import flytekit # noqa: F401
from flytekitplugins.deck import renderer # noqa: F401
except ImportError:
raise ImportError(
"To use the flyte callback manager you need"
"to have the `flytekit` and `flytekitplugins-deck-standard`"
"packages installed. Please install them with `pip install flytekit`"
"and `pip install flytekitplugins-deck-standard`."
)
return flytekit, renderer
def analyze_text(
text: str,
nlp: Any = None,
textstat: Any = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
nlp (spacy.lang): The spacy language model to use for visualization.
Returns:
(dict): A dictionary containing the complexity metrics and visualization
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
if textstat is not None:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore
doc, style="dep", jupyter=False, page=True
)
ent_out = spacy.displacy.render( # type: ignore
doc, style="ent", jupyter=False, page=True
)
text_visualizations = {
"dependency_tree": dep_out,
"entities": ent_out,
}
resp.update(text_visualizations)
return resp
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""This callback handler that is used within a Flyte task."""
def __init__(self) -> None:
"""Initialize callback handler."""
flytekit, renderer = import_flytekit()
self.pandas = import_pandas()
self.textstat = None
try:
self.textstat = import_textstat()
except ImportError:
logger.warning(
"Textstat library is not installed. \
It may result in the inability to log \
certain metrics that can be captured with Textstat."
)
spacy = None
try:
spacy = import_spacy()
except ImportError:
logger.warning(
"Spacy library is not installed. \
It may result in the inability to log \
certain metrics that can be captured with Spacy."
)
super().__init__()
self.nlp = None
if spacy:
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"FlyteCallbackHandler uses spacy's en_core_web_sm model"
" for certain metrics. To download,"
" run the following command in your terminal:"
" `python -m spacy download en_core_web_sm`"
)
self.table_renderer = renderer.TableRenderer
self.markdown_renderer = renderer.MarkdownRenderer
self.deck = flytekit.Deck(
"LangChain Metrics",
self.markdown_renderer().to_html("## LangChain Metrics"),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
prompt_responses = []
for prompt in prompts:
prompt_responses.append(prompt)
resp.update({"prompts": prompt_responses})
self.deck.append(self.markdown_renderer().to_html("### LLM Start"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### LLM End"))
self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp])))
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
if self.nlp or self.textstat:
generation_resp.update(
analyze_text(
generation.text, nlp=self.nlp, textstat=self.textstat
)
)
complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics"
) # type: ignore # noqa: E501
self.deck.append(
self.markdown_renderer().to_html("#### Text Complexity Metrics")
)
self.deck.append(
self.table_renderer().to_html(
self.pandas.DataFrame([complexity_metrics])
)
+ "\n"
)
dependency_tree = generation_resp["dependency_tree"]
self.deck.append(
self.markdown_renderer().to_html("#### Dependency Tree")
)
self.deck.append(dependency_tree)
entities = generation_resp["entities"]
self.deck.append(self.markdown_renderer().to_html("#### Entities"))
self.deck.append(entities)
else:
self.deck.append(
self.markdown_renderer().to_html("#### Generated Response")
)
self.deck.append(self.markdown_renderer().to_html(generation.text))
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.deck.append(self.markdown_renderer().to_html("### Chain Start"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([input_resp])) + "\n"
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Chain End"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Tool Start"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.step += 1
self.tool_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Tool End"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### On Text"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Agent Finish"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Agent Action"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
__all__ = ["import_flytekit", "analyze_text", "FlyteCallbackHandler"]

View File

@@ -1,88 +1,15 @@
from typing import Any, Awaitable, Callable, Dict, Optional
from uuid import UUID
from langchain_community.callbacks.human import (
AsyncHumanApprovalCallbackHandler,
HumanApprovalCallbackHandler,
HumanRejectedException,
_default_approve,
_default_true,
)
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
def _default_approve(_input: str) -> bool:
msg = (
"Do you approve of the following input? "
"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
)
msg += "\n\n" + _input + "\n"
resp = input(msg)
return resp.lower() in ("yes", "y")
async def _adefault_approve(_input: str) -> bool:
msg = (
"Do you approve of the following input? "
"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
)
msg += "\n\n" + _input + "\n"
resp = input(msg)
return resp.lower() in ("yes", "y")
def _default_true(_: Dict[str, Any]) -> bool:
return True
class HumanRejectedException(Exception):
"""Exception to raise when a person manually review and rejects a value."""
class HumanApprovalCallbackHandler(BaseCallbackHandler):
"""Callback for manually validating values."""
raise_error: bool = True
def __init__(
self,
approve: Callable[[Any], bool] = _default_approve,
should_check: Callable[[Dict[str, Any]], bool] = _default_true,
):
self._approve = approve
self._should_check = should_check
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if self._should_check(serialized) and not self._approve(input_str):
raise HumanRejectedException(
f"Inputs {input_str} to tool {serialized} were rejected."
)
class AsyncHumanApprovalCallbackHandler(AsyncCallbackHandler):
"""Asynchronous callback for manually validating values."""
raise_error: bool = True
def __init__(
self,
approve: Callable[[Any], Awaitable[bool]] = _adefault_approve,
should_check: Callable[[Dict[str, Any]], bool] = _default_true,
):
self._approve = approve
self._should_check = should_check
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if self._should_check(serialized) and not await self._approve(input_str):
raise HumanRejectedException(
f"Inputs {input_str} to tool {serialized} were rejected."
)
__all__ = [
"_default_approve",
"_default_true",
"HumanRejectedException",
"HumanApprovalCallbackHandler",
"AsyncHumanApprovalCallbackHandler",
]

View File

@@ -1,267 +1,13 @@
import time
from typing import Any, Dict, List, Optional, cast
from langchain_community.callbacks.infino_callback import (
InfinoCallbackHandler,
get_num_tokens,
import_infino,
import_tiktoken,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
def import_infino() -> Any:
"""Import the infino client."""
try:
from infinopy import InfinoClient
except ImportError:
raise ImportError(
"To use the Infino callbacks manager you need to have the"
" `infinopy` python package installed."
"Please install it with `pip install infinopy`"
)
return InfinoClient()
def import_tiktoken() -> Any:
"""Import tiktoken for counting tokens for OpenAI models."""
try:
import tiktoken
except ImportError:
raise ImportError(
"To use the ChatOpenAI model with Infino callback manager, you need to "
"have the `tiktoken` python package installed."
"Please install it with `pip install tiktoken`"
)
return tiktoken
def get_num_tokens(string: str, openai_model_name: str) -> int:
"""Calculate num tokens for OpenAI with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/main
/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
tiktoken = import_tiktoken()
encoding = tiktoken.encoding_for_model(openai_model_name)
num_tokens = len(encoding.encode(string))
return num_tokens
class InfinoCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Infino."""
def __init__(
self,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
verbose: bool = False,
) -> None:
# Set Infino client
self.client = import_infino()
self.model_id = model_id
self.model_version = model_version
self.verbose = verbose
self.is_chat_openai_model = False
self.chat_openai_model_name = "gpt-3.5-turbo"
def _send_to_infino(
self,
key: str,
value: Any,
is_ts: bool = True,
) -> None:
"""Send the key-value to Infino.
Parameters:
key (str): the key to send to Infino.
value (Any): the value to send to Infino.
is_ts (bool): if True, the value is part of a time series, else it
is sent as a log message.
"""
payload = {
"date": int(time.time()),
key: value,
"labels": {
"model_id": self.model_id,
"model_version": self.model_version,
},
}
if self.verbose:
print(f"Tracking {key} with Infino: {payload}")
# Append to Infino time series only if is_ts is True, otherwise
# append to Infino log.
if is_ts:
self.client.append_ts(payload)
else:
self.client.append_log(payload)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
**kwargs: Any,
) -> None:
"""Log the prompts to Infino, and set start time and error flag."""
for prompt in prompts:
self._send_to_infino("prompt", prompt, is_ts=False)
# Set the error flag to indicate no error (this will get overridden
# in on_llm_error if an error occurs).
self.error = 0
# Set the start time (so that we can calculate the request
# duration in on_llm_end).
self.start_time = time.time()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log the latency, error, token usage, and response to Infino."""
# Calculate and track the request latency.
self.end_time = time.time()
duration = self.end_time - self.start_time
self._send_to_infino("latency", duration)
# Track success or error flag.
self._send_to_infino("error", self.error)
# Track prompt response.
for generations in response.generations:
for generation in generations:
self._send_to_infino("prompt_response", generation.text, is_ts=False)
# Track token usage (for non-chat models).
if (response.llm_output is not None) and isinstance(response.llm_output, Dict):
token_usage = response.llm_output["token_usage"]
if token_usage is not None:
prompt_tokens = token_usage["prompt_tokens"]
total_tokens = token_usage["total_tokens"]
completion_tokens = token_usage["completion_tokens"]
self._send_to_infino("prompt_tokens", prompt_tokens)
self._send_to_infino("total_tokens", total_tokens)
self._send_to_infino("completion_tokens", completion_tokens)
# Track completion token usage (for openai chat models).
if self.is_chat_openai_model:
messages = " ".join(
generation.message.content # type: ignore[attr-defined]
for generation in generations
)
completion_tokens = get_num_tokens(
messages, openai_model_name=self.chat_openai_model_name
)
self._send_to_infino("completion_tokens", completion_tokens)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Set the error flag."""
self.error = 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when LLM chain starts."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Need to log the error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
# Currently, for chat models, we only support input prompts for ChatOpenAI.
# Check if this model is a ChatOpenAI model.
values = serialized.get("id")
if values:
for value in values:
if value == "ChatOpenAI":
self.is_chat_openai_model = True
break
# Track prompt tokens for ChatOpenAI model.
if self.is_chat_openai_model:
invocation_params = kwargs.get("invocation_params")
if invocation_params:
model_name = invocation_params.get("model_name")
if model_name:
self.chat_openai_model_name = model_name
prompt_tokens = 0
for message_list in messages:
message_string = " ".join(
cast(str, msg.content) for msg in message_list
)
num_tokens = get_num_tokens(
message_string,
openai_model_name=self.chat_openai_model_name,
)
prompt_tokens += num_tokens
self._send_to_infino("prompt_tokens", prompt_tokens)
if self.verbose:
print(
f"on_chat_model_start: is_chat_openai_model= \
{self.is_chat_openai_model}, \
chat_openai_model_name={self.chat_openai_model_name}"
)
# Send the prompt to infino
prompt = " ".join(
cast(str, msg.content) for sublist in messages for msg in sublist
)
self._send_to_infino("prompt", prompt, is_ts=False)
# Set the error flag to indicate no error (this will get overridden
# in on_llm_error if an error occurs).
self.error = 0
# Set the start time (so that we can calculate the request
# duration in on_llm_end).
self.start_time = time.time()
__all__ = [
"import_infino",
"import_tiktoken",
"get_num_tokens",
"InfinoCallbackHandler",
]

View File

@@ -1,391 +1,7 @@
import os
import warnings
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
from langchain_community.callbacks.labelstudio_callback import (
LabelStudioCallbackHandler,
LabelStudioMode,
get_default_label_configs,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage, ChatMessage
from langchain_core.outputs import Generation, LLMResult
from langchain.callbacks.base import BaseCallbackHandler
class LabelStudioMode(Enum):
"""Label Studio mode enumerator."""
PROMPT = "prompt"
CHAT = "chat"
def get_default_label_configs(
mode: Union[str, LabelStudioMode]
) -> Tuple[str, LabelStudioMode]:
"""Get default Label Studio configs for the given mode.
Parameters:
mode: Label Studio mode ("prompt" or "chat")
Returns: Tuple of Label Studio config and mode
"""
_default_label_configs = {
LabelStudioMode.PROMPT.value: """
<View>
<Style>
.prompt-box {
background-color: white;
border-radius: 10px;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
padding: 20px;
}
</Style>
<View className="root">
<View className="prompt-box">
<Text name="prompt" value="$prompt"/>
</View>
<TextArea name="response" toName="prompt"
maxSubmissions="1" editable="true"
required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="prompt"/>
</View>""",
LabelStudioMode.CHAT.value: """
<View>
<View className="root">
<Paragraphs name="dialogue"
value="$prompt"
layout="dialogue"
textKey="content"
nameKey="role"
granularity="sentence"/>
<Header value="Final response:"/>
<TextArea name="response" toName="dialogue"
maxSubmissions="1" editable="true"
required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="dialogue"/>
</View>""",
}
if isinstance(mode, str):
mode = LabelStudioMode(mode)
return _default_label_configs[mode.value], mode
class LabelStudioCallbackHandler(BaseCallbackHandler):
"""Label Studio callback handler.
Provides the ability to send predictions to Label Studio
for human evaluation, feedback and annotation.
Parameters:
api_key: Label Studio API key
url: Label Studio URL
project_id: Label Studio project ID
project_name: Label Studio project name
project_config: Label Studio project config (XML)
mode: Label Studio mode ("prompt" or "chat")
Examples:
>>> from langchain.llms import OpenAI
>>> from langchain.callbacks import LabelStudioCallbackHandler
>>> handler = LabelStudioCallbackHandler(
... api_key='<your_key_here>',
... url='http://localhost:8080',
... project_name='LangChain-%Y-%m-%d',
... mode='prompt'
... )
>>> llm = OpenAI(callbacks=[handler])
>>> llm.predict('Tell me a story about a dog.')
"""
DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
def __init__(
self,
api_key: Optional[str] = None,
url: Optional[str] = None,
project_id: Optional[int] = None,
project_name: str = DEFAULT_PROJECT_NAME,
project_config: Optional[str] = None,
mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
):
super().__init__()
# Import LabelStudio SDK
try:
import label_studio_sdk as ls
except ImportError:
raise ImportError(
f"You're using {self.__class__.__name__} in your code,"
f" but you don't have the LabelStudio SDK "
f"Python package installed or upgraded to the latest version. "
f"Please run `pip install -U label-studio-sdk`"
f" before using this callback."
)
# Check if Label Studio API key is provided
if not api_key:
if os.getenv("LABEL_STUDIO_API_KEY"):
api_key = str(os.getenv("LABEL_STUDIO_API_KEY"))
else:
raise ValueError(
f"You're using {self.__class__.__name__} in your code,"
f" Label Studio API key is not provided. "
f"Please provide Label Studio API key: "
f"go to the Label Studio instance, navigate to "
f"Account & Settings -> Access Token and copy the key. "
f"Use the key as a parameter for the callback: "
f"{self.__class__.__name__}"
f"(label_studio_api_key='<your_key_here>', ...) or "
f"set the environment variable LABEL_STUDIO_API_KEY=<your_key_here>"
)
self.api_key = api_key
if not url:
if os.getenv("LABEL_STUDIO_URL"):
url = os.getenv("LABEL_STUDIO_URL")
else:
warnings.warn(
f"Label Studio URL is not provided, "
f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}"
f"If you want to provide your own URL, use the parameter: "
f"{self.__class__.__name__}"
f"(label_studio_url='<your_url_here>', ...) "
f"or set the environment variable LABEL_STUDIO_URL=<your_url_here>"
)
url = ls.LABEL_STUDIO_DEFAULT_URL
self.url = url
# Maps run_id to prompts
self.payload: Dict[str, Dict] = {}
self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
self.project_name = project_name
if project_config:
self.project_config = project_config
self.mode = None
else:
self.project_config, self.mode = get_default_label_configs(mode)
self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
if self.project_id is not None:
self.ls_project = self.ls_client.get_project(int(self.project_id))
else:
project_title = datetime.today().strftime(self.project_name)
existing_projects = self.ls_client.get_projects(title=project_title)
if existing_projects:
self.ls_project = existing_projects[0]
self.project_id = self.ls_project.id
else:
self.ls_project = self.ls_client.create_project(
title=project_title, label_config=self.project_config
)
self.project_id = self.ls_project.id
self.parsed_label_config = self.ls_project.parsed_label_config
# Find the first TextArea tag
# "from_name", "to_name", "value" will be used to create predictions
self.from_name, self.to_name, self.value, self.input_type = (
None,
None,
None,
None,
)
for tag_name, tag_info in self.parsed_label_config.items():
if tag_info["type"] == "TextArea":
self.from_name = tag_name
self.to_name = tag_info["to_name"][0]
self.value = tag_info["inputs"][0]["value"]
self.input_type = tag_info["inputs"][0]["type"]
break
if not self.from_name:
error_message = (
f'Label Studio project "{self.project_name}" '
f"does not have a TextArea tag. "
f"Please add a TextArea tag to the project."
)
if self.mode == LabelStudioMode.PROMPT:
error_message += (
"\nHINT: go to project Settings -> "
"Labeling Interface -> Browse Templates"
' and select "Generative AI -> '
'Supervised Language Model Fine-tuning" template.'
)
else:
error_message += (
"\nHINT: go to project Settings -> "
"Labeling Interface -> Browse Templates"
" and check available templates under "
'"Generative AI" section.'
)
raise ValueError(error_message)
def add_prompts_generations(
self, run_id: str, generations: List[List[Generation]]
) -> None:
# Create tasks in Label Studio
tasks = []
prompts = self.payload[run_id]["prompts"]
model_version = (
self.payload[run_id]["kwargs"]
.get("invocation_params", {})
.get("model_name")
)
for prompt, generation in zip(prompts, generations):
tasks.append(
{
"data": {
self.value: prompt,
"run_id": run_id,
},
"predictions": [
{
"result": [
{
"from_name": self.from_name,
"to_name": self.to_name,
"type": "textarea",
"value": {"text": [g.text for g in generation]},
}
],
"model_version": model_version,
}
],
}
)
self.ls_project.import_tasks(tasks)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
**kwargs: Any,
) -> None:
"""Save the prompts in memory when an LLM starts."""
if self.input_type != "Text":
raise ValueError(
f'\nLabel Studio project "{self.project_name}" '
f"has an input type <{self.input_type}>. "
f'To make it work with the mode="chat", '
f"the input type should be <Text>.\n"
f"Read more here https://labelstud.io/tags/text"
)
run_id = str(kwargs["run_id"])
self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}
def _get_message_role(self, message: BaseMessage) -> str:
"""Get the role of the message."""
if isinstance(message, ChatMessage):
return message.role
else:
return message.__class__.__name__
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Save the prompts in memory when an LLM starts."""
if self.input_type != "Paragraphs":
raise ValueError(
f'\nLabel Studio project "{self.project_name}" '
f"has an input type <{self.input_type}>. "
f'To make it work with the mode="chat", '
f"the input type should be <Paragraphs>.\n"
f"Read more here https://labelstud.io/tags/paragraphs"
)
prompts = []
for message_list in messages:
dialog = []
for message in message_list:
dialog.append(
{
"role": self._get_message_role(message),
"content": message.content,
}
)
prompts.append(dialog)
self.payload[str(run_id)] = {
"prompts": prompts,
"tags": tags,
"metadata": metadata,
"run_id": run_id,
"parent_run_id": parent_run_id,
"kwargs": kwargs,
}
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Create a new Label Studio task for each prompt and generation."""
run_id = str(kwargs["run_id"])
# Submit results to Label Studio
self.add_prompts_generations(run_id, response.generations)
# Pop current run from `self.runs`
self.payload.pop(run_id)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass
__all__ = ["LabelStudioMode", "get_default_label_configs", "LabelStudioCallbackHandler"]

View File

@@ -1,681 +1,35 @@
import importlib.metadata
import logging
import os
import traceback
import warnings
from contextvars import ContextVar
from typing import Any, Dict, List, Union, cast
from uuid import UUID
from langchain_community.callbacks.llmonitor_callback import (
DEFAULT_API_URL,
PARAMS_TO_CAPTURE,
LLMonitorCallbackHandler,
UserContextManager,
_get_user_id,
_get_user_props,
_parse_input,
_parse_lc_message,
_parse_lc_messages,
_parse_lc_role,
_parse_output,
_serialize,
identify,
user_ctx,
user_props_ctx,
)
import requests
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from packaging.version import parse
from langchain.callbacks.base import BaseCallbackHandler
logger = logging.getLogger(__name__)
DEFAULT_API_URL = "https://app.llmonitor.com"
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
PARAMS_TO_CAPTURE = [
"temperature",
"top_p",
"top_k",
"stop",
"presence_penalty",
"frequence_penalty",
"seed",
"function_call",
"functions",
"tools",
"tool_choice",
"response_format",
"max_tokens",
"logit_bias",
__all__ = [
"DEFAULT_API_URL",
"user_ctx",
"user_props_ctx",
"PARAMS_TO_CAPTURE",
"UserContextManager",
"identify",
"_serialize",
"_parse_input",
"_parse_output",
"_parse_lc_role",
"_get_user_id",
"_get_user_props",
"_parse_lc_message",
"_parse_lc_messages",
"LLMonitorCallbackHandler",
]
class UserContextManager:
"""Context manager for LLMonitor user context."""
def __init__(self, user_id: str, user_props: Any = None) -> None:
user_ctx.set(user_id)
user_props_ctx.set(user_props)
def __enter__(self) -> Any:
pass
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any:
user_ctx.set(None)
user_props_ctx.set(None)
def identify(user_id: str, user_props: Any = None) -> UserContextManager:
"""Builds an LLMonitor UserContextManager
Parameters:
- `user_id`: The user id.
- `user_props`: The user properties.
Returns:
A context manager that sets the user context.
"""
return UserContextManager(user_id, user_props)
def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]:
if hasattr(obj, "to_json"):
return obj.to_json()
if isinstance(obj, dict):
return {key: _serialize(value) for key, value in obj.items()}
if isinstance(obj, list):
return [_serialize(element) for element in obj]
return obj
def _parse_input(raw_input: Any) -> Any:
if not raw_input:
return None
# if it's an array of 1, just parse the first element
if isinstance(raw_input, list) and len(raw_input) == 1:
return _parse_input(raw_input[0])
if not isinstance(raw_input, dict):
return _serialize(raw_input)
input_value = raw_input.get("input")
inputs_value = raw_input.get("inputs")
question_value = raw_input.get("question")
query_value = raw_input.get("query")
if input_value:
return input_value
if inputs_value:
return inputs_value
if question_value:
return question_value
if query_value:
return query_value
return _serialize(raw_input)
def _parse_output(raw_output: dict) -> Any:
if not raw_output:
return None
if not isinstance(raw_output, dict):
return _serialize(raw_output)
text_value = raw_output.get("text")
output_value = raw_output.get("output")
output_text_value = raw_output.get("output_text")
answer_value = raw_output.get("answer")
result_value = raw_output.get("result")
if text_value:
return text_value
if answer_value:
return answer_value
if output_value:
return output_value
if output_text_value:
return output_text_value
if result_value:
return result_value
return _serialize(raw_output)
def _parse_lc_role(
role: str,
) -> str:
if role == "human":
return "user"
else:
return role
def _get_user_id(metadata: Any) -> Any:
if user_ctx.get() is not None:
return user_ctx.get()
metadata = metadata or {}
user_id = metadata.get("user_id")
if user_id is None:
user_id = metadata.get("userId") # legacy, to delete in the future
return user_id
def _get_user_props(metadata: Any) -> Any:
if user_props_ctx.get() is not None:
return user_props_ctx.get()
metadata = metadata or {}
return metadata.get("user_props", None)
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
keys = ["function_call", "tool_calls", "tool_call_id", "name"]
parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
parsed.update(
{
key: cast(Any, message.additional_kwargs.get(key))
for key in keys
if message.additional_kwargs.get(key) is not None
}
)
return parsed
def _parse_lc_messages(messages: Union[List[BaseMessage], Any]) -> List[Dict[str, Any]]:
return [_parse_lc_message(message) for message in messages]
class LLMonitorCallbackHandler(BaseCallbackHandler):
"""Callback Handler for LLMonitor`.
#### Parameters:
- `app_id`: The app id of the app you want to report to. Defaults to
`None`, which means that `LLMONITOR_APP_ID` will be used.
- `api_url`: The url of the LLMonitor API. Defaults to `None`,
which means that either `LLMONITOR_API_URL` environment variable
or `https://app.llmonitor.com` will be used.
#### Raises:
- `ValueError`: if `app_id` is not provided either as an
argument or as an environment variable.
- `ConnectionError`: if the connection to the API fails.
#### Example:
```python
from langchain.llms import OpenAI
from langchain.callbacks import LLMonitorCallbackHandler
llmonitor_callback = LLMonitorCallbackHandler()
llm = OpenAI(callbacks=[llmonitor_callback],
metadata={"userId": "user-123"})
llm.predict("Hello, how are you?")
```
"""
__api_url: str
__app_id: str
__verbose: bool
__llmonitor_version: str
__has_valid_config: bool
def __init__(
self,
app_id: Union[str, None] = None,
api_url: Union[str, None] = None,
verbose: bool = False,
) -> None:
super().__init__()
self.__has_valid_config = True
try:
import llmonitor
self.__llmonitor_version = importlib.metadata.version("llmonitor")
self.__track_event = llmonitor.track_event
except ImportError:
logger.warning(
"""[LLMonitor] To use the LLMonitor callback handler you need to
have the `llmonitor` Python package installed. Please install it
with `pip install llmonitor`"""
)
self.__has_valid_config = False
return
if parse(self.__llmonitor_version) < parse("0.0.32"):
logger.warning(
f"""[LLMonitor] The installed `llmonitor` version is
{self.__llmonitor_version}
but `LLMonitorCallbackHandler` requires at least version 0.0.32
upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
)
self.__has_valid_config = False
self.__has_valid_config = True
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))
_app_id = app_id or os.getenv("LLMONITOR_APP_ID")
if _app_id is None:
logger.warning(
"""[LLMonitor] app_id must be provided either as an argument or
as an environment variable"""
)
self.__has_valid_config = False
else:
self.__app_id = _app_id
if self.__has_valid_config is False:
return None
try:
res = requests.get(f"{self.__api_url}/api/app/{self.__app_id}")
if not res.ok:
raise ConnectionError()
except Exception:
logger.warning(
f"""[LLMonitor] Could not connect to the LLMonitor API at
{self.__api_url}"""
)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_input(prompts)
self.__track_event(
"llm",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
extra=extra,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_lc_messages(messages[0])
self.__track_event(
"llm",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
extra=extra,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
token_usage = (response.llm_output or {}).get("token_usage", {})
parsed_output: Any = [
_parse_lc_message(generation.message)
if hasattr(generation, "message")
else generation.text
for generation in response.generations[0]
]
# if it's an array of 1, just parse the first element
if len(parsed_output) == 1:
parsed_output = parsed_output[0]
self.__track_event(
"llm",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=parsed_output,
token_usage={
"prompt": token_usage.get("prompt_tokens"),
"completion": token_usage.get("completion_tokens"),
},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = serialized.get("name")
self.__track_event(
"tool",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input_str,
tags=tags,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"tool",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
name = serialized.get("id", [None, None, None, None])[3]
type = "chain"
metadata = metadata or {}
agentName = metadata.get("agent_name")
if agentName is None:
agentName = metadata.get("agentName")
if name == "AgentExecutor" or name == "PlanAndExecute":
type = "agent"
if agentName is not None:
type = "agent"
name = agentName
if parent_run_id is not None:
type = "chain"
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
input = _parse_input(inputs)
self.__track_event(
type,
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
output = _parse_output(outputs)
self.__track_event(
"chain",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
name = action.tool
input = _parse_input(action.tool_input)
self.__track_event(
"tool",
"start",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
output = _parse_output(finish.return_values)
self.__track_event(
"agent",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"chain",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"tool",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"llm",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
__all__ = ["LLMonitorCallbackHandler", "identify"]

View File

@@ -1,13 +1,9 @@
from __future__ import annotations
import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
Generator,
Optional,
from langchain_community.callbacks.manager import (
get_openai_callback,
wandb_tracing_enabled,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainGroup,
@@ -34,71 +30,11 @@ from langchain_core.callbacks.manager import (
)
from langchain_core.tracers.context import (
collect_runs,
register_configure_hook,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
logger = logging.getLogger(__name__)
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar( # noqa: E501
"tracing_wandb_callback", default=None
)
register_configure_hook(openai_callback_var, True)
register_configure_hook(
wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get the OpenAI callback handler in a context manager.
which conveniently exposes token and cost information.
Returns:
OpenAICallbackHandler: The OpenAI callback handler.
Example:
>>> with get_openai_callback() as cb:
... # Use the OpenAI callback handler
"""
cb = OpenAICallbackHandler()
openai_callback_var.set(cb)
yield cb
openai_callback_var.set(None)
@contextmanager
def wandb_tracing_enabled(
session_name: str = "default",
) -> Generator[None, None, None]:
"""Get the WandbTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
None
Example:
>>> with wandb_tracing_enabled() as session:
... # Use the WandbTracer session
"""
cb = WandbTracer()
wandb_tracing_callback_var.set(cb)
yield None
wandb_tracing_callback_var.set(None)
__all__ = [
"BaseRunManager",
"RunManager",
@@ -126,4 +62,6 @@ __all__ = [
"ahandle_event",
"Callbacks",
"env_var_is_set",
"get_openai_callback",
"wandb_tracing_enabled",
]

View File

@@ -1,660 +1,15 @@
import os
import random
import string
import tempfile
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
from langchain_community.callbacks.mlflow_callback import (
MlflowCallbackHandler,
MlflowLogger,
analyze_text,
construct_html_from_prompt_and_generation,
import_mlflow,
)
from langchain.utils import get_from_dict_or_env
def import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed."""
try:
import mlflow
except ImportError:
raise ImportError(
"To use the mlflow callback manager you need to have the `mlflow` python "
"package installed. Please install it with `pip install mlflow>=2.3.0`"
)
return mlflow
def analyze_text(
text: str,
nlp: Any = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
nlp (spacy.lang): The spacy language model to use for visualization.
Returns:
(dict): A dictionary containing the complexity metrics and visualization
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
textstat = import_textstat()
spacy = import_spacy()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
# "text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None:
doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore
doc, style="dep", jupyter=False, page=True
)
ent_out = spacy.displacy.render( # type: ignore
doc, style="ent", jupyter=False, page=True
)
text_visualizations = {
"dependency_tree": dep_out,
"entities": ent_out,
}
resp.update(text_visualizations)
return resp
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
"""Construct an html element from a prompt and a generation.
Parameters:
prompt (str): The prompt.
generation (str): The generation.
Returns:
(str): The html string."""
formatted_prompt = prompt.replace("\n", "<br>")
formatted_generation = generation.replace("\n", "<br>")
return f"""
<p style="color:black;">{formatted_prompt}:</p>
<blockquote>
<p style="color:green;">
{formatted_generation}
</p>
</blockquote>
"""
class MlflowLogger:
"""Callback Handler that logs metrics and artifacts to mlflow server.
Parameters:
name (str): Name of the run.
experiment (str): Name of the experiment.
tags (dict): Tags to be attached for the run.
tracking_uri (str): MLflow tracking server uri.
This handler implements the helper functions to initialize,
log metrics and artifacts to the mlflow server.
"""
def __init__(self, **kwargs: Any):
self.mlflow = import_mlflow()
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
self.mlflow.set_tracking_uri("databricks")
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
else:
tracking_uri = get_from_dict_or_env(
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
)
self.mlflow.set_tracking_uri(tracking_uri)
# User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
)
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id
else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
self.start_run(kwargs["run_name"], kwargs["run_tags"])
def start_run(self, name: str, tags: Dict[str, str]) -> None:
"""To start a new run, auto generates the random suffix for name"""
if name.endswith("-%"):
rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
name = name.replace("%", rname)
self.run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=name, tags=tags
)
def finish_run(self) -> None:
"""To finish the run."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.end_run()
def metric(self, key: str, value: float) -> None:
"""To log metric to mlflow server."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metric(key, value)
def metrics(
self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
) -> None:
"""To log all metrics in the input dict."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metrics(data)
def jsonf(self, data: Dict[str, Any], filename: str) -> None:
"""To log the input data as json file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_dict(data, f"{filename}.json")
def table(self, name: str, dataframe) -> None: # type: ignore
"""To log the input pandas dataframe as a html table"""
self.html(dataframe.to_html(), f"table_{name}")
def html(self, html: str, filename: str) -> None:
"""To log the input html string as html file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_text(html, f"{filename}.html")
def text(self, text: str, filename: str) -> None:
"""To log the input text as text file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_text(text, f"{filename}.txt")
def artifact(self, path: str) -> None:
"""To upload the file from given path as artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_artifact(path)
def langchain_artifact(self, chain: Any) -> None:
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.langchain.log_model(chain, "langchain-model")
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs metrics and artifacts to mlflow server.
Parameters:
name (str): Name of the run.
experiment (str): Name of the experiment.
tags (dict): Tags to be attached for the run.
tracking_uri (str): MLflow tracking server uri.
This handler will utilize the associated callback method called and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response to mlflow server.
"""
def __init__(
self,
name: Optional[str] = "langchainrun-%",
experiment: Optional[str] = "langchain",
tags: Optional[Dict] = None,
tracking_uri: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
import_pandas()
import_textstat()
import_mlflow()
spacy = import_spacy()
super().__init__()
self.name = name
self.experiment = experiment
self.tags = tags or {}
self.tracking_uri = tracking_uri
self.temp_dir = tempfile.TemporaryDirectory()
self.mlflg = MlflowLogger(
tracking_uri=self.tracking_uri,
experiment_name=self.experiment,
run_name=self.name,
run_tags=self.tags,
)
self.action_records: list = []
self.nlp = spacy.load("en_core_web_sm")
self.metrics = {
"step": 0,
"starts": 0,
"ends": 0,
"errors": 0,
"text_ctr": 0,
"chain_starts": 0,
"chain_ends": 0,
"llm_starts": 0,
"llm_ends": 0,
"llm_streams": 0,
"tool_starts": 0,
"tool_ends": 0,
"agent_ends": 0,
}
self.records: Dict[str, Any] = {
"on_llm_start_records": [],
"on_llm_token_records": [],
"on_llm_end_records": [],
"on_chain_start_records": [],
"on_chain_end_records": [],
"on_tool_start_records": [],
"on_tool_end_records": [],
"on_text_records": [],
"on_agent_finish_records": [],
"on_agent_action_records": [],
"action_records": [],
}
def _reset(self) -> None:
for k, v in self.metrics.items():
self.metrics[k] = 0
for k, v in self.records.items():
self.records[k] = []
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.metrics["step"] += 1
self.metrics["llm_starts"] += 1
self.metrics["starts"] += 1
llm_starts = self.metrics["llm_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
for idx, prompt in enumerate(prompts):
prompt_resp = deepcopy(resp)
prompt_resp["prompt"] = prompt
self.records["on_llm_start_records"].append(prompt_resp)
self.records["action_records"].append(prompt_resp)
self.mlflg.jsonf(prompt_resp, f"llm_start_{llm_starts}_prompt_{idx}")
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.metrics["step"] += 1
self.metrics["llm_streams"] += 1
llm_streams = self.metrics["llm_streams"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_llm_token_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"llm_new_tokens_{llm_streams}")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.metrics["step"] += 1
self.metrics["llm_ends"] += 1
self.metrics["ends"] += 1
llm_ends = self.metrics["llm_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
for generations in response.generations:
for idx, generation in enumerate(generations):
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(
analyze_text(
generation.text,
nlp=self.nlp,
)
)
complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics"
) # type: ignore # noqa: E501
self.mlflg.metrics(
complexity_metrics,
step=self.metrics["step"],
)
self.records["on_llm_end_records"].append(generation_resp)
self.records["action_records"].append(generation_resp)
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
dependency_tree = generation_resp["dependency_tree"]
entities = generation_resp["entities"]
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.metrics["step"] += 1
self.metrics["chain_starts"] += 1
self.metrics["starts"] += 1
chain_starts = self.metrics["chain_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.records["on_chain_start_records"].append(input_resp)
self.records["action_records"].append(input_resp)
self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.metrics["step"] += 1
self.metrics["chain_ends"] += 1
self.metrics["ends"] += 1
chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_chain_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_tool_start_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
tool_ends = self.metrics["tool_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_tool_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.metrics["step"] += 1
self.metrics["text_ctr"] += 1
text_ctr = self.metrics["text_ctr"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_text_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"on_text_{text_ctr}")
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.metrics["step"] += 1
self.metrics["agent_ends"] += 1
self.metrics["ends"] += 1
agent_ends = self.metrics["agent_ends"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_agent_finish_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"agent_finish_{agent_ends}")
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_agent_action_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
llm_input_columns = ["step", "prompt"]
if "name" in on_llm_start_records_df.columns:
llm_input_columns.append("name")
elif "id" in on_llm_start_records_df.columns:
# id is llm class's full import path. For example:
# ["langchain", "llms", "openai", "AzureOpenAI"]
on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply(
lambda id_: id_[-1]
)
llm_input_columns.append("name")
llm_input_prompts_df = (
on_llm_start_records_df[llm_input_columns]
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
)
complexity_metrics_columns = []
visualizations_columns = []
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
# "text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
visualizations_columns = ["dependency_tree", "entities"]
llm_outputs_df = (
on_llm_end_records_df[
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns
]
.dropna(axis=1)
.rename({"step": "output_step", "text": "output"}, axis=1)
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
session_analysis_df["chat_html"] = session_analysis_df[
["prompt", "output"]
].apply(
lambda row: construct_html_from_prompt_and_generation(
row["prompt"], row["output"]
),
axis=1,
)
return session_analysis_df
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
pd = import_pandas()
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
session_analysis_df = self._create_session_analysis_df()
chat_html = session_analysis_df.pop("chat_html")
chat_html = chat_html.replace("\n", "", regex=True)
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
if langchain_asset:
# To avoid circular import error
# mlflow only supports LLMChain asset
if "langchain.chains.llm.LLMChain" in str(type(langchain_asset)):
self.mlflg.langchain_artifact(langchain_asset)
else:
langchain_asset_path = str(Path(self.temp_dir.name, "model.json"))
try:
langchain_asset.save(langchain_asset_path)
self.mlflg.artifact(langchain_asset_path)
except ValueError:
try:
langchain_asset.save_agent(langchain_asset_path)
self.mlflg.artifact(langchain_asset_path)
except AttributeError:
print("Could not save model.")
traceback.print_exc()
pass
except NotImplementedError:
print("Could not save model.")
traceback.print_exc()
pass
except NotImplementedError:
print("Could not save model.")
traceback.print_exc()
pass
if finish:
self.mlflg.finish_run()
self._reset()
__all__ = [
"import_mlflow",
"analyze_text",
"construct_html_from_prompt_and_generation",
"MlflowLogger",
"MlflowCallbackHandler",
]

View File

@@ -1,209 +1,13 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List
from langchain_community.callbacks.openai_info import (
MODEL_COST_PER_1K_TOKENS,
OpenAICallbackHandler,
get_openai_token_cost_for_model,
standardize_model_name,
)
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
MODEL_COST_PER_1K_TOKENS = {
# GPT-4 input
"gpt-4": 0.03,
"gpt-4-0314": 0.03,
"gpt-4-0613": 0.03,
"gpt-4-32k": 0.06,
"gpt-4-32k-0314": 0.06,
"gpt-4-32k-0613": 0.06,
"gpt-4-vision-preview": 0.01,
"gpt-4-1106-preview": 0.01,
# GPT-4 output
"gpt-4-completion": 0.06,
"gpt-4-0314-completion": 0.06,
"gpt-4-0613-completion": 0.06,
"gpt-4-32k-completion": 0.12,
"gpt-4-32k-0314-completion": 0.12,
"gpt-4-32k-0613-completion": 0.12,
"gpt-4-vision-preview-completion": 0.03,
"gpt-4-1106-preview-completion": 0.03,
# GPT-3.5 input
"gpt-3.5-turbo": 0.0015,
"gpt-3.5-turbo-0301": 0.0015,
"gpt-3.5-turbo-0613": 0.0015,
"gpt-3.5-turbo-1106": 0.001,
"gpt-3.5-turbo-instruct": 0.0015,
"gpt-3.5-turbo-16k": 0.003,
"gpt-3.5-turbo-16k-0613": 0.003,
# GPT-3.5 output
"gpt-3.5-turbo-completion": 0.002,
"gpt-3.5-turbo-0301-completion": 0.002,
"gpt-3.5-turbo-0613-completion": 0.002,
"gpt-3.5-turbo-1106-completion": 0.002,
"gpt-3.5-turbo-instruct-completion": 0.002,
"gpt-3.5-turbo-16k-completion": 0.004,
"gpt-3.5-turbo-16k-0613-completion": 0.004,
# Azure GPT-35 input
"gpt-35-turbo": 0.0015, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0301": 0.0015, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0613": 0.0015,
"gpt-35-turbo-instruct": 0.0015,
"gpt-35-turbo-16k": 0.003,
"gpt-35-turbo-16k-0613": 0.003,
# Azure GPT-35 output
"gpt-35-turbo-completion": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0301-completion": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0613-completion": 0.002,
"gpt-35-turbo-instruct-completion": 0.002,
"gpt-35-turbo-16k-completion": 0.004,
"gpt-35-turbo-16k-0613-completion": 0.004,
# Others
"text-ada-001": 0.0004,
"ada": 0.0004,
"text-babbage-001": 0.0005,
"babbage": 0.0005,
"text-curie-001": 0.002,
"curie": 0.002,
"text-davinci-003": 0.02,
"text-davinci-002": 0.02,
"code-davinci-002": 0.02,
# Fine Tuned input
"babbage-002-finetuned": 0.0016,
"davinci-002-finetuned": 0.012,
"gpt-3.5-turbo-0613-finetuned": 0.012,
# Fine Tuned output
"babbage-002-finetuned-completion": 0.0016,
"davinci-002-finetuned-completion": 0.012,
"gpt-3.5-turbo-0613-finetuned-completion": 0.016,
# Azure Fine Tuned input
"babbage-002-azure-finetuned": 0.0004,
"davinci-002-azure-finetuned": 0.002,
"gpt-35-turbo-0613-azure-finetuned": 0.0015,
# Azure Fine Tuned output
"babbage-002-azure-finetuned-completion": 0.0004,
"davinci-002-azure-finetuned-completion": 0.002,
"gpt-35-turbo-0613-azure-finetuned-completion": 0.002,
# Legacy fine-tuned models
"ada-finetuned-legacy": 0.0016,
"babbage-finetuned-legacy": 0.0024,
"curie-finetuned-legacy": 0.012,
"davinci-finetuned-legacy": 0.12,
}
def standardize_model_name(
model_name: str,
is_completion: bool = False,
) -> str:
"""
Standardize the model name to a format that can be used in the OpenAI API.
Args:
model_name: Model name to standardize.
is_completion: Whether the model is used for completion or not.
Defaults to False.
Returns:
Standardized model name.
"""
model_name = model_name.lower()
if ".ft-" in model_name:
model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
if ":ft-" in model_name:
model_name = model_name.split(":")[0] + "-finetuned-legacy"
if "ft:" in model_name:
model_name = model_name.split(":")[1] + "-finetuned"
if is_completion and (
model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
or ("finetuned" in model_name and "legacy" not in model_name)
):
return model_name + "-completion"
else:
return model_name
def get_openai_token_cost_for_model(
model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
"""
Get the cost in USD for a given model and number of tokens.
Args:
model_name: Name of the model
num_tokens: Number of tokens.
is_completion: Whether the model is used for completion or not.
Defaults to False.
Returns:
Cost in USD.
"""
model_name = standardize_model_name(model_name, is_completion=is_completion)
if model_name not in MODEL_COST_PER_1K_TOKENS:
raise ValueError(
f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
"Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
)
return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
class OpenAICallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks OpenAI info."""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0
def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost}"
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
self.successful_requests += 1
if "token_usage" not in response.llm_output:
return None
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
model_name = standardize_model_name(response.llm_output.get("model_name", ""))
if model_name in MODEL_COST_PER_1K_TOKENS:
completion_cost = get_openai_token_cost_for_model(
model_name, completion_tokens, is_completion=True
)
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
return self
def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler":
"""Return a deep copy of the callback handler."""
return self
__all__ = [
"MODEL_COST_PER_1K_TOKENS",
"standardize_model_name",
"get_openai_token_cost_for_model",
"OpenAICallbackHandler",
]

View File

@@ -1,163 +1,6 @@
"""Callback handler for promptlayer."""
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from uuid import UUID
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
LLMResult,
from langchain_community.callbacks.promptlayer_callback import (
PromptLayerCallbackHandler,
_lazy_import_promptlayer,
)
from langchain.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
import promptlayer
def _lazy_import_promptlayer() -> promptlayer:
"""Lazy import promptlayer to avoid circular imports."""
try:
import promptlayer
except ImportError:
raise ImportError(
"The PromptLayerCallbackHandler requires the promptlayer package. "
" Please install it with `pip install promptlayer`."
)
return promptlayer
class PromptLayerCallbackHandler(BaseCallbackHandler):
"""Callback handler for promptlayer."""
def __init__(
self,
pl_id_callback: Optional[Callable[..., Any]] = None,
pl_tags: Optional[List[str]] = None,
) -> None:
"""Initialize the PromptLayerCallbackHandler."""
_lazy_import_promptlayer()
self.pl_id_callback = pl_id_callback
self.pl_tags = pl_tags or []
self.runs: Dict[UUID, Dict[str, Any]] = {}
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
self.runs[run_id] = {
"messages": [self._create_message_dicts(m)[0] for m in messages],
"invocation_params": kwargs.get("invocation_params", {}),
"name": ".".join(serialized["id"]),
"request_start_time": datetime.datetime.now().timestamp(),
"tags": tags,
}
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
self.runs[run_id] = {
"prompts": prompts,
"invocation_params": kwargs.get("invocation_params", {}),
"name": ".".join(serialized["id"]),
"request_start_time": datetime.datetime.now().timestamp(),
"tags": tags,
}
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
from promptlayer.utils import get_api_key, promptlayer_api_request
run_info = self.runs.get(run_id, {})
if not run_info:
return
run_info["request_end_time"] = datetime.datetime.now().timestamp()
for i in range(len(response.generations)):
generation = response.generations[i][0]
resp = {
"text": generation.text,
"llm_output": response.llm_output,
}
model_params = run_info.get("invocation_params", {})
is_chat_model = run_info.get("messages", None) is not None
model_input = (
run_info.get("messages", [])[i]
if is_chat_model
else [run_info.get("prompts", [])[i]]
)
model_response = (
[self._convert_message_to_dict(generation.message)]
if is_chat_model and isinstance(generation, ChatGeneration)
else resp
)
pl_request_id = promptlayer_api_request(
run_info.get("name"),
"langchain",
model_input,
model_params,
self.pl_tags,
model_response,
run_info.get("request_start_time"),
run_info.get("request_end_time"),
get_api_key(),
return_pl_id=bool(self.pl_id_callback is not None),
metadata={
"_langchain_run_id": str(run_id),
"_langchain_parent_run_id": str(parent_run_id),
"_langchain_tags": str(run_info.get("tags", [])),
},
)
if self.pl_id_callback:
self.pl_id_callback(pl_request_id)
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
if isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params: Dict[str, Any] = {}
message_dicts = [self._convert_message_to_dict(m) for m in messages]
return message_dicts, params
__all__ = ["_lazy_import_promptlayer", "PromptLayerCallbackHandler"]

View File

@@ -1,276 +1,6 @@
import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
flatten_dict,
from langchain_community.callbacks.sagemaker_callback import (
SageMakerCallbackHandler,
save_json,
)
def save_json(data: dict, file_path: str) -> None:
"""Save dict to local file path.
Parameters:
data (dict): The dictionary to be saved.
file_path (str): Local file path.
"""
with open(file_path, "w") as outfile:
json.dump(data, outfile)
class SageMakerCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.
Parameters:
run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
"""
def __init__(self, run: Any) -> None:
"""Initialize callback handler."""
super().__init__()
self.run = run
self.metrics = {
"step": 0,
"starts": 0,
"ends": 0,
"errors": 0,
"text_ctr": 0,
"chain_starts": 0,
"chain_ends": 0,
"llm_starts": 0,
"llm_ends": 0,
"llm_streams": 0,
"tool_starts": 0,
"tool_ends": 0,
"agent_ends": 0,
}
# Create a temporary directory
self.temp_dir = tempfile.mkdtemp()
def _reset(self) -> None:
for k, v in self.metrics.items():
self.metrics[k] = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.metrics["step"] += 1
self.metrics["llm_starts"] += 1
self.metrics["starts"] += 1
llm_starts = self.metrics["llm_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
for idx, prompt in enumerate(prompts):
prompt_resp = deepcopy(resp)
prompt_resp["prompt"] = prompt
self.jsonf(
prompt_resp,
self.temp_dir,
f"llm_start_{llm_starts}_prompt_{idx}",
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.metrics["step"] += 1
self.metrics["llm_streams"] += 1
llm_streams = self.metrics["llm_streams"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.metrics["step"] += 1
self.metrics["llm_ends"] += 1
self.metrics["ends"] += 1
llm_ends = self.metrics["llm_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.metrics)
for generations in response.generations:
for idx, generation in enumerate(generations):
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
self.jsonf(
resp,
self.temp_dir,
f"llm_end_{llm_ends}_generation_{idx}",
)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.metrics["step"] += 1
self.metrics["chain_starts"] += 1
self.metrics["starts"] += 1
chain_starts = self.metrics["chain_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.metrics["step"] += 1
self.metrics["chain_ends"] += 1
self.metrics["ends"] += 1
chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
tool_ends = self.metrics["tool_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.metrics["step"] += 1
self.metrics["text_ctr"] += 1
text_ctr = self.metrics["text_ctr"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.metrics["step"] += 1
self.metrics["agent_ends"] += 1
self.metrics["ends"] += 1
agent_ends = self.metrics["agent_ends"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")
def jsonf(
self,
data: Dict[str, Any],
data_dir: str,
filename: str,
is_output: Optional[bool] = True,
) -> None:
"""To log the input data as json file artifact."""
file_path = os.path.join(data_dir, f"{filename}.json")
save_json(data, file_path)
self.run.log_file(file_path, name=filename, is_output=is_output)
def flush_tracker(self) -> None:
"""Reset the steps and delete the temporary local directory."""
self._reset()
shutil.rmtree(self.temp_dir)
__all__ = ["save_json", "SageMakerCallbackHandler"]

View File

@@ -1,156 +1,7 @@
from __future__ import annotations
from langchain_community.callbacks.streamlit.mutable_expander import (
ChildRecord,
ChildType,
MutableExpander,
)
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
from streamlit.type_util import SupportsStr
class ChildType(Enum):
"""The enumerator of the child type."""
MARKDOWN = "MARKDOWN"
EXCEPTION = "EXCEPTION"
class ChildRecord(NamedTuple):
"""The child record as a NamedTuple."""
type: ChildType
kwargs: Dict[str, Any]
dg: DeltaGenerator
class MutableExpander:
"""A Streamlit expander that can be renamed and dynamically expanded/collapsed."""
def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
"""Create a new MutableExpander.
Parameters
----------
parent_container
The `st.container` that the expander will be created inside.
The expander transparently deletes and recreates its underlying
`st.expander` instance when its label changes, and it uses
`parent_container` to ensure it recreates this underlying expander in the
same location onscreen.
label
The expander's initial label.
expanded
The expander's initial `expanded` value.
"""
self._label = label
self._expanded = expanded
self._parent_cursor = parent_container.empty()
self._container = self._parent_cursor.expander(label, expanded)
self._child_records: List[ChildRecord] = []
@property
def label(self) -> str:
"""The expander's label string."""
return self._label
@property
def expanded(self) -> bool:
"""True if the expander was created with `expanded=True`."""
return self._expanded
def clear(self) -> None:
"""Remove the container and its contents entirely. A cleared container can't
be reused.
"""
self._container = self._parent_cursor.empty()
self._child_records.clear()
def append_copy(self, other: MutableExpander) -> None:
"""Append a copy of another MutableExpander's children to this
MutableExpander.
"""
other_records = other._child_records.copy()
for record in other_records:
self._create_child(record.type, record.kwargs)
def update(
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
) -> None:
"""Change the expander's label and expanded state"""
if new_label is None:
new_label = self._label
if new_expanded is None:
new_expanded = self._expanded
if self._label == new_label and self._expanded == new_expanded:
# No change!
return
self._label = new_label
self._expanded = new_expanded
self._container = self._parent_cursor.expander(new_label, new_expanded)
prev_records = self._child_records
self._child_records = []
# Replay all children into the new container
for record in prev_records:
self._create_child(record.type, record.kwargs)
def markdown(
self,
body: SupportsStr,
unsafe_allow_html: bool = False,
*,
help: Optional[str] = None,
index: Optional[int] = None,
) -> int:
"""Add a Markdown element to the container and return its index."""
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
return self._add_record(record, index)
def exception(
self, exception: BaseException, *, index: Optional[int] = None
) -> int:
"""Add an Exception element to the container and return its index."""
kwargs = {"exception": exception}
new_dg = self._get_dg(index).exception(**kwargs)
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
return self._add_record(record, index)
def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
"""Create a new child with the given params"""
if type == ChildType.MARKDOWN:
self.markdown(**kwargs)
elif type == ChildType.EXCEPTION:
self.exception(**kwargs)
else:
raise RuntimeError(f"Unexpected child type {type}")
def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
"""Add a ChildRecord to self._children. If `index` is specified, replace
the existing record at that index. Otherwise, append the record to the
end of the list.
Return the index of the added record.
"""
if index is not None:
# Replace existing child
self._child_records[index] = record
return index
# Append new child
self._child_records.append(record)
return len(self._child_records) - 1
def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
if index is not None:
# Existing index: reuse child's DeltaGenerator
assert 0 <= index < len(self._child_records), f"Bad index: {index}"
return self._child_records[index].dg
# No index: use container's DeltaGenerator
return self._container
__all__ = ["ChildType", "ChildRecord", "MutableExpander"]

View File

@@ -1,414 +1,25 @@
"""Callback Handler that prints to streamlit."""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
def _convert_newlines(text: str) -> str:
"""Convert newline characters to markdown newline sequences
(space, space, newline).
"""
return text.replace("\n", " \n")
CHECKMARK_EMOJI = ""
THINKING_EMOJI = ":thinking_face:"
HISTORY_EMOJI = ":books:"
EXCEPTION_EMOJI = "⚠️"
class LLMThoughtState(Enum):
"""Enumerator of the LLMThought state."""
# The LLM is thinking about what to do next. We don't know which tool we'll run.
THINKING = "THINKING"
# The LLM has decided to run a tool. We don't have results from the tool yet.
RUNNING_TOOL = "RUNNING_TOOL"
# We have results from the tool.
COMPLETE = "COMPLETE"
class ToolRecord(NamedTuple):
"""The tool record as a NamedTuple."""
name: str
input_str: str
class LLMThoughtLabeler:
"""
Generates markdown labels for LLMThought containers. Pass a custom
subclass of this to StreamlitCallbackHandler to override its default
labeling logic.
"""
def get_initial_label(self) -> str:
"""Return the markdown label for a new LLMThought that doesn't have
an associated tool yet.
"""
return f"{THINKING_EMOJI} **Thinking...**"
def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
"""Return the label for an LLMThought that has an associated
tool.
Parameters
----------
tool
The tool's ToolRecord
is_complete
True if the thought is complete; False if the thought
is still receiving input.
Returns
-------
The markdown label for the thought's container.
"""
input = tool.input_str
name = tool.name
emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
if name == "_Exception":
emoji = EXCEPTION_EMOJI
name = "Parsing error"
idx = min([60, len(input)])
input = input[0:idx]
if len(tool.input_str) > idx:
input = input + "..."
input = input.replace("\n", " ")
label = f"{emoji} **{name}:** {input}"
return label
def get_history_label(self) -> str:
"""Return a markdown label for the special 'history' container
that contains overflow thoughts.
"""
return f"{HISTORY_EMOJI} **History**"
def get_final_agent_thought_label(self) -> str:
"""Return the markdown label for the agent's final thought -
the "Now I have the answer" thought, that doesn't involve
a tool.
"""
return f"{CHECKMARK_EMOJI} **Complete!**"
class LLMThought:
"""A thought in the LLM's thought stream."""
def __init__(
self,
parent_container: DeltaGenerator,
labeler: LLMThoughtLabeler,
expanded: bool,
collapse_on_complete: bool,
):
"""Initialize the LLMThought.
Args:
parent_container: The container we're writing into.
labeler: The labeler to use for this thought.
expanded: Whether the thought should be expanded by default.
collapse_on_complete: Whether the thought should be collapsed.
"""
self._container = MutableExpander(
parent_container=parent_container,
label=labeler.get_initial_label(),
expanded=expanded,
)
self._state = LLMThoughtState.THINKING
self._llm_token_stream = ""
self._llm_token_writer_idx: Optional[int] = None
self._last_tool: Optional[ToolRecord] = None
self._collapse_on_complete = collapse_on_complete
self._labeler = labeler
@property
def container(self) -> MutableExpander:
"""The container we're writing into."""
return self._container
@property
def last_tool(self) -> Optional[ToolRecord]:
"""The last tool executed by this thought"""
return self._last_tool
def _reset_llm_token_stream(self) -> None:
self._llm_token_stream = ""
self._llm_token_writer_idx = None
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
self._reset_llm_token_stream()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# This is only called when the LLM is initialized with `streaming=True`
self._llm_token_stream += _convert_newlines(token)
self._llm_token_writer_idx = self._container.markdown(
self._llm_token_stream, index=self._llm_token_writer_idx
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
# `response` is the concatenation of all the tokens received by the LLM.
# If we're receiving streaming tokens from `on_llm_new_token`, this response
# data is redundant
self._reset_llm_token_stream()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**LLM encountered an error...**")
self._container.exception(error)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
# Called with the name of the tool we're about to run (in `serialized[name]`),
# and its input. We change our container's label to be the tool name.
self._state = LLMThoughtState.RUNNING_TOOL
tool_name = serialized["name"]
self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
self._container.update(
new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
self._container.markdown(f"**{output}**")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**Tool encountered an error...**")
self._container.exception(error)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
# Called when we're about to kick off a new tool. The `action` data
# tells us the tool we're about to use, and the input we'll give it.
# We don't output anything here, because we'll receive this same data
# when `on_tool_start` is called immediately after.
pass
def complete(self, final_label: Optional[str] = None) -> None:
"""Finish the thought."""
if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
assert (
self._last_tool is not None
), "_last_tool should never be null when _state == RUNNING_TOOL"
final_label = self._labeler.get_tool_label(
self._last_tool, is_complete=True
)
self._state = LLMThoughtState.COMPLETE
if self._collapse_on_complete:
self._container.update(new_label=final_label, new_expanded=False)
else:
self._container.update(new_label=final_label)
def clear(self) -> None:
"""Remove the thought from the screen. A cleared thought can't be reused."""
self._container.clear()
class StreamlitCallbackHandler(BaseCallbackHandler):
"""A callback handler that writes to a Streamlit app."""
def __init__(
self,
parent_container: DeltaGenerator,
*,
max_thought_containers: int = 4,
expand_new_thoughts: bool = True,
collapse_completed_thoughts: bool = True,
thought_labeler: Optional[LLMThoughtLabeler] = None,
):
"""Create a StreamlitCallbackHandler instance.
Parameters
----------
parent_container
The `st.container` that will contain all the Streamlit elements that the
Handler creates.
max_thought_containers
The max number of completed LLM thought containers to show at once. When
this threshold is reached, a new thought will cause the oldest thoughts to
be collapsed into a "History" expander. Defaults to 4.
expand_new_thoughts
Each LLM "thought" gets its own `st.expander`. This param controls whether
that expander is expanded by default. Defaults to True.
collapse_completed_thoughts
If True, LLM thought expanders will be collapsed when completed.
Defaults to True.
thought_labeler
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
will use the default thought labeling logic. Defaults to None.
"""
self._parent_container = parent_container
self._history_parent = parent_container.container()
self._history_container: Optional[MutableExpander] = None
self._current_thought: Optional[LLMThought] = None
self._completed_thoughts: List[LLMThought] = []
self._max_thought_containers = max(max_thought_containers, 1)
self._expand_new_thoughts = expand_new_thoughts
self._collapse_completed_thoughts = collapse_completed_thoughts
self._thought_labeler = thought_labeler or LLMThoughtLabeler()
def _require_current_thought(self) -> LLMThought:
"""Return our current LLMThought. Raise an error if we have no current
thought.
"""
if self._current_thought is None:
raise RuntimeError("Current LLMThought is unexpectedly None!")
return self._current_thought
def _get_last_completed_thought(self) -> Optional[LLMThought]:
"""Return our most recent completed LLMThought, or None if we don't have one."""
if len(self._completed_thoughts) > 0:
return self._completed_thoughts[len(self._completed_thoughts) - 1]
return None
@property
def _num_thought_containers(self) -> int:
"""The number of 'thought containers' we're currently showing: the
number of completed thought containers, the history container (if it exists),
and the current thought container (if it exists).
"""
count = len(self._completed_thoughts)
if self._history_container is not None:
count += 1
if self._current_thought is not None:
count += 1
return count
def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
"""Complete the current thought, optionally assigning it a new label.
Add it to our _completed_thoughts list.
"""
thought = self._require_current_thought()
thought.complete(final_label)
self._completed_thoughts.append(thought)
self._current_thought = None
def _prune_old_thought_containers(self) -> None:
"""If we have too many thoughts onscreen, move older thoughts to the
'history container.'
"""
while (
self._num_thought_containers > self._max_thought_containers
and len(self._completed_thoughts) > 0
):
# Create our history container if it doesn't exist, and if
# max_thought_containers is > 1. (if max_thought_containers is 1, we don't
# have room to show history.)
if self._history_container is None and self._max_thought_containers > 1:
self._history_container = MutableExpander(
self._history_parent,
label=self._thought_labeler.get_history_label(),
expanded=False,
)
oldest_thought = self._completed_thoughts.pop(0)
if self._history_container is not None:
self._history_container.markdown(oldest_thought.container.label)
self._history_container.append_copy(oldest_thought.container)
oldest_thought.clear()
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
if self._current_thought is None:
self._current_thought = LLMThought(
parent_container=self._parent_container,
expanded=self._expand_new_thoughts,
collapse_on_complete=self._collapse_completed_thoughts,
labeler=self._thought_labeler,
)
self._current_thought.on_llm_start(serialized, prompts)
# We don't prune_old_thought_containers here, because our container won't
# be visible until it has a child.
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self._require_current_thought().on_llm_new_token(token, **kwargs)
self._prune_old_thought_containers()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self._require_current_thought().on_llm_end(response, **kwargs)
self._prune_old_thought_containers()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self._require_current_thought().on_llm_error(error, **kwargs)
self._prune_old_thought_containers()
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
self._prune_old_thought_containers()
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
self._require_current_thought().on_tool_end(
output, color, observation_prefix, llm_prefix, **kwargs
)
self._complete_current_thought()
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._require_current_thought().on_tool_error(error, **kwargs)
self._prune_old_thought_containers()
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Any,
) -> None:
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
self._require_current_thought().on_agent_action(action, color, **kwargs)
self._prune_old_thought_containers()
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
if self._current_thought is not None:
self._current_thought.complete(
self._thought_labeler.get_final_agent_thought_label()
)
self._current_thought = None
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
CHECKMARK_EMOJI,
EXCEPTION_EMOJI,
HISTORY_EMOJI,
THINKING_EMOJI,
LLMThought,
LLMThoughtLabeler,
LLMThoughtState,
StreamlitCallbackHandler,
ToolRecord,
_convert_newlines,
)
__all__ = [
"_convert_newlines",
"CHECKMARK_EMOJI",
"THINKING_EMOJI",
"HISTORY_EMOJI",
"EXCEPTION_EMOJI",
"LLMThoughtState",
"ToolRecord",
"LLMThoughtLabeler",
"LLMThought",
"StreamlitCallbackHandler",
]

View File

@@ -1,138 +1,7 @@
from types import ModuleType, SimpleNamespace
from typing import TYPE_CHECKING, Any, Callable, Dict
from langchain_community.callbacks.tracers.comet import (
CometTracer,
_get_run_type,
import_comet_llm_api,
)
from langchain.callbacks.tracers.base import BaseTracer
if TYPE_CHECKING:
from uuid import UUID
from comet_llm import Span
from comet_llm.chains.chain import Chain
from langchain.callbacks.tracers.schemas import Run
def _get_run_type(run: "Run") -> str:
if isinstance(run.run_type, str):
return run.run_type
elif hasattr(run.run_type, "value"):
return run.run_type.value
else:
return str(run.run_type)
def import_comet_llm_api() -> SimpleNamespace:
"""Import comet_llm api and raise an error if it is not installed."""
try:
from comet_llm import (
experiment_info, # noqa: F401
flush, # noqa: F401
)
from comet_llm.chains import api as chain_api # noqa: F401
from comet_llm.chains import (
chain, # noqa: F401
span, # noqa: F401
)
except ImportError:
raise ImportError(
"To use the CometTracer you need to have the "
"`comet_llm>=2.0.0` python package installed. Please install it with"
" `pip install -U comet_llm`"
)
return SimpleNamespace(
chain=chain,
span=span,
chain_api=chain_api,
experiment_info=experiment_info,
flush=flush,
)
class CometTracer(BaseTracer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._span_map: Dict["UUID", "Span"] = {}
self._chains_map: Dict["UUID", "Chain"] = {}
self._initialize_comet_modules()
def _initialize_comet_modules(self) -> None:
comet_llm_api = import_comet_llm_api()
self._chain: ModuleType = comet_llm_api.chain
self._span: ModuleType = comet_llm_api.span
self._chain_api: ModuleType = comet_llm_api.chain_api
self._experiment_info: ModuleType = comet_llm_api.experiment_info
self._flush: Callable[[], None] = comet_llm_api.flush
def _persist_run(self, run: "Run") -> None:
chain_ = self._chains_map[run.id]
chain_.set_outputs(outputs=run.outputs)
self._chain_api.log_chain(chain_)
def _process_start_trace(self, run: "Run") -> None:
if not run.parent_run_id:
# This is the first run, which maps to a chain
chain_: "Chain" = self._chain.Chain(
inputs=run.inputs,
metadata=None,
experiment_info=self._experiment_info.get(),
)
self._chains_map[run.id] = chain_
else:
span: "Span" = self._span.Span(
inputs=run.inputs,
category=_get_run_type(run),
metadata=run.extra,
name=run.name,
)
span.__api__start__(self._chains_map[run.parent_run_id])
self._chains_map[run.id] = self._chains_map[run.parent_run_id]
self._span_map[run.id] = span
def _process_end_trace(self, run: "Run") -> None:
if not run.parent_run_id:
pass
# Langchain will call _persist_run for us
else:
span = self._span_map[run.id]
span.set_outputs(outputs=run.outputs)
span.__api__end__()
def flush(self) -> None:
self._flush()
def _on_llm_start(self, run: "Run") -> None:
"""Process the LLM Run upon start."""
self._process_start_trace(run)
def _on_llm_end(self, run: "Run") -> None:
"""Process the LLM Run."""
self._process_end_trace(run)
def _on_llm_error(self, run: "Run") -> None:
"""Process the LLM Run upon error."""
self._process_end_trace(run)
def _on_chain_start(self, run: "Run") -> None:
"""Process the Chain Run upon start."""
self._process_start_trace(run)
def _on_chain_end(self, run: "Run") -> None:
"""Process the Chain Run."""
self._process_end_trace(run)
def _on_chain_error(self, run: "Run") -> None:
"""Process the Chain Run upon error."""
self._process_end_trace(run)
def _on_tool_start(self, run: "Run") -> None:
"""Process the Tool Run upon start."""
self._process_start_trace(run)
def _on_tool_end(self, run: "Run") -> None:
"""Process the Tool Run."""
self._process_end_trace(run)
def _on_tool_error(self, run: "Run") -> None:
"""Process the Tool Run upon error."""
self._process_end_trace(run)
__all__ = ["_get_run_type", "import_comet_llm_api", "CometTracer"]

View File

@@ -1,514 +1,15 @@
"""A Tracer Implementation that records activity to Weights & Biases."""
from __future__ import annotations
import json
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
from langchain_community.callbacks.tracers.wandb import (
PRINT_WARNINGS,
RunProcessor,
WandbRunArgs,
WandbTracer,
_serialize_io,
)
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from wandb import Settings as WBSettings
from wandb.sdk.data_types.trace_tree import Span
from wandb.sdk.lib.paths import StrPath
from wandb.wandb_run import Run as WBRun
PRINT_WARNINGS = True
def _serialize_io(run_inputs: Optional[dict]) -> dict:
if not run_inputs:
return {}
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
serialized_inputs = {}
for key, value in run_inputs.items():
if isinstance(value, Message):
serialized_inputs[key] = MessageToJson(value)
elif key == "input_documents":
serialized_inputs.update(
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
)
else:
serialized_inputs[key] = value
return serialized_inputs
class RunProcessor:
"""Handles the conversion of a LangChain Runs into a WBTraceTree."""
def __init__(self, wandb_module: Any, trace_module: Any):
self.wandb = wandb_module
self.trace_tree = trace_module
def process_span(self, run: Run) -> Optional["Span"]:
"""Converts a LangChain Run into a W&B Trace Span.
:param run: The LangChain Run to convert.
:return: The converted W&B Trace Span.
"""
try:
span = self._convert_lc_run_to_wb_span(run)
return span
except Exception as e:
if PRINT_WARNINGS:
self.wandb.termwarn(
f"Skipping trace saving - unable to safely convert LangChain Run "
f"into W&B Trace due to: {e}"
)
return None
def _convert_run_to_wb_span(self, run: Run) -> "Span":
"""Base utility to create a span from a run.
:param run: The run to convert.
:return: The converted Span.
"""
attributes = {**run.extra} if run.extra else {}
attributes["execution_order"] = run.execution_order
return self.trace_tree.Span(
span_id=str(run.id) if run.id is not None else None,
name=run.name,
start_time_ms=int(run.start_time.timestamp() * 1000),
end_time_ms=int(run.end_time.timestamp() * 1000)
if run.end_time is not None
else None,
status_code=self.trace_tree.StatusCode.SUCCESS
if run.error is None
else self.trace_tree.StatusCode.ERROR,
status_message=run.error,
attributes=attributes,
)
def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
"""Converts a LangChain LLM Run into a W&B Trace Span.
:param run: The LangChain LLM Run to convert.
:return: The converted W&B Trace Span.
"""
base_span = self._convert_run_to_wb_span(run)
if base_span.attributes is None:
base_span.attributes = {}
base_span.attributes["llm_output"] = (run.outputs or {}).get("llm_output", {})
base_span.results = [
self.trace_tree.Result(
inputs={"prompt": prompt},
outputs={
f"gen_{g_i}": gen["text"]
for g_i, gen in enumerate(run.outputs["generations"][ndx])
}
if (
run.outputs is not None
and len(run.outputs["generations"]) > ndx
and len(run.outputs["generations"][ndx]) > 0
)
else None,
)
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
]
base_span.span_kind = self.trace_tree.SpanKind.LLM
return base_span
def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
"""Converts a LangChain Chain Run into a W&B Trace Span.
:param run: The LangChain Chain Run to convert.
:return: The converted W&B Trace Span.
"""
base_span = self._convert_run_to_wb_span(run)
base_span.results = [
self.trace_tree.Result(
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
)
]
base_span.child_spans = [
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
]
base_span.span_kind = (
self.trace_tree.SpanKind.AGENT
if "agent" in run.name.lower()
else self.trace_tree.SpanKind.CHAIN
)
return base_span
def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
"""Converts a LangChain Tool Run into a W&B Trace Span.
:param run: The LangChain Tool Run to convert.
:return: The converted W&B Trace Span.
"""
base_span = self._convert_run_to_wb_span(run)
base_span.results = [
self.trace_tree.Result(
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
)
]
base_span.child_spans = [
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
]
base_span.span_kind = self.trace_tree.SpanKind.TOOL
return base_span
def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
"""Utility to convert any generic LangChain Run into a W&B Trace Span.
:param run: The LangChain Run to convert.
:return: The converted W&B Trace Span.
"""
if run.run_type == "llm":
return self._convert_llm_run_to_wb_span(run)
elif run.run_type == "chain":
return self._convert_chain_run_to_wb_span(run)
elif run.run_type == "tool":
return self._convert_tool_run_to_wb_span(run)
else:
return self._convert_run_to_wb_span(run)
def process_model(self, run: Run) -> Optional[Dict[str, Any]]:
"""Utility to process a run for wandb model_dict serialization.
:param run: The run to process.
:return: The convert model_dict to pass to WBTraceTree.
"""
try:
data = json.loads(run.json())
processed = self.flatten_run(data)
keep_keys = (
"id",
"name",
"serialized",
"inputs",
"outputs",
"parent_run_id",
"execution_order",
)
processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)
exact_keys, partial_keys = ("lc", "type"), ("api_key",)
processed = self.modify_serialized_iterative(
processed, exact_keys=exact_keys, partial_keys=partial_keys
)
output = self.build_tree(processed)
return output
except Exception as e:
if PRINT_WARNINGS:
self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
return None
def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Utility to flatten a nest run object into a list of runs.
:param run: The base run to flatten.
:return: The flattened list of runs.
"""
def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Utility to recursively flatten a list of child runs in a run.
:param child_runs: The list of child runs to flatten.
:return: The flattened list of runs.
"""
if child_runs is None:
return []
result = []
for item in child_runs:
child_runs = item.pop("child_runs", [])
result.append(item)
result.extend(flatten(child_runs))
return result
return flatten([run])
def truncate_run_iterative(
self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
) -> List[Dict[str, Any]]:
"""Utility to truncate a list of runs dictionaries to only keep the specified
keys in each run.
:param runs: The list of runs to truncate.
:param keep_keys: The keys to keep in each run.
:return: The truncated list of runs.
"""
def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]:
"""Utility to truncate a single run dictionary to only keep the specified
keys.
:param run: The run dictionary to truncate.
:return: The truncated run dictionary
"""
new_dict = {}
for key in run:
if key in keep_keys:
new_dict[key] = run.get(key)
return new_dict
return list(map(truncate_single, runs))
def modify_serialized_iterative(
self,
runs: List[Dict[str, Any]],
exact_keys: Tuple[str, ...] = (),
partial_keys: Tuple[str, ...] = (),
) -> List[Dict[str, Any]]:
"""Utility to modify the serialized field of a list of runs dictionaries.
removes any keys that match the exact_keys and any keys that contain any of the
partial_keys.
recursively moves the dictionaries under the kwargs key to the top level.
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
visualize the run. promotes the "serialized" field to the top level.
:param runs: The list of runs to modify.
:param exact_keys: A tuple of keys to remove from the serialized field.
:param partial_keys: A tuple of partial keys to remove from the serialized
field.
:return: The modified list of runs.
"""
def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively removes exact and partial keys from a dictionary.
:param obj: The dictionary to remove keys from.
:return: The modified dictionary.
"""
if isinstance(obj, dict):
obj = {
k: v
for k, v in obj.items()
if k not in exact_keys
and not any(partial in k for partial in partial_keys)
}
for k, v in obj.items():
obj[k] = remove_exact_and_partial_keys(v)
elif isinstance(obj, list):
obj = [remove_exact_and_partial_keys(x) for x in obj]
return obj
def handle_id_and_kwargs(
obj: Dict[str, Any], root: bool = False
) -> Dict[str, Any]:
"""Recursively handles the id and kwargs fields of a dictionary.
changes the id field to a string "_kind" field that tells WBTraceTree how
to visualize the run. recursively moves the dictionaries under the kwargs
key to the top level.
:param obj: a run dictionary with id and kwargs fields.
:param root: whether this is the root dictionary or the serialized
dictionary.
:return: The modified dictionary.
"""
if isinstance(obj, dict):
if ("id" in obj or "name" in obj) and not root:
_kind = obj.get("id")
if not _kind:
_kind = [obj.get("name")]
obj["_kind"] = _kind[-1]
obj.pop("id", None)
obj.pop("name", None)
if "kwargs" in obj:
kwargs = obj.pop("kwargs")
for k, v in kwargs.items():
obj[k] = v
for k, v in obj.items():
obj[k] = handle_id_and_kwargs(v)
elif isinstance(obj, list):
obj = [handle_id_and_kwargs(x) for x in obj]
return obj
def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
"""Transforms the serialized field of a run dictionary to be compatible
with WBTraceTree.
:param serialized: The serialized field of a run dictionary.
:return: The transformed serialized field.
"""
serialized = handle_id_and_kwargs(serialized, root=True)
serialized = remove_exact_and_partial_keys(serialized)
return serialized
def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
"""Transforms a run dictionary to be compatible with WBTraceTree.
:param run: The run dictionary to transform.
:return: The transformed run dictionary.
"""
transformed_dict = transform_serialized(run)
serialized = transformed_dict.pop("serialized")
for k, v in serialized.items():
transformed_dict[k] = v
_kind = transformed_dict.get("_kind", None)
name = transformed_dict.pop("name", None)
exec_ord = transformed_dict.pop("execution_order", None)
if not name:
name = _kind
output_dict = {
f"{exec_ord}_{name}": transformed_dict,
}
return output_dict
return list(map(transform_run, runs))
def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Builds a nested dictionary from a list of runs.
:param runs: The list of runs to build the tree from.
:return: The nested dictionary representing the langchain Run in a tree
structure compatible with WBTraceTree.
"""
id_to_data = {}
child_to_parent = {}
for entity in runs:
for key, data in entity.items():
id_val = data.pop("id", None)
parent_run_id = data.pop("parent_run_id", None)
id_to_data[id_val] = {key: data}
if parent_run_id:
child_to_parent[id_val] = parent_run_id
for child_id, parent_id in child_to_parent.items():
parent_dict = id_to_data[parent_id]
parent_dict[next(iter(parent_dict))][
next(iter(id_to_data[child_id]))
] = id_to_data[child_id][next(iter(id_to_data[child_id]))]
root_dict = next(
data for id_val, data in id_to_data.items() if id_val not in child_to_parent
)
return root_dict
class WandbRunArgs(TypedDict):
"""Arguments for the WandbTracer."""
job_type: Optional[str]
dir: Optional[StrPath]
config: Union[Dict, str, None]
project: Optional[str]
entity: Optional[str]
reinit: Optional[bool]
tags: Optional[Sequence]
group: Optional[str]
name: Optional[str]
notes: Optional[str]
magic: Optional[Union[dict, str, bool]]
config_exclude_keys: Optional[List[str]]
config_include_keys: Optional[List[str]]
anonymous: Optional[str]
mode: Optional[str]
allow_val_change: Optional[bool]
resume: Optional[Union[bool, str]]
force: Optional[bool]
tensorboard: Optional[bool]
sync_tensorboard: Optional[bool]
monitor_gym: Optional[bool]
save_code: Optional[bool]
id: Optional[str]
settings: Union[WBSettings, Dict[str, Any], None]
class WandbTracer(BaseTracer):
"""Callback Handler that logs to Weights and Biases.
This handler will log the model architecture and run traces to Weights and Biases.
This will ensure that all LangChain activity is logged to W&B.
"""
_run: Optional[WBRun] = None
_run_args: Optional[WandbRunArgs] = None
def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
"""Initializes the WandbTracer.
Parameters:
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
provided, `wandb.init()` will be called with no arguments. Please
refer to the `wandb.init` for more details.
To use W&B to monitor all LangChain activity, add this tracer like any other
LangChain callback:
```
from wandb.integration.langchain import WandbTracer
tracer = WandbTracer()
chain = LLMChain(llm, callbacks=[tracer])
# ...end of notebook / script:
tracer.finish()
```
"""
super().__init__(**kwargs)
try:
import wandb
from wandb.sdk.data_types import trace_tree
except ImportError as e:
raise ImportError(
"Could not import wandb python package."
"Please install it with `pip install -U wandb`."
) from e
self._wandb = wandb
self._trace_tree = trace_tree
self._run_args = run_args
self._ensure_run(should_print_url=(wandb.run is None))
self.run_processor = RunProcessor(self._wandb, self._trace_tree)
def finish(self) -> None:
"""Waits for all asynchronous processes to finish and data to upload.
Proxy for `wandb.finish()`.
"""
self._wandb.finish()
def _log_trace_from_run(self, run: Run) -> None:
"""Logs a LangChain Run to W*B as a W&B Trace."""
self._ensure_run()
root_span = self.run_processor.process_span(run)
model_dict = self.run_processor.process_model(run)
if root_span is None:
return
model_trace = self._trace_tree.WBTraceTree(
root_span=root_span,
model_dict=model_dict,
)
if self._wandb.run is not None:
self._wandb.run.log({"langchain_trace": model_trace})
def _ensure_run(self, should_print_url: bool = False) -> None:
"""Ensures an active W&B run exists.
If not, will start a new run with the provided run_args.
"""
if self._wandb.run is None:
run_args = self._run_args or {} # type: ignore
run_args: dict = {**run_args} # type: ignore
if "settings" not in run_args: # type: ignore
run_args["settings"] = {"silent": True} # type: ignore
self._wandb.init(**run_args)
if self._wandb.run is not None:
if should_print_url:
run_url = self._wandb.run.settings.run_url
self._wandb.termlog(
f"Streaming LangChain activity to W&B at {run_url}\n"
"`WandbTracer` is currently in beta.\n"
"Please report any issues to "
"https://github.com/wandb/wandb/issues with the tag "
"`langchain`."
)
self._wandb.run._label(repo="langchain")
def _persist_run(self, run: "Run") -> None:
"""Persist a run."""
self._log_trace_from_run(run)
__all__ = [
"PRINT_WARNINGS",
"_serialize_io",
"RunProcessor",
"WandbRunArgs",
"WandbTracer",
]

View File

@@ -1,126 +1,6 @@
import os
from typing import Any, Dict, List, Optional
from uuid import UUID
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
from langchain_community.callbacks.trubrics_callback import (
TrubricsCallbackHandler,
_convert_message_to_dict,
)
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class TrubricsCallbackHandler(BaseCallbackHandler):
"""
Callback handler for Trubrics.
Args:
project: a trubrics project, default project is "default"
email: a trubrics account email, can equally be set in env variables
password: a trubrics account password, can equally be set in env variables
**kwargs: all other kwargs are parsed and set to trubrics prompt variables,
or added to the `metadata` dict
"""
def __init__(
self,
project: str = "default",
email: Optional[str] = None,
password: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__()
try:
from trubrics import Trubrics
except ImportError:
raise ImportError(
"The TrubricsCallbackHandler requires installation of "
"the trubrics package. "
"Please install it with `pip install trubrics`."
)
self.trubrics = Trubrics(
project=project,
email=email or os.environ["TRUBRICS_EMAIL"],
password=password or os.environ["TRUBRICS_PASSWORD"],
)
self.config_model: dict = {}
self.prompt: Optional[str] = None
self.messages: Optional[list] = None
self.trubrics_kwargs: Optional[dict] = kwargs if kwargs else None
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.prompt = prompts[0]
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
self.messages = [_convert_message_to_dict(message) for message in messages[0]]
self.prompt = self.messages[-1]["content"]
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
tags = ["langchain"]
user_id = None
session_id = None
metadata: dict = {"langchain_run_id": run_id}
if self.messages:
metadata["messages"] = self.messages
if self.trubrics_kwargs:
if self.trubrics_kwargs.get("tags"):
tags.append(*self.trubrics_kwargs.pop("tags"))
user_id = self.trubrics_kwargs.pop("user_id", None)
session_id = self.trubrics_kwargs.pop("session_id", None)
metadata.update(self.trubrics_kwargs)
for generation in response.generations:
self.trubrics.log_prompt(
config_model={
"model": response.llm_output.get("model_name")
if response.llm_output
else "NA"
},
prompt=self.prompt,
generation=generation[0].text,
user_id=user_id,
session_id=session_id,
tags=tags,
metadata=metadata,
)
__all__ = ["_convert_message_to_dict", "TrubricsCallbackHandler"]

View File

@@ -1,258 +1,21 @@
import hashlib
from pathlib import Path
from typing import Any, Dict, Iterable, Tuple, Union
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
_flatten_dict,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
load_json,
)
def import_spacy() -> Any:
"""Import the spacy python package and raise an error if it is not installed."""
try:
import spacy
except ImportError:
raise ImportError(
"This callback manager requires the `spacy` python "
"package installed. Please install it with `pip install spacy`"
)
return spacy
def import_pandas() -> Any:
"""Import the pandas python package and raise an error if it is not installed."""
try:
import pandas
except ImportError:
raise ImportError(
"This callback manager requires the `pandas` python "
"package installed. Please install it with `pip install pandas`"
)
return pandas
def import_textstat() -> Any:
"""Import the textstat python package and raise an error if it is not installed."""
try:
import textstat
except ImportError:
raise ImportError(
"This callback manager requires the `textstat` python "
"package installed. Please install it with `pip install textstat`"
)
return textstat
def _flatten_dict(
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Iterable[Tuple[str, Any]]:
"""
Generator that yields flattened items from a nested dictionary for a flat dict.
Parameters:
nested_dict (dict): The nested dictionary to flatten.
parent_key (str): The prefix to prepend to the keys of the flattened dict.
sep (str): The separator to use between the parent key and the key of the
flattened dictionary.
Yields:
(str, any): A key-value pair from the flattened dictionary.
"""
for key, value in nested_dict.items():
new_key = parent_key + sep + key if parent_key else key
if isinstance(value, dict):
yield from _flatten_dict(value, new_key, sep)
else:
yield new_key, value
def flatten_dict(
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
"""Flattens a nested dictionary into a flat dictionary.
Parameters:
nested_dict (dict): The nested dictionary to flatten.
parent_key (str): The prefix to prepend to the keys of the flattened dict.
sep (str): The separator to use between the parent key and the key of the
flattened dictionary.
Returns:
(dict): A flat dictionary.
"""
flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
return flat_dict
def hash_string(s: str) -> str:
"""Hash a string using sha1.
Parameters:
s (str): The string to hash.
Returns:
(str): The hashed string.
"""
return hashlib.sha1(s.encode("utf-8")).hexdigest()
def load_json(json_path: Union[str, Path]) -> str:
"""Load json file to a string.
Parameters:
json_path (str): The path to the json file.
Returns:
(str): The string representation of the json file.
"""
with open(json_path, "r") as f:
data = f.read()
return data
class BaseMetadataCallbackHandler:
"""This class handles the metadata and associated function states for callbacks.
Attributes:
step (int): The current step.
starts (int): The number of times the start method has been called.
ends (int): The number of times the end method has been called.
errors (int): The number of times the error method has been called.
text_ctr (int): The number of times the text method has been called.
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
llm_starts (int): The number of times the llm start method has been called.
llm_ends (int): The number of times the llm end method has been called.
llm_streams (int): The number of times the text method has been called.
tool_starts (int): The number of times the tool start method has been called.
tool_ends (int): The number of times the tool end method has been called.
agent_ends (int): The number of times the agent end method has been called.
on_llm_start_records (list): A list of records of the on_llm_start method.
on_llm_token_records (list): A list of records of the on_llm_token method.
on_llm_end_records (list): A list of records of the on_llm_end method.
on_chain_start_records (list): A list of records of the on_chain_start method.
on_chain_end_records (list): A list of records of the on_chain_end method.
on_tool_start_records (list): A list of records of the on_tool_start method.
on_tool_end_records (list): A list of records of the on_tool_end method.
on_agent_finish_records (list): A list of records of the on_agent_end method.
"""
def __init__(self) -> None:
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
self.on_llm_start_records: list = []
self.on_llm_token_records: list = []
self.on_llm_end_records: list = []
self.on_chain_start_records: list = []
self.on_chain_end_records: list = []
self.on_tool_start_records: list = []
self.on_tool_end_records: list = []
self.on_text_records: list = []
self.on_agent_finish_records: list = []
self.on_agent_action_records: list = []
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,
"starts": self.starts,
"ends": self.ends,
"errors": self.errors,
"text_ctr": self.text_ctr,
"chain_starts": self.chain_starts,
"chain_ends": self.chain_ends,
"llm_starts": self.llm_starts,
"llm_ends": self.llm_ends,
"llm_streams": self.llm_streams,
"tool_starts": self.tool_starts,
"tool_ends": self.tool_ends,
"agent_ends": self.agent_ends,
}
def reset_callback_meta(self) -> None:
"""Reset the callback metadata."""
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
self.on_llm_start_records = []
self.on_llm_token_records = []
self.on_llm_end_records = []
self.on_chain_start_records = []
self.on_chain_end_records = []
self.on_tool_start_records = []
self.on_tool_end_records = []
self.on_text_records = []
self.on_agent_finish_records = []
self.on_agent_action_records = []
return None
__all__ = [
"import_spacy",
"import_pandas",
"import_textstat",
"_flatten_dict",
"flatten_dict",
"hash_string",
"load_json",
"BaseMetadataCallbackHandler",
]

View File

@@ -1,587 +1,15 @@
import json
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
from langchain_community.callbacks.wandb_callback import (
WandbCallbackHandler,
analyze_text,
construct_html_from_prompt_and_generation,
import_wandb,
load_json_to_dict,
)
def import_wandb() -> Any:
"""Import the wandb python package and raise an error if it is not installed."""
try:
import wandb # noqa: F401
except ImportError:
raise ImportError(
"To use the wandb callback manager you need to have the `wandb` python "
"package installed. Please install it with `pip install wandb`"
)
return wandb
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
"""Load json file to a dictionary.
Parameters:
json_path (str): The path to the json file.
Returns:
(dict): The dictionary representation of the json file.
"""
with open(json_path, "r") as f:
data = json.load(f)
return data
def analyze_text(
text: str,
complexity_metrics: bool = True,
visualize: bool = True,
nlp: Any = None,
output_dir: Optional[Union[str, Path]] = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
complexity_metrics (bool): Whether to compute complexity metrics.
visualize (bool): Whether to visualize the text.
nlp (spacy.lang): The spacy language model to use for visualization.
output_dir (str): The directory to save the visualization files to.
Returns:
(dict): A dictionary containing the complexity metrics and visualization
files serialized in a wandb.Html element.
"""
resp = {}
textstat = import_textstat()
wandb = import_wandb()
spacy = import_spacy()
if complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update(text_complexity_metrics)
if visualize and nlp and output_dir is not None:
doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore
doc, style="dep", jupyter=False, page=True
)
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render( # type: ignore
doc, style="ent", jupyter=False, page=True
)
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
ent_output_path.open("w", encoding="utf-8").write(ent_out)
text_visualizations = {
"dependency_tree": wandb.Html(str(dep_output_path)),
"entities": wandb.Html(str(ent_output_path)),
}
resp.update(text_visualizations)
return resp
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
"""Construct an html element from a prompt and a generation.
Parameters:
prompt (str): The prompt.
generation (str): The generation.
Returns:
(wandb.Html): The html element."""
wandb = import_wandb()
formatted_prompt = prompt.replace("\n", "<br>")
formatted_generation = generation.replace("\n", "<br>")
return wandb.Html(
f"""
<p style="color:black;">{formatted_prompt}:</p>
<blockquote>
<p style="color:green;">
{formatted_generation}
</p>
</blockquote>
""",
inject=False,
)
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Weights and Biases.
Parameters:
job_type (str): The type of job.
project (str): The project to log to.
entity (str): The entity to log to.
tags (list): The tags to log.
group (str): The group to log to.
name (str): The name of the run.
notes (str): The notes to log.
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics.
stream_logs (bool): Whether to stream callback actions to W&B
This handler will utilize the associated callback method called and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response using the run.log() method to Weights and Biases.
"""
def __init__(
self,
job_type: Optional[str] = None,
project: Optional[str] = "langchain_callback_demo",
entity: Optional[str] = None,
tags: Optional[Sequence] = None,
group: Optional[str] = None,
name: Optional[str] = None,
notes: Optional[str] = None,
visualize: bool = False,
complexity_metrics: bool = False,
stream_logs: bool = False,
) -> None:
"""Initialize callback handler."""
wandb = import_wandb()
import_pandas()
import_textstat()
spacy = import_spacy()
super().__init__()
self.job_type = job_type
self.project = project
self.entity = entity
self.tags = tags
self.group = group
self.name = name
self.notes = notes
self.visualize = visualize
self.complexity_metrics = complexity_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore
job_type=self.job_type,
project=self.project,
entity=self.entity,
tags=self.tags,
group=self.group,
name=self.name,
notes=self.notes,
)
warning = (
"DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
"of the `WandbTracer`. Please update your code to use the `WandbTracer` "
"instead."
)
wandb.termwarn(
warning,
repeat=False,
)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
self.visualize = visualize
self.nlp = spacy.load("en_core_web_sm")
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(resp)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self.run.log(prompt_resp)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.on_llm_token_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(
analyze_text(
generation.text,
complexity_metrics=self.complexity_metrics,
visualize=self.visualize,
nlp=self.nlp,
output_dir=self.temp_dir.name,
)
)
self.on_llm_end_records.append(generation_resp)
self.action_records.append(generation_resp)
if self.stream_logs:
self.run.log(generation_resp)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs["input"]
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.run.log(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.run.log(input_resp)
else:
raise ValueError("Unexpected data format provided!")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.on_tool_start_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.on_tool_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.on_text_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_finish_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_action_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = (
on_llm_start_records_df[["step", "prompts", "name"]]
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
)
complexity_metrics_columns = []
visualizations_columns = []
if self.complexity_metrics:
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
"text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
if self.visualize:
visualizations_columns = ["dependency_tree", "entities"]
llm_outputs_df = (
on_llm_end_records_df[
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns
]
.dropna(axis=1)
.rename({"step": "output_step", "text": "output"}, axis=1)
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
session_analysis_df["chat_html"] = session_analysis_df[
["prompts", "output"]
].apply(
lambda row: construct_html_from_prompt_and_generation(
row["prompts"], row["output"]
),
axis=1,
)
return session_analysis_df
def flush_tracker(
self,
langchain_asset: Any = None,
reset: bool = True,
finish: bool = False,
job_type: Optional[str] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
tags: Optional[Sequence] = None,
group: Optional[str] = None,
name: Optional[str] = None,
notes: Optional[str] = None,
visualize: Optional[bool] = None,
complexity_metrics: Optional[bool] = None,
) -> None:
"""Flush the tracker and reset the session.
Args:
langchain_asset: The langchain asset to save.
reset: Whether to reset the session.
finish: Whether to finish the run.
job_type: The job type.
project: The project.
entity: The entity.
tags: The tags.
group: The group.
name: The name.
notes: The notes.
visualize: Whether to visualize.
complexity_metrics: Whether to compute complexity metrics.
Returns:
None
"""
pd = import_pandas()
wandb = import_wandb()
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
session_analysis_table = wandb.Table(
dataframe=self._create_session_analysis_df()
)
self.run.log(
{
"action_records": action_records_table,
"session_analysis": session_analysis_table,
}
)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
model_artifact = wandb.Artifact(name="model", type="model")
model_artifact.add(action_records_table, name="action_records")
model_artifact.add(session_analysis_table, name="session_analysis")
try:
langchain_asset.save(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
except NotImplementedError as e:
print("Could not save model.")
print(repr(e))
pass
self.run.log_artifact(model_artifact)
if finish or reset:
self.run.finish()
self.temp_dir.cleanup()
self.reset_callback_meta()
if reset:
self.__init__( # type: ignore
job_type=job_type if job_type else self.job_type,
project=project if project else self.project,
entity=entity if entity else self.entity,
tags=tags if tags else self.tags,
group=group if group else self.group,
name=name if name else self.name,
notes=notes if notes else self.notes,
visualize=visualize if visualize else self.visualize,
complexity_metrics=complexity_metrics
if complexity_metrics
else self.complexity_metrics,
)
__all__ = [
"import_wandb",
"load_json_to_dict",
"analyze_text",
"construct_html_from_prompt_and_generation",
"WandbCallbackHandler",
]

View File

@@ -1,192 +1,7 @@
from __future__ import annotations
from langchain_community.callbacks.whylabs_callback import (
WhyLabsCallbackHandler,
diagnostic_logger,
import_langkit,
)
import logging
from typing import TYPE_CHECKING, Any, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.utils import get_from_env
if TYPE_CHECKING:
from whylogs.api.logger.logger import Logger
diagnostic_logger = logging.getLogger(__name__)
def import_langkit(
sentiment: bool = False,
toxicity: bool = False,
themes: bool = False,
) -> Any:
"""Import the langkit python package and raise an error if it is not installed.
Args:
sentiment: Whether to import the langkit.sentiment module. Defaults to False.
toxicity: Whether to import the langkit.toxicity module. Defaults to False.
themes: Whether to import the langkit.themes module. Defaults to False.
Returns:
The imported langkit module.
"""
try:
import langkit # noqa: F401
import langkit.regexes # noqa: F401
import langkit.textstat # noqa: F401
if sentiment:
import langkit.sentiment # noqa: F401
if toxicity:
import langkit.toxicity # noqa: F401
if themes:
import langkit.themes # noqa: F401
except ImportError:
raise ImportError(
"To use the whylabs callback manager you need to have the `langkit` python "
"package installed. Please install it with `pip install langkit`."
)
return langkit
class WhyLabsCallbackHandler(BaseCallbackHandler):
"""
Callback Handler for logging to WhyLabs. This callback handler utilizes
`langkit` to extract features from the prompts & responses when interacting with
an LLM. These features can be used to guardrail, evaluate, and observe interactions
over time to detect issues relating to hallucinations, prompt engineering,
or output validation. LangKit is an LLM monitoring toolkit developed by WhyLabs.
Here are some examples of what can be monitored with LangKit:
* Text Quality
- readability score
- complexity and grade scores
* Text Relevance
- Similarity scores between prompt/responses
- Similarity scores against user-defined themes
- Topic classification
* Security and Privacy
- patterns - count of strings matching a user-defined regex pattern group
- jailbreaks - similarity scores with respect to known jailbreak attempts
- prompt injection - similarity scores with respect to known prompt attacks
- refusals - similarity scores with respect to known LLM refusal responses
* Sentiment and Toxicity
- sentiment analysis
- toxicity analysis
For more information, see https://docs.whylabs.ai/docs/language-model-monitoring
or check out the LangKit repo here: https://github.com/whylabs/langkit
---
Args:
api_key (Optional[str]): WhyLabs API key. Optional because the preferred
way to specify the API key is with environment variable
WHYLABS_API_KEY.
org_id (Optional[str]): WhyLabs organization id to write profiles to.
Optional because the preferred way to specify the organization id is
with environment variable WHYLABS_DEFAULT_ORG_ID.
dataset_id (Optional[str]): WhyLabs dataset id to write profiles to.
Optional because the preferred way to specify the dataset id is
with environment variable WHYLABS_DEFAULT_DATASET_ID.
sentiment (bool): Whether to enable sentiment analysis. Defaults to False.
toxicity (bool): Whether to enable toxicity analysis. Defaults to False.
themes (bool): Whether to enable theme analysis. Defaults to False.
"""
def __init__(self, logger: Logger, handler: Any):
"""Initiate the rolling logger."""
super().__init__()
if hasattr(handler, "init"):
handler.init(self)
if hasattr(handler, "_get_callbacks"):
self._callbacks = handler._get_callbacks()
else:
self._callbacks = dict()
diagnostic_logger.warning("initialized handler without callbacks.")
self._logger = logger
def flush(self) -> None:
"""Explicitly write current profile if using a rolling logger."""
if self._logger and hasattr(self._logger, "_do_rollover"):
self._logger._do_rollover()
diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")
def close(self) -> None:
"""Close any loggers to allow writing out of any profiles before exiting."""
if self._logger and hasattr(self._logger, "close"):
self._logger.close()
diagnostic_logger.info("Closing WhyLabs logger, see you next time!")
def __enter__(self) -> WhyLabsCallbackHandler:
return self
def __exit__(
self, exception_type: Any, exception_value: Any, traceback: Any
) -> None:
self.close()
@classmethod
def from_params(
cls,
*,
api_key: Optional[str] = None,
org_id: Optional[str] = None,
dataset_id: Optional[str] = None,
sentiment: bool = False,
toxicity: bool = False,
themes: bool = False,
logger: Optional[Logger] = None,
) -> WhyLabsCallbackHandler:
"""Instantiate whylogs Logger from params.
Args:
api_key (Optional[str]): WhyLabs API key. Optional because the preferred
way to specify the API key is with environment variable
WHYLABS_API_KEY.
org_id (Optional[str]): WhyLabs organization id to write profiles to.
If not set must be specified in environment variable
WHYLABS_DEFAULT_ORG_ID.
dataset_id (Optional[str]): The model or dataset this callback is gathering
telemetry for. If not set must be specified in environment variable
WHYLABS_DEFAULT_DATASET_ID.
sentiment (bool): If True will initialize a model to perform
sentiment analysis compound score. Defaults to False and will not gather
this metric.
toxicity (bool): If True will initialize a model to score
toxicity. Defaults to False and will not gather this metric.
themes (bool): If True will initialize a model to calculate
distance to configured themes. Defaults to None and will not gather this
metric.
logger (Optional[Logger]): If specified will bind the configured logger as
the telemetry gathering agent. Defaults to LangKit schema with periodic
WhyLabs writer.
"""
# langkit library will import necessary whylogs libraries
import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
import whylogs as why
from langkit.callback_handler import get_callback_instance
from whylogs.api.writer.whylabs import WhyLabsWriter
from whylogs.experimental.core.udf_schema import udf_schema
if logger is None:
api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
dataset_id = dataset_id or get_from_env(
"dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
)
whylabs_writer = WhyLabsWriter(
api_key=api_key, org_id=org_id, dataset_id=dataset_id
)
whylabs_logger = why.logger(
mode="rolling", interval=5, when="M", schema=udf_schema()
)
whylabs_logger.append_writer(writer=whylabs_writer)
else:
diagnostic_logger.info("Using passed in whylogs logger {logger}")
whylabs_logger = logger
callback_handler_cls = get_callback_instance(logger=whylabs_logger, impl=cls)
diagnostic_logger.info(
"Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝"
)
return callback_handler_cls
__all__ = ["diagnostic_logger", "import_langkit", "WhyLabsCallbackHandler"]

View File

@@ -2,12 +2,11 @@ from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
class BasePromptSelector(BaseModel, ABC):
"""Base class for prompt selectors."""

View File

@@ -1,16 +1,3 @@
from abc import ABC, abstractmethod
from typing import Iterator, List
from langchain_community.chat_loaders.base import BaseChatLoader
from langchain_core.chat_sessions import ChatSession
class BaseChatLoader(ABC):
"""Base class for chat loaders."""
@abstractmethod
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy load the chat sessions."""
def load(self) -> List[ChatSession]:
"""Eagerly load the chat sessions into memory."""
return list(self.lazy_load())
__all__ = ["BaseChatLoader"]

View File

@@ -1,79 +1,6 @@
import json
import logging
from pathlib import Path
from typing import Iterator, Union
from langchain_community.chat_loaders.facebook_messenger import (
FolderFacebookMessengerChatLoader,
SingleFileFacebookMessengerChatLoader,
)
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage
from langchain.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__file__)
class SingleFileFacebookMessengerChatLoader(BaseChatLoader):
"""Load `Facebook Messenger` chat data from a single file.
Args:
path (Union[Path, str]): The path to the chat file.
Attributes:
path (Path): The path to the chat file.
"""
def __init__(self, path: Union[Path, str]) -> None:
super().__init__()
self.file_path = path if isinstance(path, Path) else Path(path)
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy loads the chat data from the file.
Yields:
ChatSession: A chat session containing the loaded messages.
"""
with open(self.file_path) as f:
data = json.load(f)
sorted_data = sorted(data["messages"], key=lambda x: x["timestamp_ms"])
messages = []
for m in sorted_data:
messages.append(
HumanMessage(
content=m["content"], additional_kwargs={"sender": m["sender_name"]}
)
)
yield ChatSession(messages=messages)
class FolderFacebookMessengerChatLoader(BaseChatLoader):
"""Load `Facebook Messenger` chat data from a folder.
Args:
path (Union[str, Path]): The path to the directory
containing the chat files.
Attributes:
path (Path): The path to the directory containing the chat files.
"""
def __init__(self, path: Union[str, Path]) -> None:
super().__init__()
self.directory_path = Path(path) if isinstance(path, str) else path
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy loads the chat data from the folder.
Yields:
ChatSession: A chat session containing the loaded messages.
"""
inbox_path = self.directory_path / "inbox"
for _dir in inbox_path.iterdir():
if _dir.is_dir():
for _file in _dir.iterdir():
if _file.suffix.lower() == ".json":
file_loader = SingleFileFacebookMessengerChatLoader(path=_file)
for result in file_loader.lazy_load():
yield result
__all__ = ["SingleFileFacebookMessengerChatLoader", "FolderFacebookMessengerChatLoader"]

View File

@@ -1,112 +1,7 @@
import base64
import re
from typing import Any, Iterator
from langchain_community.chat_loaders.gmail import (
GMailLoader,
_extract_email_content,
_get_message_data,
)
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage
from langchain.chat_loaders.base import BaseChatLoader
def _extract_email_content(msg: Any) -> HumanMessage:
from_email = None
for values in msg["payload"]["headers"]:
name = values["name"]
if name == "From":
from_email = values["value"]
if from_email is None:
raise ValueError
for part in msg["payload"]["parts"]:
if part["mimeType"] == "text/plain":
data = part["body"]["data"]
data = base64.urlsafe_b64decode(data).decode("utf-8")
# Regular expression to split the email body at the first
# occurrence of a line that starts with "On ... wrote:"
pattern = re.compile(r"\r\nOn .+(\r\n)*wrote:\r\n")
# Split the email body and extract the first part
newest_response = re.split(pattern, data)[0]
message = HumanMessage(
content=newest_response, additional_kwargs={"sender": from_email}
)
return message
raise ValueError
def _get_message_data(service: Any, message: Any) -> ChatSession:
msg = service.users().messages().get(userId="me", id=message["id"]).execute()
message_content = _extract_email_content(msg)
in_reply_to = None
email_data = msg["payload"]["headers"]
for values in email_data:
name = values["name"]
if name == "In-Reply-To":
in_reply_to = values["value"]
if in_reply_to is None:
raise ValueError
thread_id = msg["threadId"]
thread = service.users().threads().get(userId="me", id=thread_id).execute()
messages = thread["messages"]
response_email = None
for message in messages:
email_data = message["payload"]["headers"]
for values in email_data:
if values["name"] == "Message-ID":
message_id = values["value"]
if message_id == in_reply_to:
response_email = message
if response_email is None:
raise ValueError
starter_content = _extract_email_content(response_email)
return ChatSession(messages=[starter_content, message_content])
class GMailLoader(BaseChatLoader):
"""Load data from `GMail`.
There are many ways you could want to load data from GMail.
This loader is currently fairly opinionated in how to do so.
The way it does it is it first looks for all messages that you have sent.
It then looks for messages where you are responding to a previous email.
It then fetches that previous email, and creates a training example
of that email, followed by your email.
Note that there are clear limitations here. For example,
all examples created are only looking at the previous email for context.
To use:
- Set up a Google Developer Account:
Go to the Google Developer Console, create a project,
and enable the Gmail API for that project.
This will give you a credentials.json file that you'll need later.
"""
def __init__(self, creds: Any, n: int = 100, raise_error: bool = False) -> None:
super().__init__()
self.creds = creds
self.n = n
self.raise_error = raise_error
def lazy_load(self) -> Iterator[ChatSession]:
from googleapiclient.discovery import build
service = build("gmail", "v1", credentials=self.creds)
results = (
service.users()
.messages()
.list(userId="me", labelIds=["SENT"], maxResults=self.n)
.execute()
)
messages = results.get("messages", [])
for message in messages:
try:
yield _get_message_data(service, message)
except Exception as e:
# TODO: handle errors better
if self.raise_error:
raise e
else:
pass
__all__ = ["_extract_email_content", "_get_message_data", "GMailLoader"]

View File

@@ -1,161 +1,3 @@
from __future__ import annotations
from langchain_community.chat_loaders.imessage import IMessageChatLoader
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import HumanMessage
from langchain.chat_loaders.base import BaseChatLoader
if TYPE_CHECKING:
import sqlite3
class IMessageChatLoader(BaseChatLoader):
"""Load chat sessions from the `iMessage` chat.db SQLite file.
It only works on macOS when you have iMessage enabled and have the chat.db file.
The chat.db file is likely located at ~/Library/Messages/chat.db. However, your
terminal may not have permission to access this file. To resolve this, you can
copy the file to a different location, change the permissions of the file, or
grant full disk access for your terminal emulator
in System Settings > Security and Privacy > Full Disk Access.
"""
def __init__(self, path: Optional[Union[str, Path]] = None):
"""
Initialize the IMessageChatLoader.
Args:
path (str or Path, optional): Path to the chat.db SQLite file.
Defaults to None, in which case the default path
~/Library/Messages/chat.db will be used.
"""
if path is None:
path = Path.home() / "Library" / "Messages" / "chat.db"
self.db_path = path if isinstance(path, Path) else Path(path)
if not self.db_path.exists():
raise FileNotFoundError(f"File {self.db_path} not found")
try:
import sqlite3 # noqa: F401
except ImportError as e:
raise ImportError(
"The sqlite3 module is required to load iMessage chats.\n"
"Please install it with `pip install pysqlite3`"
) from e
def _parse_attributedBody(self, attributedBody: bytes) -> str:
"""
Parse the attributedBody field of the message table
for the text content of the message.
The attributedBody field is a binary blob that contains
the message content after the byte string b"NSString":
5 bytes 1-3 bytes `len` bytes
... | b"NSString" | preamble | `len` | contents | ...
The 5 preamble bytes are always b"\x01\x94\x84\x01+"
The size of `len` is either 1 byte or 3 bytes:
- If the first byte in `len` is b"\x81" then `len` is 3 bytes long.
So the message length is the 2 bytes after, in little Endian.
- Otherwise, the size of `len` is 1 byte, and the message length is
that byte.
Args:
attributedBody (bytes): attributedBody field of the message table.
Return:
str: Text content of the message.
"""
content = attributedBody.split(b"NSString")[1][5:]
length, start = content[0], 1
if content[0] == 129:
length, start = int.from_bytes(content[1:3], "little"), 3
return content[start : start + length].decode("utf-8", errors="ignore")
def _load_single_chat_session(
self, cursor: "sqlite3.Cursor", chat_id: int
) -> ChatSession:
"""
Load a single chat session from the iMessage chat.db.
Args:
cursor: SQLite cursor object.
chat_id (int): ID of the chat session to load.
Returns:
ChatSession: Loaded chat session.
"""
results: List[HumanMessage] = []
query = """
SELECT message.date, handle.id, message.text, message.attributedBody
FROM message
JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
JOIN handle ON message.handle_id = handle.ROWID
WHERE chat_message_join.chat_id = ?
ORDER BY message.date ASC;
"""
cursor.execute(query, (chat_id,))
messages = cursor.fetchall()
for date, sender, text, attributedBody in messages:
if text:
content = text
elif attributedBody:
content = self._parse_attributedBody(attributedBody)
else: # Skip messages with no content
continue
results.append(
HumanMessage(
role=sender,
content=content,
additional_kwargs={
"message_time": date,
"sender": sender,
},
)
)
return ChatSession(messages=results)
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the iMessage chat.db
and yield them in the required format.
Yields:
ChatSession: Loaded chat session.
"""
import sqlite3
try:
conn = sqlite3.connect(self.db_path)
except sqlite3.OperationalError as e:
raise ValueError(
f"Could not open iMessage DB file {self.db_path}.\n"
"Make sure your terminal emulator has disk access to this file.\n"
" You can either copy the DB file to an accessible location"
" or grant full disk access for your terminal emulator."
" You can grant full disk access for your terminal emulator"
" in System Settings > Security and Privacy > Full Disk Access."
) from e
cursor = conn.cursor()
# Fetch the list of chat IDs sorted by time (most recent first)
query = """SELECT chat_id
FROM message
JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
GROUP BY chat_id
ORDER BY MAX(date) DESC;"""
cursor.execute(query)
chat_ids = [row[0] for row in cursor.fetchall()]
for chat_id in chat_ids:
yield self._load_single_chat_session(cursor, chat_id)
conn.close()
__all__ = ["IMessageChatLoader"]

View File

@@ -1,159 +1,6 @@
from __future__ import annotations
from langchain_community.chat_loaders.langsmith import (
LangSmithDatasetChatLoader,
LangSmithRunChatLoader,
)
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast
from langchain_core.chat_sessions import ChatSession
from langchain_core.load import load
from langchain.chat_loaders.base import BaseChatLoader
if TYPE_CHECKING:
from langsmith.client import Client
from langsmith.schemas import Run
logger = logging.getLogger(__name__)
class LangSmithRunChatLoader(BaseChatLoader):
"""
Load chat sessions from a list of LangSmith "llm" runs.
Attributes:
runs (Iterable[Union[str, Run]]): The list of LLM run IDs or run objects.
client (Client): Instance of LangSmith client for fetching data.
"""
def __init__(
self, runs: Iterable[Union[str, Run]], client: Optional["Client"] = None
):
"""
Initialize a new LangSmithRunChatLoader instance.
:param runs: List of LLM run IDs or run objects.
:param client: An instance of LangSmith client, if not provided,
a new client instance will be created.
"""
from langsmith.client import Client
self.runs = runs
self.client = client or Client()
def _load_single_chat_session(self, llm_run: "Run") -> ChatSession:
"""
Convert an individual LangSmith LLM run to a ChatSession.
:param llm_run: The LLM run object.
:return: A chat session representing the run's data.
"""
chat_session = LangSmithRunChatLoader._get_messages_from_llm_run(llm_run)
functions = LangSmithRunChatLoader._get_functions_from_llm_run(llm_run)
if functions:
chat_session["functions"] = functions
return chat_session
@staticmethod
def _get_messages_from_llm_run(llm_run: "Run") -> ChatSession:
"""
Extract messages from a LangSmith LLM run.
:param llm_run: The LLM run object.
:return: ChatSession with the extracted messages.
"""
if llm_run.run_type != "llm":
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
if "messages" not in llm_run.inputs:
raise ValueError(f"Run has no 'messages' inputs. Got {llm_run.inputs}")
if not llm_run.outputs:
raise ValueError("Cannot convert pending run")
messages = load(llm_run.inputs)["messages"]
message_chunk = load(llm_run.outputs)["generations"][0]["message"]
return ChatSession(messages=messages + [message_chunk])
@staticmethod
def _get_functions_from_llm_run(llm_run: "Run") -> Optional[List[Dict]]:
"""
Extract functions from a LangSmith LLM run if they exist.
:param llm_run: The LLM run object.
:return: Functions from the run or None.
"""
if llm_run.run_type != "llm":
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
return (llm_run.extra or {}).get("invocation_params", {}).get("functions")
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the iterable of run IDs.
This method fetches the runs and converts them to chat sessions on-the-fly,
yielding one session at a time.
:return: Iterator of chat sessions containing messages.
"""
from langsmith.schemas import Run
for run_obj in self.runs:
try:
if hasattr(run_obj, "id"):
run = run_obj
else:
run = self.client.read_run(run_obj)
session = self._load_single_chat_session(cast(Run, run))
yield session
except ValueError as e:
logger.warning(f"Could not load run {run_obj}: {repr(e)}")
continue
class LangSmithDatasetChatLoader(BaseChatLoader):
"""
Load chat sessions from a LangSmith dataset with the "chat" data type.
Attributes:
dataset_name (str): The name of the LangSmith dataset.
client (Client): Instance of LangSmith client for fetching data.
"""
def __init__(self, *, dataset_name: str, client: Optional["Client"] = None):
"""
Initialize a new LangSmithChatDatasetLoader instance.
:param dataset_name: The name of the LangSmith dataset.
:param client: An instance of LangSmith client; if not provided,
a new client instance will be created.
"""
try:
from langsmith.client import Client
except ImportError as e:
raise ImportError(
"The LangSmith client is required to load LangSmith datasets.\n"
"Please install it with `pip install langsmith`"
) from e
self.dataset_name = dataset_name
self.client = client or Client()
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the specified LangSmith dataset.
This method fetches the chat data from the dataset and
converts each data point to chat sessions on-the-fly,
yielding one session at a time.
:return: Iterator of chat sessions containing messages.
"""
from langchain.adapters import openai as oai_adapter # noqa: E402
data = self.client.read_dataset_openai_finetuning(
dataset_name=self.dataset_name
)
for data_point in data:
yield ChatSession(
messages=[
oai_adapter.convert_dict_to_message(m)
for m in data_point.get("messages", [])
],
functions=data_point.get("functions"),
)
__all__ = ["LangSmithRunChatLoader", "LangSmithDatasetChatLoader"]

View File

@@ -1,86 +1,3 @@
import json
import logging
import re
import zipfile
from pathlib import Path
from typing import Dict, Iterator, List, Union
from langchain_community.chat_loaders.slack import SlackChatLoader
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class SlackChatLoader(BaseChatLoader):
"""Load `Slack` conversations from a dump zip file."""
def __init__(
self,
path: Union[str, Path],
):
"""
Initialize the chat loader with the path to the exported Slack dump zip file.
:param path: Path to the exported Slack dump zip file.
"""
self.zip_path = path if isinstance(path, Path) else Path(path)
if not self.zip_path.exists():
raise FileNotFoundError(f"File {self.zip_path} not found")
def _load_single_chat_session(self, messages: List[Dict]) -> ChatSession:
results: List[Union[AIMessage, HumanMessage]] = []
previous_sender = None
for message in messages:
if not isinstance(message, dict):
continue
text = message.get("text", "")
timestamp = message.get("ts", "")
sender = message.get("user", "")
if not sender:
continue
skip_pattern = re.compile(
r"<@U\d+> has joined the channel", flags=re.IGNORECASE
)
if skip_pattern.match(text):
continue
if sender == previous_sender:
results[-1].content += "\n\n" + text
results[-1].additional_kwargs["events"].append(
{"message_time": timestamp}
)
else:
results.append(
HumanMessage(
role=sender,
content=text,
additional_kwargs={
"sender": sender,
"events": [{"message_time": timestamp}],
},
)
)
previous_sender = sender
return ChatSession(messages=results)
def _read_json(self, zip_file: zipfile.ZipFile, file_path: str) -> List[dict]:
"""Read JSON data from a zip subfile."""
with zip_file.open(file_path, "r") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError(f"Expected list of dictionaries, got {type(data)}")
return data
def lazy_load(self) -> Iterator[ChatSession]:
"""
Lazy load the chat sessions from the Slack dump file and yield them
in the required format.
:return: Iterator of chat sessions containing messages.
"""
with zipfile.ZipFile(str(self.zip_path), "r") as zip_file:
for file_path in zip_file.namelist():
if file_path.endswith(".json"):
messages = self._read_json(zip_file, file_path)
yield self._load_single_chat_session(messages)
__all__ = ["SlackChatLoader"]

View File

@@ -1,151 +1,3 @@
import json
import logging
import os
import tempfile
import zipfile
from pathlib import Path
from typing import Iterator, List, Union
from langchain_community.chat_loaders.telegram import TelegramChatLoader
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class TelegramChatLoader(BaseChatLoader):
"""Load `telegram` conversations to LangChain chat messages.
To export, use the Telegram Desktop app from
https://desktop.telegram.org/, select a conversation, click the three dots
in the top right corner, and select "Export chat history". Then select
"Machine-readable JSON" (preferred) to export. Note: the 'lite' versions of
the desktop app (like "Telegram for MacOS") do not support exporting chat
history.
"""
def __init__(
self,
path: Union[str, Path],
):
"""Initialize the TelegramChatLoader.
Args:
path (Union[str, Path]): Path to the exported Telegram chat zip,
directory, json, or HTML file.
"""
self.path = path if isinstance(path, str) else str(path)
def _load_single_chat_session_html(self, file_path: str) -> ChatSession:
"""Load a single chat session from an HTML file.
Args:
file_path (str): Path to the HTML file.
Returns:
ChatSession: The loaded chat session.
"""
try:
from bs4 import BeautifulSoup
except ImportError:
raise ImportError(
"Please install the 'beautifulsoup4' package to load"
" Telegram HTML files. You can do this by running"
"'pip install beautifulsoup4' in your terminal."
)
with open(file_path, "r", encoding="utf-8") as file:
soup = BeautifulSoup(file, "html.parser")
results: List[Union[HumanMessage, AIMessage]] = []
previous_sender = None
for message in soup.select(".message.default"):
timestamp = message.select_one(".pull_right.date.details")["title"]
from_name_element = message.select_one(".from_name")
if from_name_element is None and previous_sender is None:
logger.debug("from_name not found in message")
continue
elif from_name_element is None:
from_name = previous_sender
else:
from_name = from_name_element.text.strip()
text = message.select_one(".text").text.strip()
results.append(
HumanMessage(
content=text,
additional_kwargs={
"sender": from_name,
"events": [{"message_time": timestamp}],
},
)
)
previous_sender = from_name
return ChatSession(messages=results)
def _load_single_chat_session_json(self, file_path: str) -> ChatSession:
"""Load a single chat session from a JSON file.
Args:
file_path (str): Path to the JSON file.
Returns:
ChatSession: The loaded chat session.
"""
with open(file_path, "r", encoding="utf-8") as file:
data = json.load(file)
messages = data.get("messages", [])
results: List[BaseMessage] = []
for message in messages:
text = message.get("text", "")
timestamp = message.get("date", "")
from_name = message.get("from", "")
results.append(
HumanMessage(
content=text,
additional_kwargs={
"sender": from_name,
"events": [{"message_time": timestamp}],
},
)
)
return ChatSession(messages=results)
def _iterate_files(self, path: str) -> Iterator[str]:
"""Iterate over files in a directory or zip file.
Args:
path (str): Path to the directory or zip file.
Yields:
str: Path to each file.
"""
if os.path.isfile(path) and path.endswith((".html", ".json")):
yield path
elif os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
if file.endswith((".html", ".json")):
yield os.path.join(root, file)
elif zipfile.is_zipfile(path):
with zipfile.ZipFile(path) as zip_file:
for file in zip_file.namelist():
if file.endswith((".html", ".json")):
with tempfile.TemporaryDirectory() as temp_dir:
yield zip_file.extract(file, path=temp_dir)
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy load the messages from the chat file and yield them
in as chat sessions.
Yields:
ChatSession: The loaded chat session.
"""
for file_path in self._iterate_files(self.path):
if file_path.endswith(".html"):
yield self._load_single_chat_session_html(file_path)
elif file_path.endswith(".json"):
yield self._load_single_chat_session_json(file_path)
__all__ = ["TelegramChatLoader"]

View File

@@ -1,95 +1,13 @@
"""Utilities for chat loaders."""
from copy import deepcopy
from typing import Iterable, Iterator, List
from langchain_community.chat_loaders.utils import (
map_ai_messages,
map_ai_messages_in_session,
merge_chat_runs,
merge_chat_runs_in_session,
)
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, BaseMessage
def merge_chat_runs_in_session(
chat_session: ChatSession, delimiter: str = "\n\n"
) -> ChatSession:
"""Merge chat runs together in a chat session.
A chat run is a sequence of messages from the same sender.
Args:
chat_session: A chat session.
Returns:
A chat session with merged chat runs.
"""
messages: List[BaseMessage] = []
for message in chat_session["messages"]:
if not isinstance(message.content, str):
raise ValueError(
"Chat Loaders only support messages with content type string, "
f"got {message.content}"
)
if not messages:
messages.append(deepcopy(message))
elif (
isinstance(message, type(messages[-1]))
and messages[-1].additional_kwargs.get("sender") is not None
and messages[-1].additional_kwargs["sender"]
== message.additional_kwargs.get("sender")
):
if not isinstance(messages[-1].content, str):
raise ValueError(
"Chat Loaders only support messages with content type string, "
f"got {messages[-1].content}"
)
messages[-1].content = (
messages[-1].content + delimiter + message.content
).strip()
messages[-1].additional_kwargs.get("events", []).extend(
message.additional_kwargs.get("events") or []
)
else:
messages.append(deepcopy(message))
return ChatSession(messages=messages)
def merge_chat_runs(chat_sessions: Iterable[ChatSession]) -> Iterator[ChatSession]:
"""Merge chat runs together.
A chat run is a sequence of messages from the same sender.
Args:
chat_sessions: A list of chat sessions.
Returns:
A list of chat sessions with merged chat runs.
"""
for chat_session in chat_sessions:
yield merge_chat_runs_in_session(chat_session)
def map_ai_messages_in_session(chat_sessions: ChatSession, sender: str) -> ChatSession:
"""Convert messages from the specified 'sender' to AI messages.
This is useful for fine-tuning the AI to adapt to your voice.
"""
messages = []
num_converted = 0
for message in chat_sessions["messages"]:
if message.additional_kwargs.get("sender") == sender:
message = AIMessage(
content=message.content,
additional_kwargs=message.additional_kwargs.copy(),
example=getattr(message, "example", None),
)
num_converted += 1
messages.append(message)
return ChatSession(messages=messages)
def map_ai_messages(
chat_sessions: Iterable[ChatSession], sender: str
) -> Iterator[ChatSession]:
"""Convert messages from the specified 'sender' to AI messages.
This is useful for fine-tuning the AI to adapt to your voice.
"""
for chat_session in chat_sessions:
yield map_ai_messages_in_session(chat_session, sender)
__all__ = [
"merge_chat_runs_in_session",
"merge_chat_runs",
"map_ai_messages_in_session",
"map_ai_messages",
]

View File

@@ -1,119 +1,3 @@
import logging
import os
import re
import zipfile
from typing import Iterator, List, Union
from langchain_community.chat_loaders.whatsapp import WhatsAppChatLoader
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chat_loaders.base import BaseChatLoader
logger = logging.getLogger(__name__)
class WhatsAppChatLoader(BaseChatLoader):
"""Load `WhatsApp` conversations from a dump zip file or directory."""
def __init__(self, path: str):
"""Initialize the WhatsAppChatLoader.
Args:
path (str): Path to the exported WhatsApp chat
zip directory, folder, or file.
To generate the dump, open the chat, click the three dots in the top
right corner, and select "More". Then select "Export chat" and
choose "Without media".
"""
self.path = path
ignore_lines = [
"This message was deleted",
"<Media omitted>",
"image omitted",
"Messages and calls are end-to-end encrypted. No one outside of this chat,"
" not even WhatsApp, can read or listen to them.",
]
self._ignore_lines = re.compile(
r"(" + "|".join([r"\u200E*" + line for line in ignore_lines]) + r")",
flags=re.IGNORECASE,
)
self._message_line_regex = re.compile(
r"\u200E*\[?(\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{2}:\d{2} (?:AM|PM))\]?[ \u200E]*([^:]+): (.+)", # noqa
flags=re.IGNORECASE,
)
def _load_single_chat_session(self, file_path: str) -> ChatSession:
"""Load a single chat session from a file.
Args:
file_path (str): Path to the chat file.
Returns:
ChatSession: The loaded chat session.
"""
with open(file_path, "r", encoding="utf-8") as file:
txt = file.read()
# Split messages by newlines, but keep multi-line messages grouped
chat_lines: List[str] = []
current_message = ""
for line in txt.split("\n"):
if self._message_line_regex.match(line):
if current_message:
chat_lines.append(current_message)
current_message = line
else:
current_message += " " + line.strip()
if current_message:
chat_lines.append(current_message)
results: List[Union[HumanMessage, AIMessage]] = []
for line in chat_lines:
result = self._message_line_regex.match(line.strip())
if result:
timestamp, sender, text = result.groups()
if not self._ignore_lines.match(text.strip()):
results.append(
HumanMessage(
role=sender,
content=text,
additional_kwargs={
"sender": sender,
"events": [{"message_time": timestamp}],
},
)
)
else:
logger.debug(f"Could not parse line: {line}")
return ChatSession(messages=results)
def _iterate_files(self, path: str) -> Iterator[str]:
"""Iterate over the files in a directory or zip file.
Args:
path (str): Path to the directory or zip file.
Yields:
str: The path to each file.
"""
if os.path.isfile(path):
yield path
elif os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
if file.endswith(".txt"):
yield os.path.join(root, file)
elif zipfile.is_zipfile(path):
with zipfile.ZipFile(path) as zip_file:
for file in zip_file.namelist():
if file.endswith(".txt"):
yield zip_file.extract(file)
def lazy_load(self) -> Iterator[ChatSession]:
"""Lazy load the messages from the chat file and yield
them as chat sessions.
Yields:
Iterator[ChatSession]: The loaded chat sessions.
"""
yield self._load_single_chat_session(self.path)
__all__ = ["WhatsAppChatLoader"]

View File

@@ -1,226 +1,11 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
from langchain_community.chat_models.anthropic import (
ChatAnthropic,
_convert_one_message_to_text,
convert_messages_to_prompt_anthropic,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.prompt_values import PromptValue
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain.llms.anthropic import _AnthropicCommon
def _convert_one_message_to_text(
message: BaseMessage,
human_prompt: str,
ai_prompt: str,
) -> str:
content = cast(str, message.content)
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {content}"
elif isinstance(message, HumanMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AIMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def convert_messages_to_prompt_anthropic(
messages: List[BaseMessage],
*,
human_prompt: str = "\n\nHuman:",
ai_prompt: str = "\n\nAssistant:",
) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
Returns:
str: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AIMessage):
messages.append(AIMessage(content=""))
text = "".join(
_convert_one_message_to_text(message, human_prompt, ai_prompt)
for message in messages
)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
"""`Anthropic` chat large language models.
To use, you should have the ``anthropic`` python package installed, and the
environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
import anthropic
from langchain.chat_models import ChatAnthropic
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
"""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "anthropic-chat"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "anthropic"]
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
Returns:
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
"""
prompt_params = {}
if self.HUMAN_PROMPT:
prompt_params["human_prompt"] = self.HUMAN_PROMPT
if self.AI_PROMPT:
prompt_params["ai_prompt"] = self.AI_PROMPT
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
def convert_prompt(self, prompt: PromptValue) -> str:
return self._convert_messages_to_prompt(prompt.to_messages())
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = await self.async_client.completions.create(**params, stream=True)
async for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = self.client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = await self.async_client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)
__all__ = [
"_convert_one_message_to_text",
"convert_messages_to_prompt_anthropic",
"ChatAnthropic",
]

View File

@@ -1,221 +1,7 @@
"""Anyscale Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
from __future__ import annotations
import logging
import os
import sys
from typing import TYPE_CHECKING, Dict, Optional, Set
import requests
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain.adapters.openai import convert_message_to_dict
from langchain.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
from langchain_community.chat_models.anyscale import (
DEFAULT_API_BASE,
DEFAULT_MODEL,
ChatAnyscale,
)
from langchain.utils import get_from_dict_or_env
from langchain.utils.openai import is_openai_v1
if TYPE_CHECKING:
import tiktoken
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
class ChatAnyscale(ChatOpenAI):
"""`Anyscale` Chat large language models.
See https://www.anyscale.com/ for information about Anyscale.
To use, you should have the ``openai`` python package installed, and the
environment variable ``ANYSCALE_API_KEY`` set with your API key.
Alternatively, you can use the anyscale_api_key keyword argument.
Any parameters that are valid to be passed to the `openai.create` call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.chat_models import ChatAnyscale
chat = ChatAnyscale(model_name="meta-llama/Llama-2-7b-chat-hf")
"""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "anyscale-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"anyscale_api_key": "ANYSCALE_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
return False
anyscale_api_key: SecretStr
"""AnyScale Endpoints API keys."""
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
"""Model name to use."""
anyscale_api_base: str = Field(default=DEFAULT_API_BASE)
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
anyscale_proxy: Optional[str] = None
"""To support explicit proxy for Anyscale."""
available_models: Optional[Set[str]] = None
"""Available models from Anyscale API."""
@staticmethod
def get_available_models(
anyscale_api_key: Optional[str] = None,
anyscale_api_base: str = DEFAULT_API_BASE,
) -> Set[str]:
"""Get available models from Anyscale API."""
try:
anyscale_api_key = anyscale_api_key or os.environ["ANYSCALE_API_KEY"]
except KeyError as e:
raise ValueError(
"Anyscale API key must be passed as keyword argument or "
"set in environment variable ANYSCALE_API_KEY.",
) from e
models_url = f"{anyscale_api_base}/models"
models_response = requests.get(
models_url,
headers={
"Authorization": f"Bearer {anyscale_api_key}",
},
)
if models_response.status_code != 200:
raise ValueError(
f"Error getting models from {models_url}: "
f"{models_response.status_code}",
)
return {model["id"] for model in models_response.json()["data"]}
@root_validator(pre=True)
def validate_environment_override(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values,
"anyscale_api_key",
"ANYSCALE_API_KEY",
)
values["anyscale_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"anyscale_api_key",
"ANYSCALE_API_KEY",
)
)
values["openai_api_base"] = get_from_dict_or_env(
values,
"anyscale_api_base",
"ANYSCALE_API_BASE",
default=DEFAULT_API_BASE,
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"anyscale_proxy",
"ANYSCALE_PROXY",
default="",
)
try:
import openai
except ImportError as e:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`.",
) from e
try:
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"base_url": values["openai_api_base"],
# To do: future support
# "organization": values["openai_organization"],
# "timeout": values["request_timeout"],
# "max_retries": values["max_retries"],
# "default_headers": values["default_headers"],
# "default_query": values["default_query"],
# "http_client": values["http_client"],
}
values["client"] = openai.OpenAI(**client_params).chat.completions
else:
values["client"] = openai.ChatCompletion
except AttributeError as exc:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`.",
) from exc
if "model_name" not in values.keys():
values["model_name"] = DEFAULT_MODEL
model_name = values["model_name"]
available_models = cls.get_available_models(
values["openai_api_key"],
values["openai_api_base"],
)
if model_name not in available_models:
raise ValueError(
f"Model name {model_name} not found in available models: "
f"{available_models}.",
)
values["available_models"] = available_models
return values
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
"""Calculate num tokens with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens
__all__ = ["DEFAULT_API_BASE", "DEFAULT_MODEL", "ChatAnyscale"]

View File

@@ -1,271 +1,3 @@
"""Azure OpenAI chat wrapper."""
from __future__ import annotations
from langchain_community.chat_models.azure_openai import AzureChatOpenAI
import logging
import os
import warnings
from typing import Any, Callable, Dict, List, Union
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain.chat_models.openai import ChatOpenAI
from langchain.utils import get_from_dict_or_env
from langchain.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
class AzureChatOpenAI(ChatOpenAI):
"""`Azure OpenAI` Chat Completion API.
To use this class you
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
constructor to refer to the "Model deployment name" in the Azure portal.
In addition, you should have the ``openai`` python package installed, and the
following environment variables set or passed in constructor in lower case:
- ``AZURE_OPENAI_API_KEY``
- ``AZURE_OPENAI_API_ENDPOINT``
- ``AZURE_OPENAI_AD_TOKEN``
- ``OPENAI_API_VERSION``
- ``OPENAI_PROXY``
For example, if you have `gpt-35-turbo` deployed, with the deployment name
`35-turbo-dev`, the constructor should look like:
.. code-block:: python
AzureChatOpenAI(
azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15",
)
Be aware the API version may change.
You can also specify the version of the model using ``model_version`` constructor
parameter, as Azure OpenAI doesn't return model version with the response.
Default is empty. When you specify the version, it will be appended to the
model name in the response. Setting correct version will help you to calculate the
cost properly. Model version is not validated, so make sure you set it correctly
to get the correct cost.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
"""
azure_endpoint: Union[str, None] = None
"""Your Azure endpoint, including the resource.
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
Example: `https://example-resource.azure.openai.com/`
"""
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
"""
openai_api_version: str = Field(default="", alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
""" # noqa: E501
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
"""
model_version: str = ""
"""Legacy, for openai<1.0.0 support."""
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
"OPENAI_API_VERSION"
)
# Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if is_openai_v1():
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
if openai_api_base and values["validate_base_url"]:
if "/openai" not in openai_api_base:
values["openai_api_base"] = (
values["openai_api_base"].rstrip("/") + "/openai"
)
warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
)
if values["deployment_name"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment_name"] not in values["openai_api_base"]:
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
)
values["openai_api_base"] += (
"/deployments/" + values["deployment_name"]
)
values["deployment_name"] = None
client_params = {
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
values["async_client"] = openai.AsyncAzureOpenAI(
**client_params
).chat.completions
else:
values["client"] = openai.ChatCompletion
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
if is_openai_v1():
return super()._default_params
else:
return {
**super()._default_params,
"engine": self.deployment_name,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**self._default_params}
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the config params used for the openai client."""
if is_openai_v1():
return super()._client_params
else:
return {
**super()._client_params,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
@property
def _llm_type(self) -> str:
return "azure-openai-chat"
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":
raise ValueError(
"Azure has not provided the response due to a content filter "
"being triggered"
)
chat_result = super()._create_chat_result(response)
if "model" in response:
model = response["model"]
if self.model_version:
model = f"{model}-{self.model_version}"
if chat_result.llm_output is not None and isinstance(
chat_result.llm_output, dict
):
chat_result.llm_output["model_name"] = model
return chat_result
__all__ = ["AzureChatOpenAI"]

View File

@@ -1,167 +1,6 @@
import json
from typing import Any, Dict, List, Optional, cast
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
from langchain_community.chat_models.azureml_endpoint import (
AzureMLChatOnlineEndpoint,
LlamaContentFormatter,
)
from langchain_core.pydantic_v1 import SecretStr, validator
from langchain_core.utils import convert_to_secret_str
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
from langchain.utils import get_from_dict_or_env
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for `LLaMA`."""
SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict:
"""Converts message to a dict according to role"""
content = cast(str, message.content)
if isinstance(message, HumanMessage):
return {
"role": "user",
"content": ContentFormatterBase.escape_special_characters(content),
}
elif isinstance(message, AIMessage):
return {
"role": "assistant",
"content": ContentFormatterBase.escape_special_characters(content),
}
elif isinstance(message, SystemMessage):
return {
"role": "system",
"content": ContentFormatterBase.escape_special_characters(content),
}
elif (
isinstance(message, ChatMessage)
and message.role in LlamaContentFormatter.SUPPORTED_ROLES
):
return {
"role": message.role,
"content": ContentFormatterBase.escape_special_characters(content),
}
else:
supported = ",".join(
[role for role in LlamaContentFormatter.SUPPORTED_ROLES]
)
raise ValueError(
f"""Received unsupported role.
Supported roles for the LLaMa Foundation Model: {supported}"""
)
def _format_request_payload(
self, messages: List[BaseMessage], model_kwargs: Dict
) -> bytes:
chat_messages = [
LlamaContentFormatter._convert_message_to_dict(message)
for message in messages
]
prompt = json.dumps(
{"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
)
return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
"""Formats the request according to the chosen api"""
return str.encode(prompt)
def format_response_payload(self, output: bytes) -> str:
"""Formats response"""
return json.loads(output)["output"]
class AzureMLChatOnlineEndpoint(SimpleChatModel):
"""`AzureML` Chat models API.
Example:
.. code-block:: python
azure_chat = AzureMLChatOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_key="my-api-key",
content_formatter=content_formatter,
)
"""
endpoint_url: str = ""
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_URL`."""
endpoint_api_key: SecretStr = convert_to_secret_str("")
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_API_KEY`."""
http_client: Any = None #: :meta private:
content_formatter: Any = None
"""The content formatter that provides an input and output
transform function to handle formats between the LLM and
the endpoint"""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""
@validator("http_client", always=True, allow_reuse=True)
@classmethod
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
"""Validate that api key and python package exist in environment."""
values["endpoint_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
)
endpoint_url = get_from_dict_or_env(
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
)
http_client = AzureMLEndpointClient(
endpoint_url, values["endpoint_api_key"].get_secret_value()
)
return http_client
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "azureml_chat_endpoint"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to an AzureML Managed Online endpoint.
Args:
messages: The messages in the conversation with the chat model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = azureml_model("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
request_payload = self.content_formatter._format_request_payload(
messages, _model_kwargs
)
response_payload = self.http_client.call(request_payload, **kwargs)
generated_text = self.content_formatter.format_response_payload(
response_payload
)
return generated_text
__all__ = ["LlamaContentFormatter", "AzureMLChatOnlineEndpoint"]

View File

@@ -1,296 +1,17 @@
import hashlib
import json
import logging
import time
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
import requests
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_pydantic_field_names,
from langchain_community.chat_models.baichuan import (
DEFAULT_API_BASE,
ChatBaichuan,
_convert_delta_to_message_chunk,
_convert_dict_to_message,
_convert_message_to_dict,
_signature,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel, generate_from_stream
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1"
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "")
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
# signature generation
def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str:
input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp)
md5 = hashlib.md5()
md5.update(input_str.encode("utf-8"))
return md5.hexdigest()
class ChatBaichuan(BaseChatModel):
"""Baichuan chat models API by Baichuan Intelligent Technology.
For more information, see https://platform.baichuan-ai.com/docs/api
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"baichuan_api_key": "BAICHUAN_API_KEY",
"baichuan_secret_key": "BAICHUAN_SECRET_KEY",
}
@property
def lc_serializable(self) -> bool:
return True
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints"""
baichuan_api_key: Optional[SecretStr] = None
"""Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None
"""Baichuan Secret Key"""
streaming: bool = False
"""Whether to stream the results or not."""
request_timeout: int = 60
"""request timeout for chat http requests"""
model = "Baichuan2-53B"
"""model name of Baichuan, default is `Baichuan2-53B`."""
temperature: float = 0.3
"""What sampling temperature to use."""
top_k: int = 5
"""What search sampling control to use."""
top_p: float = 0.85
"""What probability mass to use."""
with_search_enhance: bool = False
"""Whether to use search enhance, default is False."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values["baichuan_api_base"] = get_from_dict_or_env(
values,
"baichuan_api_base",
"BAICHUAN_API_BASE",
DEFAULT_API_BASE,
)
values["baichuan_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_api_key",
"BAICHUAN_API_KEY",
)
)
values["baichuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_secret_key",
"BAICHUAN_SECRET_KEY",
)
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Baichuan API."""
normal_params = {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"with_search_enhance": self.with_search_enhance,
}
return {**normal_params, **self.model_kwargs}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
response = res.json()
if response.get("code") != 0:
raise ValueError(f"Error from Baichuan api response: {response}")
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs)
default_chunk_class = AIMessageChunk
for chunk in res.iter_lines():
response = json.loads(chunk)
if response.get("code") != 0:
raise ValueError(f"Error from Baichuan api response: {response}")
data = response.get("data")
for m in data.get("messages"):
chunk = _convert_delta_to_message_chunk(m, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.baichuan_secret_key is None:
raise ValueError("Baichuan secret key is not set.")
parameters = {**self._default_params, **kwargs}
model = parameters.pop("model")
headers = parameters.pop("headers", {})
payload = {
"model": model,
"messages": [_convert_message_to_dict(m) for m in messages],
"parameters": parameters,
}
timestamp = int(time.time())
url = self.baichuan_api_base
if self.streaming:
url = f"{url}/stream"
url = f"{url}/chat"
api_key = ""
if self.baichuan_api_key:
api_key = self.baichuan_api_key.get_secret_value()
res = requests.post(
url=url,
timeout=self.request_timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
"X-BC-Timestamp": str(timestamp),
"X-BC-Signature": _signature(
secret_key=self.baichuan_secret_key,
payload=payload,
timestamp=timestamp,
),
"X-BC-Sign-Algo": "MD5",
**headers,
},
json=payload,
stream=self.streaming,
)
return res
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for m in response["data"]["messages"]:
message = _convert_dict_to_message(m)
gen = ChatGeneration(message=message)
generations.append(gen)
token_usage = response["usage"]
llm_output = {"token_usage": token_usage, "model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "baichuan-chat"
__all__ = [
"DEFAULT_API_BASE",
"_convert_message_to_dict",
"_convert_dict_to_message",
"_convert_delta_to_message_chunk",
"_signature",
"ChatBaichuan",
]

View File

@@ -1,346 +1,7 @@
from __future__ import annotations
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
from langchain_community.chat_models.baidu_qianfan_endpoint import (
QianfanChatEndpoint,
_convert_dict_to_message,
convert_message_to_dict,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dictionary that can be passed to the API."""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
content = _dict.get("result", "") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
if "thoughts" in additional_kwargs["function_call"]:
# align to api sample, which affects the llm function_call output
additional_kwargs["function_call"].pop("thoughts")
else:
additional_kwargs = {}
return AIMessage(
content=content,
additional_kwargs={**_dict.get("body", {}), **additional_kwargs},
)
class QianfanChatEndpoint(BaseChatModel):
"""Baidu Qianfan chat models.
To use, you should have the ``qianfan`` python package installed, and
the environment variable ``qianfan_ak`` and ``qianfan_sk`` set with your
API key and Secret Key.
ak, sk are required parameters
which you could get from https://cloud.baidu.com/product/wenxinworkshop
Example:
.. code-block:: python
from langchain.chat_models import QianfanChatEndpoint
qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
client: Any
qianfan_ak: Optional[SecretStr] = None
qianfan_sk: Optional[SecretStr] = None
streaming: Optional[bool] = False
"""Whether to stream the results or not."""
request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""
top_p: Optional[float] = 0.8
temperature: Optional[float] = 0.95
penalty_score: Optional[float] = 1
"""Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
In the case of other model, passing these params will not affect the result.
"""
model: str = "ERNIE-Bot-turbo"
"""Model name.
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
preset models are mapping to an endpoint.
`model` will be ignored if `endpoint` is set.
Default is ERNIE-Bot-turbo.
"""
endpoint: Optional[str] = None
"""Endpoint of the Qianfan LLM, required if custom model used."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_ak",
"QIANFAN_AK",
)
)
values["qianfan_sk"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_sk",
"QIANFAN_SK",
)
)
params = {
"ak": values["qianfan_ak"].get_secret_value(),
"sk": values["qianfan_sk"].get_secret_value(),
"model": values["model"],
"stream": values["streaming"],
}
if values["endpoint"] is not None and values["endpoint"] != "":
params["endpoint"] = values["endpoint"]
try:
import qianfan
values["client"] = qianfan.ChatCompletion(**params)
except ImportError:
raise ValueError(
"qianfan package not found, please install it with "
"`pip install qianfan`"
)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
return {
**{"endpoint": self.endpoint, "model": self.model},
**super()._identifying_params,
}
@property
def _llm_type(self) -> str:
"""Return type of chat_model."""
return "baidu-qianfan-chat"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Qianfan API."""
normal_params = {
"model": self.model,
"endpoint": self.endpoint,
"stream": self.streaming,
"request_timeout": self.request_timeout,
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
}
return {**normal_params, **self.model_kwargs}
def _convert_prompt_msg_params(
self,
messages: List[BaseMessage],
**kwargs: Any,
) -> Dict[str, Any]:
"""
Converts a list of messages into a dictionary containing the message content
and default parameters.
Args:
messages (List[BaseMessage]): The list of messages.
**kwargs (Any): Optional arguments to add additional parameters to the
resulting dictionary.
Returns:
Dict[str, Any]: A dictionary containing the message content and default
parameters.
"""
messages_dict: Dict[str, Any] = {
"messages": [
convert_message_to_dict(m)
for m in messages
if not isinstance(m, SystemMessage)
]
}
for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
if "system" not in messages_dict:
messages_dict["system"] = ""
messages_dict["system"] += cast(str, messages[i].content) + "\n"
return {
**messages_dict,
**self._default_params,
**kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to an qianfan models endpoint for each generation with a prompt.
Args:
messages: The messages to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = qianfan_model("Tell me a joke.")
"""
if self.streaming:
completion = ""
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
lc_msg = AIMessage(content=completion, additional_kwargs={})
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="stop"),
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = self.client.do(**params)
lc_msg = _convert_dict_to_message(response_payload)
gen = ChatGeneration(
message=lc_msg,
generation_info={
"finish_reason": "stop",
**response_payload.get("body", {}),
},
)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=[gen], llm_output=llm_output)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
completion = ""
token_usage = {}
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
completion += chunk.text
lc_msg = AIMessage(content=completion, additional_kwargs={})
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="stop"),
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = await self.client.ado(**params)
lc_msg = _convert_dict_to_message(response_payload)
generations = []
gen = ChatGeneration(
message=lc_msg,
generation_info={
"finish_reason": "stop",
**response_payload.get("body", {}),
},
)
generations.append(gen)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
for res in self.client.do(**params):
if res:
msg = _convert_dict_to_message(res)
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk(
content=msg.content,
role="assistant",
additional_kwargs=msg.additional_kwargs,
),
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
async for res in await self.client.ado(**params):
if res:
msg = _convert_dict_to_message(res)
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk(
content=msg.content,
role="assistant",
additional_kwargs=msg.additional_kwargs,
),
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
__all__ = ["convert_message_to_dict", "_convert_dict_to_message", "QianfanChatEndpoint"]

View File

@@ -1,129 +1,3 @@
from typing import Any, Dict, Iterator, List, Optional
from langchain_community.chat_models.bedrock import BedrockChat, ChatPromptAdapter
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models.meta import convert_messages_to_prompt_llama
from langchain.llms.bedrock import BedrockBase
from langchain.utilities.anthropic import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
class ChatPromptAdapter:
"""Adapter class to prepare the inputs from Langchain to prompt format
that Chat model expects.
"""
@classmethod
def convert_messages_to_prompt(
cls, provider: str, messages: List[BaseMessage]
) -> str:
if provider == "anthropic":
prompt = convert_messages_to_prompt_anthropic(messages=messages)
elif provider == "meta":
prompt = convert_messages_to_prompt_llama(messages=messages)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."
)
return prompt
class BedrockChat(BaseChatModel, BedrockBase):
"""A chat model that uses the Bedrock API."""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "amazon_bedrock_chat"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "bedrock"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.region_name:
attributes["region_name"] = self.region_name
return attributes
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
delta = chunk.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
completion = ""
if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
params: Dict[str, Any] = {**kwargs}
if stop:
params["stop_sequences"] = stop
completion = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **params
)
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, text: str) -> int:
if self._model_is_anthropic:
return get_num_tokens_anthropic(text)
else:
return super().get_num_tokens(text)
def get_token_ids(self, text: str) -> List[int]:
if self._model_is_anthropic:
return get_token_ids_anthropic(text)
else:
return super().get_token_ids(text)
__all__ = ["ChatPromptAdapter", "BedrockChat"]

Some files were not shown because too many files have changed in this diff Show More