community: Add ruff rule PGH003 (#30812)

See https://docs.astral.sh/ruff/rules/blanket-type-ignore/
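The rule flags "blanket" type-ignore comments: a bare # type: ignore suppresses every type error on its line, so a later, unrelated regression on that line is silently hidden as well. PGH003 requires naming the specific error code(s), e.g. # type: ignore[arg-type], which keeps silencing the one known issue while letting new ones surface. A minimal sketch of the pattern applied throughout this commit (fetch_answer is a hypothetical function, not taken from the diff):

    def fetch_answer() -> int:
        # Flagged by PGH003: the blanket ignore hides every type error on this line.
        return "42"  # type: ignore

    def fetch_answer_scoped() -> int:
        # Compliant: only the known mypy error code is suppressed.
        return "42"  # type: ignore[return-value]

Where a suppression turned out to be unnecessary, the commit removes the comment entirely instead of scoping it.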

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Authored by Christophe Bornet on 2025-04-14 04:32:13 +02:00; committed by GitHub
parent f005988e31
commit ada740b5b9
148 changed files with 448 additions and 419 deletions


@ -196,7 +196,7 @@ def create_sql_agent(
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore
runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore[arg-type]
input_keys_arg=["input"],
return_keys_arg=["output"],
**kwargs,
@ -211,9 +211,9 @@ def create_sql_agent(
]
prompt = ChatPromptTemplate.from_messages(messages)
if agent_type == "openai-tools":
runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore
runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore[arg-type]
else:
runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore
runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore[arg-type]
agent = RunnableMultiActionAgent( # type: ignore[assignment]
runnable=runnable,
input_keys_arg=["input"],


@ -135,7 +135,7 @@ def _get_assistants_tool(
Dict[str, Any]: A dictionary of tools that are converted into OpenAI tools.
"""
if _is_assistants_builtin_tool(tool):
return tool # type: ignore
return tool # type: ignore[return-value]
else:
return convert_to_openai_tool(tool)
@ -288,7 +288,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
assistant = client.beta.assistants.create(
name=name,
instructions=instructions,
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
tools=[_get_assistants_tool(tool) for tool in tools],
tool_resources=tool_resources, # type: ignore[arg-type]
model=model,
extra_body=extra_body,
@ -430,7 +430,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
assistant = await async_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=openai_tools, # type: ignore
tools=openai_tools,
tool_resources=tool_resources, # type: ignore[arg-type]
model=model,
)


@ -238,7 +238,7 @@ class InMemoryCache(BaseCache):
Base = declarative_base()
class FullLLMCache(Base): # type: ignore
class FullLLMCache(Base): # type: ignore[misc,valid-type]
"""SQLite table for full LLM Cache (all generations)."""
__tablename__ = "full_llm_cache"
@ -261,7 +261,7 @@ class SQLAlchemyCache(BaseCache):
"""Look up based on prompt and llm_string."""
stmt = (
select(self.cache_schema.response)
.where(self.cache_schema.prompt == prompt) # type: ignore
.where(self.cache_schema.prompt == prompt)
.where(self.cache_schema.llm == llm_string)
.order_by(self.cache_schema.idx)
)
@ -1531,7 +1531,7 @@ class CassandraSemanticCache(BaseCache):
await self.table.aclear()
class FullMd5LLMCache(Base): # type: ignore
class FullMd5LLMCache(Base): # type: ignore[misc,valid-type]
"""SQLite table for full LLM Cache (all generations)."""
__tablename__ = "full_md5_llm_cache"
@ -1583,7 +1583,7 @@ class SQLAlchemyMd5Cache(BaseCache):
def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
stmt = (
delete(self.cache_schema)
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt))
.where(self.cache_schema.llm == llm_string)
.where(self.cache_schema.prompt == prompt)
)
@ -1593,7 +1593,7 @@ class SQLAlchemyMd5Cache(BaseCache):
prompt_pd5 = self.get_md5(prompt)
stmt = (
select(self.cache_schema.response)
.where(self.cache_schema.prompt_md5 == prompt_pd5) # type: ignore
.where(self.cache_schema.prompt_md5 == prompt_pd5)
.where(self.cache_schema.llm == llm_string)
.where(self.cache_schema.prompt == prompt)
.order_by(self.cache_schema.idx)
@ -1796,7 +1796,7 @@ class _CachedAwaitable:
def __await__(self) -> Generator:
if self.result is _unset:
self.result = yield from self.awaitable.__await__()
return self.result # type: ignore
return self.result # type: ignore[return-value]
def _reawaitable(func: Callable) -> Callable:


@ -584,7 +584,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
)
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
self.__init__( # type: ignore
self.__init__( # type: ignore[misc]
task_type=_task_type,
workspace=_workspace,
project_name=_project_name,


@ -580,7 +580,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.temp_dir.cleanup()
self.reset_callback_meta()
if reset:
self.__init__( # type: ignore
self.__init__( # type: ignore[misc]
job_type=job_type if job_type else self.job_type,
project=project if project else self.project,
entity=entity if entity else self.entity,


@ -352,7 +352,7 @@ def create_structured_output_runnable(
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
output: output_schema # type: ignore[valid-type]
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
@ -537,7 +537,7 @@ def create_structured_output_chain(
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
output: output_schema # type: ignore[valid-type]
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(


@ -316,7 +316,7 @@ class GraphCypherQAChain(Chain):
MessagesPlaceholder(variable_name="function_response"),
]
)
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore[operator]
except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions")
else:
@ -404,15 +404,15 @@ class GraphCypherQAChain(Chain):
intermediate_steps.append({"context": context})
if self.use_function_response:
function_response = get_function_response(question, context)
final_result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke( # type: ignore[assignment]
{"question": question, "function_response": function_response},
)
else:
result = self.qa_chain.invoke( # type: ignore
result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key] # type: ignore
final_result = result[self.qa_chain.output_key] # type: ignore[union-attr]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:


@ -225,11 +225,11 @@ class MemgraphQAChain(Chain):
MessagesPlaceholder(variable_name="function_response"),
]
)
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore[operator]
except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions")
else:
qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser() # type: ignore
qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser()
prompt = use_cypher_llm_kwargs["prompt"]
llm_to_use = cypher_llm if cypher_llm is not None else llm
@ -300,11 +300,11 @@ class MemgraphQAChain(Chain):
intermediate_steps.append({"context": context})
if self.use_function_response:
function_response = get_function_response(question, context)
result = self.qa_chain.invoke( # type: ignore
result = self.qa_chain.invoke(
{"question": question, "function_response": function_response},
)
else:
result = self.qa_chain.invoke( # type: ignore
result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)


@ -67,11 +67,11 @@ def extract_cypher(text: str) -> str:
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
"""Decides whether to use the simple prompt"""
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore[attr-defined]
return True
# Bedrock anthropic
if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore
if hasattr(llm, "model_id") and "anthropic" in llm.model_id:
return True
return False


@ -313,8 +313,12 @@ class PebbloRetrievalQA(Chain):
)
@staticmethod
def _get_app_details( # type: ignore
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs
def _get_app_details(
app_name: str,
owner: str,
description: str,
llm: BaseLanguageModel,
**kwargs: Any,
) -> App:
"""Fetch app details. Internal method.
Returns:


@ -81,7 +81,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve all session messages from DB"""
# The latest are returned, in chronological order
rows = self.table.get_partition(


@ -35,7 +35,7 @@ class FileChatMessageHistory(BaseChatMessageHistory):
)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from the local file"""
items = json.loads(self.file_path.read_text(encoding=self.encoding))
messages = messages_from_dict(items)


@ -334,7 +334,7 @@ class KafkaChatMessageHistory(BaseChatMessageHistory):
)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""
Retrieve the messages for the session, from Kafka topic continuously
from last consumed message. This method is stateful and maintains


@ -60,7 +60,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self.collection.create_index("SessionId")
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from MongoDB"""
from pymongo import errors


@ -65,7 +65,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
self.connection.commit()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from PostgreSQL"""
query = (
f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;"


@ -215,7 +215,7 @@ class RocksetChatMessageHistory(BaseChatMessageHistory):
self._create_empty_doc()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Messages in this chat history."""
return messages_from_dict(
self._query(


@ -212,7 +212,7 @@ class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
conn.close()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()


@ -47,7 +47,7 @@ try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
logger = logging.getLogger(__name__)
@ -242,7 +242,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
self._table_created = True
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve all messages from db"""
with self._make_sync_session() as session:
result = (


@ -51,7 +51,7 @@ class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
return self.key_prefix + self.session_id
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Upstash Redis"""
_items = self.redis_client.lrange(self.key, 0, -1)
items = [json.loads(m) for m in _items[::-1]]


@ -83,7 +83,7 @@ class XataChatMessageHistory(BaseChatMessageHistory):
raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
r = self._client.data().query(
self._table_name,
payload={


@ -87,7 +87,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
self.session_id = session_id
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve messages from Zep memory"""
zep_memory: Optional[Memory] = self._get_memory()
if not zep_memory:


@ -134,7 +134,7 @@ class ZepCloudChatMessageHistory(BaseChatMessageHistory):
self.summary_instruction = summary_instruction
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve messages from Zep memory"""
zep_memory: Optional[Memory] = self._get_memory()
if not zep_memory:


@ -42,7 +42,7 @@ class ChatLiteLLMRouter(ChatLiteLLM):
def __init__(self, *, router: Any, **kwargs: Any) -> None:
"""Construct Chat LiteLLM Router."""
super().__init__(router=router, **kwargs) # type: ignore
super().__init__(router=router, **kwargs) # type: ignore[call-arg]
self.router = router
@property


@ -815,4 +815,4 @@ def _convert_delta_to_message_chunk(
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_)
else:
return default_class(content=content, id=id_) # type: ignore
return default_class(content=content, id=id_) # type: ignore[call-arg]


@ -716,7 +716,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
first_tool_only=True,
)
else:
output_parser = JsonOutputKeyToolsParser(


@ -158,9 +158,9 @@ def _convert_delta_response_to_message_chunk(
Optional[str],
]:
"""Converts delta response to message chunk"""
_delta = response.choices[0].delta # type: ignore
role = _delta.get("role", "") # type: ignore
content = _delta.get("content", "") # type: ignore
_delta = response.choices[0].delta
role = _delta.get("role", "")
content = _delta.get("content", "")
additional_kwargs: Dict = {}
finish_reasons: Optional[str] = response.choices[0].finish_reason
@ -398,7 +398,7 @@ class ChatPremAI(BaseChatModel, BaseModel):
messages, template_id=kwargs["template_id"]
)
else:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt
@ -425,9 +425,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
if "template_id" in kwargs:
system_prompt, messages_to_pass = _messages_to_prompt_dict(
messages, template_id=kwargs["template_id"]
) # type: ignore
)
else:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
if stop is not None:
logger.warning("stop is not supported in langchain streaming")


@ -218,7 +218,7 @@ class BlackboardLoader(WebBaseLoader):
loader = DirectoryLoader(
path=self.folder_path,
glob="*.pdf",
loader_cls=PyPDFLoader, # type: ignore
loader_cls=PyPDFLoader, # type: ignore[arg-type]
)
# Load the documents
documents = loader.load()


@ -35,7 +35,7 @@ class _CloudBlob(Blob):
from cloudpathlib import AnyPath
if self.data is None and self.path:
return AnyPath(self.path).read_text(encoding=self.encoding) # type: ignore
return AnyPath(self.path).read_text(encoding=self.encoding)
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
@ -52,7 +52,7 @@ class _CloudBlob(Blob):
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
return AnyPath(self.path).read_bytes() # type: ignore
return AnyPath(self.path).read_bytes()
else:
raise ValueError(f"Unable to get bytes for blob {self}")
@ -64,7 +64,7 @@ class _CloudBlob(Blob):
if isinstance(self.data, bytes):
yield BytesIO(self.data)
elif self.data is None and self.path:
return AnyPath(self.path).read_bytes() # type: ignore
return AnyPath(self.path).read_bytes()
else:
raise NotImplementedError(f"Unable to convert blob {self}")
@ -79,7 +79,7 @@ def _url_to_filename(url: str) -> str:
url_parsed = urlparse(url)
suffix = Path(url_parsed.path).suffix
if url_parsed.scheme in ["s3", "az", "gs"]:
with AnyPath(url).open("rb") as f: # type: ignore
with AnyPath(url).open("rb") as f:
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
while True:
buf = f.read()
@ -116,7 +116,7 @@ def _make_iterator(
iterator = _with_tqdm
else:
iterator = iter # type: ignore
iterator = iter # type: ignore[assignment]
return iterator
@ -220,7 +220,7 @@ class CloudBlobLoader(BlobLoader):
def _yield_paths(self) -> Iterable["AnyPath"]:
"""Yield paths that match the requested pattern."""
if self.path.is_file(): # type: ignore
if self.path.is_file():
yield self.path
return
@ -269,7 +269,7 @@ class CloudBlobLoader(BlobLoader):
Blob instance
"""
if mime_type is None and guess_type:
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None # type: ignore
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
else:
_mimetype = mime_type


@ -252,7 +252,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
files = self._fetch_files_recursive(service, folder_id)
# If file types filter is provided, we'll filter by the file type.
if file_types:
_files = [f for f in files if f["mimeType"] in file_types] # type: ignore
_files = [f for f in files if f["mimeType"] in file_types]
else:
_files = files
@ -261,14 +261,14 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
if file["trashed"] and not self.load_trashed_files:
continue
elif file["mimeType"] == "application/vnd.google-apps.document":
returns.append(self._load_document_from_id(file["id"])) # type: ignore
returns.append(self._load_document_from_id(file["id"])) # type: ignore[arg-type]
elif file["mimeType"] == "application/vnd.google-apps.spreadsheet":
returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore
returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore[arg-type]
elif (
file["mimeType"] == "application/pdf"
or self.file_loader_cls is not None
):
returns.extend(self._load_file_from_id(file["id"])) # type: ignore
returns.extend(self._load_file_from_id(file["id"])) # type: ignore[arg-type]
else:
pass
return returns


@ -267,7 +267,7 @@ class DocAIParser(BaseBlobParser):
"""Initializes Long-Running Operations from their names."""
try:
from google.longrunning.operations_pb2 import (
GetOperationRequest, # type: ignore
GetOperationRequest,
)
except ImportError as exc:
raise ImportError(


@ -59,9 +59,9 @@ class DocumentLoaderAsParser(BaseBlobParser):
"""
Use underlying DocumentLoader to lazily parse the blob.
"""
doc_loader = self.DocumentLoaderClass(
doc_loader = self.DocumentLoaderClass( # type: ignore[call-arg]
file_path=blob.path, **self.document_loader_kwargs
) # type: ignore
)
for document in doc_loader.lazy_load():
document.metadata.update(blob.metadata)
yield document


@ -107,7 +107,7 @@ class RapidOCRBlobParser(BaseImageBlobParser):
"`rapidocr-onnxruntime` package not found, please install it with "
"`pip install rapidocr-onnxruntime`"
)
ocr_result, _ = self.ocr(np.array(img)) # type: ignore
ocr_result, _ = self.ocr(np.array(img)) # type: ignore[misc]
content = ""
if ocr_result:
content = ("\n".join([text[1] for text in ocr_result])).strip()


@ -82,7 +82,7 @@ class TreeSitterSegmenter(CodeSegmenter):
)
for line_num in range(start_line + 1, end_line + 1):
simplified_lines[line_num] = None # type: ignore
simplified_lines[line_num] = None # type: ignore[call-overload]
processed_lines.update(lines)


@ -6,7 +6,7 @@ import json
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional
import requests # type: ignore
import requests
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from pydantic import Field


@ -78,7 +78,7 @@ class TrelloLoader(BaseLoader):
"""
try:
from trello import TrelloClient # type: ignore
from trello import TrelloClient
except ImportError as ex:
raise ImportError(
"Could not import trello python package. "
@ -124,7 +124,7 @@ class TrelloLoader(BaseLoader):
return board
def _card_to_doc(self, card: Card, list_dict: dict) -> Document:
from bs4 import BeautifulSoup # type: ignore
from bs4 import BeautifulSoup
text_content = ""
if self.include_card_name:


@ -245,8 +245,8 @@ def get_elements_from_api(
from unstructured.partition.api import partition_multiple_via_api
_doc_elements = partition_multiple_via_api(
filenames=file_path, # type: ignore
files=file, # type: ignore
filenames=file_path,
files=file,
api_key=api_key,
api_url=api_url,
**unstructured_kwargs,


@ -393,7 +393,7 @@ class WebBaseLoader(BaseLoader):
"https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.web_base.WebBaseLoader.html" # noqa: E501
),
)
def aload(self) -> List[Document]: # type: ignore
def aload(self) -> List[Document]: # type: ignore[override]
"""Load text from the urls in web_path async into Documents."""
results = self.scrape_all(self.web_paths)


@ -439,7 +439,7 @@ class HypotheticalDocumentEmbedder:
)
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
return H(*args, **kwargs) # type: ignore
return H(*args, **kwargs) # type: ignore[return-value]
@classmethod
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:


@ -68,7 +68,7 @@ class InfinityEmbeddingsLocal(BaseModel, Embeddings):
"""Validate that api key and python package exists in environment."""
try:
from infinity_emb import AsyncEmbeddingEngine # type: ignore
from infinity_emb import AsyncEmbeddingEngine
except ImportError:
raise ImportError(
"Please install the "


@ -76,7 +76,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
def _embed(self, input: Any) -> List[List[float]]:
# Call Jina AI Embedding API
resp = self.session.post( # type: ignore
resp = self.session.post(
JINA_API_URL, json={"input": input, "model": self.model_name}
).json()
if "data" not in resp:
@ -85,7 +85,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
embeddings = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]


@ -309,7 +309,7 @@ class AsyncOpenAITextEmbedEmbeddingClient:
Raises:
Exception: If the response status is not 200.
"""
async with session.post(**kwargs) as response: # type: ignore
async with session.post(**kwargs) as response: # type: ignore[arg-type]
if response.status != 200:
raise Exception(
f"TextEmbed responded with an unexpected status message "


@ -22,7 +22,7 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float:
https://aclanthology.org/P02-1040.pdf
"""
from nltk.translate.bleu_score import (
SmoothingFunction, # type: ignore
SmoothingFunction,
sentence_bleu,
)


@ -54,7 +54,7 @@ try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
from langchain_community.indexes.base import RecordManager
@ -308,8 +308,8 @@ class SQLRecordManager(RecordManager):
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id,
),
)
elif self.dialect == "postgresql":
@ -322,8 +322,8 @@ class SQLRecordManager(RecordManager):
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id,
),
)
else:
@ -383,8 +383,8 @@ class SQLRecordManager(RecordManager):
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id,
),
)
elif self.dialect == "postgresql":
@ -397,8 +397,8 @@ class SQLRecordManager(RecordManager):
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id,
),
)
else:


@ -486,10 +486,10 @@ class AzureMLBaseEndpoint(BaseModel):
timeout = values.get("timeout", DEFAULT_TIMEOUT)
http_client = AzureMLEndpointClient(
endpoint_url, # type: ignore
endpoint_key.get_secret_value(), # type: ignore
deployment_name, # type: ignore
timeout, # type: ignore
endpoint_url, # type: ignore[arg-type]
endpoint_key.get_secret_value(), # type: ignore[union-attr]
deployment_name, # type: ignore[arg-type]
timeout,
)
return http_client


@ -201,7 +201,7 @@ class Beam(LLM): # type: ignore[override, override, override, override]
def _deploy(self) -> str:
"""Call to Beam."""
try:
import beam # type: ignore
import beam
if beam.__path__ == "":
raise ImportError


@ -181,7 +181,7 @@ class _BaseGigaChat(Serializable):
def get_num_tokens(self, text: str) -> int:
"""Count approximate number of tokens"""
if self.use_api_for_tokens:
return self.tokens_count([text])[0].tokens # type: ignore
return self.tokens_count([text])[0].tokens
else:
return round(len(text) / 4.6)


@ -142,7 +142,7 @@ class HuggingFaceHub(LLM):
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
response_key = VALID_TASKS_DICT[self.task] # type: ignore
response_key = VALID_TASKS_DICT[self.task] # type: ignore[index]
if isinstance(response, list):
text = response[0][response_key]
else:


@ -172,7 +172,7 @@ class IpexLLM(LLM):
if not low_bit_model:
if load_in_low_bit is not None:
load_function_name = "from_pretrained"
load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore
load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore[assignment]
else:
load_function_name = "from_pretrained"
load_kwargs["load_in_4bit"] = load_in_4bit


@ -246,7 +246,7 @@ class BaseOpenAI(BaseLLM):
http_client: Union[Any, None] = None
"""Optional httpx.Client."""
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore[misc]
"""Initialize the OpenAI object."""
model_name = data.get("model_name", "")
if (


@ -47,7 +47,7 @@ class YiLLM(LLM):
def _post(self, request: Any) -> Any:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore[union-attr]
}
urls = []


@ -161,11 +161,11 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore[union-attr]
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore[index]
return pydantic_args


@ -53,9 +53,9 @@ class DeepLakeTranslator(Visitor):
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
if isinstance(func, Operator):
value = OPERATOR_TO_TQL[func.value] # type: ignore
value = OPERATOR_TO_TQL[func.value] # type: ignore[index]
elif isinstance(func, Comparator):
value = COMPARATOR_TO_TQL[func.value] # type: ignore
value = COMPARATOR_TO_TQL[func.value] # type: ignore[index]
return f"{value}"
def visit_operation(self, operation: Operation) -> str:


@ -42,9 +42,9 @@ class TimescaleVectorTranslator(Visitor):
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
if isinstance(func, Operator):
value = self.OPERATOR_MAP[func.value] # type: ignore
value = self.OPERATOR_MAP[func.value] # type: ignore[index]
elif isinstance(func, Comparator):
value = self.COMPARATOR_MAP[func.value] # type: ignore
value = self.COMPARATOR_MAP[func.value] # type: ignore[index]
return f"{value}"
def visit_operation(self, operation: Operation) -> client.Predicates:


@ -41,7 +41,7 @@ try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
Base = declarative_base()
@ -255,7 +255,7 @@ class SQLStore(BaseStore[str, bytes]):
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
with self._make_sync_session() as session:
for v in session.query(LangchainKeyValueStores).filter( # type: ignore
for v in session.query(LangchainKeyValueStores).filter(
LangchainKeyValueStores.namespace == self.namespace
):
if str(v.key).startswith(prefix or ""):


@ -178,7 +178,7 @@ def execute_function(
statement=parametrized_statement.statement,
warehouse_id=warehouse_id,
parameters=parametrized_statement.parameters,
**execute_statement_args, # type: ignore
**execute_statement_args,
)
if response.status and job_pending(response.status.state) and response.statement_id:
statement_id = response.statement_id
@ -197,7 +197,7 @@ def execute_function(
f"status after {wait} seconds."
)
time.sleep(wait)
response = ws.statement_execution.get_statement(statement_id) # type: ignore
response = ws.statement_execution.get_statement(statement_id)
if response.status is None or not job_pending(response.status.state):
break
wait_time += wait
@ -228,7 +228,7 @@ def execute_function(
if is_scalar(function):
value = None
if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
value = str(data_array[0][0]) # type: ignore
value = str(data_array[0][0])
return FunctionExecutionResult(
format="SCALAR", value=value, truncated=truncated
)


@ -51,8 +51,8 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
if tpe == "array":
element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
if uc_type_json["containsNull"]:
element_type = Optional[element_type] # type: ignore
return List[element_type] # type: ignore
element_type = Optional[element_type] # type: ignore[assignment]
return List[element_type] # type: ignore[valid-type]
elif tpe == "map":
key_type = uc_type_json["keyType"]
assert key_type == "string", TypeError(
@ -60,14 +60,14 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
)
value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
if uc_type_json["valueContainsNull"]:
value_type: Type = Optional[value_type] # type: ignore
return Dict[str, value_type] # type: ignore
value_type: Type = Optional[value_type] # type: ignore[no-redef]
return Dict[str, value_type] # type: ignore[valid-type]
elif tpe == "struct":
fields = {}
for field in uc_type_json["fields"]:
field_type = _uc_type_to_pydantic_type(field["type"])
if field.get("nullable"):
field_type = Optional[field_type] # type: ignore
field_type = Optional[field_type] # type: ignore[assignment]
comment = (
uc_type_json["metadata"].get("comment")
if "metadata" in uc_type_json
@ -76,7 +76,7 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
fields[field["name"]] = (field_type, Field(..., description=comment))
uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
return create_model(f"Struct_{type_hash}", **fields) # type: ignore
return create_model(f"Struct_{type_hash}", **fields) # type: ignore[call-overload]
else:
raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")
@ -94,7 +94,7 @@ def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
description = p.comment
default: Any = ...
if p.parameter_default:
pydantic_type = Optional[pydantic_type] # type: ignore
pydantic_type = Optional[pydantic_type] # type: ignore[assignment]
default = None
# TODO: Convert default value string to the correct type.
# We might need to use statement execution API
@ -108,9 +108,9 @@ def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
pydantic_type,
Field(default=default, description=description),
)
return create_model(
return create_model( # type: ignore[call-overload]
f"{function.catalog_name}__{function.schema_name}__{function.name}__params",
**fields, # type: ignore
**fields,
)


@ -217,10 +217,10 @@ class NucliaUnderstandingAPI(BaseTool): # type: ignore[override, override]
logger.info(f"No matching id for {uuid}")
else:
self._results[matching_id]["status"] = "done"
data = MessageToJson(
data = MessageToJson( # type: ignore[call-arg]
pb,
preserving_proto_field_name=True,
including_default_value_fields=True, # type: ignore
including_default_value_fields=True,
)
self._results[matching_id]["data"] = data


@ -161,9 +161,7 @@ class ZapierNLARunAction(BaseTool): # type: ignore[override]
)
ZapierNLARunAction.__doc__ = (
ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore
)
ZapierNLARunAction.__doc__ = ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore[operator]
# other useful actions
@ -210,5 +208,5 @@ class ZapierNLAListActions(BaseTool): # type: ignore[override]
ZapierNLAListActions.__doc__ = (
ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore
ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore[operator]
)


@ -50,7 +50,7 @@ class BingSearchAPIWrapper(BaseModel):
response = requests.get(
self.bing_search_url,
headers=headers,
params=params, # type: ignore
params=params,
)
response.raise_for_status()
search_results = response.json()


@ -82,7 +82,7 @@ class GoogleScholarAPIWrapper(BaseModel):
# 0 is the first page of results, 20 is the 2nd page of results,
# 40 is the 3rd page of results, etc.
results = (
self.google_scholar_engine( # type: ignore
self.google_scholar_engine(
{
"q": query,
"start": page,
@ -106,7 +106,7 @@ class GoogleScholarAPIWrapper(BaseModel):
): # From the last page we would only need top_k_results%20 results
# if k is not divisible by 20.
results = (
self.google_scholar_engine( # type: ignore
self.google_scholar_engine(
{
"q": query,
"start": page,


@ -49,7 +49,6 @@ class MetaphorSearchAPIWrapper(BaseModel):
"useAutoprompt": use_autoprompt,
}
response = requests.post(
# type: ignore
f"{METAPHOR_API_URL}/search",
headers=headers,
json=params,


@ -53,7 +53,7 @@ if TYPE_CHECKING:
try:
from openapi_pydantic import OpenAPI
except ImportError:
OpenAPI = object # type: ignore
OpenAPI = object
class OpenAPISpec(OpenAPI):


@ -134,7 +134,7 @@ class NutritionAIAPI(BaseModel):
return requests.get(
self.nutritionai_api_url,
headers=self.auth_.headers,
params=params, # type: ignore
params=params,
)
def _api_call_results(self, search_term: str) -> dict:


@ -395,7 +395,7 @@ class SQLDatabase:
try:
# get the sample rows
with self._engine.connect() as connection:
sample_rows_result = connection.execute(command) # type: ignore
sample_rows_result = connection.execute(command)
# shorten values in the sample rows
sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)


@ -123,10 +123,10 @@ class SteamWebAPIWrapper(BaseModel):
except ImportError:
raise ImportError("steamspypi library is not installed.")
users_games = self.get_users_games(steam_id)
result = {} # type: ignore
result: dict[str, int] = {}
most_popular_genre = ""
most_popular_genre_count = 0
for game in users_games["games"]: # type: ignore
for game in users_games["games"]: # type: ignore[call-overload]
appid = game["appid"]
data_request = {"request": "appdetails", "appid": appid}
genreStore = steamspypi.download(data_request)
@ -148,7 +148,7 @@ class SteamWebAPIWrapper(BaseModel):
sorted_data = sorted(
data.values(), key=lambda x: x.get("average_forever", 0), reverse=True
)
owned_games = [game["appid"] for game in users_games["games"]] # type: ignore
owned_games = [game["appid"] for game in users_games["games"]] # type: ignore[call-overload]
remaining_games = [
game for game in sorted_data if game["appid"] not in owned_games
]


@ -58,7 +58,6 @@ class TavilySearchAPIWrapper(BaseModel):
"include_images": include_images,
}
response = requests.post(
# type: ignore
f"{TAVILY_API_URL}/search",
json=params,
)


@ -240,7 +240,6 @@ class YouSearchAPIWrapper(BaseModel):
if self.endpoint_type == "snippet":
self.endpoint_type = "search"
response = requests.get(
# type: ignore
f"{YOU_API_URL}/{self.endpoint_type}",
params=params,
headers=headers,


@ -71,4 +71,4 @@ def cosine_similarity_top_k(
top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
scores = score_array.ravel()[top_k_idxs].tolist()
return list(zip(*ret_idxs)), scores # type: ignore
return list(zip(*ret_idxs)), scores # type: ignore[return-value]


@ -407,7 +407,7 @@ class AzureCosmosDBVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# insert the documents in Cosmos DB
insert_result = self._collection.insert_many(to_insert) # type: ignore
insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids
@classmethod


@ -1571,7 +1571,7 @@ class AzureSearch(VectorStore):
azure_search.add_embeddings(text_embeddings, metadatas, **kwargs)
return azure_search
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore[override]
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore.
Args:
@ -1781,7 +1781,7 @@ async def _areorder_results_with_maximal_marginal_relevance(
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x])) # type: ignore
ret.append((documents[x], scores[x]))
return ret
@ -1816,7 +1816,7 @@ def _reorder_results_with_maximal_marginal_relevance(
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x])) # type: ignore
ret.append((documents[x], scores[x]))
return ret


@ -656,7 +656,7 @@ class BigQueryVectorSearch(VectorStore):
Returns:
List of Documents most similar to the query vector, with similarity scores.
"""
emb = self.embedding_model.embed_query(query) # type: ignore
emb = self.embedding_model.embed_query(query)
return self.similarity_search_with_score_by_vector(
emb, k, filter, brute_force, fraction_lists_to_search, **kwargs
)
@ -738,9 +738,7 @@ class BigQueryVectorSearch(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self.embedding_model.embed_query( # type: ignore
query
)
query_embedding = self.embedding_model.embed_query(query)
doc_tuples = self._search_with_score_and_embeddings_by_vector(
query_embedding, fetch_k, filter, brute_force, fraction_lists_to_search
)


@ -183,7 +183,7 @@ class Clarifai(VectorStore):
try:
from clarifai.client.search import Search
from clarifai_grpc.grpc.api import resources_pb2
from google.protobuf import json_format # type: ignore
from google.protobuf import json_format
except ImportError as e:
raise ImportError(
"Could not import clarifai python package. "


@ -275,7 +275,7 @@ class DeepLake(VectorStore):
metadata=metadatas,
embedding_data=texts,
embedding_tensor="embedding",
embedding_function=self._embedding_function.embed_documents, # type: ignore
embedding_function=self._embedding_function.embed_documents, # type: ignore[union-attr]
return_ids=True,
**kwargs,
)
@ -464,8 +464,8 @@ class DeepLake(VectorStore):
if use_maximal_marginal_relevance:
lambda_mult = kwargs.get("lambda_mult", 0.5)
indices = maximal_marginal_relevance( # type: ignore
embedding, # type: ignore
indices = maximal_marginal_relevance(
embedding, # type: ignore[arg-type]
embeddings,
k=min(k, len(texts)),
lambda_mult=lambda_mult,
@ -829,7 +829,7 @@ class DeepLake(VectorStore):
use_maximal_marginal_relevance=True,
lambda_mult=lambda_mult,
exec_option=exec_option,
embedding_function=embedding_function, # type: ignore
embedding_function=embedding_function, # type: ignore[arg-type]
**kwargs,
)


@ -103,7 +103,7 @@ class DocArrayIndex(VectorStore, ABC):
Lower score represents more similarity.
"""
query_embedding = self.embedding.embed_query(query)
query_doc = self.doc_cls(embedding=query_embedding) # type: ignore
query_doc = self.doc_cls(embedding=query_embedding)
docs, scores = self.doc_index.find(query_doc, search_field="embedding", limit=k)
result = [
@ -152,7 +152,7 @@ class DocArrayIndex(VectorStore, ABC):
List of Documents most similar to the query vector.
"""
query_doc = self.doc_cls(embedding=embedding) # type: ignore
query_doc = self.doc_cls(embedding=embedding)
docs = self.doc_index.find(
query_doc, search_field="embedding", limit=k
).documents
@ -187,7 +187,7 @@ class DocArrayIndex(VectorStore, ABC):
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self.embedding.embed_query(query)
query_doc = self.doc_cls(embedding=query_embedding) # type: ignore
query_doc = self.doc_cls(embedding=query_embedding)
docs = self.doc_index.find(
query_doc, search_field="embedding", limit=fetch_k


@ -71,7 +71,7 @@ class DocArrayHnswSearch(DocArrayIndex):
num_threads=num_threads,
**kwargs,
)
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir) # type: ignore
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
return cls(doc_index, embedding)
@classmethod


@ -41,7 +41,7 @@ class DocArrayInMemorySearch(DocArrayIndex):
from docarray.index import InMemoryExactNNIndex
doc_cls = cls._get_doc_cls(space=metric, **kwargs)
doc_index = InMemoryExactNNIndex[doc_cls]() # type: ignore
doc_index = InMemoryExactNNIndex[doc_cls]()
return cls(doc_index, embedding)
@classmethod


@ -265,7 +265,7 @@ class DocumentDBVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# insert the documents in DocumentDB
insert_result = self._collection.insert_many(to_insert) # type: ignore
insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids
@classmethod


@ -220,7 +220,7 @@ class DuckDB(VectorStore):
except ImportError:
warnings.warn("You may need to `pip install pandas` to use this method.")
embedding = self._embedding.embed_query(query) # type: ignore
embedding = self._embedding.embed_query(query)
list_cosine_similarity = self.duckdb.FunctionExpression(
"list_cosine_similarity",
self.duckdb.ColumnExpression(self._vector_key),
@ -265,7 +265,7 @@ class DuckDB(VectorStore):
A list of Documents most similar to the query.
"""
embedding = self._embedding.embed_query(query) # type: ignore
embedding = self._embedding.embed_query(query)
list_cosine_similarity = self.duckdb.FunctionExpression(
"list_cosine_similarity",
self.duckdb.ColumnExpression(self._vector_key),


@ -254,7 +254,7 @@ repeated float %s = 1;
texts_l = list(texts)
if last_vector:
texts_l.pop()
embeds = self._embedding.embed_documents(texts_l) # type: ignore
embeds = self._embedding.embed_documents(texts_l) # type: ignore[union-attr]
if last_vector:
embeds.append(last_vector)
if not metadatas:
@ -288,7 +288,7 @@ repeated float %s = 1;
Returns:
List[Tuple[Document, float]]
"""
embed = self._embedding.embed_query(query) # type: ignore
embed = self._embedding.embed_query(query) # type: ignore[union-attr]
documents = self.similarity_search_with_score_by_vector(embedding=embed, k=k)
return documents


@ -201,7 +201,7 @@ class LanceDB(VectorStore):
"""
docs = []
ids = ids or [str(uuid.uuid4()) for _ in texts]
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore[union-attr]
for idx, text in enumerate(texts):
embedding = embeddings[idx]
metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
@ -490,7 +490,7 @@ class LanceDB(VectorStore):
embedding = self._embedding.embed_query(query)
_query = (embedding, query)
else:
_query = query # type: ignore
_query = query # type: ignore[assignment]
res = self._query(_query, k, filter=filter, name=name, **kwargs)
return self.results_to_docs(res, score=score)


@ -58,9 +58,9 @@ def get_embedding_store(
embedding_type = None
if distance_strategy == DistanceStrategy.HAMMING:
embedding_type = sqlalchemy.INTEGER # type: ignore
embedding_type = sqlalchemy.INTEGER
else:
embedding_type = sqlalchemy.REAL # type: ignore
embedding_type = sqlalchemy.REAL # type: ignore[assignment]
DynamicBase = declarative_base(class_registry=dict()) # type: Any
@ -74,7 +74,7 @@ def get_embedding_store(
cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore[arg-type,var-annotated]
return EmbeddingStore


@ -397,7 +397,7 @@ class MomentoVectorIndex(VectorStore):
)
selected = [response.hits[i].metadata for i in mmr_selected]
return [
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata)
for metadata in selected
]
@ -484,6 +484,6 @@ class MomentoVectorIndex(VectorStore):
configuration=VectorIndexConfigurations.Default.latest(),
credential_provider=CredentialProvider.from_string(api_key),
)
vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore
vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore[call-arg]
vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
return vector_db


@ -183,7 +183,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# insert the documents in MongoDB Atlas
insert_result = self._collection.insert_many(to_insert) # type: ignore
insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids
def _similarity_search_with_score(


@ -857,7 +857,7 @@ class OracleVS(VectorStore):
)
documents.append((document, distance, current_embedding))
return documents # type: ignore
return documents
@_handle_exceptions
def max_marginal_relevance_search_with_score_by_vector(


@ -49,7 +49,7 @@ class CollectionStore(BaseModel):
@classmethod
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
return session.query(cls).filter(cls.name == name).first() # type: ignore
return session.query(cls).filter(cls.name == name).first()
@classmethod
def get_or_create(
@ -88,7 +88,7 @@ class EmbeddingStore(BaseModel):
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore[var-annotated]
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)


@ -33,7 +33,7 @@ try:
from sqlalchemy import SQLColumnExpression
except ImportError:
# for sqlalchemy < 2
SQLColumnExpression = Any # type: ignore
SQLColumnExpression = Any # type: ignore[assignment,misc]
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
@ -126,7 +126,7 @@ def _get_embedding_collection_store(
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
return session.query(cls).filter(cls.name == name).first() # type: ignore
return session.query(cls).filter(cls.name == name).first()
@classmethod
def get_or_create(
@ -956,7 +956,7 @@ class PGVector(VectorStore):
results: List[Any] = (
session.query(
self.EmbeddingStore,
self.distance_strategy(embedding).label("distance"), # type: ignore
self.distance_strategy(embedding).label("distance"),
)
.filter(*filter_by)
.order_by(sqlalchemy.asc("distance"))


@ -419,7 +419,7 @@ class Redis(VectorStore):
# type check for metadata
if metadatas:
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore
if isinstance(metadatas, list) and len(metadatas) != len(texts):
raise ValueError("Number of metadatas must match number of texts")
if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
raise ValueError("Metadatas must be a list of dicts")
@ -427,7 +427,7 @@ class Redis(VectorStore):
generated_schema = _generate_field_schema(metadatas[0])
if index_schema:
# read in the schema solely to compare to the generated schema
user_schema = read_schema(index_schema) # type: ignore
user_schema = read_schema(index_schema)
# the very rare case where a super user decides to pass the index
# schema and a document loader is used that has metadata which
@ -722,7 +722,7 @@ class Redis(VectorStore):
# type check for metadata
if metadatas:
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore[arg-type]
raise ValueError("Number of metadatas must match number of texts")
if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
raise ValueError("Metadatas must be a list of dicts")
@ -850,7 +850,7 @@ class Redis(VectorStore):
# Perform vector search
# ignore type because redis-py is wrong about bytes
try:
results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore
results = self.client.ft(self.index_name).search(redis_query, params_dict)
except redis.exceptions.ResponseError as e:
# split error message and see if it starts with "Syntax"
if str(e).split(" ")[0] == "Syntax":
@ -966,7 +966,7 @@ class Redis(VectorStore):
# Perform vector search
# ignore type because redis-py is wrong about bytes
try:
results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore
results = self.client.ft(self.index_name).search(redis_query, params_dict)
except redis.exceptions.ResponseError as e:
# split error message and see if it starts with "Syntax"
if str(e).split(" ")[0] == "Syntax":
@ -1206,7 +1206,7 @@ class Redis(VectorStore):
# read in schema (yaml file or dict) and
# pass to the Pydantic validators
if index_schema:
schema_values = read_schema(index_schema) # type: ignore
schema_values = read_schema(index_schema)
schema = RedisModel(**schema_values)
# ensure user did not exclude the content field
@ -1242,7 +1242,7 @@ class Redis(VectorStore):
def _create_index_if_not_exist(self, dim: int = 1536) -> None:
try:
from redis.commands.search.indexDefinition import ( # type: ignore
from redis.commands.search.indexDefinition import (
IndexDefinition,
IndexType,
)


@ -140,7 +140,7 @@ class RedisTag(RedisFilterField):
elif isinstance(other, str):
other = [other]
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore[arg-type]
@check_operator_misuse
def __eq__(
@ -240,7 +240,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") == 90210
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
@check_operator_misuse
@ -254,7 +254,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") != 90210
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -267,7 +267,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") > 18
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -280,7 +280,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") < 18
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -293,7 +293,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") >= 18
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -306,7 +306,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") <= 18
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
@ -336,7 +336,7 @@ class RedisText(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisText
>>> filter = RedisText("job") == "engineer"
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
@check_operator_misuse
@ -350,7 +350,7 @@ class RedisText(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisText
>>> filter = RedisText("job") != "engineer"
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
def __mod__(self, other: str) -> "RedisFilterExpression":
@ -366,7 +366,7 @@ class RedisText(RedisFilterField):
>>> filter = RedisText("job") % "engineer|doctor" # contains either term
>>> filter = RedisText("job") % "engineer doctor" # contains both terms
"""
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore[arg-type]
return RedisFilterExpression(str(self))
def __str__(self) -> str:
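Every hunk above follows the recipe PGH003 enforces: a surviving suppression must name the exact diagnostic it silences. A minimal sketch of the rule's behavior (the function and variable here are hypothetical, not from this repo):

from typing import Union

def set_value(value: Union[int, float]) -> None:
    pass

raw: object = 42
set_value(raw)  # type: ignore            # flagged by PGH003: no error code given
set_value(raw)  # type: ignore[arg-type]  # accepted: the suppressed code is explicit

Running `ruff check --select PGH003` reports only the first call; mypy treats both comments the same, so narrowing the code loses nothing while documenting intent.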

View File

@ -14,7 +14,7 @@ from typing_extensions import TYPE_CHECKING, Literal
from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
if TYPE_CHECKING:
from redis.commands.search.field import ( # type: ignore
from redis.commands.search.field import (
NumericField,
TagField,
TextField,
@ -47,13 +47,13 @@ class TextFieldSchema(RedisField):
sortable: Optional[bool] = False
def as_field(self) -> TextField:
from redis.commands.search.field import TextField # type: ignore
from redis.commands.search.field import TextField
return TextField(
self.name,
weight=self.weight,
no_stem=self.no_stem,
phonetic_matcher=self.phonetic_matcher, # type: ignore
phonetic_matcher=self.phonetic_matcher,
sortable=self.sortable,
no_index=self.no_index,
)
@ -68,7 +68,7 @@ class TagFieldSchema(RedisField):
sortable: Optional[bool] = False
def as_field(self) -> TagField:
from redis.commands.search.field import TagField # type: ignore
from redis.commands.search.field import TagField
return TagField(
self.name,
@ -86,7 +86,7 @@ class NumericFieldSchema(RedisField):
sortable: Optional[bool] = False
def as_field(self) -> NumericField:
from redis.commands.search.field import NumericField # type: ignore
from redis.commands.search.field import NumericField
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
@ -131,7 +131,7 @@ class FlatVectorField(RedisVectorField): # type: ignore[override]
block_size: Optional[int] = None
def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore
from redis.commands.search.field import VectorField
field_data = super()._fields()
if self.block_size is not None:
@ -149,7 +149,7 @@ class HNSWVectorField(RedisVectorField): # type: ignore[override]
epsilon: float = Field(default=0.01)
def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore
from redis.commands.search.field import VectorField
field_data = super()._fields()
field_data.update(
@ -193,9 +193,9 @@ class RedisModel(BaseModel):
# ignore types as pydantic is handling type validation and conversion
if vector_field["algorithm"] == "FLAT":
self.vector.append(FlatVectorField(**vector_field)) # type: ignore
self.vector.append(FlatVectorField(**vector_field))
elif vector_field["algorithm"] == "HNSW":
self.vector.append(HNSWVectorField(**vector_field)) # type: ignore
self.vector.append(HNSWVectorField(**vector_field))
else:
raise ValueError(
f"algorithm must be either FLAT or HNSW. Got "

View File

@ -257,7 +257,7 @@ class SupabaseVectorStore(VectorStore):
match_result = [
(
Document(
metadata=search.get("metadata", {}), # type: ignore
metadata=search.get("metadata", {}),
page_content=search.get("content", ""),
),
search.get("similarity", 0.0),
@ -302,7 +302,7 @@ class SupabaseVectorStore(VectorStore):
match_result = [
(
Document(
metadata=search.get("metadata", {}), # type: ignore
metadata=search.get("metadata", {}),
page_content=search.get("content", ""),
),
search.get("similarity", 0.0),
@ -351,7 +351,7 @@ class SupabaseVectorStore(VectorStore):
"id": ids[idx],
"content": documents[idx].page_content,
"embedding": embedding,
"metadata": documents[idx].metadata, # type: ignore
"metadata": documents[idx].metadata,
**kwargs,
}
for idx, embedding in enumerate(vectors)
@ -360,7 +360,7 @@ class SupabaseVectorStore(VectorStore):
for i in range(0, len(rows), chunk_size):
chunk = rows[i : i + chunk_size]
result = client.from_(table_name).upsert(chunk).execute() # type: ignore
result = client.from_(table_name).upsert(chunk).execute()
if len(result.data) == 0:
raise Exception("Error inserting: No rows added")
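The Supabase deletions look like "unused ignore" cases: the search payload is an untyped mapping, so `.get(...)` already returns `Any` and needs no suppression. Reduced to its core (the dict literal is hypothetical):

from typing import Any, Dict

search: Dict[str, Any] = {"metadata": {"source": "docs"}, "content": "text"}
metadata = search.get("metadata", {})  # inferred as Any: assignable anywhere, no ignore needed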

View File

@ -153,7 +153,7 @@ class UpstashVectorStore(VectorStore):
self._namespace = namespace
@property
def embeddings(self) -> Optional[Union[Embeddings, bool]]: # type: ignore
def embeddings(self) -> Optional[Union[Embeddings, bool]]: # type: ignore[override]
"""Access the query embedding object if available."""
return self._embeddings
@ -730,7 +730,7 @@ class UpstashVectorStore(VectorStore):
)
selected = [results[i].metadata for i in mmr_selected]
return [
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) # type: ignore
Document(page_content=metadata.pop((self._text_key)), metadata=metadata)
for metadata in selected
]
@ -798,7 +798,7 @@ class UpstashVectorStore(VectorStore):
)
selected = [results[i].metadata for i in mmr_selected]
return [
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) # type: ignore
Document(page_content=metadata.pop((self._text_key)), metadata=metadata)
for metadata in selected
]
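One Upstash ignore is kept but narrowed to `[override]`: the `embeddings` property intentionally widens the return type declared on the base class, which is precisely the incompatible-override diagnostic. The mechanism in miniature (class names are stand-ins):

from typing import Optional, Union

class Base:
    @property
    def embeddings(self) -> Optional[str]:
        return None

class Child(Base):
    @property
    def embeddings(self) -> Optional[Union[str, bool]]:  # type: ignore[override]
        # Wider return type than Base declares: mypy reports an incompatible override.
        return None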

View File

@ -467,7 +467,7 @@ class Vectara(VectorStore):
}
if config.lambda_val > 0:
body["query"][0]["corpusKey"][0]["lexicalInterpolationConfig"] = { # type: ignore
body["query"][0]["corpusKey"][0]["lexicalInterpolationConfig"] = { # type: ignore[index]
"lambda": config.lambda_val
}
@ -495,7 +495,7 @@ class Vectara(VectorStore):
}
]
if chat:
body["query"][0]["summary"][0]["chat"] = { # type: ignore
body["query"][0]["summary"][0]["chat"] = { # type: ignore[index]
"store": True,
"conversationId": chat_conv_id,
}
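Vectara's surviving ignores gain `[index]`: the request body is inferred from a mixed dict literal, so mypy collapses the value type to `object`, which is not subscriptable. Reduced form (the payload is a stand-in for the real request body):

body = {"query": [{"start": 0}], "chat": True}  # inferred as Dict[str, object]
body["query"][0]["summary"] = [{}]  # type: ignore[index]  # "object" is not indexable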

View File

@ -65,7 +65,7 @@ lint = [
]
dev = ["jupyter<2.0.0,>=1.0.0", "setuptools<68.0.0,>=67.6.1", "langchain-core"]
typing = [
"mypy<2.0,>=1.12",
"mypy<2.0,>=1.15",
"types-pyyaml<7.0.0.0,>=6.0.12.2",
"types-requests<3.0.0.0,>=2.28.11.5",
"types-toml<1.0.0.0,>=0.10.8.1",
@ -103,7 +103,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin,cann"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
select = ["E", "F", "I", "PGH003", "T201"]
[tool.coverage.run]
omit = ["tests/*"]
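The two config changes travel together: ruff's `select` gains PGH003 so new blanket ignores cannot land, and the mypy floor rises to 1.15, under which many of the diagnostics suppressed above apparently no longer fire, leaving their ignores dead. The failure mode the pair prevents, in two hypothetical lines:

x: int = 1
y = x + 1  # type: ignore

PGH003 flags the comment for carrying no error code, and with `warn_unused_ignores = true` mypy additionally reports it as unused, since the line type-checks cleanly.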

View File

@ -27,7 +27,7 @@ def init_gptcache_map(cache_obj: Any) -> None:
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
init_gptcache_map._i = i + 1 # type: ignore
init_gptcache_map._i = i + 1 # type: ignore[attr-defined]
def init_gptcache_map_with_llm(cache_obj: Any, llm: str) -> None:
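`[attr-defined]` is the precise code for this test trick: the function object doubles as a mutable counter, and mypy knows of no `_i` attribute on a callable. In isolation:

def counter() -> None:
    counter._i = getattr(counter, "_i", 0) + 1  # type: ignore[attr-defined]
    # mypy: "Callable[[], None]" has no attribute "_i" -- hence attr-defined

counter()
counter()
print(counter._i)  # type: ignore[attr-defined]  # prints 2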

View File

@ -37,7 +37,7 @@ def test_messages(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> Non
Message(content="message2", role="human", metadata={"key2": "value2"}),
],
)
zep_chat.zep_client.memory.get_memory.return_value = mock_memory # type: ignore
zep_chat.zep_client.memory.get_memory.return_value = mock_memory
result = zep_chat.messages
@ -52,25 +52,25 @@ def test_add_user_message(
mocker: MockerFixture, zep_chat: ZepChatMessageHistory
) -> None:
zep_chat.add_user_message("test message")
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore
zep_chat.zep_client.memory.add_memory.assert_called_once()
@pytest.mark.requires("zep_python")
def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.add_ai_message("test message")
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore
zep_chat.zep_client.memory.add_memory.assert_called_once()
@pytest.mark.requires("zep_python")
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.add_message(AIMessage(content="test message"))
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore
zep_chat.zep_client.memory.add_memory.assert_called_once()
@pytest.mark.requires("zep_python")
def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.search("test query")
zep_chat.zep_client.memory.search_memory.assert_called_once_with( # type: ignore
zep_chat.zep_client.memory.search_memory.assert_called_once_with(
"test_session", mocker.ANY, limit=None
)
@ -78,6 +78,4 @@ def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
@pytest.mark.requires("zep_python")
def test_clear(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.clear()
zep_chat.zep_client.memory.delete_memory.assert_called_once_with( # type: ignore
"test_session"
)
zep_chat.zep_client.memory.delete_memory.assert_called_once_with("test_session")
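The Zep test deletions are the mock variant of an unused ignore: once `zep_client` is replaced by a `MagicMock`, every attribute on it is created dynamically and typed `Any`, so the assertion helpers need no suppression. The pattern in isolation:

from unittest.mock import MagicMock

client = MagicMock()
client.memory.add_memory("test_session", "message")
client.memory.add_memory.assert_called_once()  # attributes on a MagicMock are Any: no ignore needed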

View File

@ -67,7 +67,7 @@ class AnswerWithJustification(BaseModel):
def test_chat_minimax_with_structured_output() -> None:
"""Test MiniMaxChat with structured output."""
llm = MiniMaxChat() # type: ignore
llm = MiniMaxChat() # type: ignore[call-arg]
structured_llm = llm.with_structured_output(AnswerWithJustification)
response = structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
@ -77,7 +77,7 @@ def test_chat_minimax_with_structured_output() -> None:
def test_chat_tongyi_with_structured_output_include_raw() -> None:
"""Test MiniMaxChat with structured output."""
llm = MiniMaxChat() # type: ignore
llm = MiniMaxChat() # type: ignore[call-arg]
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)
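`MiniMaxChat()` keeps its suppression but as `[call-arg]`: required constructor fields are filled in from the environment during validation, which the type checker cannot see. A reduced sketch of that shape (pydantic v2; the class, field, and env var names are hypothetical):

import os

from pydantic import BaseModel, model_validator

class Chat(BaseModel):
    api_key: str

    @model_validator(mode="before")
    @classmethod
    def _from_env(cls, values: dict) -> dict:
        # Fall back to the environment when the caller passes nothing.
        values.setdefault("api_key", os.environ.get("API_KEY", ""))
        return values

chat = Chat()  # type: ignore[call-arg]  # mypy: missing "api_key"; supplied at runtime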

View File

@ -170,7 +170,7 @@ class GenerateUsername(BaseModel):
def test_tool_use() -> None:
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore[call-arg]
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
HumanMessage(content="Sally has green hair, what would her username be?")
@ -187,7 +187,7 @@ def test_tool_use() -> None:
tool_msg = ToolMessage(
content="sally_green_hair",
tool_call_id=ai_msg.tool_calls[0]["id"], # type: ignore
tool_call_id=ai_msg.tool_calls[0]["id"],
name=ai_msg.tool_calls[0]["name"],
)
msgs.extend([ai_msg, tool_msg])
@ -201,7 +201,7 @@ def test_tool_use() -> None:
gathered = message
first = False
else:
gathered = gathered + message # type: ignore
gathered = gathered + message # type: ignore[assignment]
assert isinstance(gathered, AIMessageChunk)
streaming_tool_msg = ToolMessage(
@ -215,7 +215,7 @@ def test_tool_use() -> None:
def test_manual_tool_call_msg() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore[call-arg]
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
HumanMessage(content="Sally has green hair, what would her username be?"),
@ -246,7 +246,7 @@ class AnswerWithJustification(BaseModel):
def test_chat_tongyi_with_structured_output() -> None:
"""Test ChatTongyi with structured output."""
llm = ChatTongyi() # type: ignore
llm = ChatTongyi() # type: ignore[call-arg]
structured_llm = llm.with_structured_output(AnswerWithJustification)
response = structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
@ -256,7 +256,7 @@ def test_chat_tongyi_with_structured_output() -> None:
def test_chat_tongyi_with_structured_output_include_raw() -> None:
"""Test ChatTongyi with structured output."""
llm = ChatTongyi() # type: ignore
llm = ChatTongyi() # type: ignore[call-arg]
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)
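The Tongyi file shows the other surviving pattern, `[assignment]` on the streaming loop: `+` on message chunks returns the base chunk type, so reassigning the result to a variable first inferred as a concrete subclass trips mypy. The mechanism with stand-in types:

class Chunk:
    def __add__(self, other: "Chunk") -> "Chunk":
        return Chunk()

class AIChunk(Chunk):
    pass

gathered = AIChunk()             # inferred as AIChunk
gathered = gathered + AIChunk()  # type: ignore[assignment]  # expression has type Chunk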

View File

@ -75,7 +75,7 @@ async def test_vertexai_agenerate(model_name: str) -> None:
message = HumanMessage(content="Hello")
response = await model.agenerate([[message]])
assert isinstance(response, LLMResult)
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore[union-attr]
sync_response = model.generate([[message]])
assert response.generations[0][0] == sync_response.generations[0][0]
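`[union-attr]` covers attribute access that only some members of a union (or an Optional) support; here the generation type evidently includes a variant without `.message`. Stripped down (class names are stand-ins):

from typing import Union

class Generation:
    text: str = ""

class ChatGeneration(Generation):
    message: str = "hi"

gen: Union[Generation, ChatGeneration] = ChatGeneration()
print(gen.message)  # type: ignore[union-attr]  # "Generation" has no attribute "message"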

View File

@ -1,4 +1,6 @@
from typing import Dict
from __future__ import annotations
from typing import TYPE_CHECKING, Dict
from unittest.mock import MagicMock, patch
import pytest
@ -6,6 +8,9 @@ from langchain_core.documents import Document
from langchain_community.document_loaders.quip import QuipLoader
if TYPE_CHECKING:
from collections.abc import Iterator
try:
from quip_api.quip import QuipClient # noqa: F401
@ -15,7 +20,7 @@ except ImportError:
@pytest.fixture
def mock_quip(): # type: ignore
def mock_quip() -> Iterator[MagicMock]:
# mock quip_client
with patch("quip_api.quip.QuipClient") as mock_quip:
yield mock_quip
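This is the one place the commit replaces an ignore with real types instead of a code: the generator fixture yields the patched client, so `Iterator[MagicMock]` describes it exactly, and the `TYPE_CHECKING` guard plus `from __future__ import annotations` keeps the import free at runtime. The same shape against a stdlib target (the patch target here is illustrative):

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest

if TYPE_CHECKING:
    from collections.abc import Iterator

@pytest.fixture
def mock_loads() -> Iterator[MagicMock]:
    # patch() is a context manager; yielding keeps it active for the test's duration.
    with patch("json.loads") as mock:
        yield mock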

View File

@ -105,7 +105,7 @@ def test_unstructured_api_file_loader_io_multiple_files() -> None:
files = [stack.enter_context(open(file_path, "rb")) for file_path in file_paths]
loader = UnstructuredAPIFileIOLoader(
file=files, # type: ignore
file=files,
api_key="FAKE_API_KEY",
strategy="fast",
mode="elements",

Some files were not shown because too many files have changed in this diff.