azure-dynamic-sessions: migrate to repo (#27468)

This commit is contained in:
Erick Friis 2024-10-18 12:30:48 -07:00 committed by GitHub
parent 30660786b3
commit a562c54f7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 2 additions and 3122 deletions

View File

@ -1 +0,0 @@
__pycache__

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,61 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@ -1,36 +1,3 @@
# langchain-azure-dynamic-sessions
This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.
## Installation
```bash
pip install -U langchain-azure-dynamic-sessions
```
## Usage
You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.
```python
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
# get the management endpoint from the session pool in the Azure portal
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
prompt = hub.pull("hwchase17/react")
tools=[tool]
react_agent = create_react_agent(
llm=llm,
tools=tools,
prompt=prompt,
)
react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)
react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
```
By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.
This package has moved!
https://github.com/langchain-ai/langchain-azure/tree/main/libs/azure-dynamic-sessions

View File

@ -1,7 +0,0 @@
"""This package provides tools for managing dynamic sessions in Azure."""
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
__all__ = [
"SessionsPythonREPLTool",
]

View File

@ -1,7 +0,0 @@
"""This package provides tools for managing dynamic sessions in Azure."""
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
__all__ = [
"SessionsPythonREPLTool",
]

View File

@ -1,314 +0,0 @@
"""This is the Azure Dynamic Sessions module.
This module provides the SessionsPythonREPLTool class for
managing dynamic sessions in Azure.
"""
import importlib.metadata
import json
import os
import re
import urllib
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
from uuid import uuid4
import requests
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
from langchain_core.tools import BaseTool
try:
_package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
except importlib.metadata.PackageNotFoundError:
_package_version = "0.0.0"
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
def _access_token_provider_factory() -> Callable[[], Optional[str]]:
"""Factory function for creating an access token provider function.
Returns:
Callable[[], Optional[str]]: The access token provider function
"""
access_token: Optional[AccessToken] = None
def access_token_provider() -> Optional[str]:
nonlocal access_token
if access_token is None or datetime.fromtimestamp(
access_token.expires_on, timezone.utc
) < datetime.now(timezone.utc) + timedelta(minutes=5):
credential = DefaultAzureCredential()
access_token = credential.get_token("https://dynamicsessions.io/.default")
return access_token.token
return access_token_provider
def _sanitize_input(query: str) -> str:
"""Sanitize input to the python REPL.
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
Args:
query: The query to sanitize
Returns:
str: The sanitized query
"""
# Removes `, whitespace & python from start
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
# Removes whitespace & ` from end
query = re.sub(r"(\s|`)*$", "", query)
return query
@dataclass
class RemoteFileMetadata:
"""Metadata for a file in the session."""
filename: str
"""The filename relative to `/mnt/data`."""
size_in_bytes: int
"""The size of the file in bytes."""
@property
def full_path(self) -> str:
"""Get the full path of the file."""
return f"/mnt/data/{self.filename}"
@staticmethod
def from_dict(data: dict) -> "RemoteFileMetadata":
"""Create a RemoteFileMetadata object from a dictionary."""
properties = data.get("properties", {})
return RemoteFileMetadata(
filename=properties.get("filename"),
size_in_bytes=properties.get("size"),
)
class SessionsPythonREPLTool(BaseTool):
r"""Azure Dynamic Sessions tool.
Setup:
Install ``langchain-azure-dynamic-sessions`` and create a session pool, which you can do by following the instructions [here](https://learn.microsoft.com/en-us/azure/container-apps/sessions-code-interpreter?tabs=azure-cli#create-a-session-pool-with-azure-cli).
.. code-block:: bash
pip install -U langchain-azure-dynamic-sessions
.. code-block:: python
import getpass
POOL_MANAGEMENT_ENDPOINT = getpass.getpass("Enter the management endpoint of the session pool: ")
Instantiation:
.. code-block:: python
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT
)
Invocation with args:
.. code-block:: python
tool.invoke("6 * 7")
.. code-block:: python
'{\\n "result": 42,\\n "stdout": "",\\n "stderr": ""\\n}'
Invocation with ToolCall:
.. code-block:: python
tool.invoke({"args": {"input":"6 * 7"}, "id": "1", "name": tool.name, "type": "tool_call"})
.. code-block:: python
'{\\n "result": 42,\\n "stdout": "",\\n "stderr": ""\\n}'
""" # noqa: E501
name: str = "Python_REPL"
description: str = (
"A Python shell. Use this to execute python commands "
"when you need to perform calculations or computations. "
"Input should be a valid python command. "
"Returns a JSON object with the result, stdout, and stderr. "
)
sanitize_input: bool = True
"""Whether to sanitize input to the python REPL."""
pool_management_endpoint: str
"""The management endpoint of the session pool. Should end with a '/'."""
access_token_provider: Callable[[], Optional[str]] = (
_access_token_provider_factory()
)
"""A function that returns the access token to use for the session pool."""
session_id: str = str(uuid4())
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
response_format: Literal["content_and_artifact"] = "content_and_artifact"
def _build_url(self, path: str) -> str:
pool_management_endpoint = self.pool_management_endpoint
if not pool_management_endpoint:
raise ValueError("pool_management_endpoint is not set")
if not pool_management_endpoint.endswith("/"):
pool_management_endpoint += "/"
encoded_session_id = urllib.parse.quote(self.session_id)
query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
query_separator = "&" if "?" in pool_management_endpoint else "?"
full_url = pool_management_endpoint + path + query_separator + query
return full_url
def execute(self, python_code: str) -> Any:
"""Execute Python code in the session."""
if self.sanitize_input:
python_code = _sanitize_input(python_code)
access_token = self.access_token_provider()
api_url = self._build_url("code/execute")
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"User-Agent": USER_AGENT,
}
body = {
"properties": {
"codeInputType": "inline",
"executionType": "synchronous",
"code": python_code,
}
}
response = requests.post(api_url, headers=headers, json=body)
response.raise_for_status()
response_json = response.json()
properties = response_json.get("properties", {})
return properties
def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
response = self.execute(python_code)
# if the result is an image, remove the base64 data
result = deepcopy(response.get("result"))
if isinstance(result, dict):
if result.get("type") == "image" and "base64_data" in result:
result.pop("base64_data")
content = json.dumps(
{
"result": result,
"stdout": response.get("stdout"),
"stderr": response.get("stderr"),
},
indent=2,
)
return content, response
def upload_file(
self,
*,
data: Optional[BinaryIO] = None,
remote_file_path: Optional[str] = None,
local_file_path: Optional[str] = None,
) -> RemoteFileMetadata:
"""Upload a file to the session.
Args:
data: The data to upload.
remote_file_path: The path to upload the file to, relative to
`/mnt/data`. If local_file_path is provided, this is defaulted
to its filename.
local_file_path: The path to the local file to upload.
Returns:
RemoteFileMetadata: The metadata for the uploaded file
"""
if data and local_file_path:
raise ValueError("data and local_file_path cannot be provided together")
if data:
file_data = data
elif local_file_path:
if not remote_file_path:
remote_file_path = os.path.basename(local_file_path)
file_data = open(local_file_path, "rb")
access_token = self.access_token_provider()
api_url = self._build_url("files/upload")
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": USER_AGENT,
}
files = [("file", (remote_file_path, file_data, "application/octet-stream"))]
response = requests.request(
"POST", api_url, headers=headers, data={}, files=files
)
response.raise_for_status()
response_json = response.json()
return RemoteFileMetadata.from_dict(response_json["value"][0])
def download_file(
self, *, remote_file_path: str, local_file_path: Optional[str] = None
) -> BinaryIO:
"""Download a file from the session.
Args:
remote_file_path: The path to download the file from,
relative to `/mnt/data`.
local_file_path: The path to save the downloaded file to.
If not provided, the file is returned as a BufferedReader.
Returns:
BinaryIO: The data of the downloaded file.
"""
access_token = self.access_token_provider()
encoded_remote_file_path = urllib.parse.quote(remote_file_path)
api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": USER_AGENT,
}
response = requests.get(api_url, headers=headers)
response.raise_for_status()
if local_file_path:
with open(local_file_path, "wb") as f:
f.write(response.content)
return BytesIO(response.content)
def list_files(self) -> List[RemoteFileMetadata]:
"""List the files in the session.
Returns:
list[RemoteFileMetadata]: The metadata for the files in the session
"""
access_token = self.access_token_provider()
api_url = self._build_url("files")
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": USER_AGENT,
}
response = requests.get(api_url, headers=headers)
response.raise_for_status()
response_json = response.json()
return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]

