community: Add ruff rule PGH003 (#30812)

See https://docs.astral.sh/ruff/rules/blanket-type-ignore/

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Christophe Bornet 2025-04-14 04:32:13 +02:00 committed by GitHub
parent f005988e31
commit ada740b5b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
148 changed files with 448 additions and 419 deletions

View File

@ -196,7 +196,7 @@ def create_sql_agent(
] ]
prompt = ChatPromptTemplate.from_messages(messages) prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent( agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore[arg-type]
input_keys_arg=["input"], input_keys_arg=["input"],
return_keys_arg=["output"], return_keys_arg=["output"],
**kwargs, **kwargs,
@ -211,9 +211,9 @@ def create_sql_agent(
] ]
prompt = ChatPromptTemplate.from_messages(messages) prompt = ChatPromptTemplate.from_messages(messages)
if agent_type == "openai-tools": if agent_type == "openai-tools":
runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore[arg-type]
else: else:
runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore[arg-type]
agent = RunnableMultiActionAgent( # type: ignore[assignment] agent = RunnableMultiActionAgent( # type: ignore[assignment]
runnable=runnable, runnable=runnable,
input_keys_arg=["input"], input_keys_arg=["input"],

View File

@ -135,7 +135,7 @@ def _get_assistants_tool(
Dict[str, Any]: A dictionary of tools that are converted into OpenAI tools. Dict[str, Any]: A dictionary of tools that are converted into OpenAI tools.
""" """
if _is_assistants_builtin_tool(tool): if _is_assistants_builtin_tool(tool):
return tool # type: ignore return tool # type: ignore[return-value]
else: else:
return convert_to_openai_tool(tool) return convert_to_openai_tool(tool)
@ -288,7 +288,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
assistant = client.beta.assistants.create( assistant = client.beta.assistants.create(
name=name, name=name,
instructions=instructions, instructions=instructions,
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore tools=[_get_assistants_tool(tool) for tool in tools],
tool_resources=tool_resources, # type: ignore[arg-type] tool_resources=tool_resources, # type: ignore[arg-type]
model=model, model=model,
extra_body=extra_body, extra_body=extra_body,
@ -430,7 +430,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
assistant = await async_client.beta.assistants.create( assistant = await async_client.beta.assistants.create(
name=name, name=name,
instructions=instructions, instructions=instructions,
tools=openai_tools, # type: ignore tools=openai_tools,
tool_resources=tool_resources, # type: ignore[arg-type] tool_resources=tool_resources, # type: ignore[arg-type]
model=model, model=model,
) )

View File

@ -238,7 +238,7 @@ class InMemoryCache(BaseCache):
Base = declarative_base() Base = declarative_base()
class FullLLMCache(Base): # type: ignore class FullLLMCache(Base): # type: ignore[misc,valid-type]
"""SQLite table for full LLM Cache (all generations).""" """SQLite table for full LLM Cache (all generations)."""
__tablename__ = "full_llm_cache" __tablename__ = "full_llm_cache"
@ -261,7 +261,7 @@ class SQLAlchemyCache(BaseCache):
"""Look up based on prompt and llm_string.""" """Look up based on prompt and llm_string."""
stmt = ( stmt = (
select(self.cache_schema.response) select(self.cache_schema.response)
.where(self.cache_schema.prompt == prompt) # type: ignore .where(self.cache_schema.prompt == prompt)
.where(self.cache_schema.llm == llm_string) .where(self.cache_schema.llm == llm_string)
.order_by(self.cache_schema.idx) .order_by(self.cache_schema.idx)
) )
@ -1531,7 +1531,7 @@ class CassandraSemanticCache(BaseCache):
await self.table.aclear() await self.table.aclear()
class FullMd5LLMCache(Base): # type: ignore class FullMd5LLMCache(Base): # type: ignore[misc,valid-type]
"""SQLite table for full LLM Cache (all generations).""" """SQLite table for full LLM Cache (all generations)."""
__tablename__ = "full_md5_llm_cache" __tablename__ = "full_md5_llm_cache"
@ -1583,7 +1583,7 @@ class SQLAlchemyMd5Cache(BaseCache):
def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None: def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
stmt = ( stmt = (
delete(self.cache_schema) delete(self.cache_schema)
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))
.where(self.cache_schema.llm == llm_string) .where(self.cache_schema.llm == llm_string)
.where(self.cache_schema.prompt == prompt) .where(self.cache_schema.prompt == prompt)
) )
@ -1593,7 +1593,7 @@ class SQLAlchemyMd5Cache(BaseCache):
prompt_pd5 = self.get_md5(prompt) prompt_pd5 = self.get_md5(prompt)
stmt = ( stmt = (
select(self.cache_schema.response) select(self.cache_schema.response)
.where(self.cache_schema.prompt_md5 == prompt_pd5) # type: ignore .where(self.cache_schema.prompt_md5 == prompt_pd5)
.where(self.cache_schema.llm == llm_string) .where(self.cache_schema.llm == llm_string)
.where(self.cache_schema.prompt == prompt) .where(self.cache_schema.prompt == prompt)
.order_by(self.cache_schema.idx) .order_by(self.cache_schema.idx)
@ -1796,7 +1796,7 @@ class _CachedAwaitable:
def __await__(self) -> Generator: def __await__(self) -> Generator:
if self.result is _unset: if self.result is _unset:
self.result = yield from self.awaitable.__await__() self.result = yield from self.awaitable.__await__()
return self.result # type: ignore return self.result # type: ignore[return-value]
def _reawaitable(func: Callable) -> Callable: def _reawaitable(func: Callable) -> Callable:

View File

@ -584,7 +584,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
self.__init__( # type: ignore self.__init__( # type: ignore[misc]
task_type=_task_type, task_type=_task_type,
workspace=_workspace, workspace=_workspace,
project_name=_project_name, project_name=_project_name,

View File

@ -580,7 +580,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.temp_dir.cleanup() self.temp_dir.cleanup()
self.reset_callback_meta() self.reset_callback_meta()
if reset: if reset:
self.__init__( # type: ignore self.__init__( # type: ignore[misc]
job_type=job_type if job_type else self.job_type, job_type=job_type if job_type else self.job_type,
project=project if project else self.project, project=project if project else self.project,
entity=entity if entity else self.entity, entity=entity if entity else self.entity,

View File

@ -352,7 +352,7 @@ def create_structured_output_runnable(
class _OutputFormatter(BaseModel): class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501 """Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore output: output_schema # type: ignore[valid-type]
function = _OutputFormatter function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser( output_parser = output_parser or PydanticAttrOutputFunctionsParser(
@ -537,7 +537,7 @@ def create_structured_output_chain(
class _OutputFormatter(BaseModel): class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501 """Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore output: output_schema # type: ignore[valid-type]
function = _OutputFormatter function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser( output_parser = output_parser or PydanticAttrOutputFunctionsParser(

View File

@ -316,7 +316,7 @@ class GraphCypherQAChain(Chain):
MessagesPlaceholder(variable_name="function_response"), MessagesPlaceholder(variable_name="function_response"),
] ]
) )
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore[operator]
except (NotImplementedError, AttributeError): except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions") raise ValueError("Provided LLM does not support native tools/functions")
else: else:
@ -404,15 +404,15 @@ class GraphCypherQAChain(Chain):
intermediate_steps.append({"context": context}) intermediate_steps.append({"context": context})
if self.use_function_response: if self.use_function_response:
function_response = get_function_response(question, context) function_response = get_function_response(question, context)
final_result = self.qa_chain.invoke( # type: ignore final_result = self.qa_chain.invoke( # type: ignore[assignment]
{"question": question, "function_response": function_response}, {"question": question, "function_response": function_response},
) )
else: else:
result = self.qa_chain.invoke( # type: ignore result = self.qa_chain.invoke(
{"question": question, "context": context}, {"question": question, "context": context},
callbacks=callbacks, callbacks=callbacks,
) )
final_result = result[self.qa_chain.output_key] # type: ignore final_result = result[self.qa_chain.output_key] # type: ignore[union-attr]
chain_result: Dict[str, Any] = {self.output_key: final_result} chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps: if self.return_intermediate_steps:

View File

@ -225,11 +225,11 @@ class MemgraphQAChain(Chain):
MessagesPlaceholder(variable_name="function_response"), MessagesPlaceholder(variable_name="function_response"),
] ]
) )
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore[operator]
except (NotImplementedError, AttributeError): except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions") raise ValueError("Provided LLM does not support native tools/functions")
else: else:
qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser() # type: ignore qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser()
prompt = use_cypher_llm_kwargs["prompt"] prompt = use_cypher_llm_kwargs["prompt"]
llm_to_use = cypher_llm if cypher_llm is not None else llm llm_to_use = cypher_llm if cypher_llm is not None else llm
@ -300,11 +300,11 @@ class MemgraphQAChain(Chain):
intermediate_steps.append({"context": context}) intermediate_steps.append({"context": context})
if self.use_function_response: if self.use_function_response:
function_response = get_function_response(question, context) function_response = get_function_response(question, context)
result = self.qa_chain.invoke( # type: ignore result = self.qa_chain.invoke(
{"question": question, "function_response": function_response}, {"question": question, "function_response": function_response},
) )
else: else:
result = self.qa_chain.invoke( # type: ignore result = self.qa_chain.invoke(
{"question": question, "context": context}, {"question": question, "context": context},
callbacks=callbacks, callbacks=callbacks,
) )

View File

@ -67,11 +67,11 @@ def extract_cypher(text: str) -> str:
def use_simple_prompt(llm: BaseLanguageModel) -> bool: def use_simple_prompt(llm: BaseLanguageModel) -> bool:
"""Decides whether to use the simple prompt""" """Decides whether to use the simple prompt"""
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore[attr-defined]
return True return True
# Bedrock anthropic # Bedrock anthropic
if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore if hasattr(llm, "model_id") and "anthropic" in llm.model_id:
return True return True
return False return False

View File

@ -313,8 +313,12 @@ class PebbloRetrievalQA(Chain):
) )
@staticmethod @staticmethod
def _get_app_details( # type: ignore def _get_app_details(
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs app_name: str,
owner: str,
description: str,
llm: BaseLanguageModel,
**kwargs: Any,
) -> App: ) -> App:
"""Fetch app details. Internal method. """Fetch app details. Internal method.
Returns: Returns:

