community[patch]: Resolve more linting issues (#26115)

Resolve a bunch of errors caught with mypy
This commit is contained in:
Eugene Yurtsev
2024-09-05 15:59:30 -04:00
committed by GitHub
parent 6e1b0d0228
commit 0cc6584889
13 changed files with 43 additions and 47 deletions

View File

@@ -56,7 +56,7 @@ class AINetworkToolkit(BaseToolkit):
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_all=True,
validate_default=True,
)
def get_tools(self) -> List[BaseTool]:

View File

@@ -31,7 +31,8 @@ from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import BaseModel, Field, SecretStr, model_validator, root_validator
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_community.adapters.openai import (
convert_dict_to_message,
@@ -150,7 +151,7 @@ class GPTRouter(BaseChatModel):
"""
client: Any = Field(default=None, exclude=True) #: :meta private:
models_priority_list: List[GPTRouterModel] = Field(min_items=1)
models_priority_list: List[GPTRouterModel] = Field(min_length=1)
gpt_router_api_base: str = Field(default=None)
"""WriteSonic GPTRouter custom endpoint"""
gpt_router_api_key: Optional[SecretStr] = None
@@ -186,8 +187,8 @@ class GPTRouter(BaseChatModel):
)
return values
@root_validator(pre=True, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def post_init(self) -> Self:
try:
from gpt_router.client import GPTRouterClient
@@ -198,12 +199,14 @@ class GPTRouter(BaseChatModel):
)
gpt_router_client = GPTRouterClient(
values["gpt_router_api_base"],
values["gpt_router_api_key"].get_secret_value(),
self.gpt_router_api_base,
self.gpt_router_api_key.get_secret_value()
if self.gpt_router_api_key
else None,
)
values["client"] = gpt_router_client
self.client = gpt_router_client
return values
return self
@property
def lc_secrets(self) -> Dict[str, str]:

View File

@@ -34,11 +34,11 @@ CHUNK_SIZE = 1024 * 1024 * 5
class _O365Settings(BaseSettings):
client_id: str = Field(..., env="O365_CLIENT_ID")
client_secret: SecretStr = Field(..., env="O365_CLIENT_SECRET")
client_id: str = Field(..., alias="O365_CLIENT_ID")
client_secret: SecretStr = Field(..., alias="O365_CLIENT_SECRET")
model_config = SettingsConfigDict(
case_sentive=False, env_file=".env", env_prefix=""
case_sensitive=False, env_file=".env", env_prefix=""
)

View File

@@ -11,19 +11,20 @@ from pydantic import (
FilePath,
SecretStr,
)
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
from langchain_community.document_loaders.base import BaseLoader
class _OneNoteGraphSettings(BaseSettings):
client_id: str = Field(..., env="MS_GRAPH_CLIENT_ID")
client_secret: SecretStr = Field(..., env="MS_GRAPH_CLIENT_SECRET")
client_id: str = Field(..., alias="MS_GRAPH_CLIENT_ID")
client_secret: SecretStr = Field(..., alias="MS_GRAPH_CLIENT_SECRET")
class Config:
case_sensitive = False
env_file = ".env"
env_prefix = ""
model_config = SettingsConfigDict(
case_sensitive=False,
env_file=".env",
env_prefix="",
)
class OneNoteLoader(BaseLoader, BaseModel):

View File

@@ -62,7 +62,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
session: Any #: :meta private:
model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
"""The model used to embed the documents."""
baichuan_api_key: Optional[SecretStr] = Field(
baichuan_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(["BAICHUAN_API_KEY", "BAICHUAN_AUTH_TOKEN"]),
)

View File

@@ -194,6 +194,12 @@ class OCIGenAIBase(BaseModel, ABC):
if self.provider is not None:
provider = self.provider
else:
if self.model_id is None:
raise ValueError(
"model_id is required to derive the provider, "
"please provide the provider explicitly or specify "
"the model_id to derive the provider."
)
provider = self.model_id.split(".")[0].lower()
if provider not in provider_map:
@@ -267,6 +273,12 @@ class OCIGenAI(LLM, OCIGenAIBase):
if stop is not None:
_model_kwargs[self._provider.stop_sequence_key] = stop
if self.model_id is None:
raise ValueError(
"model_id is required to call the model, "
"please provide the model_id."
)
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
else:

View File

@@ -20,9 +20,6 @@ class BaseCassandraDatabaseTool(BaseModel):
db: CassandraDatabase = Field(exclude=True)
class Config(BaseTool.Config):
pass
class _QueryCassandraDatabaseToolInput(BaseModel):
query: str = Field(..., description="A detailed and correct CQL query.")

View File

@@ -30,13 +30,12 @@ class DetectorAPI(str, Enum):
class ZenGuardInput(BaseModel):
prompts: List[str] = Field(
...,
min_items=1,
min_length=1,
description="Prompt to check",
)
detectors: List[Detector] = Field(
...,
min_items=1,
min_length=1,
description="List of detectors by which you want to check the prompt",
)
in_parallel: bool = Field(

View File

@@ -139,7 +139,6 @@ from pydantic import (
Field,
PrivateAttr,
model_validator,
validator,
)
@@ -215,20 +214,6 @@ class SearxSearchWrapper(BaseModel):
k: int = 10
aiosession: Optional[Any] = None
@validator("unsecure")
def disable_ssl_warnings(cls, v: bool) -> bool:
"""Disable SSL warnings."""
if v:
# requests.urllib3.disable_warnings()
try:
import urllib3
urllib3.disable_warnings()
except ImportError as e:
print(e) # noqa: T201
return v
@model_validator(mode="before")
@classmethod
def validate_params(cls, values: Dict) -> Any:
@@ -254,7 +239,6 @@ class SearxSearchWrapper(BaseModel):
searx_host = "https://" + searx_host
elif searx_host.startswith("http://"):
values["unsecure"] = True
cls.disable_ssl_warnings(True)
values["searx_host"] = searx_host
return values

View File

@@ -51,9 +51,9 @@ class DocArrayIndex(VectorStore, ABC):
from docarray.typing import NdArray
class DocArrayDoc(BaseDoc):
text: Optional[str] = Field(default=None, required=False)
text: Optional[str] = Field(default=None)
embedding: Optional[NdArray] = Field(**embeddings_params)
metadata: Optional[dict] = Field(default=None, required=False)
metadata: Optional[dict] = Field(default=None)
return DocArrayDoc

View File

@@ -28,11 +28,11 @@ def amazon_retriever(
)
def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
def test_create_client() -> None:
# Import error if boto3 is not installed
# Value error if credentials are not supplied.
with pytest.raises((ImportError, ValueError)):
amazon_retriever.create_client({})
AmazonKnowledgeBasesRetriever() # type: ignore
def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:

View File

@@ -128,7 +128,7 @@ def test_init(asr: RivaASR) -> None:
"""Test that ASR accepts valid arguments."""
for key, expected_val in CONFIG.items():
if key == "url":
assert asr.url == AnyHttpUrl(expected_val)
assert asr.url == AnyHttpUrl(expected_val) # type: ignore
else:
assert getattr(asr, key, None) == expected_val

View File

@@ -58,7 +58,7 @@ def test_init(tts: RivaTTS) -> None:
"""Test that ASR accepts valid arguments."""
for key, expected_val in CONFIG.items():
if key == "url":
assert str(tts.url) == expected_val + "/"
assert str(tts.url) == expected_val + "/" # type: ignore
else:
assert getattr(tts, key, None) == expected_val