mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
docs: more standardization (#33124)
This commit is contained in:
@@ -7,10 +7,9 @@ import os
|
||||
import re
|
||||
import ssl
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
@@ -77,6 +76,10 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mistral enforces a specific pattern for tool call IDs
|
||||
@@ -139,7 +142,7 @@ def _convert_mistral_chat_message_to_message(
|
||||
if role != "assistant":
|
||||
msg = f"Expected role to be 'assistant', got {role}"
|
||||
raise ValueError(msg)
|
||||
content = cast(str, _message["content"])
|
||||
content = cast("str", _message["content"])
|
||||
|
||||
additional_kwargs: dict = {}
|
||||
tool_calls = []
|
||||
@@ -149,7 +152,7 @@ def _convert_mistral_chat_message_to_message(
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
parsed: dict = cast(
|
||||
dict, parse_tool_call(raw_tool_call, return_id=True)
|
||||
"dict", parse_tool_call(raw_tool_call, return_id=True)
|
||||
)
|
||||
if not parsed["id"]:
|
||||
parsed["id"] = uuid.uuid4().hex[:]
|
||||
@@ -516,7 +519,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
else:
|
||||
api_key_str = self.mistral_api_key
|
||||
|
||||
# todo: handle retries
|
||||
# TODO: handle retries
|
||||
base_url_str = (
|
||||
self.endpoint
|
||||
or os.environ.get("MISTRAL_BASE_URL")
|
||||
@@ -534,7 +537,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
timeout=self.timeout,
|
||||
verify=global_ssl_context,
|
||||
)
|
||||
# todo: handle retries and max_concurrency
|
||||
# TODO: handle retries and max_concurrency
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=base_url_str,
|
||||
@@ -639,7 +642,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=cast(str, new_chunk.content), chunk=gen_chunk
|
||||
token=cast("str", new_chunk.content), chunk=gen_chunk
|
||||
)
|
||||
yield gen_chunk
|
||||
|
||||
@@ -665,7 +668,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token=cast(str, new_chunk.content), chunk=gen_chunk
|
||||
token=cast("str", new_chunk.content), chunk=gen_chunk
|
||||
)
|
||||
yield gen_chunk
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate configuration."""
|
||||
api_key_str = self.mistral_api_key.get_secret_value()
|
||||
# todo: handle retries
|
||||
# TODO: handle retries
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=self.endpoint,
|
||||
@@ -166,7 +166,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
# todo: handle retries and max_concurrency
|
||||
# TODO: handle retries and max_concurrency
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.endpoint,
|
||||
@@ -255,8 +255,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
for response in batch_responses
|
||||
for embedding_obj in response.json()["data"]
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
except Exception:
|
||||
logger.exception("An error occurred with MistralAI")
|
||||
raise
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
@@ -287,8 +287,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
for response in batch_responses
|
||||
for embedding_obj in response.json()["data"]
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
except Exception:
|
||||
logger.exception("An error occurred with MistralAI")
|
||||
raise
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
|
||||
@@ -50,52 +50,8 @@ target-version = "py39"
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"A", # flake8-builtins
|
||||
"B", # flake8-bugbear
|
||||
"ASYNC", # flake8-async
|
||||
"C4", # flake8-comprehensions
|
||||
"COM", # flake8-commas
|
||||
"D", # pydocstyle
|
||||
"E", # pycodestyle error
|
||||
"EM", # flake8-errmsg
|
||||
"F", # pyflakes
|
||||
"FA", # flake8-future-annotations
|
||||
"FBT", # flake8-boolean-trap
|
||||
"FLY", # flake8-flynt
|
||||
"I", # isort
|
||||
"ICN", # flake8-import-conventions
|
||||
"INT", # flake8-gettext
|
||||
"ISC", # isort-comprehensions
|
||||
"PGH", # pygrep-hooks
|
||||
"PIE", # flake8-pie
|
||||
"PERF", # flake8-perf
|
||||
"PYI", # flake8-pyi
|
||||
"Q", # flake8-quotes
|
||||
"RET", # flake8-return
|
||||
"RSE", # flake8-rst-docstrings
|
||||
"RUF", # ruff
|
||||
"S", # flake8-bandit
|
||||
"SLF", # flake8-self
|
||||
"SLOT", # flake8-slots
|
||||
"SIM", # flake8-simplify
|
||||
"T10", # flake8-debugger
|
||||
"T20", # flake8-print
|
||||
"TID", # flake8-tidy-imports
|
||||
"UP", # pyupgrade
|
||||
"W", # pycodestyle warning
|
||||
"YTT", # flake8-2020
|
||||
]
|
||||
select = ["ALL"]
|
||||
ignore = [
|
||||
"D100", # pydocstyle: Missing docstring in public module
|
||||
"D101", # pydocstyle: Missing docstring in public class
|
||||
"D102", # pydocstyle: Missing docstring in public method
|
||||
"D103", # pydocstyle: Missing docstring in public function
|
||||
"D104", # pydocstyle: Missing docstring in public package
|
||||
"D105", # pydocstyle: Missing docstring in magic method
|
||||
"D107", # pydocstyle: Missing docstring in __init__
|
||||
"D203", # Messes with the formatter
|
||||
"D407", # pydocstyle: Missing-dashed-underline-after-section
|
||||
"COM812", # Messes with the formatter
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
@@ -104,11 +60,29 @@ ignore = [
|
||||
"SLF001", # Private member access
|
||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||
"UP045", # pyupgrade: non-pep604-annotation-optional
|
||||
"TD",
|
||||
"PLR0912",
|
||||
"C901",
|
||||
"FIX",
|
||||
|
||||
# TODO
|
||||
"TC002",
|
||||
"ANN401",
|
||||
"ARG001",
|
||||
"ARG002",
|
||||
"PT011",
|
||||
"PLC0415",
|
||||
"PLR2004",
|
||||
"BLE001",
|
||||
"D100",
|
||||
"D102",
|
||||
"D104",
|
||||
]
|
||||
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
ignore-var-parameters = true # ignore missing documentation for *args and **kwargs parameters
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
@@ -125,4 +99,9 @@ asyncio_mode = "auto"
|
||||
"tests/**/*.py" = [
|
||||
"S101", # Tests need assertions
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"PLR2004",
|
||||
"D",
|
||||
]
|
||||
"scripts/*.py" = [
|
||||
"INP001", # Not a package
|
||||
]
|
||||
|
||||
@@ -320,6 +320,7 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
|
||||
|
||||
# Measure start time
|
||||
t0 = time.time()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
# Try to get a response
|
||||
@@ -327,7 +328,7 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
|
||||
|
||||
# If successful, validate the response
|
||||
elapsed_time = time.time() - t0
|
||||
logging.info(f"Request succeeded in {elapsed_time:.2f} seconds")
|
||||
logger.info("Request succeeded in %.2f seconds", elapsed_time)
|
||||
# Check that we got a valid response
|
||||
assert response.content
|
||||
assert isinstance(response.content, str)
|
||||
@@ -335,9 +336,9 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
|
||||
|
||||
except ReadTimeout:
|
||||
elapsed_time = time.time() - t0
|
||||
logging.info(f"Request timed out after {elapsed_time:.2f} seconds")
|
||||
logger.info("Request timed out after %.2f seconds", elapsed_time)
|
||||
assert elapsed_time >= 3.0
|
||||
pytest.skip("Test timed out as expected with short timeout")
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected exception: {e}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
|
||||
@@ -43,11 +43,11 @@ def test_mistralai_initialization() -> None:
|
||||
ChatMistralAI(model="test", mistral_api_key="test"), # type: ignore[call-arg, call-arg]
|
||||
ChatMistralAI(model="test", api_key="test"), # type: ignore[call-arg, arg-type]
|
||||
]:
|
||||
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
|
||||
assert cast("SecretStr", model.mistral_api_key).get_secret_value() == "test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected_url",
|
||||
("model", "expected_url"),
|
||||
[
|
||||
(ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type]
|
||||
(ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type]
|
||||
|
||||
@@ -14,4 +14,4 @@ def test_mistral_init() -> None:
|
||||
MistralAIEmbeddings(model="mistral-embed", api_key="test"), # type: ignore[arg-type]
|
||||
]:
|
||||
assert model.model == "mistral-embed"
|
||||
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
|
||||
assert cast("SecretStr", model.mistral_api_key).get_secret_value() == "test"
|
||||
|
||||
Reference in New Issue
Block a user