File diff suppressed because it is too large Load Diff

View File

@ -1,115 +0,0 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-azure-dynamic-sessions"
version = "0.2.0"
description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.mypy]
disallow_untyped_defs = "True"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-azure-dynamic-sessions%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.3.0"
azure-identity = "^1.16.0"
requests = "^2.31.0"
[tool.ruff.lint]
select = ["E", "F", "I", "D"]
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
"requires: mark tests as requiring a specific library",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.dev]
optional = true
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["D"]
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
python-dotenv = "^1.0.1"
# TODO: hack to fix 3.9 builds
cffi = [
{ version = "<1.17.1", python = "<3.10" },
{ version = "*", python = ">=3.10" },
]
[tool.poetry.group.test_integration.dependencies]
pytest = "^7.3.0"
python-dotenv = "^1.0.1"
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
python-dotenv = "^1.0.1"
pytest = "^7.3.0"
# TODO: hack to fix 3.9 builds
cffi = [
{ version = "<1.17.1", python = "<3.10" },
{ version = "*", python = ">=3.10" },
]
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"
langchainhub = "^0.1.15"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
types-requests = "^2.31.0.20240406"
[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.dev.dependencies.langchain-openai]
path = "../openai"
develop = true
[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true

View File

@ -1,19 +0,0 @@
"""This module checks for specific import statements in the codebase."""
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)