View File

@ -81,7 +81,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
) )
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve all session messages from DB""" """Retrieve all session messages from DB"""
# The latest are returned, in chronological order # The latest are returned, in chronological order
rows = self.table.get_partition( rows = self.table.get_partition(

View File

@ -35,7 +35,7 @@ class FileChatMessageHistory(BaseChatMessageHistory):
) )
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from the local file""" """Retrieve the messages from the local file"""
items = json.loads(self.file_path.read_text(encoding=self.encoding)) items = json.loads(self.file_path.read_text(encoding=self.encoding))
messages = messages_from_dict(items) messages = messages_from_dict(items)

View File

@ -334,7 +334,7 @@ class KafkaChatMessageHistory(BaseChatMessageHistory):
) )
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
""" """
Retrieve the messages for the session, from Kafka topic continuously Retrieve the messages for the session, from Kafka topic continuously
from last consumed message. This method is stateful and maintains from last consumed message. This method is stateful and maintains

View File

@ -60,7 +60,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self.collection.create_index("SessionId") self.collection.create_index("SessionId")
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from MongoDB""" """Retrieve the messages from MongoDB"""
from pymongo import errors from pymongo import errors

View File

@ -65,7 +65,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
self.connection.commit() self.connection.commit()
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from PostgreSQL""" """Retrieve the messages from PostgreSQL"""
query = ( query = (
f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;" f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;"

View File

@ -215,7 +215,7 @@ class RocksetChatMessageHistory(BaseChatMessageHistory):
self._create_empty_doc() self._create_empty_doc()
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Messages in this chat history.""" """Messages in this chat history."""
return messages_from_dict( return messages_from_dict(
self._query( self._query(

View File

@ -212,7 +212,7 @@ class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
conn.close() conn.close()
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from SingleStoreDB""" """Retrieve the messages from SingleStoreDB"""
self._create_table_if_not_exists() self._create_table_if_not_exists()
conn = self.connection_pool.connect() conn = self.connection_pool.connect()

View File

@ -47,7 +47,7 @@ try:
from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError: except ImportError:
# dummy for sqlalchemy < 2 # dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -242,7 +242,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
self._table_created = True self._table_created = True
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve all messages from db""" """Retrieve all messages from db"""
with self._make_sync_session() as session: with self._make_sync_session() as session:
result = ( result = (

View File

@ -51,7 +51,7 @@ class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
return self.key_prefix + self.session_id return self.key_prefix + self.session_id
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Upstash Redis""" """Retrieve the messages from Upstash Redis"""
_items = self.redis_client.lrange(self.key, 0, -1) _items = self.redis_client.lrange(self.key, 0, -1)
items = [json.loads(m) for m in _items[::-1]] items = [json.loads(m) for m in _items[::-1]]

View File

@ -83,7 +83,7 @@ class XataChatMessageHistory(BaseChatMessageHistory):
raise Exception(f"Error adding message to Xata: {r.status_code} {r}") raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
r = self._client.data().query( r = self._client.data().query(
self._table_name, self._table_name,
payload={ payload={

View File

@ -87,7 +87,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
self.session_id = session_id self.session_id = session_id
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve messages from Zep memory""" """Retrieve messages from Zep memory"""
zep_memory: Optional[Memory] = self._get_memory() zep_memory: Optional[Memory] = self._get_memory()
if not zep_memory: if not zep_memory:

View File

@ -134,7 +134,7 @@ class ZepCloudChatMessageHistory(BaseChatMessageHistory):
self.summary_instruction = summary_instruction self.summary_instruction = summary_instruction
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve messages from Zep memory""" """Retrieve messages from Zep memory"""
zep_memory: Optional[Memory] = self._get_memory() zep_memory: Optional[Memory] = self._get_memory()
if not zep_memory: if not zep_memory:

View File

@ -42,7 +42,7 @@ class ChatLiteLLMRouter(ChatLiteLLM):
def __init__(self, *, router: Any, **kwargs: Any) -> None: def __init__(self, *, router: Any, **kwargs: Any) -> None:
"""Construct Chat LiteLLM Router.""" """Construct Chat LiteLLM Router."""
super().__init__(router=router, **kwargs) # type: ignore super().__init__(router=router, **kwargs) # type: ignore[call-arg]
self.router = router self.router = router
@property @property

View File

@ -815,4 +815,4 @@ def _convert_delta_to_message_chunk(
elif role or default_class == ChatMessageChunk: elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_) return ChatMessageChunk(content=content, role=role, id=id_)
else: else:
return default_class(content=content, id=id_) # type: ignore return default_class(content=content, id=id_) # type: ignore[call-arg]

View File

@ -716,7 +716,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item] tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item] first_tool_only=True,
) )
else: else:
output_parser = JsonOutputKeyToolsParser( output_parser = JsonOutputKeyToolsParser(

View File

@ -158,9 +158,9 @@ def _convert_delta_response_to_message_chunk(
Optional[str], Optional[str],
]: ]:
"""Converts delta response to message chunk""" """Converts delta response to message chunk"""
_delta = response.choices[0].delta # type: ignore _delta = response.choices[0].delta
role = _delta.get("role", "") # type: ignore role = _delta.get("role", "")
content = _delta.get("content", "") # type: ignore content = _delta.get("content", "")
additional_kwargs: Dict = {} additional_kwargs: Dict = {}
finish_reasons: Optional[str] = response.choices[0].finish_reason finish_reasons: Optional[str] = response.choices[0].finish_reason
@ -398,7 +398,7 @@ class ChatPremAI(BaseChatModel, BaseModel):
messages, template_id=kwargs["template_id"] messages, template_id=kwargs["template_id"]
) )
else: else:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
if system_prompt is not None and system_prompt != "": if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt kwargs["system_prompt"] = system_prompt
@ -425,9 +425,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
if "template_id" in kwargs: if "template_id" in kwargs:
system_prompt, messages_to_pass = _messages_to_prompt_dict( system_prompt, messages_to_pass = _messages_to_prompt_dict(
messages, template_id=kwargs["template_id"] messages, template_id=kwargs["template_id"]
) # type: ignore )
else: else:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
if stop is not None: if stop is not None:
logger.warning("stop is not supported in langchain streaming") logger.warning("stop is not supported in langchain streaming")

View File

@ -218,7 +218,7 @@ class BlackboardLoader(WebBaseLoader):
loader = DirectoryLoader( loader = DirectoryLoader(
path=self.folder_path, path=self.folder_path,
glob="*.pdf", glob="*.pdf",
loader_cls=PyPDFLoader, # type: ignore loader_cls=PyPDFLoader, # type: ignore[arg-type]
) )
# Load the documents # Load the documents
documents = loader.load() documents = loader.load()

View File

@ -35,7 +35,7 @@ class _CloudBlob(Blob):
from cloudpathlib import AnyPath from cloudpathlib import AnyPath
if self.data is None and self.path: if self.data is None and self.path:
return AnyPath(self.path).read_text(encoding=self.encoding) # type: ignore return AnyPath(self.path).read_text(encoding=self.encoding)
elif isinstance(self.data, bytes): elif isinstance(self.data, bytes):
return self.data.decode(self.encoding) return self.data.decode(self.encoding)
elif isinstance(self.data, str): elif isinstance(self.data, str):
@ -52,7 +52,7 @@ class _CloudBlob(Blob):
elif isinstance(self.data, str): elif isinstance(self.data, str):
return self.data.encode(self.encoding) return self.data.encode(self.encoding)
elif self.data is None and self.path: elif self.data is None and self.path:
return AnyPath(self.path).read_bytes() # type: ignore return AnyPath(self.path).read_bytes()
else: else:
raise ValueError(f"Unable to get bytes for blob {self}") raise ValueError(f"Unable to get bytes for blob {self}")
@ -64,7 +64,7 @@ class _CloudBlob(Blob):
if isinstance(self.data, bytes): if isinstance(self.data, bytes):
yield BytesIO(self.data) yield BytesIO(self.data)
elif self.data is None and self.path: elif self.data is None and self.path:
return AnyPath(self.path).read_bytes() # type: ignore return AnyPath(self.path).read_bytes()
else: else:
raise NotImplementedError(f"Unable to convert blob {self}") raise NotImplementedError(f"Unable to convert blob {self}")
@ -79,7 +79,7 @@ def _url_to_filename(url: str) -> str:
url_parsed = urlparse(url) url_parsed = urlparse(url)
suffix = Path(url_parsed.path).suffix suffix = Path(url_parsed.path).suffix
if url_parsed.scheme in ["s3", "az", "gs"]: if url_parsed.scheme in ["s3", "az", "gs"]:
with AnyPath(url).open("rb") as f: # type: ignore with AnyPath(url).open("rb") as f:
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
while True: while True:
buf = f.read() buf = f.read()
@ -116,7 +116,7 @@ def _make_iterator(
iterator = _with_tqdm iterator = _with_tqdm
else: else:
iterator = iter # type: ignore iterator = iter # type: ignore[assignment]
return iterator return iterator
@ -220,7 +220,7 @@ class CloudBlobLoader(BlobLoader):
def _yield_paths(self) -> Iterable["AnyPath"]: def _yield_paths(self) -> Iterable["AnyPath"]:
"""Yield paths that match the requested pattern.""" """Yield paths that match the requested pattern."""
if self.path.is_file(): # type: ignore if self.path.is_file():
yield self.path yield self.path
return return
@ -269,7 +269,7 @@ class CloudBlobLoader(BlobLoader):
Blob instance Blob instance
""" """
if mime_type is None and guess_type: if mime_type is None and guess_type:
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None # type: ignore _mimetype = mimetypes.guess_type(path)[0] if guess_type else None
else: else:
_mimetype = mime_type _mimetype = mime_type

View File

@ -252,7 +252,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
files = self._fetch_files_recursive(service, folder_id) files = self._fetch_files_recursive(service, folder_id)
# If file types filter is provided, we'll filter by the file type. # If file types filter is provided, we'll filter by the file type.
if file_types: if file_types:
_files = [f for f in files if f["mimeType"] in file_types] # type: ignore _files = [f for f in files if f["mimeType"] in file_types]
else: else:
_files = files _files = files
@ -261,14 +261,14 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
if file["trashed"] and not self.load_trashed_files: if file["trashed"] and not self.load_trashed_files:
continue continue
elif file["mimeType"] == "application/vnd.google-apps.document": elif file["mimeType"] == "application/vnd.google-apps.document":
returns.append(self._load_document_from_id(file["id"])) # type: ignore returns.append(self._load_document_from_id(file["id"])) # type: ignore[arg-type]
elif file["mimeType"] == "application/vnd.google-apps.spreadsheet": elif file["mimeType"] == "application/vnd.google-apps.spreadsheet":
returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore[arg-type]
elif ( elif (
file["mimeType"] == "application/pdf" file["mimeType"] == "application/pdf"
or self.file_loader_cls is not None or self.file_loader_cls is not None
): ):
returns.extend(self._load_file_from_id(file["id"])) # type: ignore returns.extend(self._load_file_from_id(file["id"])) # type: ignore[arg-type]
else: else:
pass pass
return returns return returns

View File

@ -267,7 +267,7 @@ class DocAIParser(BaseBlobParser):
"""Initializes Long-Running Operations from their names.""" """Initializes Long-Running Operations from their names."""
try: try:
from google.longrunning.operations_pb2 import ( from google.longrunning.operations_pb2 import (
GetOperationRequest, # type: ignore GetOperationRequest,
) )
except ImportError as exc: except ImportError as exc:
raise ImportError( raise ImportError(

View File

@ -59,9 +59,9 @@ class DocumentLoaderAsParser(BaseBlobParser):
""" """
Use underlying DocumentLoader to lazily parse the blob. Use underlying DocumentLoader to lazily parse the blob.
""" """
doc_loader = self.DocumentLoaderClass( doc_loader = self.DocumentLoaderClass( # type: ignore[call-arg]
file_path=blob.path, **self.document_loader_kwargs file_path=blob.path, **self.document_loader_kwargs
) # type: ignore )
for document in doc_loader.lazy_load(): for document in doc_loader.lazy_load():
document.metadata.update(blob.metadata) document.metadata.update(blob.metadata)
yield document yield document

View File

@ -107,7 +107,7 @@ class RapidOCRBlobParser(BaseImageBlobParser):
"`rapidocr-onnxruntime` package not found, please install it with " "`rapidocr-onnxruntime` package not found, please install it with "
"`pip install rapidocr-onnxruntime`" "`pip install rapidocr-onnxruntime`"
) )
ocr_result, _ = self.ocr(np.array(img)) # type: ignore ocr_result, _ = self.ocr(np.array(img)) # type: ignore[misc]
content = "" content = ""
if ocr_result: if ocr_result:
content = ("\n".join([text[1] for text in ocr_result])).strip() content = ("\n".join([text[1] for text in ocr_result])).strip()

View File

@ -82,7 +82,7 @@ class TreeSitterSegmenter(CodeSegmenter):
) )
for line_num in range(start_line + 1, end_line + 1): for line_num in range(start_line + 1, end_line + 1):
simplified_lines[line_num] = None # type: ignore simplified_lines[line_num] = None # type: ignore[call-overload]
processed_lines.update(lines) processed_lines.update(lines)

View File

@ -6,7 +6,7 @@ import json
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional from typing import Any, Dict, Iterator, List, Optional
import requests # type: ignore import requests
from langchain_core.document_loaders import BaseLoader from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from pydantic import Field from pydantic import Field

View File

@ -78,7 +78,7 @@ class TrelloLoader(BaseLoader):
""" """
try: try:
from trello import TrelloClient # type: ignore from trello import TrelloClient
except ImportError as ex: except ImportError as ex:
raise ImportError( raise ImportError(
"Could not import trello python package. " "Could not import trello python package. "
@ -124,7 +124,7 @@ class TrelloLoader(BaseLoader):
return board return board
def _card_to_doc(self, card: Card, list_dict: dict) -> Document: def _card_to_doc(self, card: Card, list_dict: dict) -> Document:
from bs4 import BeautifulSoup # type: ignore from bs4 import BeautifulSoup
text_content = "" text_content = ""
if self.include_card_name: if self.include_card_name:

View File

@ -245,8 +245,8 @@ def get_elements_from_api(
from unstructured.partition.api import partition_multiple_via_api from unstructured.partition.api import partition_multiple_via_api
_doc_elements = partition_multiple_via_api( _doc_elements = partition_multiple_via_api(
filenames=file_path, # type: ignore filenames=file_path,
files=file, # type: ignore files=file,
api_key=api_key, api_key=api_key,
api_url=api_url, api_url=api_url,
**unstructured_kwargs, **unstructured_kwargs,

View File

@ -393,7 +393,7 @@ class WebBaseLoader(BaseLoader):
"https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.web_base.WebBaseLoader.html" # noqa: E501 "https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.web_base.WebBaseLoader.html" # noqa: E501
), ),
) )
def aload(self) -> List[Document]: # type: ignore def aload(self) -> List[Document]: # type: ignore[override]
"""Load text from the urls in web_path async into Documents.""" """Load text from the urls in web_path async into Documents."""
results = self.scrape_all(self.web_paths) results = self.scrape_all(self.web_paths)

View File

@ -439,7 +439,7 @@ class HypotheticalDocumentEmbedder:
) )
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
return H(*args, **kwargs) # type: ignore return H(*args, **kwargs) # type: ignore[return-value]
@classmethod @classmethod
def from_llm(cls, *args: Any, **kwargs: Any) -> Any: def from_llm(cls, *args: Any, **kwargs: Any) -> Any:

View File

@ -68,7 +68,7 @@ class InfinityEmbeddingsLocal(BaseModel, Embeddings):
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
try: try:
from infinity_emb import AsyncEmbeddingEngine # type: ignore from infinity_emb import AsyncEmbeddingEngine
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Please install the " "Please install the "

View File

@ -76,7 +76,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
def _embed(self, input: Any) -> List[List[float]]: def _embed(self, input: Any) -> List[List[float]]:
# Call Jina AI Embedding API # Call Jina AI Embedding API
resp = self.session.post( # type: ignore resp = self.session.post(
JINA_API_URL, json={"input": input, "model": self.model_name} JINA_API_URL, json={"input": input, "model": self.model_name}
).json() ).json()
if "data" not in resp: if "data" not in resp:
@ -85,7 +85,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
embeddings = resp["data"] embeddings = resp["data"]
# Sort resulting embeddings by index # Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
# Return just the embeddings # Return just the embeddings
return [result["embedding"] for result in sorted_embeddings] return [result["embedding"] for result in sorted_embeddings]

