langchain[patch]: Add bandit rules (#31818)

Integrate Bandit for security analysis, suppress warnings for specific issues, and address potential vulnerabilities such as hardcoded passwords and SQL injection risks. Adjust documentation and formatting for clarity.
This commit is contained in:
Mason Daugherty 2025-07-03 14:20:33 -04:00 committed by GitHub
parent df06041eb2
commit 6a5073b227
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 151 additions and 108 deletions
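Context for the diffs below: the commit enables Ruff's flake8-bandit (`S`) rule set in `pyproject.toml` rather than running standalone Bandit. The short sketch that follows is illustrative only (the function, table name, and values are made up) and shows the kinds of findings those rules raise and how the inline `# noqa: S…` suppressions used throughout this changeset look; the `S106` (hardcoded password) and `S113` (request without timeout) suppressions appear later in the test diffs.

```python
import random
import sqlite3


def demo(table: str) -> int:
    # S101: bare assert; library code in this commit replaces these with explicit raises
    assert table.isidentifier()
    # S311: standard pseudo-random generator, acceptable outside cryptographic use
    value = random.randint(1, 100)  # noqa: S311
    # S608: SQL assembled from a string; tolerable only because `table` was validated above
    conn = sqlite3.connect(":memory:")
    conn.execute(f'CREATE TABLE IF NOT EXISTS "{table}" (v INTEGER)')  # noqa: S608
    conn.execute(f'INSERT INTO "{table}" (v) VALUES (?)', (value,))  # noqa: S608
    return value
```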

View File

@ -1,7 +1,7 @@
# Streamlit
> **[Streamlit](https://streamlit.io/) is a faster way to build and share data apps.**
> Streamlit turns data scripts into shareable web apps in minutes. All in pure Python. No frontend experience required.
> Streamlit turns data scripts into shareable web apps in minutes. All in pure Python. No front-end experience required.
> See more examples at [streamlit.io/generative-ai](https://streamlit.io/generative-ai).
[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/langchain-ai/streamlit-agent?quickstart=1)

View File

@ -1,7 +1,7 @@
# Streamlit
>[Streamlit](https://streamlit.io/) is a faster way to build and share data apps.
>`Streamlit` turns data scripts into shareable web apps in minutes. All in pure Python. No frontend experience required.
>`Streamlit` turns data scripts into shareable web apps in minutes. All in pure Python. No front-end experience required.
>See more examples at [streamlit.io/generative-ai](https://streamlit.io/generative-ai).
## Installation and Setup

View File

@ -1307,7 +1307,10 @@ class AgentExecutor(Chain):
self, values: NextStepOutput
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
if isinstance(values[-1], AgentFinish):
assert len(values) == 1
if len(values) != 1:
raise ValueError(
"Expected a single AgentFinish output, but got multiple values."
)
return values[-1]
else:
return [

View File

@ -173,6 +173,7 @@ def openapi_spec_to_openai_fn(
fn_args: dict,
headers: Optional[dict] = None,
params: Optional[dict] = None,
timeout: Optional[int] = 30,
**kwargs: Any,
) -> Any:
method = _name_to_call_map[name]["method"]
@ -192,7 +193,7 @@ def openapi_spec_to_openai_fn(
_kwargs["params"].update(params)
else:
_kwargs["params"] = params
return requests.request(method, url, **_kwargs)
return requests.request(method, url, **_kwargs, timeout=timeout)
return functions, default_call_api
@ -367,7 +368,7 @@ def get_openapi_chain(
break
except ImportError as e:
raise e
except Exception:
except Exception: # noqa: S110
pass
if isinstance(spec, str):
raise ValueError(f"Unable to parse spec from source {spec}")

View File

@ -26,16 +26,16 @@ NAMESPACE_UUID = uuid.UUID(int=1985)
def _sha1_hash_to_uuid(text: str) -> uuid.UUID:
"""Return a UUID derived from *text* using SHA1 (deterministic).
"""Return a UUID derived from *text* using SHA-1 (deterministic).
Deterministic and fast, **but not collision‑resistant**.
Deterministic and fast, **but not collision-resistant**.
A malicious attacker could try to create two different texts that hash to the same
UUID. This may not necessarily be an issue in the context of caching embeddings,
but new applications should swap this out for a stronger hash function like
xxHash, BLAKE2 or SHA256, which are collision-resistant.
xxHash, BLAKE2 or SHA-256, which are collision-resistant.
"""
sha1_hex = hashlib.sha1(text.encode("utf-8")).hexdigest()
sha1_hex = hashlib.sha1(text.encode("utf-8"), usedforsecurity=False).hexdigest()
# Embed the hex string in `uuid5` to obtain a valid UUID.
return uuid.uuid5(NAMESPACE_UUID, sha1_hex)
@ -46,10 +46,10 @@ def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str],
Args:
namespace: Prefix that segregates keys from different embedding models.
algorithm:
* `sha1` - fast but not collision‑resistant
* `blake2b` - cryptographically strong, faster than SHA1
* `sha256` - cryptographically strong, slower than SHA1
* `sha512` - cryptographically strong, slower than SHA1
* ``'sha1'`` - fast but not collision-resistant
* ``'blake2b'`` - cryptographically strong, faster than SHA-1
* ``'sha256'`` - cryptographically strong, slower than SHA-1
* ``'sha512'`` - cryptographically strong, slower than SHA-1
Returns:
A function that encodes a key using the specified algorithm.
@ -87,15 +87,15 @@ _warned_about_sha1: bool = False
def _warn_about_sha1_encoder() -> None:
"""Emit a onetime warning about SHA1 collision weaknesses."""
"""Emit a one-time warning about SHA-1 collision weaknesses."""
global _warned_about_sha1
if not _warned_about_sha1:
warnings.warn(
"Using default key encoder: SHA1 is *not* collisionresistant. "
"Using default key encoder: SHA-1 is *not* collision-resistant. "
"While acceptable for most cache scenarios, a motivated attacker "
"can craft two different payloads that map to the same cache key. "
"If that risk matters in your environment, supply a stronger "
"encoder (e.g. SHA256 or BLAKE2) via the `key_encoder` argument. "
"encoder (e.g. SHA-256 or BLAKE2) via the `key_encoder` argument. "
"If you change the key encoder, consider also creating a new cache, "
"to avoid (the potential for) collisions with existing keys.",
category=UserWarning,
@ -118,7 +118,6 @@ class CacheBackedEmbeddings(Embeddings):
embeddings too, pass in a query_embedding_store to constructor.
Examples:
.. code-block: python
from langchain.embeddings import CacheBackedEmbeddings
@ -154,7 +153,7 @@ class CacheBackedEmbeddings(Embeddings):
document_embedding_store: The store to use for caching document embeddings.
batch_size: The number of documents to embed between store updates.
query_embedding_store: The store to use for caching query embeddings.
If None, query embeddings are not cached.
If ``None``, query embeddings are not cached.
"""
super().__init__()
self.document_embedding_store = document_embedding_store
@ -236,7 +235,7 @@ class CacheBackedEmbeddings(Embeddings):
"""Embed query text.
By default, this method does not cache queries. To enable caching, set the
`cache_query` parameter to `True` when initializing the embedder.
``cache_query`` parameter to ``True`` when initializing the embedder.
Args:
text: The text to embed.
@ -259,7 +258,7 @@ class CacheBackedEmbeddings(Embeddings):
"""Embed query text.
By default, this method does not cache queries. To enable caching, set the
`cache_query` parameter to `True` when initializing the embedder.
``cache_query`` parameter to ``True`` when initializing the embedder.
Args:
text: The text to embed.
@ -305,7 +304,7 @@ class CacheBackedEmbeddings(Embeddings):
True to use the same cache as document embeddings.
False to not cache query embeddings.
key_encoder: Optional callable to encode keys. If not provided,
a default encoder using SHA1 will be used. SHA-1 is not
a default encoder using SHA-1 will be used. SHA-1 is not
collision-resistant, and a motivated attacker could craft two
different texts that hash to the same cache key.
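As a sketch of the `key_encoder` escape hatch documented above (assuming only that the callable maps a text key to a string, mirroring what the default SHA-1 encoder produces), a collision-resistant encoder can be built from the standard library:

```python
import hashlib
import uuid

NAMESPACE_UUID = uuid.UUID(int=1985)


def blake2b_key_encoder(text: str) -> str:
    """Collision-resistant alternative to the default SHA-1 key encoder (sketch)."""
    digest = hashlib.blake2b(text.encode("utf-8")).hexdigest()
    # Embed the digest in uuid5 so keys keep the same UUID shape as the default encoder.
    return str(uuid.uuid5(NAMESPACE_UUID, digest))


# Hypothetical usage, following the `from_bytes_store` parameters described above:
# embedder = CacheBackedEmbeddings.from_bytes_store(
#     underlying_embeddings, store, key_encoder=blake2b_key_encoder
# )
```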

View File

@ -4,7 +4,7 @@ import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@ -20,6 +20,9 @@ from langchain.memory.prompt import (
)
from langchain.memory.utils import get_prompt_input_key
if TYPE_CHECKING:
import sqlite3
logger = logging.getLogger(__name__)
@ -283,7 +286,7 @@ class RedisEntityStore(BaseEntityStore):
),
)
class SQLiteEntityStore(BaseEntityStore):
"""SQLite-backed Entity store"""
"""SQLite-backed Entity store with safe query construction."""
session_id: str = "default"
table_name: str = "memory_store"
@ -301,6 +304,7 @@ class SQLiteEntityStore(BaseEntityStore):
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
try:
import sqlite3
except ImportError:
@ -308,7 +312,13 @@ class SQLiteEntityStore(BaseEntityStore):
"Could not import sqlite3 python package. "
"Please install it with `pip install sqlite3`."
)
super().__init__(*args, **kwargs)
# Basic validation to prevent obviously malicious table/session names
if not table_name.isidentifier() or not session_id.isidentifier():
# Since we validate here, we can safely suppress the S608 bandit warning
raise ValueError(
"Table name and session ID must be valid Python identifiers."
)
self.conn = sqlite3.connect(db_file)
self.session_id = session_id
@ -319,62 +329,60 @@ class SQLiteEntityStore(BaseEntityStore):
def full_table_name(self) -> str:
return f"{self.table_name}_{self.session_id}"
def _execute_query(self, query: str, params: tuple = ()) -> "sqlite3.Cursor":
"""Executes a query with proper connection handling."""
with self.conn:
return self.conn.execute(query, params)
def _create_table_if_not_exists(self) -> None:
"""Creates the entity table if it doesn't exist, using safe quoting."""
# Use standard SQL double quotes for the table name identifier
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
CREATE TABLE IF NOT EXISTS "{self.full_table_name}" (
key TEXT PRIMARY KEY,
value TEXT
)
"""
with self.conn:
self.conn.execute(create_table_query)
self._execute_query(create_table_query)
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
query = f"""
SELECT value
FROM {self.full_table_name}
WHERE key = ?
"""
cursor = self.conn.execute(query, (key,))
"""Retrieves a value, safely quoting the table name."""
# `?` placeholder is used for the value to prevent SQL injection
# noqa since we validate for malicious table/session names in `__init__`
query = f'SELECT value FROM "{self.full_table_name}" WHERE key = ?' # noqa: S608
cursor = self._execute_query(query, (key,))
result = cursor.fetchone()
if result is not None:
value = result[0]
return value
return default
return result[0] if result is not None else default
def set(self, key: str, value: Optional[str]) -> None:
"""Inserts or replaces a value, safely quoting the table name."""
if not value:
return self.delete(key)
query = f"""
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
VALUES (?, ?)
"""
with self.conn:
self.conn.execute(query, (key, value))
# noqa since we validate for malicious table/session names in `__init__`
query = (
"INSERT OR REPLACE INTO " # noqa: S608
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
)
self._execute_query(query, (key, value))
def delete(self, key: str) -> None:
query = f"""
DELETE FROM {self.full_table_name}
WHERE key = ?
"""
with self.conn:
self.conn.execute(query, (key,))
"""Deletes a key-value pair, safely quoting the table name."""
# noqa since we validate for malicious table/session names in `__init__`
query = f'DELETE FROM "{self.full_table_name}" WHERE key = ?' # noqa: S608
self._execute_query(query, (key,))
def exists(self, key: str) -> bool:
query = f"""
SELECT 1
FROM {self.full_table_name}
WHERE key = ?
LIMIT 1
"""
cursor = self.conn.execute(query, (key,))
result = cursor.fetchone()
return result is not None
"""Checks for the existence of a key, safely quoting the table name."""
# noqa since we validate for malicious table/session names in `__init__`
query = f'SELECT 1 FROM "{self.full_table_name}" WHERE key = ? LIMIT 1' # noqa: S608
cursor = self._execute_query(query, (key,))
return cursor.fetchone() is not None
def clear(self) -> None:
# noqa since we validate for malicious table/session names in `__init__`
query = f"""
DELETE FROM {self.full_table_name}
"""
""" # noqa: S608
with self.conn:
self.conn.execute(query)
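The two-part defence used in `SQLiteEntityStore` above (identifier validation before interpolating the table name, `?` placeholders for every value) can be seen in isolation in the sketch below; the table name and stored values are invented for illustration.

```python
import sqlite3
from typing import Optional


def safe_get(conn: sqlite3.Connection, table: str, key: str) -> Optional[str]:
    # Identifiers cannot be bound as `?` parameters, so validate before interpolating.
    if not table.isidentifier():
        raise ValueError("Table name must be a valid Python identifier.")
    # Values always go through `?` placeholders and never into the SQL text itself.
    cursor = conn.execute(f'SELECT value FROM "{table}" WHERE key = ?', (key,))  # noqa: S608
    row = cursor.fetchone()
    return row[0] if row is not None else None


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE memory_store (key TEXT PRIMARY KEY, value TEXT)")
conn.execute("INSERT INTO memory_store VALUES (?, ?)", ("Deven", "is reviewing Bandit rules"))
print(safe_get(conn, "memory_store", "Deven"))
# safe_get(conn, "memory_store; DROP TABLE memory_store", "Deven")  # would raise ValueError
```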

View File

@ -1,42 +1,43 @@
import random
from datetime import datetime, timedelta
from datetime import datetime
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.utils import comma_list
def _generate_random_datetime_strings(
pattern: str,
n: int = 3,
start_date: datetime = datetime(1, 1, 1),
end_date: datetime = datetime.now() + timedelta(days=3650),
) -> list[str]:
"""Generates n random datetime strings conforming to the
given pattern within the specified date range.
Pattern should be a string containing the desired format codes.
start_date and end_date should be datetime objects representing
the start and end of the date range.
"""
examples = []
delta = end_date - start_date
for i in range(n):
random_delta = random.uniform(0, delta.total_seconds())
dt = start_date + timedelta(seconds=random_delta)
date_string = dt.strftime(pattern)
examples.append(date_string)
return examples
class DatetimeOutputParser(BaseOutputParser[datetime]):
"""Parse the output of an LLM call to a datetime."""
format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
"""The string value that used as the datetime format."""
"""The string value that is used as the datetime format.
Update this to match the desired datetime format for your application.
"""
def get_format_instructions(self) -> str:
examples = comma_list(_generate_random_datetime_strings(self.format))
"""Returns the format instructions for the given format."""
if self.format == "%Y-%m-%dT%H:%M:%S.%fZ":
examples = comma_list(
[
"2023-07-04T14:30:00.000000Z",
"1999-12-31T23:59:59.999999Z",
"2025-01-01T00:00:00.000000Z",
]
)
else:
try:
now = datetime.now()
examples = comma_list(
[
now.strftime(self.format),
(now.replace(year=now.year - 1)).strftime(self.format),
(now.replace(day=now.day - 1)).strftime(self.format),
]
)
except ValueError:
# Fallback if the format is very unusual
examples = f"e.g., a valid string in the format {self.format}"
return (
f"Write a datetime string that matches the "
f"following pattern: '{self.format}'.\n\n"
@ -45,6 +46,7 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):
)
def parse(self, response: str) -> datetime:
"""Parse a string into a datetime object."""
try:
return datetime.strptime(response.strip(), self.format)
except ValueError as e:
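A brief usage sketch of the parser changed above; the import path is assumed from the package layout, and only behaviour visible in this diff (the `format` field, `get_format_instructions`, and `parse` via `strptime`) is relied on.

```python
from datetime import datetime

from langchain.output_parsers import DatetimeOutputParser  # assumed import path

parser = DatetimeOutputParser(format="%Y-%m-%d")
# Instructions are now built from fixed or derived examples instead of random dates.
print(parser.get_format_instructions())
parsed = parser.parse("2024-06-01")
assert parsed == datetime(2024, 6, 1)
```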

View File

@ -37,8 +37,14 @@ class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
functions: Optional[list[OpenAIFunction]] = None,
):
if functions is not None:
assert len(functions) == len(runnables)
assert all(func["name"] in runnables for func in functions)
if len(functions) != len(runnables):
raise ValueError(
"The number of functions does not match the number of runnables."
)
if not all(func["name"] in runnables for func in functions):
raise ValueError(
"One or more function names are not found in runnables."
)
router = (
JsonOutputFunctionsParser(args_only=False)
| {"key": itemgetter("name"), "input": itemgetter("arguments")}

View File

@ -721,7 +721,7 @@ nouns = [
def random_name() -> str:
"""Generate a random name."""
adjective = random.choice(adjectives)
noun = random.choice(nouns)
number = random.randint(1, 100)
adjective = random.choice(adjectives) # noqa: S311
noun = random.choice(nouns) # noqa: S311
number = random.randint(1, 100) # noqa: S311
return f"{adjective}-{noun}-{number}"

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
select = ["E", "F", "I", "PGH003", "T201", "D", "UP"]
select = ["E", "F", "I", "PGH003", "T201", "D", "UP", "S"]
ignore = ["UP007", ]
pydocstyle = { convention = "google" }
@ -151,6 +151,12 @@ pydocstyle = { convention = "google" }
"tests/*" = ["D"]
"!langchain/indexes/vectorstore.py" = ["D"]
[tool.ruff.lint.extend-per-file-ignores]
"tests/**/*.py" = [
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
]
[tool.coverage.run]
omit = ["tests/*"]

View File

@ -1,3 +1,12 @@
"""
Quickly verify that a list of Python files can be loaded by the Python interpreter
without raising any errors. Ran before running more expensive tests. Useful in
Makefiles.
If loading a file fails, the script prints the problematic filename and the detailed
error traceback.
"""
import random
import string
import sys
@ -10,7 +19,8 @@ if __name__ == "__main__":
for file in files:
try:
module_name = "".join(
random.choice(string.ascii_letters) for _ in range(20)
random.choice(string.ascii_letters) # noqa: S311
for _ in range(20)
)
SourceFileLoader(module_name, file).load_module()
except Exception:

View File

@ -201,5 +201,6 @@ def custom_openapi() -> dict[str, Any]:
# This lets us prevent the "servers" configuration from being overwritten in
# the auto-generated OpenAPI schema
app.openapi = custom_openapi
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=PORT)
uvicorn.run(app, host="127.0.0.1", port=PORT)

View File

@ -201,11 +201,11 @@ def test_sha512_encoder() -> None:
def test_sha1_warning_emitted_once() -> None:
"""Test that a warning is emitted when using SHA1 as the default key encoder."""
"""Test that a warning is emitted when using SHA-1 as the default key encoder."""
module = importlib.import_module(CacheBackedEmbeddings.__module__)
# Create a *temporary* MonkeyPatch object whose effects disappear
automatically when the with‑block exits.
# automatically when the with-block exits.
with pytest.MonkeyPatch.context() as mp:
# We're monkey patching the module to reset the `_warned_about_sha1` flag
# which may have been set while testing other parts of the codebase.
@ -219,7 +219,7 @@ def test_sha1_warning_emitted_once() -> None:
CacheBackedEmbeddings.from_bytes_store(emb, store) # triggers warning
CacheBackedEmbeddings.from_bytes_store(emb, store) # silent
sha1_msgs = [w for w in caught if "SHA1" in str(w.message)]
sha1_msgs = [w for w in caught if "SHA-1" in str(w.message)]
assert len(sha1_msgs) == 1

View File

@ -54,9 +54,9 @@ class NotSerializable:
def test_person(snapshot: Any) -> None:
p = Person(secret="hello")
p = Person(secret="parrot party") # noqa: S106
assert dumps(p, pretty=True) == snapshot
sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
sp = SpecialPerson(another_secret="Wooo", secret="Hmm") # noqa: S106
assert dumps(sp, pretty=True) == snapshot
assert Person.lc_id() == ["tests", "unit_tests", "load", "test_dump", "Person"]
assert SpecialPerson.lc_id() == ["my", "special", "namespace", "SpecialPerson"]
@ -70,12 +70,12 @@ def test_typeerror() -> None:
def test_person_with_kwargs(snapshot: Any) -> None:
person = Person(secret="hello")
person = Person(secret="parrot party") # noqa: S106
assert dumps(person, separators=(",", ":")) == snapshot
def test_person_with_invalid_kwargs() -> None:
person = Person(secret="hello")
person = Person(secret="parrot party") # noqa: S106
with pytest.raises(TypeError):
dumps(person, invalid_kwarg="hello")
@ -115,7 +115,10 @@ class TestClass(Serializable):
def test_aliases_hidden() -> None:
test_class = TestClass(my_favorite_secret="hello", my_other_secret="world") # type: ignore[call-arg]
test_class = TestClass(
my_favorite_secret="hello", # noqa: S106 # type: ignore[call-arg]
my_other_secret="world", # noqa: S106
) # type: ignore[call-arg]
dumped = json.loads(dumps(test_class, pretty=True))
expected_dump = {
"lc": 1,
@ -139,7 +142,10 @@ def test_aliases_hidden() -> None:
dumped = json.loads(dumps(test_class, pretty=True))
# Check by alias
test_class = TestClass(my_favorite_secret_alias="hello", my_other_secret="world")
test_class = TestClass(
my_favorite_secret_alias="hello", # noqa: S106
my_other_secret="parrot party", # noqa: S106
)
dumped = json.loads(dumps(test_class, pretty=True))
expected_dump = {
"lc": 1,

View File

@ -6,4 +6,5 @@ import requests
def test_socket_disabled() -> None:
"""This test should fail."""
with pytest.raises(pytest_socket.SocketBlockedError):
requests.get("https://www.example.com")
# noqa since we don't need a timeout here as the request should fail immediately
requests.get("https://www.example.com") # noqa: S113