View File

@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@ -1,7 +0,0 @@
import pytest # type: ignore[import-not-found]
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@ -1,68 +0,0 @@
import json
import os
from io import BytesIO
import dotenv # type: ignore[import-not-found]
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
dotenv.load_dotenv()
POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()
def test_end_to_end() -> None:
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT) # type: ignore[arg-type]
result = tool.run("print('hello world')\n1 + 1")
assert json.loads(result) == {
"result": 2,
"stdout": "hello world\n",
"stderr": "",
}
# upload file content
uploaded_file1_metadata = tool.upload_file(
remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
)
assert uploaded_file1_metadata.filename == "test1.txt"
assert uploaded_file1_metadata.size_in_bytes == 16
assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
assert downloaded_file1.read() == b"hello world!!!!!"
# upload file from buffer
with open(TEST_DATA_PATH, "rb") as f:
uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
assert uploaded_file2_metadata.filename == "test2.txt"
downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
assert downloaded_file2.read() == TEST_DATA_CONTENT
# upload file from disk, specifying remote file path
uploaded_file3_metadata = tool.upload_file(
remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
)
assert uploaded_file3_metadata.filename == "test3.txt"
downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
assert downloaded_file3.read() == TEST_DATA_CONTENT
# upload file from disk, without specifying remote file path
uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
downloaded_file4 = tool.download_file(
remote_file_path=uploaded_file4_metadata.filename
)
assert downloaded_file4.read() == TEST_DATA_CONTENT
# list files
remote_files_metadata = tool.list_files()
assert len(remote_files_metadata) == 4
remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
expected_filenames = [
"test1.txt",
"test2.txt",
"test3.txt",
os.path.basename(TEST_DATA_PATH),
]
assert set(remote_file_paths) == set(expected_filenames)

View File

