docs: more standardization (#33124)

This commit is contained in:
Mason Daugherty
2025-09-25 20:46:20 -04:00
committed by GitHub
parent a5137b0a3e
commit 986302322f
125 changed files with 889 additions and 869 deletions

View File

@@ -7,10 +7,9 @@ import os
import re
import ssl
import uuid
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
@@ -77,6 +76,10 @@ from pydantic import (
)
from typing_extensions import Self
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
logger = logging.getLogger(__name__)
# Mistral enforces a specific pattern for tool call IDs
@@ -139,7 +142,7 @@ def _convert_mistral_chat_message_to_message(
if role != "assistant":
msg = f"Expected role to be 'assistant', got {role}"
raise ValueError(msg)
content = cast(str, _message["content"])
content = cast("str", _message["content"])
additional_kwargs: dict = {}
tool_calls = []
@@ -149,7 +152,7 @@ def _convert_mistral_chat_message_to_message(
for raw_tool_call in raw_tool_calls:
try:
parsed: dict = cast(
dict, parse_tool_call(raw_tool_call, return_id=True)
"dict", parse_tool_call(raw_tool_call, return_id=True)
)
if not parsed["id"]:
parsed["id"] = uuid.uuid4().hex[:]
@@ -516,7 +519,7 @@ class ChatMistralAI(BaseChatModel):
else:
api_key_str = self.mistral_api_key
# todo: handle retries
# TODO: handle retries
base_url_str = (
self.endpoint
or os.environ.get("MISTRAL_BASE_URL")
@@ -534,7 +537,7 @@ class ChatMistralAI(BaseChatModel):
timeout=self.timeout,
verify=global_ssl_context,
)
# todo: handle retries and max_concurrency
# TODO: handle retries and max_concurrency
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=base_url_str,
@@ -639,7 +642,7 @@ class ChatMistralAI(BaseChatModel):
gen_chunk = ChatGenerationChunk(message=new_chunk)
if run_manager:
run_manager.on_llm_new_token(
token=cast(str, new_chunk.content), chunk=gen_chunk
token=cast("str", new_chunk.content), chunk=gen_chunk
)
yield gen_chunk
@@ -665,7 +668,7 @@ class ChatMistralAI(BaseChatModel):
gen_chunk = ChatGenerationChunk(message=new_chunk)
if run_manager:
await run_manager.on_llm_new_token(
token=cast(str, new_chunk.content), chunk=gen_chunk
token=cast("str", new_chunk.content), chunk=gen_chunk
)
yield gen_chunk

View File

@@ -155,7 +155,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
def validate_environment(self) -> Self:
"""Validate configuration."""
api_key_str = self.mistral_api_key.get_secret_value()
# todo: handle retries
# TODO: handle retries
if not self.client:
self.client = httpx.Client(
base_url=self.endpoint,
@@ -166,7 +166,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
},
timeout=self.timeout,
)
# todo: handle retries and max_concurrency
# TODO: handle retries and max_concurrency
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=self.endpoint,
@@ -255,8 +255,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
for response in batch_responses
for embedding_obj in response.json()["data"]
]
except Exception as e:
logger.error(f"An error occurred with MistralAI: {e}")
except Exception:
logger.exception("An error occurred with MistralAI")
raise
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -287,8 +287,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
for response in batch_responses
for embedding_obj in response.json()["data"]
]
except Exception as e:
logger.error(f"An error occurred with MistralAI: {e}")
except Exception:
logger.exception("An error occurred with MistralAI")
raise
def embed_query(self, text: str) -> list[float]:

View File

@@ -50,52 +50,8 @@ target-version = "py39"
docstring-code-format = true
[tool.ruff.lint]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # flake8-perf
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
select = ["ALL"]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D203", # Messes with the formatter
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
@@ -104,11 +60,29 @@ ignore = [
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
"TD",
"PLR0912",
"C901",
"FIX",
# TODO
"TC002",
"ANN401",
"ARG001",
"ARG002",
"PT011",
"PLC0415",
"PLR2004",
"BLE001",
"D100",
"D102",
"D104",
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
[tool.ruff.lint.pydocstyle]
convention = "google"
ignore-var-parameters = true # ignore missing documentation for *args and **kwargs parameters
[tool.coverage.run]
omit = ["tests/*"]
@@ -125,4 +99,9 @@ asyncio_mode = "auto"
"tests/**/*.py" = [
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"PLR2004",
"D",
]
"scripts/*.py" = [
"INP001", # Not a package
]

View File

@@ -320,6 +320,7 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
# Measure start time
t0 = time.time()
logger = logging.getLogger(__name__)
try:
# Try to get a response
@@ -327,7 +328,7 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
# If successful, validate the response
elapsed_time = time.time() - t0
logging.info(f"Request succeeded in {elapsed_time:.2f} seconds")
logger.info("Request succeeded in %.2f seconds", elapsed_time)
# Check that we got a valid response
assert response.content
assert isinstance(response.content, str)
@@ -335,9 +336,9 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
except ReadTimeout:
elapsed_time = time.time() - t0
logging.info(f"Request timed out after {elapsed_time:.2f} seconds")
logger.info("Request timed out after %.2f seconds", elapsed_time)
assert elapsed_time >= 3.0
pytest.skip("Test timed out as expected with short timeout")
except Exception as e:
logging.error(f"Unexpected exception: {e}")
except Exception:
logger.exception("Unexpected exception")
raise

View File

@@ -43,11 +43,11 @@ def test_mistralai_initialization() -> None:
ChatMistralAI(model="test", mistral_api_key="test"), # type: ignore[call-arg, call-arg]
ChatMistralAI(model="test", api_key="test"), # type: ignore[call-arg, arg-type]
]:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
assert cast("SecretStr", model.mistral_api_key).get_secret_value() == "test"
@pytest.mark.parametrize(
"model,expected_url",
("model", "expected_url"),
[
(ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type]
(ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type]

View File

@@ -14,4 +14,4 @@ def test_mistral_init() -> None:
MistralAIEmbeddings(model="mistral-embed", api_key="test"), # type: ignore[arg-type]
]:
assert model.model == "mistral-embed"
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
assert cast("SecretStr", model.mistral_api_key).get_secret_value() == "test"