mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-02 19:34:04 +00:00
azure-dynamic-sessions: add Python REPL tool (#21264)
Adds a Python REPL that executes code in a code interpreter session using Azure Container Apps dynamic sessions. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
02701c277f
commit
c735849e76
1
libs/partners/azure-dynamic-sessions/.gitignore
vendored
Normal file
1
libs/partners/azure-dynamic-sessions/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
__pycache__
|
21
libs/partners/azure-dynamic-sessions/LICENSE
Normal file
21
libs/partners/azure-dynamic-sessions/LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 LangChain, Inc.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
59
libs/partners/azure-dynamic-sessions/Makefile
Normal file
59
libs/partners/azure-dynamic-sessions/Makefile
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||||
|
|
||||||
|
# Default target executed when no arguments are given to make.
|
||||||
|
all: help
|
||||||
|
|
||||||
|
# Define a variable for the test file path.
|
||||||
|
TEST_FILE ?= tests/unit_tests/
|
||||||
|
|
||||||
|
test:
|
||||||
|
poetry run pytest $(TEST_FILE)
|
||||||
|
|
||||||
|
tests:
|
||||||
|
poetry run pytest $(TEST_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
######################
|
||||||
|
# LINTING AND FORMATTING
|
||||||
|
######################
|
||||||
|
|
||||||
|
# Define a variable for Python and notebook files.
|
||||||
|
PYTHON_FILES=.
|
||||||
|
MYPY_CACHE=.mypy_cache
|
||||||
|
lint format: PYTHON_FILES=.
|
||||||
|
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||||
|
lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
|
||||||
|
lint_tests: PYTHON_FILES=tests
|
||||||
|
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||||
|
|
||||||
|
lint lint_diff lint_package lint_tests:
|
||||||
|
poetry run ruff .
|
||||||
|
poetry run ruff format $(PYTHON_FILES) --diff
|
||||||
|
poetry run ruff --select I $(PYTHON_FILES)
|
||||||
|
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||||
|
|
||||||
|
format format_diff:
|
||||||
|
poetry run ruff format $(PYTHON_FILES)
|
||||||
|
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||||
|
|
||||||
|
spell_check:
|
||||||
|
poetry run codespell --toml pyproject.toml
|
||||||
|
|
||||||
|
spell_fix:
|
||||||
|
poetry run codespell --toml pyproject.toml -w
|
||||||
|
|
||||||
|
check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
|
||||||
|
poetry run python ./scripts/check_imports.py $^
|
||||||
|
|
||||||
|
######################
|
||||||
|
# HELP
|
||||||
|
######################
|
||||||
|
|
||||||
|
help:
|
||||||
|
@echo '----'
|
||||||
|
@echo 'check_imports - check imports'
|
||||||
|
@echo 'format - run code formatters'
|
||||||
|
@echo 'lint - run linters'
|
||||||
|
@echo 'test - run unit tests'
|
||||||
|
@echo 'tests - run unit tests'
|
||||||
|
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
36
libs/partners/azure-dynamic-sessions/README.md
Normal file
36
libs/partners/azure-dynamic-sessions/README.md
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# langchain-azure-dynamic-sessions
|
||||||
|
|
||||||
|
This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -U langchain-azure-dynamic-sessions
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
|
||||||
|
|
||||||
|
|
||||||
|
# get the management endpoint from the session pool in the Azure portal
|
||||||
|
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
|
||||||
|
|
||||||
|
prompt = hub.pull("hwchase17/react")
|
||||||
|
tools=[tool]
|
||||||
|
react_agent = create_react_agent(
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)
|
||||||
|
|
||||||
|
react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.
|
||||||
|
|
169
libs/partners/azure-dynamic-sessions/docs/provider.ipynb
Normal file
169
libs/partners/azure-dynamic-sessions/docs/provider.ipynb
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Azure Container Apps dynamic sessions\n",
|
||||||
|
"\n",
|
||||||
|
"Azure Container Apps dynamic sessions provides a secure and scalable way to run a Python code interpreter in Hyper-V isolated sandboxes. This allows your agents to run potentially untrusted code in a secure environment. The code interpreter environment includes many popular Python packages, such as NumPy, pandas, and scikit-learn.\n",
|
||||||
|
"\n",
|
||||||
|
"## Pre-requisites\n",
|
||||||
|
"\n",
|
||||||
|
"By default, the `SessionsPythonREPLTool` tool uses `DefaultAzureCredential` to authenticate with Azure. Locally, it'll use your credentials from the Azure CLI or VS Code. Install the Azure CLI and log in with `az login` to authenticate.\n",
|
||||||
|
"\n",
|
||||||
|
"## Using the tool\n",
|
||||||
|
"\n",
|
||||||
|
"Set variables:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import dotenv\n",
|
||||||
|
"dotenv.load_dotenv()\n",
|
||||||
|
"\n",
|
||||||
|
"POOL_MANAGEMENT_ENDPOINT = os.getenv(\"POOL_MANAGEMENT_ENDPOINT\")\n",
|
||||||
|
"AZURE_OPENAI_ENDPOINT = os.getenv(\"AZURE_OPENAI_ENDPOINT\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'{\\n \"result\": 42,\\n \"stdout\": \"\",\\n \"stderr\": \"\"\\n}'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)\n",
|
||||||
|
"tool.run(\"6 * 7\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Full agent example"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI need to calculate the compound interest on the initial amount over 6 years.\n",
|
||||||
|
"Action: Python_REPL\n",
|
||||||
|
"Action Input: \n",
|
||||||
|
"```python\n",
|
||||||
|
"initial_amount = 500\n",
|
||||||
|
"interest_rate = 0.05\n",
|
||||||
|
"time_period = 6\n",
|
||||||
|
"final_amount = initial_amount * (1 + interest_rate)**time_period\n",
|
||||||
|
"final_amount\n",
|
||||||
|
"```\u001b[0m\u001b[36;1m\u001b[1;3m{\n",
|
||||||
|
" \"result\": 670.0478203125002,\n",
|
||||||
|
" \"stdout\": \"\",\n",
|
||||||
|
" \"stderr\": \"\"\n",
|
||||||
|
"}\u001b[0m\u001b[32;1m\u001b[1;3mThe final amount after 6 years will be $670.05\n",
|
||||||
|
"Final Answer: $670.05\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"{'input': 'If I put $500 in a bank account with a 5% interest rate, how much money will I have in the account after 6 years?',\n",
|
||||||
|
" 'output': '$670.05'}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from azure.identity import DefaultAzureCredential\n",
|
||||||
|
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
|
||||||
|
"from langchain_openai import AzureChatOpenAI\n",
|
||||||
|
"from langchain import agents, hub\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"credential = DefaultAzureCredential()\n",
|
||||||
|
"os.environ[\"OPENAI_API_TYPE\"] = \"azure_ad\"\n",
|
||||||
|
"os.environ[\"OPENAI_API_KEY\"] = credential.get_token(\"https://cognitiveservices.azure.com/.default\").token\n",
|
||||||
|
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = AZURE_OPENAI_ENDPOINT\n",
|
||||||
|
"\n",
|
||||||
|
"llm = AzureChatOpenAI(\n",
|
||||||
|
" azure_deployment=\"gpt-35-turbo\",\n",
|
||||||
|
" openai_api_version=\"2023-09-15-preview\",\n",
|
||||||
|
" streaming=True,\n",
|
||||||
|
" temperature=0,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"repl = SessionsPythonREPLTool(\n",
|
||||||
|
" pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"tools = [repl]\n",
|
||||||
|
"react_agent = agents.create_react_agent(\n",
|
||||||
|
" llm=llm,\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
" prompt=hub.pull(\"hwchase17/react\"),\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"react_agent_executor = agents.AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)\n",
|
||||||
|
"\n",
|
||||||
|
"react_agent_executor.invoke({\"input\": \"If I put $500 in a bank account with a 5% interest rate, how much money will I have in the account after 6 years?\"})"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 1
|
||||||
|
}
|
@ -0,0 +1,5 @@
|
|||||||
|
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SessionsPythonREPLTool",
|
||||||
|
]
|
@ -0,0 +1,5 @@
|
|||||||
|
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SessionsPythonREPLTool",
|
||||||
|
]
|
@ -0,0 +1,273 @@
|
|||||||
|
import importlib.metadata
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import urllib
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any, BinaryIO, Callable, List, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from azure.core.credentials import AccessToken
|
||||||
|
from azure.identity import DefaultAzureCredential
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
try:
|
||||||
|
_package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
|
||||||
|
except importlib.metadata.PackageNotFoundError:
|
||||||
|
_package_version = "0.0.0"
|
||||||
|
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
|
||||||
|
|
||||||
|
|
||||||
|
def _access_token_provider_factory() -> Callable[[], Optional[str]]:
|
||||||
|
"""Factory function for creating an access token provider function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[], Optional[str]]: The access token provider function
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_token: Optional[AccessToken] = None
|
||||||
|
|
||||||
|
def access_token_provider() -> Optional[str]:
|
||||||
|
nonlocal access_token
|
||||||
|
if access_token is None or datetime.fromtimestamp(
|
||||||
|
access_token.expires_on, timezone.utc
|
||||||
|
) < datetime.now(timezone.utc) + timedelta(minutes=5):
|
||||||
|
credential = DefaultAzureCredential()
|
||||||
|
access_token = credential.get_token("https://dynamicsessions.io/.default")
|
||||||
|
return access_token.token
|
||||||
|
|
||||||
|
return access_token_provider
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_input(query: str) -> str:
|
||||||
|
"""Sanitize input to the python REPL.
|
||||||
|
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The sanitized query
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Removes `, whitespace & python from start
|
||||||
|
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
|
||||||
|
# Removes whitespace & ` from end
|
||||||
|
query = re.sub(r"(\s|`)*$", "", query)
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RemoteFileMetadata:
|
||||||
|
"""Metadata for a file in the session."""
|
||||||
|
|
||||||
|
filename: str
|
||||||
|
"""The filename relative to `/mnt/data`."""
|
||||||
|
|
||||||
|
size_in_bytes: int
|
||||||
|
"""The size of the file in bytes."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def full_path(self) -> str:
|
||||||
|
"""Get the full path of the file."""
|
||||||
|
return f"/mnt/data/{self.filename}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(data: dict) -> "RemoteFileMetadata":
|
||||||
|
"""Create a RemoteFileMetadata object from a dictionary."""
|
||||||
|
properties = data.get("properties", {})
|
||||||
|
return RemoteFileMetadata(
|
||||||
|
filename=properties.get("filename"),
|
||||||
|
size_in_bytes=properties.get("size"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionsPythonREPLTool(BaseTool):
|
||||||
|
"""A tool for running Python code in an Azure Container Apps dynamic sessions
|
||||||
|
code interpreter.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
|
||||||
|
tool = SessionsPythonREPLTool(pool_management_endpoint="...")
|
||||||
|
result = tool.run("6 * 7")
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "Python_REPL"
|
||||||
|
description: str = (
|
||||||
|
"A Python shell. Use this to execute python commands "
|
||||||
|
"when you need to perform calculations or computations. "
|
||||||
|
"Input should be a valid python command. "
|
||||||
|
"Returns a JSON object with the result, stdout, and stderr. "
|
||||||
|
)
|
||||||
|
|
||||||
|
sanitize_input: bool = True
|
||||||
|
"""Whether to sanitize input to the python REPL."""
|
||||||
|
|
||||||
|
pool_management_endpoint: str
|
||||||
|
"""The management endpoint of the session pool. Should end with a '/'."""
|
||||||
|
|
||||||
|
access_token_provider: Callable[
|
||||||
|
[], Optional[str]
|
||||||
|
] = _access_token_provider_factory()
|
||||||
|
"""A function that returns the access token to use for the session pool."""
|
||||||
|
|
||||||
|
session_id: str = str(uuid4())
|
||||||
|
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
|
||||||
|
|
||||||
|
def _build_url(self, path: str) -> str:
|
||||||
|
pool_management_endpoint = self.pool_management_endpoint
|
||||||
|
if not pool_management_endpoint:
|
||||||
|
raise ValueError("pool_management_endpoint is not set")
|
||||||
|
if not pool_management_endpoint.endswith("/"):
|
||||||
|
pool_management_endpoint += "/"
|
||||||
|
encoded_session_id = urllib.parse.quote(self.session_id)
|
||||||
|
query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
|
||||||
|
query_separator = "&" if "?" in pool_management_endpoint else "?"
|
||||||
|
full_url = pool_management_endpoint + path + query_separator + query
|
||||||
|
return full_url
|
||||||
|
|
||||||
|
def execute(self, python_code: str) -> Any:
|
||||||
|
"""Execute Python code in the session."""
|
||||||
|
|
||||||
|
if self.sanitize_input:
|
||||||
|
python_code = _sanitize_input(python_code)
|
||||||
|
|
||||||
|
access_token = self.access_token_provider()
|
||||||
|
api_url = self._build_url("code/execute")
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": USER_AGENT,
|
||||||
|
}
|
||||||
|
body = {
|
||||||
|
"properties": {
|
||||||
|
"codeInputType": "inline",
|
||||||
|
"executionType": "synchronous",
|
||||||
|
"code": python_code,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(api_url, headers=headers, json=body)
|
||||||
|
response.raise_for_status()
|
||||||
|
response_json = response.json()
|
||||||
|
properties = response_json.get("properties", {})
|
||||||
|
return properties
|
||||||
|
|
||||||
|
def _run(self, python_code: str) -> Any:
|
||||||
|
response = self.execute(python_code)
|
||||||
|
|
||||||
|
# if the result is an image, remove the base64 data
|
||||||
|
result = response.get("result")
|
||||||
|
if isinstance(result, dict):
|
||||||
|
if result.get("type") == "image" and "base64_data" in result:
|
||||||
|
result.pop("base64_data")
|
||||||
|
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"result": result,
|
||||||
|
"stdout": response.get("stdout"),
|
||||||
|
"stderr": response.get("stderr"),
|
||||||
|
},
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upload_file(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
data: Optional[BinaryIO] = None,
|
||||||
|
remote_file_path: Optional[str] = None,
|
||||||
|
local_file_path: Optional[str] = None,
|
||||||
|
) -> RemoteFileMetadata:
|
||||||
|
"""Upload a file to the session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data to upload.
|
||||||
|
remote_file_path: The path to upload the file to, relative to
|
||||||
|
`/mnt/data`. If local_file_path is provided, this is defaulted
|
||||||
|
to its filename.
|
||||||
|
local_file_path: The path to the local file to upload.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RemoteFileMetadata: The metadata for the uploaded file
|
||||||
|
"""
|
||||||
|
|
||||||
|
if data and local_file_path:
|
||||||
|
raise ValueError("data and local_file_path cannot be provided together")
|
||||||
|
|
||||||
|
if data:
|
||||||
|
file_data = data
|
||||||
|
elif local_file_path:
|
||||||
|
if not remote_file_path:
|
||||||
|
remote_file_path = os.path.basename(local_file_path)
|
||||||
|
file_data = open(local_file_path, "rb")
|
||||||
|
|
||||||
|
access_token = self.access_token_provider()
|
||||||
|
api_url = self._build_url("files/upload")
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"User-Agent": USER_AGENT,
|
||||||
|
}
|
||||||
|
files = [("file", (remote_file_path, file_data, "application/octet-stream"))]
|
||||||
|
|
||||||
|
response = requests.request(
|
||||||
|
"POST", api_url, headers=headers, data={}, files=files
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
return RemoteFileMetadata.from_dict(response_json["value"][0])
|
||||||
|
|
||||||
|
def download_file(
|
||||||
|
self, *, remote_file_path: str, local_file_path: Optional[str] = None
|
||||||
|
) -> BinaryIO:
|
||||||
|
"""Download a file from the session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_file_path: The path to download the file from,
|
||||||
|
relative to `/mnt/data`.
|
||||||
|
local_file_path: The path to save the downloaded file to.
|
||||||
|
If not provided, the file is returned as a BufferedReader.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BinaryIO: The data of the downloaded file.
|
||||||
|
"""
|
||||||
|
access_token = self.access_token_provider()
|
||||||
|
encoded_remote_file_path = urllib.parse.quote(remote_file_path)
|
||||||
|
api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"User-Agent": USER_AGENT,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(api_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
if local_file_path:
|
||||||
|
with open(local_file_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
return BytesIO(response.content)
|
||||||
|
|
||||||
|
def list_files(self) -> List[RemoteFileMetadata]:
|
||||||
|
"""List the files in the session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[RemoteFileMetadata]: The metadata for the files in the session
|
||||||
|
"""
|
||||||
|
access_token = self.access_token_provider()
|
||||||
|
api_url = self._build_url("files")
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"User-Agent": USER_AGENT,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(api_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]
|
2055
libs/partners/azure-dynamic-sessions/poetry.lock
generated
Normal file
2055
libs/partners/azure-dynamic-sessions/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
104
libs/partners/azure-dynamic-sessions/pyproject.toml
Normal file
104
libs/partners/azure-dynamic-sessions/pyproject.toml
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "langchain-azure-dynamic-sessions"
|
||||||
|
version = "0.1.0rc0"
|
||||||
|
description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
|
||||||
|
authors = []
|
||||||
|
readme = "README.md"
|
||||||
|
repository = "https://github.com/langchain-ai/langchain"
|
||||||
|
license = "MIT"
|
||||||
|
|
||||||
|
[tool.poetry.urls]
|
||||||
|
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.8.1,<4.0"
|
||||||
|
langchain-core = "^0.1.52"
|
||||||
|
azure-identity = "^1.16.0"
|
||||||
|
requests = "^2.31.0"
|
||||||
|
|
||||||
|
[tool.poetry.group.test]
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[tool.poetry.group.test.dependencies]
|
||||||
|
pytest = "^7.3.0"
|
||||||
|
freezegun = "^1.2.2"
|
||||||
|
pytest-mock = "^3.10.0"
|
||||||
|
syrupy = "^4.0.2"
|
||||||
|
pytest-watcher = "^0.3.4"
|
||||||
|
pytest-asyncio = "^0.21.1"
|
||||||
|
langchain-core = {path = "../../core", develop = true}
|
||||||
|
python-dotenv = "^1.0.1"
|
||||||
|
|
||||||
|
[tool.poetry.group.test_integration]
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[tool.poetry.group.test_integration.dependencies]
|
||||||
|
pytest = "^7.3.0"
|
||||||
|
python-dotenv = "^1.0.1"
|
||||||
|
|
||||||
|
[tool.poetry.group.codespell]
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[tool.poetry.group.codespell.dependencies]
|
||||||
|
codespell = "^2.2.0"
|
||||||
|
|
||||||
|
[tool.poetry.group.lint]
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[tool.poetry.group.lint.dependencies]
|
||||||
|
ruff = "^0.1.5"
|
||||||
|
python-dotenv = "^1.0.1"
|
||||||
|
pytest = "^7.3.0"
|
||||||
|
|
||||||
|
[tool.poetry.group.typing.dependencies]
|
||||||
|
mypy = "^0.991"
|
||||||
|
langchain-core = {path = "../../core", develop = true}
|
||||||
|
types-requests = "^2.31.0.20240406"
|
||||||
|
|
||||||
|
[tool.poetry.group.dev]
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
langchain-core = {path = "../../core", develop = true}
|
||||||
|
ipykernel = "^6.29.4"
|
||||||
|
langchain-openai = {path = "../openai", develop = true}
|
||||||
|
langchainhub = "^0.1.15"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
disallow_untyped_defs = "True"
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
omit = [
|
||||||
|
"tests/*",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
# --strict-markers will raise errors on unknown marks.
|
||||||
|
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||||
|
#
|
||||||
|
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||||
|
# --strict-config any warnings encountered while parsing the `pytest`
|
||||||
|
# section of the configuration file raise errors.
|
||||||
|
#
|
||||||
|
# https://github.com/tophat/syrupy
|
||||||
|
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||||
|
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||||
|
# Registering custom markers.
|
||||||
|
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||||
|
markers = [
|
||||||
|
"requires: mark tests as requiring a specific library",
|
||||||
|
"asyncio: mark tests as requiring asyncio",
|
||||||
|
"compile: mark placeholder test used to compile integration tests without running them",
|
||||||
|
]
|
||||||
|
asyncio_mode = "auto"
|
@ -0,0 +1,17 @@
|
|||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from importlib.machinery import SourceFileLoader
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
files = sys.argv[1:]
|
||||||
|
has_failure = False
|
||||||
|
for file in files:
|
||||||
|
try:
|
||||||
|
SourceFileLoader("x", file).load_module()
|
||||||
|
except Exception:
|
||||||
|
has_faillure = True
|
||||||
|
print(file)
|
||||||
|
traceback.print_exc()
|
||||||
|
print()
|
||||||
|
|
||||||
|
sys.exit(1 if has_failure else 0)
|
27
libs/partners/azure-dynamic-sessions/scripts/check_pydantic.sh
Executable file
27
libs/partners/azure-dynamic-sessions/scripts/check_pydantic.sh
Executable file
@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||||
|
# in tracked files within a Git repository.
|
||||||
|
#
|
||||||
|
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||||
|
|
||||||
|
# Check if a path argument is provided
|
||||||
|
if [ $# -ne 1 ]; then
|
||||||
|
echo "Usage: $0 /path/to/repository"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
repository_path="$1"
|
||||||
|
|
||||||
|
# Search for lines matching the pattern within the specified repository
|
||||||
|
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||||
|
|
||||||
|
# Check if any matching lines were found
|
||||||
|
if [ -n "$result" ]; then
|
||||||
|
echo "ERROR: The following lines need to be updated:"
|
||||||
|
echo "$result"
|
||||||
|
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||||
|
echo "For example, replace 'from pydantic import BaseModel'"
|
||||||
|
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||||
|
exit 1
|
||||||
|
fi
|
17
libs/partners/azure-dynamic-sessions/scripts/lint_imports.sh
Executable file
17
libs/partners/azure-dynamic-sessions/scripts/lint_imports.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
# Initialize a variable to keep track of errors
|
||||||
|
errors=0
|
||||||
|
|
||||||
|
# make sure not importing from langchain or langchain_experimental
|
||||||
|
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||||
|
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||||
|
|
||||||
|
# Decide on an exit status based on the errors
|
||||||
|
if [ "$errors" -gt 0 ]; then
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
exit 0
|
||||||
|
fi
|
@ -0,0 +1 @@
|
|||||||
|
test file content
|
@ -0,0 +1,7 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.compile
|
||||||
|
def test_placeholder() -> None:
|
||||||
|
"""Used for compiling integration tests without running any real tests."""
|
||||||
|
pass
|
@ -0,0 +1,68 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
|
||||||
|
TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
|
||||||
|
TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()
|
||||||
|
|
||||||
|
|
||||||
|
def test_end_to_end() -> None:
|
||||||
|
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
|
||||||
|
result = tool.run("print('hello world')\n1 + 1")
|
||||||
|
assert json.loads(result) == {
|
||||||
|
"result": 2,
|
||||||
|
"stdout": "hello world\n",
|
||||||
|
"stderr": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
# upload file content
|
||||||
|
uploaded_file1_metadata = tool.upload_file(
|
||||||
|
remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
|
||||||
|
)
|
||||||
|
assert uploaded_file1_metadata.filename == "test1.txt"
|
||||||
|
assert uploaded_file1_metadata.size_in_bytes == 16
|
||||||
|
assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
|
||||||
|
downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
|
||||||
|
assert downloaded_file1.read() == b"hello world!!!!!"
|
||||||
|
|
||||||
|
# upload file from buffer
|
||||||
|
with open(TEST_DATA_PATH, "rb") as f:
|
||||||
|
uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
|
||||||
|
assert uploaded_file2_metadata.filename == "test2.txt"
|
||||||
|
downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
|
||||||
|
assert downloaded_file2.read() == TEST_DATA_CONTENT
|
||||||
|
|
||||||
|
# upload file from disk, specifying remote file path
|
||||||
|
uploaded_file3_metadata = tool.upload_file(
|
||||||
|
remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
|
||||||
|
)
|
||||||
|
assert uploaded_file3_metadata.filename == "test3.txt"
|
||||||
|
downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
|
||||||
|
assert downloaded_file3.read() == TEST_DATA_CONTENT
|
||||||
|
|
||||||
|
# upload file from disk, without specifying remote file path
|
||||||
|
uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
|
||||||
|
assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
|
||||||
|
downloaded_file4 = tool.download_file(
|
||||||
|
remote_file_path=uploaded_file4_metadata.filename
|
||||||
|
)
|
||||||
|
assert downloaded_file4.read() == TEST_DATA_CONTENT
|
||||||
|
|
||||||
|
# list files
|
||||||
|
remote_files_metadata = tool.list_files()
|
||||||
|
assert len(remote_files_metadata) == 4
|
||||||
|
remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
|
||||||
|
expected_filenames = [
|
||||||
|
"test1.txt",
|
||||||
|
"test2.txt",
|
||||||
|
"test3.txt",
|
||||||
|
os.path.basename(TEST_DATA_PATH),
|
||||||
|
]
|
||||||
|
assert set(remote_file_paths) == set(expected_filenames)
|
@ -0,0 +1,9 @@
|
|||||||
|
from langchain_azure_dynamic_sessions import __all__
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"SessionsPythonREPLTool",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
@ -0,0 +1,208 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from unittest import mock
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from azure.core.credentials import AccessToken
|
||||||
|
|
||||||
|
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
|
||||||
|
from langchain_azure_dynamic_sessions.tools.sessions import (
|
||||||
|
_access_token_provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_access_token_provider_returns_token() -> None:
|
||||||
|
access_token_provider = _access_token_provider_factory()
|
||||||
|
with mock.patch(
|
||||||
|
"azure.identity.DefaultAzureCredential.get_token"
|
||||||
|
) as mock_get_token:
|
||||||
|
mock_get_token.return_value = AccessToken("token_value", 0)
|
||||||
|
access_token = access_token_provider()
|
||||||
|
assert access_token == "token_value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_access_token_provider_returns_cached_token() -> None:
|
||||||
|
access_token_provider = _access_token_provider_factory()
|
||||||
|
with mock.patch(
|
||||||
|
"azure.identity.DefaultAzureCredential.get_token"
|
||||||
|
) as mock_get_token:
|
||||||
|
mock_get_token.return_value = AccessToken(
|
||||||
|
"token_value", int(time.time() + 1000)
|
||||||
|
)
|
||||||
|
access_token = access_token_provider()
|
||||||
|
assert access_token == "token_value"
|
||||||
|
assert mock_get_token.call_count == 1
|
||||||
|
|
||||||
|
mock_get_token.return_value = AccessToken(
|
||||||
|
"new_token_value", int(time.time() + 1000)
|
||||||
|
)
|
||||||
|
access_token = access_token_provider()
|
||||||
|
assert access_token == "token_value"
|
||||||
|
assert mock_get_token.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_access_token_provider_refreshes_expiring_token() -> None:
|
||||||
|
access_token_provider = _access_token_provider_factory()
|
||||||
|
with mock.patch(
|
||||||
|
"azure.identity.DefaultAzureCredential.get_token"
|
||||||
|
) as mock_get_token:
|
||||||
|
mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
|
||||||
|
access_token = access_token_provider()
|
||||||
|
assert access_token == "token_value"
|
||||||
|
assert mock_get_token.call_count == 1
|
||||||
|
|
||||||
|
mock_get_token.return_value = AccessToken(
|
||||||
|
"new_token_value", int(time.time() + 1000)
|
||||||
|
)
|
||||||
|
access_token = access_token_provider()
|
||||||
|
assert access_token == "new_token_value"
|
||||||
|
assert mock_get_token.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("requests.post")
|
||||||
|
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
|
||||||
|
def test_code_execution_calls_api(
|
||||||
|
mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
|
||||||
|
) -> None:
|
||||||
|
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
|
||||||
|
mock_post.return_value.json.return_value = {
|
||||||
|
"$id": "1",
|
||||||
|
"properties": {
|
||||||
|
"$id": "2",
|
||||||
|
"status": "Success",
|
||||||
|
"stdout": "hello world\n",
|
||||||
|
"stderr": "",
|
||||||
|
"result": "",
|
||||||
|
"executionTimeInMilliseconds": 33,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
|
||||||
|
|
||||||
|
result = tool.run("print('hello world')")
|
||||||
|
|
||||||
|
assert json.loads(result) == {
|
||||||
|
"result": "",
|
||||||
|
"stdout": "hello world\n",
|
||||||
|
"stderr": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
|
||||||
|
headers = {
|
||||||
|
"Authorization": "Bearer token_value",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": mock.ANY,
|
||||||
|
}
|
||||||
|
body = {
|
||||||
|
"properties": {
|
||||||
|
"codeInputType": "inline",
|
||||||
|
"executionType": "synchronous",
|
||||||
|
"code": "print('hello world')",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)
|
||||||
|
|
||||||
|
called_headers = mock_post.call_args.kwargs["headers"]
|
||||||
|
assert re.match(
|
||||||
|
r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
|
||||||
|
called_headers["User-Agent"],
|
||||||
|
)
|
||||||
|
|
||||||
|
called_api_url = mock_post.call_args.args[0]
|
||||||
|
assert called_api_url.startswith(api_url)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("requests.post")
|
||||||
|
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
|
||||||
|
def test_uses_specified_session_id(
|
||||||
|
mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
|
||||||
|
) -> None:
|
||||||
|
tool = SessionsPythonREPLTool(
|
||||||
|
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
|
||||||
|
session_id="00000000-0000-0000-0000-000000000003",
|
||||||
|
)
|
||||||
|
mock_post.return_value.json.return_value = {
|
||||||
|
"$id": "1",
|
||||||
|
"properties": {
|
||||||
|
"$id": "2",
|
||||||
|
"status": "Success",
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"result": "2",
|
||||||
|
"executionTimeInMilliseconds": 33,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
|
||||||
|
tool.run("1 + 1")
|
||||||
|
call_url = mock_post.call_args.args[0]
|
||||||
|
parsed_url = urlparse(call_url)
|
||||||
|
call_identifier = parse_qs(parsed_url.query)["identifier"][0]
|
||||||
|
assert call_identifier == "00000000-0000-0000-0000-000000000003"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitizes_input() -> None:
|
||||||
|
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
|
||||||
|
with mock.patch("requests.post") as mock_post:
|
||||||
|
mock_post.return_value.json.return_value = {
|
||||||
|
"$id": "1",
|
||||||
|
"properties": {
|
||||||
|
"$id": "2",
|
||||||
|
"status": "Success",
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"result": "",
|
||||||
|
"executionTimeInMilliseconds": 33,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tool.run("```python\nprint('hello world')\n```")
|
||||||
|
body = mock_post.call_args.kwargs["json"]
|
||||||
|
assert body["properties"]["code"] == "print('hello world')"
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_sanitize_input() -> None:
|
||||||
|
tool = SessionsPythonREPLTool(
|
||||||
|
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
|
||||||
|
)
|
||||||
|
with mock.patch("requests.post") as mock_post:
|
||||||
|
mock_post.return_value.json.return_value = {
|
||||||
|
"$id": "1",
|
||||||
|
"properties": {
|
||||||
|
"$id": "2",
|
||||||
|
"status": "Success",
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"result": "",
|
||||||
|
"executionTimeInMilliseconds": 33,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tool.run("```python\nprint('hello world')\n```")
|
||||||
|
body = mock_post.call_args.kwargs["json"]
|
||||||
|
assert body["properties"]["code"] == "```python\nprint('hello world')\n```"
|
||||||
|
|
||||||
|
|
||||||
|
def test_uses_custom_access_token_provider() -> None:
|
||||||
|
def custom_access_token_provider() -> str:
|
||||||
|
return "custom_token"
|
||||||
|
|
||||||
|
tool = SessionsPythonREPLTool(
|
||||||
|
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
|
||||||
|
access_token_provider=custom_access_token_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch("requests.post") as mock_post:
|
||||||
|
mock_post.return_value.json.return_value = {
|
||||||
|
"$id": "1",
|
||||||
|
"properties": {
|
||||||
|
"$id": "2",
|
||||||
|
"status": "Success",
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"result": "",
|
||||||
|
"executionTimeInMilliseconds": 33,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tool.run("print('hello world')")
|
||||||
|
headers = mock_post.call_args.kwargs["headers"]
|
||||||
|
assert headers["Authorization"] == "Bearer custom_token"
|
Loading…
Reference in New Issue
Block a user