View File

@ -309,7 +309,7 @@ class AsyncOpenAITextEmbedEmbeddingClient:
Raises: Raises:
Exception: If the response status is not 200. Exception: If the response status is not 200.
""" """
async with session.post(**kwargs) as response: # type: ignore async with session.post(**kwargs) as response: # type: ignore[arg-type]
if response.status != 200: if response.status != 200:
raise Exception( raise Exception(
f"TextEmbed responded with an unexpected status message " f"TextEmbed responded with an unexpected status message "

View File

@ -22,7 +22,7 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float:
https://aclanthology.org/P02-1040.pdf https://aclanthology.org/P02-1040.pdf
""" """
from nltk.translate.bleu_score import ( from nltk.translate.bleu_score import (
SmoothingFunction, # type: ignore SmoothingFunction,
sentence_bleu, sentence_bleu,
) )

View File

@ -54,7 +54,7 @@ try:
from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError: except ImportError:
# dummy for sqlalchemy < 2 # dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
from langchain_community.indexes.base import RecordManager from langchain_community.indexes.base import RecordManager
@ -308,8 +308,8 @@ class SQLRecordManager(RecordManager):
[UpsertionRecord.key, UpsertionRecord.namespace], [UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict( set_=dict(
# attr-defined type ignore # attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id, # type: ignore group_id=insert_stmt.excluded.group_id,
), ),
) )
elif self.dialect == "postgresql": elif self.dialect == "postgresql":
@ -322,8 +322,8 @@ class SQLRecordManager(RecordManager):
"uix_key_namespace", # Name of constraint "uix_key_namespace", # Name of constraint
set_=dict( set_=dict(
# attr-defined type ignore # attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id, # type: ignore group_id=insert_stmt.excluded.group_id,
), ),
) )
else: else:
@ -383,8 +383,8 @@ class SQLRecordManager(RecordManager):
[UpsertionRecord.key, UpsertionRecord.namespace], [UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict( set_=dict(
# attr-defined type ignore # attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id, # type: ignore group_id=insert_stmt.excluded.group_id,
), ),
) )
elif self.dialect == "postgresql": elif self.dialect == "postgresql":
@ -397,8 +397,8 @@ class SQLRecordManager(RecordManager):
"uix_key_namespace", # Name of constraint "uix_key_namespace", # Name of constraint
set_=dict( set_=dict(
# attr-defined type ignore # attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore updated_at=insert_stmt.excluded.updated_at,
group_id=insert_stmt.excluded.group_id, # type: ignore group_id=insert_stmt.excluded.group_id,
), ),
) )
else: else:

View File

@ -486,10 +486,10 @@ class AzureMLBaseEndpoint(BaseModel):
timeout = values.get("timeout", DEFAULT_TIMEOUT) timeout = values.get("timeout", DEFAULT_TIMEOUT)
http_client = AzureMLEndpointClient( http_client = AzureMLEndpointClient(
endpoint_url, # type: ignore endpoint_url, # type: ignore[arg-type]
endpoint_key.get_secret_value(), # type: ignore endpoint_key.get_secret_value(), # type: ignore[union-attr]
deployment_name, # type: ignore deployment_name, # type: ignore[arg-type]
timeout, # type: ignore timeout,
) )
return http_client return http_client

View File

@ -201,7 +201,7 @@ class Beam(LLM): # type: ignore[override, override, override, override]
def _deploy(self) -> str: def _deploy(self) -> str:
"""Call to Beam.""" """Call to Beam."""
try: try:
import beam # type: ignore import beam
if beam.__path__ == "": if beam.__path__ == "":
raise ImportError raise ImportError

View File

