mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 14:26:48 +00:00
community[patch]: Resolve more linting issues (#26115)
Resolve a bunch of errors caught with mypy
This commit is contained in:
@@ -56,7 +56,7 @@ class AINetworkToolkit(BaseToolkit):
|
|||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
arbitrary_types_allowed=True,
|
arbitrary_types_allowed=True,
|
||||||
validate_all=True,
|
validate_default=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tools(self) -> List[BaseTool]:
|
def get_tools(self) -> List[BaseTool]:
|
||||||
|
@@ -31,7 +31,8 @@ from langchain_core.language_models.llms import create_base_retry_decorator
|
|||||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
from pydantic import BaseModel, Field, SecretStr, model_validator, root_validator
|
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from langchain_community.adapters.openai import (
|
from langchain_community.adapters.openai import (
|
||||||
convert_dict_to_message,
|
convert_dict_to_message,
|
||||||
@@ -150,7 +151,7 @@ class GPTRouter(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
models_priority_list: List[GPTRouterModel] = Field(min_items=1)
|
models_priority_list: List[GPTRouterModel] = Field(min_length=1)
|
||||||
gpt_router_api_base: str = Field(default=None)
|
gpt_router_api_base: str = Field(default=None)
|
||||||
"""WriteSonic GPTRouter custom endpoint"""
|
"""WriteSonic GPTRouter custom endpoint"""
|
||||||
gpt_router_api_key: Optional[SecretStr] = None
|
gpt_router_api_key: Optional[SecretStr] = None
|
||||||
@@ -186,8 +187,8 @@ class GPTRouter(BaseChatModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator(pre=True, skip_on_failure=True)
|
@model_validator(mode="after")
|
||||||
def post_init(cls, values: Dict) -> Dict:
|
def post_init(self) -> Self:
|
||||||
try:
|
try:
|
||||||
from gpt_router.client import GPTRouterClient
|
from gpt_router.client import GPTRouterClient
|
||||||
|
|
||||||
@@ -198,12 +199,14 @@ class GPTRouter(BaseChatModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
gpt_router_client = GPTRouterClient(
|
gpt_router_client = GPTRouterClient(
|
||||||
values["gpt_router_api_base"],
|
self.gpt_router_api_base,
|
||||||
values["gpt_router_api_key"].get_secret_value(),
|
self.gpt_router_api_key.get_secret_value()
|
||||||
|
if self.gpt_router_api_key
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
values["client"] = gpt_router_client
|
self.client = gpt_router_client
|
||||||
|
|
||||||
return values
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_secrets(self) -> Dict[str, str]:
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
@@ -34,11 +34,11 @@ CHUNK_SIZE = 1024 * 1024 * 5
|
|||||||
|
|
||||||
|
|
||||||
class _O365Settings(BaseSettings):
|
class _O365Settings(BaseSettings):
|
||||||
client_id: str = Field(..., env="O365_CLIENT_ID")
|
client_id: str = Field(..., alias="O365_CLIENT_ID")
|
||||||
client_secret: SecretStr = Field(..., env="O365_CLIENT_SECRET")
|
client_secret: SecretStr = Field(..., alias="O365_CLIENT_SECRET")
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
case_sentive=False, env_file=".env", env_prefix=""
|
case_sensitive=False, env_file=".env", env_prefix=""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -11,19 +11,20 @@ from pydantic import (
|
|||||||
FilePath,
|
FilePath,
|
||||||
SecretStr,
|
SecretStr,
|
||||||
)
|
)
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from langchain_community.document_loaders.base import BaseLoader
|
from langchain_community.document_loaders.base import BaseLoader
|
||||||
|
|
||||||
|
|
||||||
class _OneNoteGraphSettings(BaseSettings):
|
class _OneNoteGraphSettings(BaseSettings):
|
||||||
client_id: str = Field(..., env="MS_GRAPH_CLIENT_ID")
|
client_id: str = Field(..., alias="MS_GRAPH_CLIENT_ID")
|
||||||
client_secret: SecretStr = Field(..., env="MS_GRAPH_CLIENT_SECRET")
|
client_secret: SecretStr = Field(..., alias="MS_GRAPH_CLIENT_SECRET")
|
||||||
|
|
||||||
class Config:
|
model_config = SettingsConfigDict(
|
||||||
case_sensitive = False
|
case_sensitive=False,
|
||||||
env_file = ".env"
|
env_file=".env",
|
||||||
env_prefix = ""
|
env_prefix="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OneNoteLoader(BaseLoader, BaseModel):
|
class OneNoteLoader(BaseLoader, BaseModel):
|
||||||
|
@@ -62,7 +62,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
|
|||||||
session: Any #: :meta private:
|
session: Any #: :meta private:
|
||||||
model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
|
model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
|
||||||
"""The model used to embed the documents."""
|
"""The model used to embed the documents."""
|
||||||
baichuan_api_key: Optional[SecretStr] = Field(
|
baichuan_api_key: SecretStr = Field(
|
||||||
alias="api_key",
|
alias="api_key",
|
||||||
default_factory=secret_from_env(["BAICHUAN_API_KEY", "BAICHUAN_AUTH_TOKEN"]),
|
default_factory=secret_from_env(["BAICHUAN_API_KEY", "BAICHUAN_AUTH_TOKEN"]),
|
||||||
)
|
)
|
||||||
|
@@ -194,6 +194,12 @@ class OCIGenAIBase(BaseModel, ABC):
|
|||||||
if self.provider is not None:
|
if self.provider is not None:
|
||||||
provider = self.provider
|
provider = self.provider
|
||||||
else:
|
else:
|
||||||
|
if self.model_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"model_id is required to derive the provider, "
|
||||||
|
"please provide the provider explicitly or specify "
|
||||||
|
"the model_id to derive the provider."
|
||||||
|
)
|
||||||
provider = self.model_id.split(".")[0].lower()
|
provider = self.model_id.split(".")[0].lower()
|
||||||
|
|
||||||
if provider not in provider_map:
|
if provider not in provider_map:
|
||||||
@@ -267,6 +273,12 @@ class OCIGenAI(LLM, OCIGenAIBase):
|
|||||||
if stop is not None:
|
if stop is not None:
|
||||||
_model_kwargs[self._provider.stop_sequence_key] = stop
|
_model_kwargs[self._provider.stop_sequence_key] = stop
|
||||||
|
|
||||||
|
if self.model_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"model_id is required to call the model, "
|
||||||
|
"please provide the model_id."
|
||||||
|
)
|
||||||
|
|
||||||
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
|
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
|
||||||
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
|
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
|
||||||
else:
|
else:
|
||||||
|
@@ -20,9 +20,6 @@ class BaseCassandraDatabaseTool(BaseModel):
|
|||||||
|
|
||||||
db: CassandraDatabase = Field(exclude=True)
|
db: CassandraDatabase = Field(exclude=True)
|
||||||
|
|
||||||
class Config(BaseTool.Config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _QueryCassandraDatabaseToolInput(BaseModel):
|
class _QueryCassandraDatabaseToolInput(BaseModel):
|
||||||
query: str = Field(..., description="A detailed and correct CQL query.")
|
query: str = Field(..., description="A detailed and correct CQL query.")
|
||||||
|
@@ -30,13 +30,12 @@ class DetectorAPI(str, Enum):
|
|||||||
class ZenGuardInput(BaseModel):
|
class ZenGuardInput(BaseModel):
|
||||||
prompts: List[str] = Field(
|
prompts: List[str] = Field(
|
||||||
...,
|
...,
|
||||||
min_items=1,
|
|
||||||
min_length=1,
|
min_length=1,
|
||||||
description="Prompt to check",
|
description="Prompt to check",
|
||||||
)
|
)
|
||||||
detectors: List[Detector] = Field(
|
detectors: List[Detector] = Field(
|
||||||
...,
|
...,
|
||||||
min_items=1,
|
min_length=1,
|
||||||
description="List of detectors by which you want to check the prompt",
|
description="List of detectors by which you want to check the prompt",
|
||||||
)
|
)
|
||||||
in_parallel: bool = Field(
|
in_parallel: bool = Field(
|
||||||
|
@@ -139,7 +139,6 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
PrivateAttr,
|
PrivateAttr,
|
||||||
model_validator,
|
model_validator,
|
||||||
validator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -215,20 +214,6 @@ class SearxSearchWrapper(BaseModel):
|
|||||||
k: int = 10
|
k: int = 10
|
||||||
aiosession: Optional[Any] = None
|
aiosession: Optional[Any] = None
|
||||||
|
|
||||||
@validator("unsecure")
|
|
||||||
def disable_ssl_warnings(cls, v: bool) -> bool:
|
|
||||||
"""Disable SSL warnings."""
|
|
||||||
if v:
|
|
||||||
# requests.urllib3.disable_warnings()
|
|
||||||
try:
|
|
||||||
import urllib3
|
|
||||||
|
|
||||||
urllib3.disable_warnings()
|
|
||||||
except ImportError as e:
|
|
||||||
print(e) # noqa: T201
|
|
||||||
|
|
||||||
return v
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_params(cls, values: Dict) -> Any:
|
def validate_params(cls, values: Dict) -> Any:
|
||||||
@@ -254,7 +239,6 @@ class SearxSearchWrapper(BaseModel):
|
|||||||
searx_host = "https://" + searx_host
|
searx_host = "https://" + searx_host
|
||||||
elif searx_host.startswith("http://"):
|
elif searx_host.startswith("http://"):
|
||||||
values["unsecure"] = True
|
values["unsecure"] = True
|
||||||
cls.disable_ssl_warnings(True)
|
|
||||||
values["searx_host"] = searx_host
|
values["searx_host"] = searx_host
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
@@ -51,9 +51,9 @@ class DocArrayIndex(VectorStore, ABC):
|
|||||||
from docarray.typing import NdArray
|
from docarray.typing import NdArray
|
||||||
|
|
||||||
class DocArrayDoc(BaseDoc):
|
class DocArrayDoc(BaseDoc):
|
||||||
text: Optional[str] = Field(default=None, required=False)
|
text: Optional[str] = Field(default=None)
|
||||||
embedding: Optional[NdArray] = Field(**embeddings_params)
|
embedding: Optional[NdArray] = Field(**embeddings_params)
|
||||||
metadata: Optional[dict] = Field(default=None, required=False)
|
metadata: Optional[dict] = Field(default=None)
|
||||||
|
|
||||||
return DocArrayDoc
|
return DocArrayDoc
|
||||||
|
|
||||||
|
@@ -28,11 +28,11 @@ def amazon_retriever(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
|
def test_create_client() -> None:
|
||||||
# Import error if boto3 is not installed
|
# Import error if boto3 is not installed
|
||||||
# Value error if credentials are not supplied.
|
# Value error if credentials are not supplied.
|
||||||
with pytest.raises((ImportError, ValueError)):
|
with pytest.raises((ImportError, ValueError)):
|
||||||
amazon_retriever.create_client({})
|
AmazonKnowledgeBasesRetriever() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
|
def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
|
||||||
|
@@ -128,7 +128,7 @@ def test_init(asr: RivaASR) -> None:
|
|||||||
"""Test that ASR accepts valid arguments."""
|
"""Test that ASR accepts valid arguments."""
|
||||||
for key, expected_val in CONFIG.items():
|
for key, expected_val in CONFIG.items():
|
||||||
if key == "url":
|
if key == "url":
|
||||||
assert asr.url == AnyHttpUrl(expected_val)
|
assert asr.url == AnyHttpUrl(expected_val) # type: ignore
|
||||||
else:
|
else:
|
||||||
assert getattr(asr, key, None) == expected_val
|
assert getattr(asr, key, None) == expected_val
|
||||||
|
|
||||||
|
@@ -58,7 +58,7 @@ def test_init(tts: RivaTTS) -> None:
|
|||||||
"""Test that ASR accepts valid arguments."""
|
"""Test that ASR accepts valid arguments."""
|
||||||
for key, expected_val in CONFIG.items():
|
for key, expected_val in CONFIG.items():
|
||||||
if key == "url":
|
if key == "url":
|
||||||
assert str(tts.url) == expected_val + "/"
|
assert str(tts.url) == expected_val + "/" # type: ignore
|
||||||
else:
|
else:
|
||||||
assert getattr(tts, key, None) == expected_val
|
assert getattr(tts, key, None) == expected_val
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user