@ -1,9 +0,0 @@
from langchain_azure_dynamic_sessions import __all__
EXPECTED_ALL = [
"SessionsPythonREPLTool",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

View File

@ -1,208 +0,0 @@
import json
import re
import time
from unittest import mock
from urllib.parse import parse_qs, urlparse
from azure.core.credentials import AccessToken
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
from langchain_azure_dynamic_sessions.tools.sessions import (
_access_token_provider_factory,
)
POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"
def test_default_access_token_provider_returns_token() -> None:
access_token_provider = _access_token_provider_factory()
with mock.patch(
"azure.identity.DefaultAzureCredential.get_token"
) as mock_get_token:
mock_get_token.return_value = AccessToken("token_value", 0)
access_token = access_token_provider()
assert access_token == "token_value"
def test_default_access_token_provider_returns_cached_token() -> None:
access_token_provider = _access_token_provider_factory()
with mock.patch(
"azure.identity.DefaultAzureCredential.get_token"
) as mock_get_token:
mock_get_token.return_value = AccessToken(
"token_value", int(time.time() + 1000)
)
access_token = access_token_provider()
assert access_token == "token_value"
assert mock_get_token.call_count == 1
mock_get_token.return_value = AccessToken(
"new_token_value", int(time.time() + 1000)
)
access_token = access_token_provider()
assert access_token == "token_value"
assert mock_get_token.call_count == 1
def test_default_access_token_provider_refreshes_expiring_token() -> None:
access_token_provider = _access_token_provider_factory()
with mock.patch(
"azure.identity.DefaultAzureCredential.get_token"
) as mock_get_token:
mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
access_token = access_token_provider()
assert access_token == "token_value"
assert mock_get_token.call_count == 1
mock_get_token.return_value = AccessToken(
"new_token_value", int(time.time() + 1000)
)
access_token = access_token_provider()
assert access_token == "new_token_value"
assert mock_get_token.call_count == 2
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_code_execution_calls_api(
mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "hello world\n",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
result = tool.run("print('hello world')")
assert json.loads(result) == {
"result": "",
"stdout": "hello world\n",
"stderr": "",
}
api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
headers = {
"Authorization": "Bearer token_value",
"Content-Type": "application/json",
"User-Agent": mock.ANY,
}
body = {
"properties": {
"codeInputType": "inline",
"executionType": "synchronous",
"code": "print('hello world')",
}
}
mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)
called_headers = mock_post.call_args.kwargs["headers"]
assert re.match(
r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
called_headers["User-Agent"],
)
called_api_url = mock_post.call_args.args[0]
assert called_api_url.startswith(api_url)
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_uses_specified_session_id(
mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
session_id="00000000-0000-0000-0000-000000000003",
)
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "2",
"executionTimeInMilliseconds": 33,
},
}
mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
tool.run("1 + 1")
call_url = mock_post.call_args.args[0]
parsed_url = urlparse(call_url)
call_identifier = parse_qs(parsed_url.query)["identifier"][0]
assert call_identifier == "00000000-0000-0000-0000-000000000003"
def test_sanitizes_input() -> None:
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
with mock.patch("requests.post") as mock_post:
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
tool.run("```python\nprint('hello world')\n```")
body = mock_post.call_args.kwargs["json"]
assert body["properties"]["code"] == "print('hello world')"
def test_does_not_sanitize_input() -> None:
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
)
with mock.patch("requests.post") as mock_post:
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
tool.run("```python\nprint('hello world')\n```")
body = mock_post.call_args.kwargs["json"]
assert body["properties"]["code"] == "```python\nprint('hello world')\n```"
def test_uses_custom_access_token_provider() -> None:
def custom_access_token_provider() -> str:
return "custom_token"
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
access_token_provider=custom_access_token_provider,
)
with mock.patch("requests.post") as mock_post:
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
tool.run("print('hello world')")
headers = mock_post.call_args.kwargs["headers"]
assert headers["Authorization"] == "Bearer custom_token"