@ -181,7 +181,7 @@ class _BaseGigaChat(Serializable):
def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
"""Count approximate number of tokens""" """Count approximate number of tokens"""
if self.use_api_for_tokens: if self.use_api_for_tokens:
return self.tokens_count([text])[0].tokens # type: ignore return self.tokens_count([text])[0].tokens
else: else:
return round(len(text) / 4.6) return round(len(text) / 4.6)

View File

@ -142,7 +142,7 @@ class HuggingFaceHub(LLM):
if "error" in response: if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}") raise ValueError(f"Error raised by inference API: {response['error']}")
response_key = VALID_TASKS_DICT[self.task] # type: ignore response_key = VALID_TASKS_DICT[self.task] # type: ignore[index]
if isinstance(response, list): if isinstance(response, list):
text = response[0][response_key] text = response[0][response_key]
else: else:

View File

@ -172,7 +172,7 @@ class IpexLLM(LLM):
if not low_bit_model: if not low_bit_model:
if load_in_low_bit is not None: if load_in_low_bit is not None:
load_function_name = "from_pretrained" load_function_name = "from_pretrained"
load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore[assignment]
else: else:
load_function_name = "from_pretrained" load_function_name = "from_pretrained"
load_kwargs["load_in_4bit"] = load_in_4bit load_kwargs["load_in_4bit"] = load_in_4bit

View File

@ -246,7 +246,7 @@ class BaseOpenAI(BaseLLM):
http_client: Union[Any, None] = None http_client: Union[Any, None] = None
"""Optional httpx.Client.""" """Optional httpx.Client."""
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore[misc]
"""Initialize the OpenAI object.""" """Initialize the OpenAI object."""
model_name = data.get("model_name", "") model_name = data.get("model_name", "")
if ( if (

View File

@ -47,7 +47,7 @@ class YiLLM(LLM):
def _post(self, request: Any) -> Any: def _post(self, request: Any) -> Any:
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore "Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore[union-attr]
} }
urls = [] urls = []

View File

@ -161,11 +161,11 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result) _result = super().parse_result(result)
if self.args_only: if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore[union-attr]
else: else:
fn_name = _result["name"] fn_name = _result["name"]
_args = _result["arguments"] _args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore[index]
return pydantic_args return pydantic_args

View File

@ -53,9 +53,9 @@ class DeepLakeTranslator(Visitor):
def _format_func(self, func: Union[Operator, Comparator]) -> str: def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func) self._validate_func(func)
if isinstance(func, Operator): if isinstance(func, Operator):
value = OPERATOR_TO_TQL[func.value] # type: ignore value = OPERATOR_TO_TQL[func.value] # type: ignore[index]
elif isinstance(func, Comparator): elif isinstance(func, Comparator):
value = COMPARATOR_TO_TQL[func.value] # type: ignore value = COMPARATOR_TO_TQL[func.value] # type: ignore[index]
return f"{value}" return f"{value}"
def visit_operation(self, operation: Operation) -> str: def visit_operation(self, operation: Operation) -> str:

View File

@ -42,9 +42,9 @@ class TimescaleVectorTranslator(Visitor):
def _format_func(self, func: Union[Operator, Comparator]) -> str: def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func) self._validate_func(func)
if isinstance(func, Operator): if isinstance(func, Operator):
value = self.OPERATOR_MAP[func.value] # type: ignore value = self.OPERATOR_MAP[func.value] # type: ignore[index]
elif isinstance(func, Comparator): elif isinstance(func, Comparator):
value = self.COMPARATOR_MAP[func.value] # type: ignore value = self.COMPARATOR_MAP[func.value] # type: ignore[index]
return f"{value}" return f"{value}"
def visit_operation(self, operation: Operation) -> client.Predicates: def visit_operation(self, operation: Operation) -> client.Predicates:

View File

@ -41,7 +41,7 @@ try:
from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError: except ImportError:
# dummy for sqlalchemy < 2 # dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
Base = declarative_base() Base = declarative_base()
@ -255,7 +255,7 @@ class SQLStore(BaseStore[str, bytes]):
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
with self._make_sync_session() as session: with self._make_sync_session() as session:
for v in session.query(LangchainKeyValueStores).filter( # type: ignore for v in session.query(LangchainKeyValueStores).filter(
LangchainKeyValueStores.namespace == self.namespace LangchainKeyValueStores.namespace == self.namespace
): ):
if str(v.key).startswith(prefix or ""): if str(v.key).startswith(prefix or ""):

View File

@ -178,7 +178,7 @@ def execute_function(
statement=parametrized_statement.statement, statement=parametrized_statement.statement,
warehouse_id=warehouse_id, warehouse_id=warehouse_id,
parameters=parametrized_statement.parameters, parameters=parametrized_statement.parameters,
**execute_statement_args, # type: ignore **execute_statement_args,
) )
if response.status and job_pending(response.status.state) and response.statement_id: if response.status and job_pending(response.status.state) and response.statement_id:
statement_id = response.statement_id statement_id = response.statement_id
@ -197,7 +197,7 @@ def execute_function(
f"status after {wait} seconds." f"status after {wait} seconds."
) )
time.sleep(wait) time.sleep(wait)
response = ws.statement_execution.get_statement(statement_id) # type: ignore response = ws.statement_execution.get_statement(statement_id)
if response.status is None or not job_pending(response.status.state): if response.status is None or not job_pending(response.status.state):
break break
wait_time += wait wait_time += wait
@ -228,7 +228,7 @@ def execute_function(
if is_scalar(function): if is_scalar(function):
value = None value = None
if data_array and len(data_array) > 0 and len(data_array[0]) > 0: if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
value = str(data_array[0][0]) # type: ignore value = str(data_array[0][0])
return FunctionExecutionResult( return FunctionExecutionResult(
format="SCALAR", value=value, truncated=truncated format="SCALAR", value=value, truncated=truncated
) )

View File

@ -51,8 +51,8 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
if tpe == "array": if tpe == "array":
element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"]) element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
if uc_type_json["containsNull"]: if uc_type_json["containsNull"]:
element_type = Optional[element_type] # type: ignore element_type = Optional[element_type] # type: ignore[assignment]
return List[element_type] # type: ignore return List[element_type] # type: ignore[valid-type]
elif tpe == "map": elif tpe == "map":
key_type = uc_type_json["keyType"] key_type = uc_type_json["keyType"]
assert key_type == "string", TypeError( assert key_type == "string", TypeError(
@ -60,14 +60,14 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
) )
value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"]) value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
if uc_type_json["valueContainsNull"]: if uc_type_json["valueContainsNull"]:
value_type: Type = Optional[value_type] # type: ignore value_type: Type = Optional[value_type] # type: ignore[no-redef]
return Dict[str, value_type] # type: ignore return Dict[str, value_type] # type: ignore[valid-type]
elif tpe == "struct": elif tpe == "struct":
fields = {} fields = {}
for field in uc_type_json["fields"]: for field in uc_type_json["fields"]:
field_type = _uc_type_to_pydantic_type(field["type"]) field_type = _uc_type_to_pydantic_type(field["type"])
if field.get("nullable"): if field.get("nullable"):
field_type = Optional[field_type] # type: ignore field_type = Optional[field_type] # type: ignore[assignment]
comment = ( comment = (
uc_type_json["metadata"].get("comment") uc_type_json["metadata"].get("comment")
if "metadata" in uc_type_json if "metadata" in uc_type_json
@ -76,7 +76,7 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
fields[field["name"]] = (field_type, Field(..., description=comment)) fields[field["name"]] = (field_type, Field(..., description=comment))
uc_type_json_str = json.dumps(uc_type_json, sort_keys=True) uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8] type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
return create_model(f"Struct_{type_hash}", **fields) # type: ignore return create_model(f"Struct_{type_hash}", **fields) # type: ignore[call-overload]
else: else:
raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.") raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")
@ -94,7 +94,7 @@ def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
description = p.comment description = p.comment
default: Any = ... default: Any = ...
if p.parameter_default: if p.parameter_default:
pydantic_type = Optional[pydantic_type] # type: ignore pydantic_type = Optional[pydantic_type] # type: ignore[assignment]
default = None default = None
# TODO: Convert default value string to the correct type. # TODO: Convert default value string to the correct type.
# We might need to use statement execution API # We might need to use statement execution API
@ -108,9 +108,9 @@ def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
pydantic_type, pydantic_type,
Field(default=default, description=description), Field(default=default, description=description),
) )
return create_model( return create_model( # type: ignore[call-overload]
f"{function.catalog_name}__{function.schema_name}__{function.name}__params", f"{function.catalog_name}__{function.schema_name}__{function.name}__params",
**fields, # type: ignore **fields,
) )

View File

@ -217,10 +217,10 @@ class NucliaUnderstandingAPI(BaseTool): # type: ignore[override, override]
logger.info(f"No matching id for {uuid}") logger.info(f"No matching id for {uuid}")
else: else:
self._results[matching_id]["status"] = "done" self._results[matching_id]["status"] = "done"
data = MessageToJson( data = MessageToJson( # type: ignore[call-arg]
pb, pb,
preserving_proto_field_name=True, preserving_proto_field_name=True,
including_default_value_fields=True, # type: ignore including_default_value_fields=True,
) )
self._results[matching_id]["data"] = data self._results[matching_id]["data"] = data

View File

@ -161,9 +161,7 @@ class ZapierNLARunAction(BaseTool): # type: ignore[override]
) )
ZapierNLARunAction.__doc__ = ( ZapierNLARunAction.__doc__ = ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore[operator]
ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore
)
# other useful actions # other useful actions
@ -210,5 +208,5 @@ class ZapierNLAListActions(BaseTool): # type: ignore[override]
ZapierNLAListActions.__doc__ = ( ZapierNLAListActions.__doc__ = (
ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore[operator]
) )

View File

@ -50,7 +50,7 @@ class BingSearchAPIWrapper(BaseModel):
response = requests.get( response = requests.get(
self.bing_search_url, self.bing_search_url,
headers=headers, headers=headers,
params=params, # type: ignore params=params,
) )
response.raise_for_status() response.raise_for_status()
search_results = response.json() search_results = response.json()

View File

