azure-dynamic-sessions: migrate to repo (#27468)

parent 30660786b3
commit a562c54f7d
@@ -1 +0,0 @@
__pycache__
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,61 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)

test_watch:
	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports              - check imports'
	@echo 'format                     - run code formatters'
	@echo 'lint                       - run linters'
	@echo 'test                       - run unit tests'
	@echo 'tests                      - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
@@ -1,36 +1,3 @@
# langchain-azure-dynamic-sessions

This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.

## Installation

```bash
pip install -U langchain-azure-dynamic-sessions
```

## Usage

You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.

```python
from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool


# get the management endpoint from the session pool in the Azure portal
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)

# llm can be any LangChain chat model, e.g. AzureChatOpenAI
prompt = hub.pull("hwchase17/react")
tools = [tool]
react_agent = create_react_agent(
    llm=llm,
    tools=tools,
    prompt=prompt,
)

react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)

react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
```

By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.
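For example, a minimal sketch for a user-assigned managed identity (the client ID value below is a hypothetical placeholder):

```python
import os

# DefaultAzureCredential picks up the user-assigned managed identity via
# the AZURE_CLIENT_ID environment variable; set it before the credential
# is first used.
os.environ["AZURE_CLIENT_ID"] = "<your-managed-identity-client-id>"

from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
```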
This package has moved!

https://github.com/langchain-ai/langchain-azure/tree/main/libs/azure-dynamic-sessions
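Installation should be unchanged after the move, assuming the package keeps the same name on PyPI:

```bash
pip install -U langchain-azure-dynamic-sessions
```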
@@ -1,7 +0,0 @@
"""This package provides tools for managing dynamic sessions in Azure."""

from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool

__all__ = [
    "SessionsPythonREPLTool",
]
@@ -1,7 +0,0 @@
"""This package provides tools for managing dynamic sessions in Azure."""

from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool

__all__ = [
    "SessionsPythonREPLTool",
]
@@ -1,314 +0,0 @@
"""This is the Azure Dynamic Sessions module.

This module provides the SessionsPythonREPLTool class for
managing dynamic sessions in Azure.
"""

import importlib.metadata
import json
import os
import re
import urllib.parse
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
from uuid import uuid4

import requests
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
from langchain_core.tools import BaseTool

try:
    _package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
except importlib.metadata.PackageNotFoundError:
    _package_version = "0.0.0"
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"


def _access_token_provider_factory() -> Callable[[], Optional[str]]:
    """Factory function for creating an access token provider function.

    Returns:
        Callable[[], Optional[str]]: The access token provider function
    """
    access_token: Optional[AccessToken] = None

    def access_token_provider() -> Optional[str]:
        nonlocal access_token
        if access_token is None or datetime.fromtimestamp(
            access_token.expires_on, timezone.utc
        ) < datetime.now(timezone.utc) + timedelta(minutes=5):
            credential = DefaultAzureCredential()
            access_token = credential.get_token("https://dynamicsessions.io/.default")
        return access_token.token

    return access_token_provider


def _sanitize_input(query: str) -> str:
    """Sanitize input to the Python REPL.

    Removes leading/trailing whitespace and backticks, and a leading "python"
    language tag (in case the LLM wraps the code in a Markdown fence).

    Args:
        query: The query to sanitize

    Returns:
        str: The sanitized query
    """
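    # Example (behavior pinned by the unit tests in this package):
    #   "```python\nprint('hello world')\n```"  ->  "print('hello world')"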
    # Removes `, whitespace & python from start
    query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
    # Removes whitespace & ` from end
    query = re.sub(r"(\s|`)*$", "", query)
    return query


@dataclass
class RemoteFileMetadata:
    """Metadata for a file in the session."""

    filename: str
    """The filename relative to `/mnt/data`."""

    size_in_bytes: int
    """The size of the file in bytes."""

    @property
    def full_path(self) -> str:
        """Get the full path of the file."""
        return f"/mnt/data/{self.filename}"

    @staticmethod
    def from_dict(data: dict) -> "RemoteFileMetadata":
        """Create a RemoteFileMetadata object from a dictionary."""
        properties = data.get("properties", {})
        return RemoteFileMetadata(
            filename=properties.get("filename"),
            size_in_bytes=properties.get("size"),
        )


class SessionsPythonREPLTool(BaseTool):
    r"""Azure Dynamic Sessions tool.

    Setup:
        Install ``langchain-azure-dynamic-sessions`` and create a session pool, which you can do by following the instructions [here](https://learn.microsoft.com/en-us/azure/container-apps/sessions-code-interpreter?tabs=azure-cli#create-a-session-pool-with-azure-cli).

        .. code-block:: bash

            pip install -U langchain-azure-dynamic-sessions

        .. code-block:: python

            import getpass

            POOL_MANAGEMENT_ENDPOINT = getpass.getpass("Enter the management endpoint of the session pool: ")

    Instantiation:
        .. code-block:: python

            from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

            tool = SessionsPythonREPLTool(
                pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT
            )

    Invocation with args:
        .. code-block:: python

            tool.invoke("6 * 7")

        .. code-block:: python

            '{\\n  "result": 42,\\n  "stdout": "",\\n  "stderr": ""\\n}'

    Invocation with ToolCall:

        .. code-block:: python

            tool.invoke({"args": {"input": "6 * 7"}, "id": "1", "name": tool.name, "type": "tool_call"})

        .. code-block:: python

            '{\\n  "result": 42,\\n  "stdout": "",\\n  "stderr": ""\\n}'
    """  # noqa: E501

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands "
        "when you need to perform calculations or computations. "
        "Input should be a valid python command. "
        "Returns a JSON object with the result, stdout, and stderr."
    )

    sanitize_input: bool = True
    """Whether to sanitize input to the python REPL."""

    pool_management_endpoint: str
    """The management endpoint of the session pool. Should end with a '/'."""

    access_token_provider: Callable[[], Optional[str]] = (
        _access_token_provider_factory()
    )
    """A function that returns the access token to use for the session pool."""

    session_id: str = str(uuid4())
    """The session ID to use for the code interpreter. Defaults to a random UUID."""

    response_format: Literal["content_and_artifact"] = "content_and_artifact"

    def _build_url(self, path: str) -> str:
        pool_management_endpoint = self.pool_management_endpoint
        if not pool_management_endpoint:
            raise ValueError("pool_management_endpoint is not set")
        if not pool_management_endpoint.endswith("/"):
            pool_management_endpoint += "/"
        encoded_session_id = urllib.parse.quote(self.session_id)
        query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
        query_separator = "&" if "?" in pool_management_endpoint else "?"
        # e.g. <endpoint>/code/execute?identifier=<session-id>&api-version=2024-02-02-preview
        full_url = pool_management_endpoint + path + query_separator + query
        return full_url

    def execute(self, python_code: str) -> Any:
        """Execute Python code in the session."""
        if self.sanitize_input:
            python_code = _sanitize_input(python_code)

        access_token = self.access_token_provider()
        api_url = self._build_url("code/execute")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            "User-Agent": USER_AGENT,
        }
        body = {
            "properties": {
                "codeInputType": "inline",
                "executionType": "synchronous",
                "code": python_code,
            }
        }

        response = requests.post(api_url, headers=headers, json=body)
        response.raise_for_status()
        response_json = response.json()
        properties = response_json.get("properties", {})
        return properties

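    # A hedged usage sketch (assumes a valid pool endpoint and Azure credentials):
    #   tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    #   properties = tool.execute("1 + 1")
    #   properties["result"]  # -> 2; "stdout" and "stderr" keys are also present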
    def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
        response = self.execute(python_code)

        # if the result is an image, remove the base64 data
        result = deepcopy(response.get("result"))
        if isinstance(result, dict):
            if result.get("type") == "image" and "base64_data" in result:
                result.pop("base64_data")

        content = json.dumps(
            {
                "result": result,
                "stdout": response.get("stdout"),
                "stderr": response.get("stderr"),
            },
            indent=2,
        )
        return content, response

    def upload_file(
        self,
        *,
        data: Optional[BinaryIO] = None,
        remote_file_path: Optional[str] = None,
        local_file_path: Optional[str] = None,
    ) -> RemoteFileMetadata:
        """Upload a file to the session.

        Args:
            data: The data to upload.
            remote_file_path: The path to upload the file to, relative to
                `/mnt/data`. If local_file_path is provided, this is defaulted
                to its filename.
            local_file_path: The path to the local file to upload.

        Returns:
            RemoteFileMetadata: The metadata for the uploaded file
        """
        if data and local_file_path:
            raise ValueError("data and local_file_path cannot be provided together")

        if data:
            file_data = data
        elif local_file_path:
            if not remote_file_path:
                remote_file_path = os.path.basename(local_file_path)
            file_data = open(local_file_path, "rb")
        else:
            raise ValueError("data or local_file_path must be provided")

        access_token = self.access_token_provider()
        api_url = self._build_url("files/upload")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }
        files = [("file", (remote_file_path, file_data, "application/octet-stream"))]

        response = requests.request(
            "POST", api_url, headers=headers, data={}, files=files
        )
        response.raise_for_status()

        response_json = response.json()
        return RemoteFileMetadata.from_dict(response_json["value"][0])

    def download_file(
        self, *, remote_file_path: str, local_file_path: Optional[str] = None
    ) -> BinaryIO:
        """Download a file from the session.

        Args:
            remote_file_path: The path to download the file from,
                relative to `/mnt/data`.
            local_file_path: The path to save the downloaded file to.
                If not provided, the file is returned as a BufferedReader.

        Returns:
            BinaryIO: The data of the downloaded file.
        """
        access_token = self.access_token_provider()
        encoded_remote_file_path = urllib.parse.quote(remote_file_path)
        api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()

        if local_file_path:
            with open(local_file_path, "wb") as f:
                f.write(response.content)

        return BytesIO(response.content)

    def list_files(self) -> List[RemoteFileMetadata]:
        """List the files in the session.

        Returns:
            list[RemoteFileMetadata]: The metadata for the files in the session
        """
        access_token = self.access_token_provider()
        api_url = self._build_url("files")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()

        response_json = response.json()
        return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]
libs/partners/azure-dynamic-sessions/poetry.lock (generated, 2232 lines)
File diff suppressed because it is too large
@@ -1,115 +0,0 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "langchain-azure-dynamic-sessions"
version = "0.2.0"
description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.mypy]
disallow_untyped_defs = "True"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-azure-dynamic-sessions%3D%3D0%22&expanded=true"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.3.0"
azure-identity = "^1.16.0"
requests = "^2.31.0"

[tool.ruff.lint]
select = ["E", "F", "I", "D"]

[tool.coverage.run]
omit = ["tests/*"]

[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
    "requires: mark tests as requiring a specific library",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.dev]
optional = true

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["D"]

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
python-dotenv = "^1.0.1"
# TODO: hack to fix 3.9 builds
cffi = [
    { version = "<1.17.1", python = "<3.10" },
    { version = "*", python = ">=3.10" },
]

[tool.poetry.group.test_integration.dependencies]
pytest = "^7.3.0"
python-dotenv = "^1.0.1"

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
python-dotenv = "^1.0.1"
pytest = "^7.3.0"
# TODO: hack to fix 3.9 builds
cffi = [
    { version = "<1.17.1", python = "<3.10" },
    { version = "*", python = ">=3.10" },
]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"
langchainhub = "^0.1.15"

[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
types-requests = "^2.31.0.20240406"

[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true

[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true

[tool.poetry.group.dev.dependencies.langchain-openai]
path = "../openai"
develop = true

[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true
@@ -1,19 +0,0 @@
"""Try to load each given Python file and report any import errors."""

import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
@@ -1,17 +0,0 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure we are not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
    exit 1
else
    exit 0
fi
@@ -1 +0,0 @@
test file content
@@ -1,7 +0,0 @@
import pytest  # type: ignore[import-not-found]


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@@ -1,68 +0,0 @@
import json
import os
from io import BytesIO

import dotenv  # type: ignore[import-not-found]

from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

dotenv.load_dotenv()

POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()


def test_end_to_end() -> None:
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)  # type: ignore[arg-type]
    result = tool.run("print('hello world')\n1 + 1")
    assert json.loads(result) == {
        "result": 2,
        "stdout": "hello world\n",
        "stderr": "",
    }

    # upload file content from an in-memory buffer
    uploaded_file1_metadata = tool.upload_file(
        remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
    )
    assert uploaded_file1_metadata.filename == "test1.txt"
    assert uploaded_file1_metadata.size_in_bytes == 16
    assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
    downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
    assert downloaded_file1.read() == b"hello world!!!!!"

    # upload file from an open file object
    with open(TEST_DATA_PATH, "rb") as f:
        uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
    assert uploaded_file2_metadata.filename == "test2.txt"
    downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
    assert downloaded_file2.read() == TEST_DATA_CONTENT

    # upload file from disk, specifying remote file path
    uploaded_file3_metadata = tool.upload_file(
        remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
    )
    assert uploaded_file3_metadata.filename == "test3.txt"
    downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
    assert downloaded_file3.read() == TEST_DATA_CONTENT

    # upload file from disk, without specifying remote file path
    uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
    assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
    downloaded_file4 = tool.download_file(
        remote_file_path=uploaded_file4_metadata.filename
    )
    assert downloaded_file4.read() == TEST_DATA_CONTENT

    # list files
    remote_files_metadata = tool.list_files()
    assert len(remote_files_metadata) == 4
    remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
    expected_filenames = [
        "test1.txt",
        "test2.txt",
        "test3.txt",
        os.path.basename(TEST_DATA_PATH),
    ]
    assert set(remote_file_paths) == set(expected_filenames)
@@ -1,9 +0,0 @@
from langchain_azure_dynamic_sessions import __all__

EXPECTED_ALL = [
    "SessionsPythonREPLTool",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@@ -1,208 +0,0 @@
import json
import re
import time
from unittest import mock
from urllib.parse import parse_qs, urlparse

from azure.core.credentials import AccessToken

from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
from langchain_azure_dynamic_sessions.tools.sessions import (
    _access_token_provider_factory,
)

POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"


def test_default_access_token_provider_returns_token() -> None:
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken("token_value", 0)
        access_token = access_token_provider()
        assert access_token == "token_value"


def test_default_access_token_provider_returns_cached_token() -> None:
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken(
            "token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1

        mock_get_token.return_value = AccessToken(
            "new_token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1


def test_default_access_token_provider_refreshes_expiring_token() -> None:
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1

        mock_get_token.return_value = AccessToken(
            "new_token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "new_token_value"
        assert mock_get_token.call_count == 2


@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_code_execution_calls_api(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "hello world\n",
            "stderr": "",
            "result": "",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))

    result = tool.run("print('hello world')")

    assert json.loads(result) == {
        "result": "",
        "stdout": "hello world\n",
        "stderr": "",
    }

    api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
    headers = {
        "Authorization": "Bearer token_value",
        "Content-Type": "application/json",
        "User-Agent": mock.ANY,
    }
    body = {
        "properties": {
            "codeInputType": "inline",
            "executionType": "synchronous",
            "code": "print('hello world')",
        }
    }
    mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)

    called_headers = mock_post.call_args.kwargs["headers"]
    assert re.match(
        r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
        called_headers["User-Agent"],
    )

    called_api_url = mock_post.call_args.args[0]
    assert called_api_url.startswith(api_url)


@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_uses_specified_session_id(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        session_id="00000000-0000-0000-0000-000000000003",
    )
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "",
            "stderr": "",
            "result": "2",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
    tool.run("1 + 1")
    call_url = mock_post.call_args.args[0]
    parsed_url = urlparse(call_url)
    call_identifier = parse_qs(parsed_url.query)["identifier"][0]
    assert call_identifier == "00000000-0000-0000-0000-000000000003"


def test_sanitizes_input() -> None:
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "print('hello world')"


def test_does_not_sanitize_input() -> None:
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
    )
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "```python\nprint('hello world')\n```"


def test_uses_custom_access_token_provider() -> None:
    def custom_access_token_provider() -> str:
        return "custom_token"

    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        access_token_provider=custom_access_token_provider,
    )

    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("print('hello world')")
        headers = mock_post.call_args.kwargs["headers"]
        assert headers["Authorization"] == "Bearer custom_token"