@ -82,7 +82,7 @@ class GoogleScholarAPIWrapper(BaseModel):
# 0 is the first page of results, 20 is the 2nd page of results, # 0 is the first page of results, 20 is the 2nd page of results,
# 40 is the 3rd page of results, etc. # 40 is the 3rd page of results, etc.
results = ( results = (
self.google_scholar_engine( # type: ignore self.google_scholar_engine(
{ {
"q": query, "q": query,
"start": page, "start": page,
@ -106,7 +106,7 @@ class GoogleScholarAPIWrapper(BaseModel):
): # From the last page we would only need top_k_results%20 results ): # From the last page we would only need top_k_results%20 results
# if k is not divisible by 20. # if k is not divisible by 20.
results = ( results = (
self.google_scholar_engine( # type: ignore self.google_scholar_engine(
{ {
"q": query, "q": query,
"start": page, "start": page,

View File

@ -49,7 +49,6 @@ class MetaphorSearchAPIWrapper(BaseModel):
"useAutoprompt": use_autoprompt, "useAutoprompt": use_autoprompt,
} }
response = requests.post( response = requests.post(
# type: ignore
f"{METAPHOR_API_URL}/search", f"{METAPHOR_API_URL}/search",
headers=headers, headers=headers,
json=params, json=params,

View File

@ -53,7 +53,7 @@ if TYPE_CHECKING:
try: try:
from openapi_pydantic import OpenAPI from openapi_pydantic import OpenAPI
except ImportError: except ImportError:
OpenAPI = object # type: ignore OpenAPI = object
class OpenAPISpec(OpenAPI): class OpenAPISpec(OpenAPI):

View File

@ -134,7 +134,7 @@ class NutritionAIAPI(BaseModel):
return requests.get( return requests.get(
self.nutritionai_api_url, self.nutritionai_api_url,
headers=self.auth_.headers, headers=self.auth_.headers,
params=params, # type: ignore params=params,
) )
def _api_call_results(self, search_term: str) -> dict: def _api_call_results(self, search_term: str) -> dict:

View File

@ -395,7 +395,7 @@ class SQLDatabase:
try: try:
# get the sample rows # get the sample rows
with self._engine.connect() as connection: with self._engine.connect() as connection:
sample_rows_result = connection.execute(command) # type: ignore sample_rows_result = connection.execute(command)
# shorten values in the sample rows # shorten values in the sample rows
sample_rows = list( sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)

View File

@ -123,10 +123,10 @@ class SteamWebAPIWrapper(BaseModel):
except ImportError: except ImportError:
raise ImportError("steamspypi library is not installed.") raise ImportError("steamspypi library is not installed.")
users_games = self.get_users_games(steam_id) users_games = self.get_users_games(steam_id)
result = {} # type: ignore result: dict[str, int] = {}
most_popular_genre = "" most_popular_genre = ""
most_popular_genre_count = 0 most_popular_genre_count = 0
for game in users_games["games"]: # type: ignore for game in users_games["games"]: # type: ignore[call-overload]
appid = game["appid"] appid = game["appid"]
data_request = {"request": "appdetails", "appid": appid} data_request = {"request": "appdetails", "appid": appid}
genreStore = steamspypi.download(data_request) genreStore = steamspypi.download(data_request)
@ -148,7 +148,7 @@ class SteamWebAPIWrapper(BaseModel):
sorted_data = sorted( sorted_data = sorted(
data.values(), key=lambda x: x.get("average_forever", 0), reverse=True data.values(), key=lambda x: x.get("average_forever", 0), reverse=True
) )
owned_games = [game["appid"] for game in users_games["games"]] # type: ignore owned_games = [game["appid"] for game in users_games["games"]] # type: ignore[call-overload]
remaining_games = [ remaining_games = [
game for game in sorted_data if game["appid"] not in owned_games game for game in sorted_data if game["appid"] not in owned_games
] ]

View File

@ -58,7 +58,6 @@ class TavilySearchAPIWrapper(BaseModel):
"include_images": include_images, "include_images": include_images,
} }
response = requests.post( response = requests.post(
# type: ignore
f"{TAVILY_API_URL}/search", f"{TAVILY_API_URL}/search",
json=params, json=params,
) )

View File

@ -240,7 +240,6 @@ class YouSearchAPIWrapper(BaseModel):
if self.endpoint_type == "snippet": if self.endpoint_type == "snippet":
self.endpoint_type = "search" self.endpoint_type = "search"
response = requests.get( response = requests.get(
# type: ignore
f"{YOU_API_URL}/{self.endpoint_type}", f"{YOU_API_URL}/{self.endpoint_type}",
params=params, params=params,
headers=headers, headers=headers,

View File

@ -71,4 +71,4 @@ def cosine_similarity_top_k(
top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1] top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
scores = score_array.ravel()[top_k_idxs].tolist() scores = score_array.ravel()[top_k_idxs].tolist()
return list(zip(*ret_idxs)), scores # type: ignore return list(zip(*ret_idxs)), scores # type: ignore[return-value]

View File

@ -407,7 +407,7 @@ class AzureCosmosDBVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings) for t, m, embedding in zip(texts, metadatas, embeddings)
] ]
# insert the documents in Cosmos DB # insert the documents in Cosmos DB
insert_result = self._collection.insert_many(to_insert) # type: ignore insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids return insert_result.inserted_ids
@classmethod @classmethod

View File

@ -1571,7 +1571,7 @@ class AzureSearch(VectorStore):
azure_search.add_embeddings(text_embeddings, metadatas, **kwargs) azure_search.add_embeddings(text_embeddings, metadatas, **kwargs)
return azure_search return azure_search
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore[override]
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore. """Return AzureSearchVectorStoreRetriever initialized from this VectorStore.
Args: Args:
@ -1781,7 +1781,7 @@ async def _areorder_results_with_maximal_marginal_relevance(
# Function can return -1 index # Function can return -1 index
if x == -1: if x == -1:
break break
ret.append((documents[x], scores[x])) # type: ignore ret.append((documents[x], scores[x]))
return ret return ret
@ -1816,7 +1816,7 @@ def _reorder_results_with_maximal_marginal_relevance(
# Function can return -1 index # Function can return -1 index
if x == -1: if x == -1:
break break
ret.append((documents[x], scores[x])) # type: ignore ret.append((documents[x], scores[x]))
return ret return ret

View File

@ -656,7 +656,7 @@ class BigQueryVectorSearch(VectorStore):
Returns: Returns:
List of Documents most similar to the query vector, with similarity scores. List of Documents most similar to the query vector, with similarity scores.
""" """
emb = self.embedding_model.embed_query(query) # type: ignore emb = self.embedding_model.embed_query(query)
return self.similarity_search_with_score_by_vector( return self.similarity_search_with_score_by_vector(
emb, k, filter, brute_force, fraction_lists_to_search, **kwargs emb, k, filter, brute_force, fraction_lists_to_search, **kwargs
) )
@ -738,9 +738,7 @@ class BigQueryVectorSearch(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
query_embedding = self.embedding_model.embed_query( # type: ignore query_embedding = self.embedding_model.embed_query(query)
query
)
doc_tuples = self._search_with_score_and_embeddings_by_vector( doc_tuples = self._search_with_score_and_embeddings_by_vector(
query_embedding, fetch_k, filter, brute_force, fraction_lists_to_search query_embedding, fetch_k, filter, brute_force, fraction_lists_to_search
) )

View File

@ -183,7 +183,7 @@ class Clarifai(VectorStore):
try: try:
from clarifai.client.search import Search from clarifai.client.search import Search
from clarifai_grpc.grpc.api import resources_pb2 from clarifai_grpc.grpc.api import resources_pb2
from google.protobuf import json_format # type: ignore from google.protobuf import json_format
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Could not import clarifai python package. " "Could not import clarifai python package. "

View File

@ -275,7 +275,7 @@ class DeepLake(VectorStore):
metadata=metadatas, metadata=metadatas,
embedding_data=texts, embedding_data=texts,
embedding_tensor="embedding", embedding_tensor="embedding",
embedding_function=self._embedding_function.embed_documents, # type: ignore embedding_function=self._embedding_function.embed_documents, # type: ignore[union-attr]
return_ids=True, return_ids=True,
**kwargs, **kwargs,
) )
@ -464,8 +464,8 @@ class DeepLake(VectorStore):
if use_maximal_marginal_relevance: if use_maximal_marginal_relevance:
lambda_mult = kwargs.get("lambda_mult", 0.5) lambda_mult = kwargs.get("lambda_mult", 0.5)
indices = maximal_marginal_relevance( # type: ignore indices = maximal_marginal_relevance(
embedding, # type: ignore embedding, # type: ignore[arg-type]
embeddings, embeddings,
k=min(k, len(texts)), k=min(k, len(texts)),
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
@ -829,7 +829,7 @@ class DeepLake(VectorStore):
use_maximal_marginal_relevance=True, use_maximal_marginal_relevance=True,
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
exec_option=exec_option, exec_option=exec_option,
embedding_function=embedding_function, # type: ignore embedding_function=embedding_function, # type: ignore[arg-type]
**kwargs, **kwargs,
) )

View File

@ -103,7 +103,7 @@ class DocArrayIndex(VectorStore, ABC):
Lower score represents more similarity. Lower score represents more similarity.
""" """
query_embedding = self.embedding.embed_query(query) query_embedding = self.embedding.embed_query(query)
query_doc = self.doc_cls(embedding=query_embedding) # type: ignore query_doc = self.doc_cls(embedding=query_embedding)
docs, scores = self.doc_index.find(query_doc, search_field="embedding", limit=k) docs, scores = self.doc_index.find(query_doc, search_field="embedding", limit=k)
result = [ result = [
@ -152,7 +152,7 @@ class DocArrayIndex(VectorStore, ABC):
List of Documents most similar to the query vector. List of Documents most similar to the query vector.
""" """
query_doc = self.doc_cls(embedding=embedding) # type: ignore query_doc = self.doc_cls(embedding=embedding)
docs = self.doc_index.find( docs = self.doc_index.find(
query_doc, search_field="embedding", limit=k query_doc, search_field="embedding", limit=k
).documents ).documents
@ -187,7 +187,7 @@ class DocArrayIndex(VectorStore, ABC):
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
query_embedding = self.embedding.embed_query(query) query_embedding = self.embedding.embed_query(query)
query_doc = self.doc_cls(embedding=query_embedding) # type: ignore query_doc = self.doc_cls(embedding=query_embedding)
docs = self.doc_index.find( docs = self.doc_index.find(
query_doc, search_field="embedding", limit=fetch_k query_doc, search_field="embedding", limit=fetch_k

View File

@ -71,7 +71,7 @@ class DocArrayHnswSearch(DocArrayIndex):
num_threads=num_threads, num_threads=num_threads,
**kwargs, **kwargs,
) )
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir) # type: ignore doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
return cls(doc_index, embedding) return cls(doc_index, embedding)
@classmethod @classmethod

View File

@ -41,7 +41,7 @@ class DocArrayInMemorySearch(DocArrayIndex):
from docarray.index import InMemoryExactNNIndex from docarray.index import InMemoryExactNNIndex
doc_cls = cls._get_doc_cls(space=metric, **kwargs) doc_cls = cls._get_doc_cls(space=metric, **kwargs)
doc_index = InMemoryExactNNIndex[doc_cls]() # type: ignore doc_index = InMemoryExactNNIndex[doc_cls]()
return cls(doc_index, embedding) return cls(doc_index, embedding)
@classmethod @classmethod

View File

@ -265,7 +265,7 @@ class DocumentDBVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings) for t, m, embedding in zip(texts, metadatas, embeddings)
] ]
# insert the documents in DocumentDB # insert the documents in DocumentDB
insert_result = self._collection.insert_many(to_insert) # type: ignore insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids return insert_result.inserted_ids
@classmethod @classmethod

View File

@ -220,7 +220,7 @@ class DuckDB(VectorStore):
except ImportError: except ImportError:
warnings.warn("You may need to `pip install pandas` to use this method.") warnings.warn("You may need to `pip install pandas` to use this method.")
embedding = self._embedding.embed_query(query) # type: ignore embedding = self._embedding.embed_query(query)
list_cosine_similarity = self.duckdb.FunctionExpression( list_cosine_similarity = self.duckdb.FunctionExpression(
"list_cosine_similarity", "list_cosine_similarity",
self.duckdb.ColumnExpression(self._vector_key), self.duckdb.ColumnExpression(self._vector_key),
@ -265,7 +265,7 @@ class DuckDB(VectorStore):
A list of Documents most similar to the query. A list of Documents most similar to the query.
""" """
embedding = self._embedding.embed_query(query) # type: ignore embedding = self._embedding.embed_query(query)
list_cosine_similarity = self.duckdb.FunctionExpression( list_cosine_similarity = self.duckdb.FunctionExpression(
"list_cosine_similarity", "list_cosine_similarity",
self.duckdb.ColumnExpression(self._vector_key), self.duckdb.ColumnExpression(self._vector_key),

View File

@ -254,7 +254,7 @@ repeated float %s = 1;
texts_l = list(texts) texts_l = list(texts)
if last_vector: if last_vector:
texts_l.pop() texts_l.pop()
embeds = self._embedding.embed_documents(texts_l) # type: ignore embeds = self._embedding.embed_documents(texts_l) # type: ignore[union-attr]
if last_vector: if last_vector:
embeds.append(last_vector) embeds.append(last_vector)
if not metadatas: if not metadatas:
@ -288,7 +288,7 @@ repeated float %s = 1;
Returns: Returns:
List[Tuple[Document, float]] List[Tuple[Document, float]]
""" """
embed = self._embedding.embed_query(query) # type: ignore embed = self._embedding.embed_query(query) # type: ignore[union-attr]
documents = self.similarity_search_with_score_by_vector(embedding=embed, k=k) documents = self.similarity_search_with_score_by_vector(embedding=embed, k=k)
return documents return documents

View File

@ -201,7 +201,7 @@ class LanceDB(VectorStore):
""" """
docs = [] docs = []
ids = ids or [str(uuid.uuid4()) for _ in texts] ids = ids or [str(uuid.uuid4()) for _ in texts]
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore embeddings = self._embedding.embed_documents(list(texts)) # type: ignore[union-attr]
for idx, text in enumerate(texts): for idx, text in enumerate(texts):
embedding = embeddings[idx] embedding = embeddings[idx]
metadata = metadatas[idx] if metadatas else {"id": ids[idx]} metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
@ -490,7 +490,7 @@ class LanceDB(VectorStore):
embedding = self._embedding.embed_query(query) embedding = self._embedding.embed_query(query)
_query = (embedding, query) _query = (embedding, query)
else: else:
_query = query # type: ignore _query = query # type: ignore[assignment]
res = self._query(_query, k, filter=filter, name=name, **kwargs) res = self._query(_query, k, filter=filter, name=name, **kwargs)
return self.results_to_docs(res, score=score) return self.results_to_docs(res, score=score)

View File

@ -58,9 +58,9 @@ def get_embedding_store(
embedding_type = None embedding_type = None
if distance_strategy == DistanceStrategy.HAMMING: if distance_strategy == DistanceStrategy.HAMMING:
embedding_type = sqlalchemy.INTEGER # type: ignore embedding_type = sqlalchemy.INTEGER
else: else:
embedding_type = sqlalchemy.REAL # type: ignore embedding_type = sqlalchemy.REAL # type: ignore[assignment]
DynamicBase = declarative_base(class_registry=dict()) # type: Any DynamicBase = declarative_base(class_registry=dict()) # type: Any
@ -74,7 +74,7 @@ def get_embedding_store(
cmetadata = sqlalchemy.Column(JSON, nullable=True) cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id # custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore[arg-type,var-annotated]
return EmbeddingStore return EmbeddingStore

View File

@ -397,7 +397,7 @@ class MomentoVectorIndex(VectorStore):
) )
selected = [response.hits[i].metadata for i in mmr_selected] selected = [response.hits[i].metadata for i in mmr_selected]
return [ return [
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata)
for metadata in selected for metadata in selected
] ]
@ -484,6 +484,6 @@ class MomentoVectorIndex(VectorStore):
configuration=VectorIndexConfigurations.Default.latest(), configuration=VectorIndexConfigurations.Default.latest(),
credential_provider=CredentialProvider.from_string(api_key), credential_provider=CredentialProvider.from_string(api_key),
) )
vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore[call-arg]
vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs) vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
return vector_db return vector_db

View File

@ -183,7 +183,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
for t, m, embedding in zip(texts, metadatas, embeddings) for t, m, embedding in zip(texts, metadatas, embeddings)
] ]
# insert the documents in MongoDB Atlas # insert the documents in MongoDB Atlas
insert_result = self._collection.insert_many(to_insert) # type: ignore insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids return insert_result.inserted_ids
def _similarity_search_with_score( def _similarity_search_with_score(

View File

@ -857,7 +857,7 @@ class OracleVS(VectorStore):
) )
documents.append((document, distance, current_embedding)) documents.append((document, distance, current_embedding))
return documents # type: ignore return documents
@_handle_exceptions @_handle_exceptions
def max_marginal_relevance_search_with_score_by_vector( def max_marginal_relevance_search_with_score_by_vector(

View File

@ -49,7 +49,7 @@ class CollectionStore(BaseModel):
@classmethod @classmethod
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]: def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
return session.query(cls).filter(cls.name == name).first() # type: ignore return session.query(cls).filter(cls.name == name).first()
@classmethod @classmethod
def get_or_create( def get_or_create(
@ -88,7 +88,7 @@ class EmbeddingStore(BaseModel):
) )
collection = relationship(CollectionStore, back_populates="embeddings") collection = relationship(CollectionStore, back_populates="embeddings")
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore[var-annotated]
document = sqlalchemy.Column(sqlalchemy.String, nullable=True) document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True) cmetadata = sqlalchemy.Column(JSON, nullable=True)

View File

@ -33,7 +33,7 @@ try:
from sqlalchemy import SQLColumnExpression from sqlalchemy import SQLColumnExpression
except ImportError: except ImportError:
# for sqlalchemy < 2 # for sqlalchemy < 2
SQLColumnExpression = Any # type: ignore SQLColumnExpression = Any # type: ignore[assignment,misc]
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -126,7 +126,7 @@ def _get_embedding_collection_store(
def get_by_name( def get_by_name(
cls, session: Session, name: str cls, session: Session, name: str
) -> Optional["CollectionStore"]: ) -> Optional["CollectionStore"]:
return session.query(cls).filter(cls.name == name).first() # type: ignore return session.query(cls).filter(cls.name == name).first()
@classmethod @classmethod
def get_or_create( def get_or_create(
@ -956,7 +956,7 @@ class PGVector(VectorStore):
results: List[Any] = ( results: List[Any] = (
session.query( session.query(
self.EmbeddingStore, self.EmbeddingStore,
self.distance_strategy(embedding).label("distance"), # type: ignore self.distance_strategy(embedding).label("distance"),
) )
.filter(*filter_by) .filter(*filter_by)
.order_by(sqlalchemy.asc("distance")) .order_by(sqlalchemy.asc("distance"))

View File

@ -419,7 +419,7 @@ class Redis(VectorStore):
# type check for metadata # type check for metadata
if metadatas: if metadatas:
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore if isinstance(metadatas, list) and len(metadatas) != len(texts):
raise ValueError("Number of metadatas must match number of texts") raise ValueError("Number of metadatas must match number of texts")
if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)): if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
raise ValueError("Metadatas must be a list of dicts") raise ValueError("Metadatas must be a list of dicts")
@ -427,7 +427,7 @@ class Redis(VectorStore):
generated_schema = _generate_field_schema(metadatas[0]) generated_schema = _generate_field_schema(metadatas[0])
if index_schema: if index_schema:
# read in the schema solely to compare to the generated schema # read in the schema solely to compare to the generated schema
user_schema = read_schema(index_schema) # type: ignore user_schema = read_schema(index_schema)
# the very rare case where a super user decides to pass the index # the very rare case where a super user decides to pass the index
# schema and a document loader is used that has metadata which # schema and a document loader is used that has metadata which
@ -722,7 +722,7 @@ class Redis(VectorStore):
# type check for metadata # type check for metadata
if metadatas: if metadatas:
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore[arg-type]
raise ValueError("Number of metadatas must match number of texts") raise ValueError("Number of metadatas must match number of texts")
if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)): if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
raise ValueError("Metadatas must be a list of dicts") raise ValueError("Metadatas must be a list of dicts")
@ -850,7 +850,7 @@ class Redis(VectorStore):
# Perform vector search # Perform vector search
# ignore type because redis-py is wrong about bytes # ignore type because redis-py is wrong about bytes
try: try:
results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore results = self.client.ft(self.index_name).search(redis_query, params_dict)
except redis.exceptions.ResponseError as e: except redis.exceptions.ResponseError as e:
# split error message and see if it starts with "Syntax" # split error message and see if it starts with "Syntax"
if str(e).split(" ")[0] == "Syntax": if str(e).split(" ")[0] == "Syntax":
@ -966,7 +966,7 @@ class Redis(VectorStore):
# Perform vector search # Perform vector search
# ignore type because redis-py is wrong about bytes # ignore type because redis-py is wrong about bytes
try: try:
results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore results = self.client.ft(self.index_name).search(redis_query, params_dict)
except redis.exceptions.ResponseError as e: except redis.exceptions.ResponseError as e:
# split error message and see if it starts with "Syntax" # split error message and see if it starts with "Syntax"
if str(e).split(" ")[0] == "Syntax": if str(e).split(" ")[0] == "Syntax":
@ -1206,7 +1206,7 @@ class Redis(VectorStore):
# read in schema (yaml file or dict) and # read in schema (yaml file or dict) and
# pass to the Pydantic validators # pass to the Pydantic validators
if index_schema: if index_schema:
schema_values = read_schema(index_schema) # type: ignore schema_values = read_schema(index_schema)
schema = RedisModel(**schema_values) schema = RedisModel(**schema_values)
# ensure user did not exclude the content field # ensure user did not exclude the content field
@ -1242,7 +1242,7 @@ class Redis(VectorStore):
def _create_index_if_not_exist(self, dim: int = 1536) -> None: def _create_index_if_not_exist(self, dim: int = 1536) -> None:
try: try:
from redis.commands.search.indexDefinition import ( # type: ignore from redis.commands.search.indexDefinition import (
IndexDefinition, IndexDefinition,
IndexType, IndexType,
) )

View File

@ -140,7 +140,7 @@ class RedisTag(RedisFilterField):
elif isinstance(other, str): elif isinstance(other, str):
other = [other] other = [other]
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore[arg-type]
@check_operator_misuse @check_operator_misuse
def __eq__( def __eq__(
@ -240,7 +240,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum >>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") == 90210 >>> filter = RedisNum("zipcode") == 90210
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
@check_operator_misuse @check_operator_misuse
@ -254,7 +254,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum >>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") != 90210 >>> filter = RedisNum("zipcode") != 90210
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression": def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -267,7 +267,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum >>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") > 18 >>> filter = RedisNum("age") > 18
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression": def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -280,7 +280,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum >>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") < 18 >>> filter = RedisNum("age") < 18
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression": def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -293,7 +293,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum >>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") >= 18 >>> filter = RedisNum("age") >= 18
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression": def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
@ -306,7 +306,7 @@ class RedisNum(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisNum >>> from langchain_community.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") <= 18 >>> filter = RedisNum("age") <= 18
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
@ -336,7 +336,7 @@ class RedisText(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisText >>> from langchain_community.vectorstores.redis import RedisText
>>> filter = RedisText("job") == "engineer" >>> filter = RedisText("job") == "engineer"
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
@check_operator_misuse @check_operator_misuse
@ -350,7 +350,7 @@ class RedisText(RedisFilterField):
>>> from langchain_community.vectorstores.redis import RedisText >>> from langchain_community.vectorstores.redis import RedisText
>>> filter = RedisText("job") != "engineer" >>> filter = RedisText("job") != "engineer"
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
def __mod__(self, other: str) -> "RedisFilterExpression": def __mod__(self, other: str) -> "RedisFilterExpression":
@ -366,7 +366,7 @@ class RedisText(RedisFilterField):
>>> filter = RedisText("job") % "engineer|doctor" # contains either term >>> filter = RedisText("job") % "engineer|doctor" # contains either term
>>> filter = RedisText("job") % "engineer doctor" # contains both terms >>> filter = RedisText("job") % "engineer doctor" # contains both terms
""" """
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore[arg-type]
return RedisFilterExpression(str(self)) return RedisFilterExpression(str(self))
def __str__(self) -> str: def __str__(self) -> str:

View File

@ -14,7 +14,7 @@ from typing_extensions import TYPE_CHECKING, Literal
from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
if TYPE_CHECKING: if TYPE_CHECKING:
from redis.commands.search.field import ( # type: ignore from redis.commands.search.field import (
NumericField, NumericField,
TagField, TagField,
TextField, TextField,
@ -47,13 +47,13 @@ class TextFieldSchema(RedisField):
sortable: Optional[bool] = False sortable: Optional[bool] = False
def as_field(self) -> TextField: def as_field(self) -> TextField:
from redis.commands.search.field import TextField # type: ignore from redis.commands.search.field import TextField
return TextField( return TextField(
self.name, self.name,
weight=self.weight, weight=self.weight,
no_stem=self.no_stem, no_stem=self.no_stem,
phonetic_matcher=self.phonetic_matcher, # type: ignore phonetic_matcher=self.phonetic_matcher,
sortable=self.sortable, sortable=self.sortable,
no_index=self.no_index, no_index=self.no_index,
) )
@ -68,7 +68,7 @@ class TagFieldSchema(RedisField):
sortable: Optional[bool] = False sortable: Optional[bool] = False
def as_field(self) -> TagField: def as_field(self) -> TagField:
from redis.commands.search.field import TagField # type: ignore from redis.commands.search.field import TagField
return TagField( return TagField(
self.name, self.name,
@ -86,7 +86,7 @@ class NumericFieldSchema(RedisField):
sortable: Optional[bool] = False sortable: Optional[bool] = False
def as_field(self) -> NumericField: def as_field(self) -> NumericField:
from redis.commands.search.field import NumericField # type: ignore from redis.commands.search.field import NumericField
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index) return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
@ -131,7 +131,7 @@ class FlatVectorField(RedisVectorField): # type: ignore[override]
block_size: Optional[int] = None block_size: Optional[int] = None
def as_field(self) -> VectorField: def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore from redis.commands.search.field import VectorField
field_data = super()._fields() field_data = super()._fields()
if self.block_size is not None: if self.block_size is not None:
@ -149,7 +149,7 @@ class HNSWVectorField(RedisVectorField): # type: ignore[override]
epsilon: float = Field(default=0.01) epsilon: float = Field(default=0.01)
def as_field(self) -> VectorField: def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore from redis.commands.search.field import VectorField
field_data = super()._fields() field_data = super()._fields()
field_data.update( field_data.update(
@ -193,9 +193,9 @@ class RedisModel(BaseModel):
# ignore types as pydantic is handling type validation and conversion # ignore types as pydantic is handling type validation and conversion
if vector_field["algorithm"] == "FLAT": if vector_field["algorithm"] == "FLAT":
self.vector.append(FlatVectorField(**vector_field)) # type: ignore self.vector.append(FlatVectorField(**vector_field))
elif vector_field["algorithm"] == "HNSW": elif vector_field["algorithm"] == "HNSW":
self.vector.append(HNSWVectorField(**vector_field)) # type: ignore self.vector.append(HNSWVectorField(**vector_field))
else: else:
raise ValueError( raise ValueError(
f"algorithm must be either FLAT or HNSW. Got " f"algorithm must be either FLAT or HNSW. Got "

View File

@ -257,7 +257,7 @@ class SupabaseVectorStore(VectorStore):
match_result = [ match_result = [
( (
Document( Document(
metadata=search.get("metadata", {}), # type: ignore metadata=search.get("metadata", {}),
page_content=search.get("content", ""), page_content=search.get("content", ""),
), ),
search.get("similarity", 0.0), search.get("similarity", 0.0),
@ -302,7 +302,7 @@ class SupabaseVectorStore(VectorStore):
match_result = [ match_result = [
( (
Document( Document(
metadata=search.get("metadata", {}), # type: ignore metadata=search.get("metadata", {}),
page_content=search.get("content", ""), page_content=search.get("content", ""),
), ),
search.get("similarity", 0.0), search.get("similarity", 0.0),
@ -351,7 +351,7 @@ class SupabaseVectorStore(VectorStore):
"id": ids[idx], "id": ids[idx],
"content": documents[idx].page_content, "content": documents[idx].page_content,
"embedding": embedding, "embedding": embedding,
"metadata": documents[idx].metadata, # type: ignore "metadata": documents[idx].metadata,
**kwargs, **kwargs,
} }
for idx, embedding in enumerate(vectors) for idx, embedding in enumerate(vectors)
@ -360,7 +360,7 @@ class SupabaseVectorStore(VectorStore):
for i in range(0, len(rows), chunk_size): for i in range(0, len(rows), chunk_size):
chunk = rows[i : i + chunk_size] chunk = rows[i : i + chunk_size]
result = client.from_(table_name).upsert(chunk).execute() # type: ignore result = client.from_(table_name).upsert(chunk).execute()
if len(result.data) == 0: if len(result.data) == 0:
raise Exception("Error inserting: No rows added") raise Exception("Error inserting: No rows added")

View File

@ -153,7 +153,7 @@ class UpstashVectorStore(VectorStore):
self._namespace = namespace self._namespace = namespace
@property @property
def embeddings(self) -> Optional[Union[Embeddings, bool]]: # type: ignore def embeddings(self) -> Optional[Union[Embeddings, bool]]: # type: ignore[override]
"""Access the query embedding object if available.""" """Access the query embedding object if available."""
return self._embeddings return self._embeddings
@ -730,7 +730,7 @@ class UpstashVectorStore(VectorStore):
) )
selected = [results[i].metadata for i in mmr_selected] selected = [results[i].metadata for i in mmr_selected]
return [ return [
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) # type: ignore Document(page_content=metadata.pop((self._text_key)), metadata=metadata)
for metadata in selected for metadata in selected
] ]
@ -798,7 +798,7 @@ class UpstashVectorStore(VectorStore):
) )
selected = [results[i].metadata for i in mmr_selected] selected = [results[i].metadata for i in mmr_selected]
return [ return [
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) # type: ignore Document(page_content=metadata.pop((self._text_key)), metadata=metadata)
for metadata in selected for metadata in selected
] ]

View File

@ -467,7 +467,7 @@ class Vectara(VectorStore):
} }
if config.lambda_val > 0: if config.lambda_val > 0:
body["query"][0]["corpusKey"][0]["lexicalInterpolationConfig"] = { # type: ignore body["query"][0]["corpusKey"][0]["lexicalInterpolationConfig"] = { # type: ignore[index]
"lambda": config.lambda_val "lambda": config.lambda_val
} }
@ -495,7 +495,7 @@ class Vectara(VectorStore):
} }
] ]
if chat: if chat:
body["query"][0]["summary"][0]["chat"] = { # type: ignore body["query"][0]["summary"][0]["chat"] = { # type: ignore[index]
"store": True, "store": True,
"conversationId": chat_conv_id, "conversationId": chat_conv_id,
} }

View File

@ -65,7 +65,7 @@ lint = [
] ]
dev = ["jupyter<2.0.0,>=1.0.0", "setuptools<68.0.0,>=67.6.1", "langchain-core"] dev = ["jupyter<2.0.0,>=1.0.0", "setuptools<68.0.0,>=67.6.1", "langchain-core"]
typing = [ typing = [
"mypy<2.0,>=1.12", "mypy<2.0,>=1.15",
"types-pyyaml<7.0.0.0,>=6.0.12.2", "types-pyyaml<7.0.0.0,>=6.0.12.2",
"types-requests<3.0.0.0,>=2.28.11.5", "types-requests<3.0.0.0,>=2.28.11.5",
"types-toml<1.0.0.0,>=0.10.8.1", "types-toml<1.0.0.0,>=0.10.8.1",
@ -103,7 +103,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin,cann" ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin,cann"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "PGH003", "T201"]
[tool.coverage.run] [tool.coverage.run]
omit = ["tests/*"] omit = ["tests/*"]

View File

@ -27,7 +27,7 @@ def init_gptcache_map(cache_obj: Any) -> None:
pre_embedding_func=get_prompt, pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path), data_manager=get_data_manager(data_path=cache_path),
) )
init_gptcache_map._i = i + 1 # type: ignore init_gptcache_map._i = i + 1 # type: ignore[attr-defined]
def init_gptcache_map_with_llm(cache_obj: Any, llm: str) -> None: def init_gptcache_map_with_llm(cache_obj: Any, llm: str) -> None:

View File

@ -37,7 +37,7 @@ def test_messages(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> Non
Message(content="message2", role="human", metadata={"key2": "value2"}), Message(content="message2", role="human", metadata={"key2": "value2"}),
], ],
) )
zep_chat.zep_client.memory.get_memory.return_value = mock_memory # type: ignore zep_chat.zep_client.memory.get_memory.return_value = mock_memory
result = zep_chat.messages result = zep_chat.messages
@ -52,25 +52,25 @@ def test_add_user_message(
mocker: MockerFixture, zep_chat: ZepChatMessageHistory mocker: MockerFixture, zep_chat: ZepChatMessageHistory
) -> None: ) -> None:
zep_chat.add_user_message("test message") zep_chat.add_user_message("test message")
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore zep_chat.zep_client.memory.add_memory.assert_called_once()
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None: def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.add_ai_message("test message") zep_chat.add_ai_message("test message")
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore zep_chat.zep_client.memory.add_memory.assert_called_once()
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None: def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.add_message(AIMessage(content="test message")) zep_chat.add_message(AIMessage(content="test message"))
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore zep_chat.zep_client.memory.add_memory.assert_called_once()
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None: def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.search("test query") zep_chat.search("test query")
zep_chat.zep_client.memory.search_memory.assert_called_once_with( # type: ignore zep_chat.zep_client.memory.search_memory.assert_called_once_with(
"test_session", mocker.ANY, limit=None "test_session", mocker.ANY, limit=None
) )
@ -78,6 +78,4 @@ def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def test_clear(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None: def test_clear(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.clear() zep_chat.clear()
zep_chat.zep_client.memory.delete_memory.assert_called_once_with( # type: ignore zep_chat.zep_client.memory.delete_memory.assert_called_once_with("test_session")
"test_session"
)

View File

@ -67,7 +67,7 @@ class AnswerWithJustification(BaseModel):
def test_chat_minimax_with_structured_output() -> None: def test_chat_minimax_with_structured_output() -> None:
"""Test MiniMaxChat with structured output.""" """Test MiniMaxChat with structured output."""
llm = MiniMaxChat() # type: ignore llm = MiniMaxChat() # type: ignore[call-arg]
structured_llm = llm.with_structured_output(AnswerWithJustification) structured_llm = llm.with_structured_output(AnswerWithJustification)
response = structured_llm.invoke( response = structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers" "What weighs more a pound of bricks or a pound of feathers"
@ -77,7 +77,7 @@ def test_chat_minimax_with_structured_output() -> None:
def test_chat_tongyi_with_structured_output_include_raw() -> None: def test_chat_tongyi_with_structured_output_include_raw() -> None:
"""Test MiniMaxChat with structured output.""" """Test MiniMaxChat with structured output."""
llm = MiniMaxChat() # type: ignore llm = MiniMaxChat() # type: ignore[call-arg]
structured_llm = llm.with_structured_output( structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True AnswerWithJustification, include_raw=True
) )

View File

@ -170,7 +170,7 @@ class GenerateUsername(BaseModel):
def test_tool_use() -> None: def test_tool_use() -> None:
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore[call-arg]
llm_with_tool = llm.bind_tools(tools=[GenerateUsername]) llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [ msgs: List = [
HumanMessage(content="Sally has green hair, what would her username be?") HumanMessage(content="Sally has green hair, what would her username be?")
@ -187,7 +187,7 @@ def test_tool_use() -> None:
tool_msg = ToolMessage( tool_msg = ToolMessage(
content="sally_green_hair", content="sally_green_hair",
tool_call_id=ai_msg.tool_calls[0]["id"], # type: ignore tool_call_id=ai_msg.tool_calls[0]["id"],
name=ai_msg.tool_calls[0]["name"], name=ai_msg.tool_calls[0]["name"],
) )
msgs.extend([ai_msg, tool_msg]) msgs.extend([ai_msg, tool_msg])
@ -201,7 +201,7 @@ def test_tool_use() -> None:
gathered = message gathered = message
first = False first = False
else: else:
gathered = gathered + message # type: ignore gathered = gathered + message # type: ignore[assignment]
assert isinstance(gathered, AIMessageChunk) assert isinstance(gathered, AIMessageChunk)
streaming_tool_msg = ToolMessage( streaming_tool_msg = ToolMessage(
@ -215,7 +215,7 @@ def test_tool_use() -> None:
def test_manual_tool_call_msg() -> None: def test_manual_tool_call_msg() -> None:
"""Test passing in manually construct tool call message.""" """Test passing in manually construct tool call message."""
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore[call-arg]
llm_with_tool = llm.bind_tools(tools=[GenerateUsername]) llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [ msgs: List = [
HumanMessage(content="Sally has green hair, what would her username be?"), HumanMessage(content="Sally has green hair, what would her username be?"),
@ -246,7 +246,7 @@ class AnswerWithJustification(BaseModel):
def test_chat_tongyi_with_structured_output() -> None: def test_chat_tongyi_with_structured_output() -> None:
"""Test ChatTongyi with structured output.""" """Test ChatTongyi with structured output."""
llm = ChatTongyi() # type: ignore llm = ChatTongyi() # type: ignore[call-arg]
structured_llm = llm.with_structured_output(AnswerWithJustification) structured_llm = llm.with_structured_output(AnswerWithJustification)
response = structured_llm.invoke( response = structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers" "What weighs more a pound of bricks or a pound of feathers"
@ -256,7 +256,7 @@ def test_chat_tongyi_with_structured_output() -> None:
def test_chat_tongyi_with_structured_output_include_raw() -> None: def test_chat_tongyi_with_structured_output_include_raw() -> None:
"""Test ChatTongyi with structured output.""" """Test ChatTongyi with structured output."""
llm = ChatTongyi() # type: ignore llm = ChatTongyi() # type: ignore[call-arg]
structured_llm = llm.with_structured_output( structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True AnswerWithJustification, include_raw=True
) )

View File

@ -75,7 +75,7 @@ async def test_vertexai_agenerate(model_name: str) -> None:
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
response = await model.agenerate([[message]]) response = await model.agenerate([[message]])
assert isinstance(response, LLMResult) assert isinstance(response, LLMResult)
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore[union-attr]
sync_response = model.generate([[message]]) sync_response = model.generate([[message]])
assert response.generations[0][0] == sync_response.generations[0][0] assert response.generations[0][0] == sync_response.generations[0][0]

View File

@ -1,4 +1,6 @@
from typing import Dict from __future__ import annotations
from typing import TYPE_CHECKING, Dict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -6,6 +8,9 @@ from langchain_core.documents import Document
from langchain_community.document_loaders.quip import QuipLoader from langchain_community.document_loaders.quip import QuipLoader
if TYPE_CHECKING:
from collections.abc import Iterator
try: try:
from quip_api.quip import QuipClient # noqa: F401 from quip_api.quip import QuipClient # noqa: F401
@ -15,7 +20,7 @@ except ImportError:
@pytest.fixture @pytest.fixture
def mock_quip(): # type: ignore def mock_quip() -> Iterator[MagicMock]:
# mock quip_client # mock quip_client
with patch("quip_api.quip.QuipClient") as mock_quip: with patch("quip_api.quip.QuipClient") as mock_quip:
yield mock_quip yield mock_quip

View File

@ -105,7 +105,7 @@ def test_unstructured_api_file_loader_io_multiple_files() -> None:
files = [stack.enter_context(open(file_path, "rb")) for file_path in file_paths] files = [stack.enter_context(open(file_path, "rb")) for file_path in file_paths]
loader = UnstructuredAPIFileIOLoader( loader = UnstructuredAPIFileIOLoader(
file=files, # type: ignore file=files,
api_key="FAKE_API_KEY", api_key="FAKE_API_KEY",
strategy="fast", strategy="fast",
mode="elements", mode="elements",

Some files were not shown because too many files have changed in this diff Show More