mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
community: Add ruff rule PGH003 (#30812)
See https://docs.astral.sh/ruff/rules/blanket-type-ignore/ --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
f005988e31
commit
ada740b5b9
@ -196,7 +196,7 @@ def create_sql_agent(
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
agent = RunnableAgent(
|
||||
runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore
|
||||
runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore[arg-type]
|
||||
input_keys_arg=["input"],
|
||||
return_keys_arg=["output"],
|
||||
**kwargs,
|
||||
@ -211,9 +211,9 @@ def create_sql_agent(
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
if agent_type == "openai-tools":
|
||||
runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore
|
||||
runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore[arg-type]
|
||||
else:
|
||||
runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore
|
||||
runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore[arg-type]
|
||||
agent = RunnableMultiActionAgent( # type: ignore[assignment]
|
||||
runnable=runnable,
|
||||
input_keys_arg=["input"],
|
||||
|
@ -135,7 +135,7 @@ def _get_assistants_tool(
|
||||
Dict[str, Any]: A dictionary of tools that are converted into OpenAI tools.
|
||||
"""
|
||||
if _is_assistants_builtin_tool(tool):
|
||||
return tool # type: ignore
|
||||
return tool # type: ignore[return-value]
|
||||
else:
|
||||
return convert_to_openai_tool(tool)
|
||||
|
||||
@ -288,7 +288,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
|
||||
assistant = client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
|
||||
tools=[_get_assistants_tool(tool) for tool in tools],
|
||||
tool_resources=tool_resources, # type: ignore[arg-type]
|
||||
model=model,
|
||||
extra_body=extra_body,
|
||||
@ -430,7 +430,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
|
||||
assistant = await async_client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=openai_tools, # type: ignore
|
||||
tools=openai_tools,
|
||||
tool_resources=tool_resources, # type: ignore[arg-type]
|
||||
model=model,
|
||||
)
|
||||
|
@ -238,7 +238,7 @@ class InMemoryCache(BaseCache):
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class FullLLMCache(Base): # type: ignore
|
||||
class FullLLMCache(Base): # type: ignore[misc,valid-type]
|
||||
"""SQLite table for full LLM Cache (all generations)."""
|
||||
|
||||
__tablename__ = "full_llm_cache"
|
||||
@ -261,7 +261,7 @@ class SQLAlchemyCache(BaseCache):
|
||||
"""Look up based on prompt and llm_string."""
|
||||
stmt = (
|
||||
select(self.cache_schema.response)
|
||||
.where(self.cache_schema.prompt == prompt) # type: ignore
|
||||
.where(self.cache_schema.prompt == prompt)
|
||||
.where(self.cache_schema.llm == llm_string)
|
||||
.order_by(self.cache_schema.idx)
|
||||
)
|
||||
@ -1531,7 +1531,7 @@ class CassandraSemanticCache(BaseCache):
|
||||
await self.table.aclear()
|
||||
|
||||
|
||||
class FullMd5LLMCache(Base): # type: ignore
|
||||
class FullMd5LLMCache(Base): # type: ignore[misc,valid-type]
|
||||
"""SQLite table for full LLM Cache (all generations)."""
|
||||
|
||||
__tablename__ = "full_md5_llm_cache"
|
||||
@ -1583,7 +1583,7 @@ class SQLAlchemyMd5Cache(BaseCache):
|
||||
def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
|
||||
stmt = (
|
||||
delete(self.cache_schema)
|
||||
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore
|
||||
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt))
|
||||
.where(self.cache_schema.llm == llm_string)
|
||||
.where(self.cache_schema.prompt == prompt)
|
||||
)
|
||||
@ -1593,7 +1593,7 @@ class SQLAlchemyMd5Cache(BaseCache):
|
||||
prompt_pd5 = self.get_md5(prompt)
|
||||
stmt = (
|
||||
select(self.cache_schema.response)
|
||||
.where(self.cache_schema.prompt_md5 == prompt_pd5) # type: ignore
|
||||
.where(self.cache_schema.prompt_md5 == prompt_pd5)
|
||||
.where(self.cache_schema.llm == llm_string)
|
||||
.where(self.cache_schema.prompt == prompt)
|
||||
.order_by(self.cache_schema.idx)
|
||||
@ -1796,7 +1796,7 @@ class _CachedAwaitable:
|
||||
def __await__(self) -> Generator:
|
||||
if self.result is _unset:
|
||||
self.result = yield from self.awaitable.__await__()
|
||||
return self.result # type: ignore
|
||||
return self.result # type: ignore[return-value]
|
||||
|
||||
|
||||
def _reawaitable(func: Callable) -> Callable:
|
||||
|
@ -584,7 +584,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
)
|
||||
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
|
||||
|
||||
self.__init__( # type: ignore
|
||||
self.__init__( # type: ignore[misc]
|
||||
task_type=_task_type,
|
||||
workspace=_workspace,
|
||||
project_name=_project_name,
|
||||
|
@ -580,7 +580,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.temp_dir.cleanup()
|
||||
self.reset_callback_meta()
|
||||
if reset:
|
||||
self.__init__( # type: ignore
|
||||
self.__init__( # type: ignore[misc]
|
||||
job_type=job_type if job_type else self.job_type,
|
||||
project=project if project else self.project,
|
||||
entity=entity if entity else self.entity,
|
||||
|
@ -352,7 +352,7 @@ def create_structured_output_runnable(
|
||||
class _OutputFormatter(BaseModel):
|
||||
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
||||
|
||||
output: output_schema # type: ignore
|
||||
output: output_schema # type: ignore[valid-type]
|
||||
|
||||
function = _OutputFormatter
|
||||
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
||||
@ -537,7 +537,7 @@ def create_structured_output_chain(
|
||||
class _OutputFormatter(BaseModel):
|
||||
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
||||
|
||||
output: output_schema # type: ignore
|
||||
output: output_schema # type: ignore[valid-type]
|
||||
|
||||
function = _OutputFormatter
|
||||
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
||||
|
@ -316,7 +316,7 @@ class GraphCypherQAChain(Chain):
|
||||
MessagesPlaceholder(variable_name="function_response"),
|
||||
]
|
||||
)
|
||||
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
|
||||
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore[operator]
|
||||
except (NotImplementedError, AttributeError):
|
||||
raise ValueError("Provided LLM does not support native tools/functions")
|
||||
else:
|
||||
@ -404,15 +404,15 @@ class GraphCypherQAChain(Chain):
|
||||
intermediate_steps.append({"context": context})
|
||||
if self.use_function_response:
|
||||
function_response = get_function_response(question, context)
|
||||
final_result = self.qa_chain.invoke( # type: ignore
|
||||
final_result = self.qa_chain.invoke( # type: ignore[assignment]
|
||||
{"question": question, "function_response": function_response},
|
||||
)
|
||||
else:
|
||||
result = self.qa_chain.invoke( # type: ignore
|
||||
result = self.qa_chain.invoke(
|
||||
{"question": question, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
final_result = result[self.qa_chain.output_key] # type: ignore
|
||||
final_result = result[self.qa_chain.output_key] # type: ignore[union-attr]
|
||||
|
||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||
if self.return_intermediate_steps:
|
||||
|
@ -225,11 +225,11 @@ class MemgraphQAChain(Chain):
|
||||
MessagesPlaceholder(variable_name="function_response"),
|
||||
]
|
||||
)
|
||||
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
|
||||
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore[operator]
|
||||
except (NotImplementedError, AttributeError):
|
||||
raise ValueError("Provided LLM does not support native tools/functions")
|
||||
else:
|
||||
qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser() # type: ignore
|
||||
qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser()
|
||||
|
||||
prompt = use_cypher_llm_kwargs["prompt"]
|
||||
llm_to_use = cypher_llm if cypher_llm is not None else llm
|
||||
@ -300,11 +300,11 @@ class MemgraphQAChain(Chain):
|
||||
intermediate_steps.append({"context": context})
|
||||
if self.use_function_response:
|
||||
function_response = get_function_response(question, context)
|
||||
result = self.qa_chain.invoke( # type: ignore
|
||||
result = self.qa_chain.invoke(
|
||||
{"question": question, "function_response": function_response},
|
||||
)
|
||||
else:
|
||||
result = self.qa_chain.invoke( # type: ignore
|
||||
result = self.qa_chain.invoke(
|
||||
{"question": question, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
@ -67,11 +67,11 @@ def extract_cypher(text: str) -> str:
|
||||
|
||||
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
|
||||
"""Decides whether to use the simple prompt"""
|
||||
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
|
||||
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore[attr-defined]
|
||||
return True
|
||||
|
||||
# Bedrock anthropic
|
||||
if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore
|
||||
if hasattr(llm, "model_id") and "anthropic" in llm.model_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -313,8 +313,12 @@ class PebbloRetrievalQA(Chain):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_app_details( # type: ignore
|
||||
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs
|
||||
def _get_app_details(
|
||||
app_name: str,
|
||||
owner: str,
|
||||
description: str,
|
||||
llm: BaseLanguageModel,
|
||||
**kwargs: Any,
|
||||
) -> App:
|
||||
"""Fetch app details. Internal method.
|
||||
Returns:
|
||||
|
@ -81,7 +81,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
|
||||
)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve all session messages from DB"""
|
||||
# The latest are returned, in chronological order
|
||||
rows = self.table.get_partition(
|
||||
|
@ -35,7 +35,7 @@ class FileChatMessageHistory(BaseChatMessageHistory):
|
||||
)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from the local file"""
|
||||
items = json.loads(self.file_path.read_text(encoding=self.encoding))
|
||||
messages = messages_from_dict(items)
|
||||
|
@ -334,7 +334,7 @@ class KafkaChatMessageHistory(BaseChatMessageHistory):
|
||||
)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""
|
||||
Retrieve the messages for the session, from Kafka topic continuously
|
||||
from last consumed message. This method is stateful and maintains
|
||||
|
@ -60,7 +60,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
self.collection.create_index("SessionId")
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from MongoDB"""
|
||||
from pymongo import errors
|
||||
|
||||
|
@ -65,7 +65,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
|
||||
self.connection.commit()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from PostgreSQL"""
|
||||
query = (
|
||||
f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;"
|
||||
|
@ -215,7 +215,7 @@ class RocksetChatMessageHistory(BaseChatMessageHistory):
|
||||
self._create_empty_doc()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Messages in this chat history."""
|
||||
return messages_from_dict(
|
||||
self._query(
|
||||
|
@ -212,7 +212,7 @@ class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
|
||||
conn.close()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from SingleStoreDB"""
|
||||
self._create_table_if_not_exists()
|
||||
conn = self.connection_pool.connect()
|
||||
|
@ -47,7 +47,7 @@ try:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -242,7 +242,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
self._table_created = True
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve all messages from db"""
|
||||
with self._make_sync_session() as session:
|
||||
result = (
|
||||
|
@ -51,7 +51,7 @@ class UpstashRedisChatMessageHistory(BaseChatMessageHistory):
|
||||
return self.key_prefix + self.session_id
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Upstash Redis"""
|
||||
_items = self.redis_client.lrange(self.key, 0, -1)
|
||||
items = [json.loads(m) for m in _items[::-1]]
|
||||
|
@ -83,7 +83,7 @@ class XataChatMessageHistory(BaseChatMessageHistory):
|
||||
raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
r = self._client.data().query(
|
||||
self._table_name,
|
||||
payload={
|
||||
|
@ -87,7 +87,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
|
||||
self.session_id = session_id
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve messages from Zep memory"""
|
||||
zep_memory: Optional[Memory] = self._get_memory()
|
||||
if not zep_memory:
|
||||
|
@ -134,7 +134,7 @@ class ZepCloudChatMessageHistory(BaseChatMessageHistory):
|
||||
self.summary_instruction = summary_instruction
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve messages from Zep memory"""
|
||||
zep_memory: Optional[Memory] = self._get_memory()
|
||||
if not zep_memory:
|
||||
|
@ -42,7 +42,7 @@ class ChatLiteLLMRouter(ChatLiteLLM):
|
||||
|
||||
def __init__(self, *, router: Any, **kwargs: Any) -> None:
|
||||
"""Construct Chat LiteLLM Router."""
|
||||
super().__init__(router=router, **kwargs) # type: ignore
|
||||
super().__init__(router=router, **kwargs) # type: ignore[call-arg]
|
||||
self.router = router
|
||||
|
||||
@property
|
||||
|
@ -815,4 +815,4 @@ def _convert_delta_to_message_chunk(
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role, id=id_)
|
||||
else:
|
||||
return default_class(content=content, id=id_) # type: ignore
|
||||
return default_class(content=content, id=id_) # type: ignore[call-arg]
|
||||
|
@ -716,7 +716,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], # type: ignore[list-item]
|
||||
first_tool_only=True, # type: ignore[list-item]
|
||||
first_tool_only=True,
|
||||
)
|
||||
else:
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
|
@ -158,9 +158,9 @@ def _convert_delta_response_to_message_chunk(
|
||||
Optional[str],
|
||||
]:
|
||||
"""Converts delta response to message chunk"""
|
||||
_delta = response.choices[0].delta # type: ignore
|
||||
role = _delta.get("role", "") # type: ignore
|
||||
content = _delta.get("content", "") # type: ignore
|
||||
_delta = response.choices[0].delta
|
||||
role = _delta.get("role", "")
|
||||
content = _delta.get("content", "")
|
||||
additional_kwargs: Dict = {}
|
||||
finish_reasons: Optional[str] = response.choices[0].finish_reason
|
||||
|
||||
@ -398,7 +398,7 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
messages, template_id=kwargs["template_id"]
|
||||
)
|
||||
else:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
|
||||
|
||||
if system_prompt is not None and system_prompt != "":
|
||||
kwargs["system_prompt"] = system_prompt
|
||||
@ -425,9 +425,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
if "template_id" in kwargs:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(
|
||||
messages, template_id=kwargs["template_id"]
|
||||
) # type: ignore
|
||||
)
|
||||
else:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning("stop is not supported in langchain streaming")
|
||||
|
@ -218,7 +218,7 @@ class BlackboardLoader(WebBaseLoader):
|
||||
loader = DirectoryLoader(
|
||||
path=self.folder_path,
|
||||
glob="*.pdf",
|
||||
loader_cls=PyPDFLoader, # type: ignore
|
||||
loader_cls=PyPDFLoader, # type: ignore[arg-type]
|
||||
)
|
||||
# Load the documents
|
||||
documents = loader.load()
|
||||
|
@ -35,7 +35,7 @@ class _CloudBlob(Blob):
|
||||
from cloudpathlib import AnyPath
|
||||
|
||||
if self.data is None and self.path:
|
||||
return AnyPath(self.path).read_text(encoding=self.encoding) # type: ignore
|
||||
return AnyPath(self.path).read_text(encoding=self.encoding)
|
||||
elif isinstance(self.data, bytes):
|
||||
return self.data.decode(self.encoding)
|
||||
elif isinstance(self.data, str):
|
||||
@ -52,7 +52,7 @@ class _CloudBlob(Blob):
|
||||
elif isinstance(self.data, str):
|
||||
return self.data.encode(self.encoding)
|
||||
elif self.data is None and self.path:
|
||||
return AnyPath(self.path).read_bytes() # type: ignore
|
||||
return AnyPath(self.path).read_bytes()
|
||||
else:
|
||||
raise ValueError(f"Unable to get bytes for blob {self}")
|
||||
|
||||
@ -64,7 +64,7 @@ class _CloudBlob(Blob):
|
||||
if isinstance(self.data, bytes):
|
||||
yield BytesIO(self.data)
|
||||
elif self.data is None and self.path:
|
||||
return AnyPath(self.path).read_bytes() # type: ignore
|
||||
return AnyPath(self.path).read_bytes()
|
||||
else:
|
||||
raise NotImplementedError(f"Unable to convert blob {self}")
|
||||
|
||||
@ -79,7 +79,7 @@ def _url_to_filename(url: str) -> str:
|
||||
url_parsed = urlparse(url)
|
||||
suffix = Path(url_parsed.path).suffix
|
||||
if url_parsed.scheme in ["s3", "az", "gs"]:
|
||||
with AnyPath(url).open("rb") as f: # type: ignore
|
||||
with AnyPath(url).open("rb") as f:
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||
while True:
|
||||
buf = f.read()
|
||||
@ -116,7 +116,7 @@ def _make_iterator(
|
||||
|
||||
iterator = _with_tqdm
|
||||
else:
|
||||
iterator = iter # type: ignore
|
||||
iterator = iter # type: ignore[assignment]
|
||||
|
||||
return iterator
|
||||
|
||||
@ -220,7 +220,7 @@ class CloudBlobLoader(BlobLoader):
|
||||
|
||||
def _yield_paths(self) -> Iterable["AnyPath"]:
|
||||
"""Yield paths that match the requested pattern."""
|
||||
if self.path.is_file(): # type: ignore
|
||||
if self.path.is_file():
|
||||
yield self.path
|
||||
return
|
||||
|
||||
@ -269,7 +269,7 @@ class CloudBlobLoader(BlobLoader):
|
||||
Blob instance
|
||||
"""
|
||||
if mime_type is None and guess_type:
|
||||
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None # type: ignore
|
||||
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
|
||||
else:
|
||||
_mimetype = mime_type
|
||||
|
||||
|
@ -252,7 +252,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
||||
files = self._fetch_files_recursive(service, folder_id)
|
||||
# If file types filter is provided, we'll filter by the file type.
|
||||
if file_types:
|
||||
_files = [f for f in files if f["mimeType"] in file_types] # type: ignore
|
||||
_files = [f for f in files if f["mimeType"] in file_types]
|
||||
else:
|
||||
_files = files
|
||||
|
||||
@ -261,14 +261,14 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
||||
if file["trashed"] and not self.load_trashed_files:
|
||||
continue
|
||||
elif file["mimeType"] == "application/vnd.google-apps.document":
|
||||
returns.append(self._load_document_from_id(file["id"])) # type: ignore
|
||||
returns.append(self._load_document_from_id(file["id"])) # type: ignore[arg-type]
|
||||
elif file["mimeType"] == "application/vnd.google-apps.spreadsheet":
|
||||
returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore
|
||||
returns.extend(self._load_sheet_from_id(file["id"])) # type: ignore[arg-type]
|
||||
elif (
|
||||
file["mimeType"] == "application/pdf"
|
||||
or self.file_loader_cls is not None
|
||||
):
|
||||
returns.extend(self._load_file_from_id(file["id"])) # type: ignore
|
||||
returns.extend(self._load_file_from_id(file["id"])) # type: ignore[arg-type]
|
||||
else:
|
||||
pass
|
||||
return returns
|
||||
|
@ -267,7 +267,7 @@ class DocAIParser(BaseBlobParser):
|
||||
"""Initializes Long-Running Operations from their names."""
|
||||
try:
|
||||
from google.longrunning.operations_pb2 import (
|
||||
GetOperationRequest, # type: ignore
|
||||
GetOperationRequest,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
|
@ -59,9 +59,9 @@ class DocumentLoaderAsParser(BaseBlobParser):
|
||||
"""
|
||||
Use underlying DocumentLoader to lazily parse the blob.
|
||||
"""
|
||||
doc_loader = self.DocumentLoaderClass(
|
||||
doc_loader = self.DocumentLoaderClass( # type: ignore[call-arg]
|
||||
file_path=blob.path, **self.document_loader_kwargs
|
||||
) # type: ignore
|
||||
)
|
||||
for document in doc_loader.lazy_load():
|
||||
document.metadata.update(blob.metadata)
|
||||
yield document
|
||||
|
@ -107,7 +107,7 @@ class RapidOCRBlobParser(BaseImageBlobParser):
|
||||
"`rapidocr-onnxruntime` package not found, please install it with "
|
||||
"`pip install rapidocr-onnxruntime`"
|
||||
)
|
||||
ocr_result, _ = self.ocr(np.array(img)) # type: ignore
|
||||
ocr_result, _ = self.ocr(np.array(img)) # type: ignore[misc]
|
||||
content = ""
|
||||
if ocr_result:
|
||||
content = ("\n".join([text[1] for text in ocr_result])).strip()
|
||||
|
@ -82,7 +82,7 @@ class TreeSitterSegmenter(CodeSegmenter):
|
||||
)
|
||||
|
||||
for line_num in range(start_line + 1, end_line + 1):
|
||||
simplified_lines[line_num] = None # type: ignore
|
||||
simplified_lines[line_num] = None # type: ignore[call-overload]
|
||||
|
||||
processed_lines.update(lines)
|
||||
|
||||
|
@ -6,7 +6,7 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
import requests # type: ignore
|
||||
import requests
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from pydantic import Field
|
||||
|
@ -78,7 +78,7 @@ class TrelloLoader(BaseLoader):
|
||||
"""
|
||||
|
||||
try:
|
||||
from trello import TrelloClient # type: ignore
|
||||
from trello import TrelloClient
|
||||
except ImportError as ex:
|
||||
raise ImportError(
|
||||
"Could not import trello python package. "
|
||||
@ -124,7 +124,7 @@ class TrelloLoader(BaseLoader):
|
||||
return board
|
||||
|
||||
def _card_to_doc(self, card: Card, list_dict: dict) -> Document:
|
||||
from bs4 import BeautifulSoup # type: ignore
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
text_content = ""
|
||||
if self.include_card_name:
|
||||
|
@ -245,8 +245,8 @@ def get_elements_from_api(
|
||||
from unstructured.partition.api import partition_multiple_via_api
|
||||
|
||||
_doc_elements = partition_multiple_via_api(
|
||||
filenames=file_path, # type: ignore
|
||||
files=file, # type: ignore
|
||||
filenames=file_path,
|
||||
files=file,
|
||||
api_key=api_key,
|
||||
api_url=api_url,
|
||||
**unstructured_kwargs,
|
||||
|
@ -393,7 +393,7 @@ class WebBaseLoader(BaseLoader):
|
||||
"https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.web_base.WebBaseLoader.html" # noqa: E501
|
||||
),
|
||||
)
|
||||
def aload(self) -> List[Document]: # type: ignore
|
||||
def aload(self) -> List[Document]: # type: ignore[override]
|
||||
"""Load text from the urls in web_path async into Documents."""
|
||||
|
||||
results = self.scrape_all(self.web_paths)
|
||||
|
@ -439,7 +439,7 @@ class HypotheticalDocumentEmbedder:
|
||||
)
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
|
||||
|
||||
return H(*args, **kwargs) # type: ignore
|
||||
return H(*args, **kwargs) # type: ignore[return-value]
|
||||
|
||||
@classmethod
|
||||
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
|
@ -68,7 +68,7 @@ class InfinityEmbeddingsLocal(BaseModel, Embeddings):
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
try:
|
||||
from infinity_emb import AsyncEmbeddingEngine # type: ignore
|
||||
from infinity_emb import AsyncEmbeddingEngine
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install the "
|
||||
|
@ -76,7 +76,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
def _embed(self, input: Any) -> List[List[float]]:
|
||||
# Call Jina AI Embedding API
|
||||
resp = self.session.post( # type: ignore
|
||||
resp = self.session.post(
|
||||
JINA_API_URL, json={"input": input, "model": self.model_name}
|
||||
).json()
|
||||
if "data" not in resp:
|
||||
@ -85,7 +85,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = resp["data"]
|
||||
|
||||
# Sort resulting embeddings by index
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
|
||||
|
||||
# Return just the embeddings
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
|
@ -309,7 +309,7 @@ class AsyncOpenAITextEmbedEmbeddingClient:
|
||||
Raises:
|
||||
Exception: If the response status is not 200.
|
||||
"""
|
||||
async with session.post(**kwargs) as response: # type: ignore
|
||||
async with session.post(**kwargs) as response: # type: ignore[arg-type]
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"TextEmbed responded with an unexpected status message "
|
||||
|
@ -22,7 +22,7 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float:
|
||||
https://aclanthology.org/P02-1040.pdf
|
||||
"""
|
||||
from nltk.translate.bleu_score import (
|
||||
SmoothingFunction, # type: ignore
|
||||
SmoothingFunction,
|
||||
sentence_bleu,
|
||||
)
|
||||
|
||||
|
@ -54,7 +54,7 @@ try:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
|
||||
|
||||
from langchain_community.indexes.base import RecordManager
|
||||
|
||||
@ -308,8 +308,8 @@ class SQLRecordManager(RecordManager):
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=insert_stmt.excluded.updated_at,
|
||||
group_id=insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
elif self.dialect == "postgresql":
|
||||
@ -322,8 +322,8 @@ class SQLRecordManager(RecordManager):
|
||||
"uix_key_namespace", # Name of constraint
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=insert_stmt.excluded.updated_at,
|
||||
group_id=insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -383,8 +383,8 @@ class SQLRecordManager(RecordManager):
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=insert_stmt.excluded.updated_at,
|
||||
group_id=insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
elif self.dialect == "postgresql":
|
||||
@ -397,8 +397,8 @@ class SQLRecordManager(RecordManager):
|
||||
"uix_key_namespace", # Name of constraint
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=insert_stmt.excluded.updated_at,
|
||||
group_id=insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
@ -486,10 +486,10 @@ class AzureMLBaseEndpoint(BaseModel):
|
||||
timeout = values.get("timeout", DEFAULT_TIMEOUT)
|
||||
|
||||
http_client = AzureMLEndpointClient(
|
||||
endpoint_url, # type: ignore
|
||||
endpoint_key.get_secret_value(), # type: ignore
|
||||
deployment_name, # type: ignore
|
||||
timeout, # type: ignore
|
||||
endpoint_url, # type: ignore[arg-type]
|
||||
endpoint_key.get_secret_value(), # type: ignore[union-attr]
|
||||
deployment_name, # type: ignore[arg-type]
|
||||
timeout,
|
||||
)
|
||||
|
||||
return http_client
|
||||
|
@ -201,7 +201,7 @@ class Beam(LLM): # type: ignore[override, override, override, override]
|
||||
def _deploy(self) -> str:
|
||||
"""Call to Beam."""
|
||||
try:
|
||||
import beam # type: ignore
|
||||
import beam
|
||||
|
||||
if beam.__path__ == "":
|
||||
raise ImportError
|
||||
|
@ -181,7 +181,7 @@ class _BaseGigaChat(Serializable):
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Count approximate number of tokens"""
|
||||
if self.use_api_for_tokens:
|
||||
return self.tokens_count([text])[0].tokens # type: ignore
|
||||
return self.tokens_count([text])[0].tokens
|
||||
else:
|
||||
return round(len(text) / 4.6)
|
||||
|
||||
|
@ -142,7 +142,7 @@ class HuggingFaceHub(LLM):
|
||||
if "error" in response:
|
||||
raise ValueError(f"Error raised by inference API: {response['error']}")
|
||||
|
||||
response_key = VALID_TASKS_DICT[self.task] # type: ignore
|
||||
response_key = VALID_TASKS_DICT[self.task] # type: ignore[index]
|
||||
if isinstance(response, list):
|
||||
text = response[0][response_key]
|
||||
else:
|
||||
|
@ -172,7 +172,7 @@ class IpexLLM(LLM):
|
||||
if not low_bit_model:
|
||||
if load_in_low_bit is not None:
|
||||
load_function_name = "from_pretrained"
|
||||
load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore
|
||||
load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore[assignment]
|
||||
else:
|
||||
load_function_name = "from_pretrained"
|
||||
load_kwargs["load_in_4bit"] = load_in_4bit
|
||||
|
@ -246,7 +246,7 @@ class BaseOpenAI(BaseLLM):
|
||||
http_client: Union[Any, None] = None
|
||||
"""Optional httpx.Client."""
|
||||
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore[misc]
|
||||
"""Initialize the OpenAI object."""
|
||||
model_name = data.get("model_name", "")
|
||||
if (
|
||||
|
@ -47,7 +47,7 @@ class YiLLM(LLM):
|
||||
def _post(self, request: Any) -> Any:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore
|
||||
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore[union-attr]
|
||||
}
|
||||
|
||||
urls = []
|
||||
|
@ -161,11 +161,11 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
_result = super().parse_result(result)
|
||||
if self.args_only:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore[union-attr]
|
||||
else:
|
||||
fn_name = _result["name"]
|
||||
_args = _result["arguments"]
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore[index]
|
||||
return pydantic_args
|
||||
|
||||
|
||||
|
@ -53,9 +53,9 @@ class DeepLakeTranslator(Visitor):
|
||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||
self._validate_func(func)
|
||||
if isinstance(func, Operator):
|
||||
value = OPERATOR_TO_TQL[func.value] # type: ignore
|
||||
value = OPERATOR_TO_TQL[func.value] # type: ignore[index]
|
||||
elif isinstance(func, Comparator):
|
||||
value = COMPARATOR_TO_TQL[func.value] # type: ignore
|
||||
value = COMPARATOR_TO_TQL[func.value] # type: ignore[index]
|
||||
return f"{value}"
|
||||
|
||||
def visit_operation(self, operation: Operation) -> str:
|
||||
|
@ -42,9 +42,9 @@ class TimescaleVectorTranslator(Visitor):
|
||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||
self._validate_func(func)
|
||||
if isinstance(func, Operator):
|
||||
value = self.OPERATOR_MAP[func.value] # type: ignore
|
||||
value = self.OPERATOR_MAP[func.value] # type: ignore[index]
|
||||
elif isinstance(func, Comparator):
|
||||
value = self.COMPARATOR_MAP[func.value] # type: ignore
|
||||
value = self.COMPARATOR_MAP[func.value] # type: ignore[index]
|
||||
return f"{value}"
|
||||
|
||||
def visit_operation(self, operation: Operation) -> client.Predicates:
|
||||
|
@ -41,7 +41,7 @@ try:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@ -255,7 +255,7 @@ class SQLStore(BaseStore[str, bytes]):
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
with self._make_sync_session() as session:
|
||||
for v in session.query(LangchainKeyValueStores).filter( # type: ignore
|
||||
for v in session.query(LangchainKeyValueStores).filter(
|
||||
LangchainKeyValueStores.namespace == self.namespace
|
||||
):
|
||||
if str(v.key).startswith(prefix or ""):
|
||||
|
@ -178,7 +178,7 @@ def execute_function(
|
||||
statement=parametrized_statement.statement,
|
||||
warehouse_id=warehouse_id,
|
||||
parameters=parametrized_statement.parameters,
|
||||
**execute_statement_args, # type: ignore
|
||||
**execute_statement_args,
|
||||
)
|
||||
if response.status and job_pending(response.status.state) and response.statement_id:
|
||||
statement_id = response.statement_id
|
||||
@ -197,7 +197,7 @@ def execute_function(
|
||||
f"status after {wait} seconds."
|
||||
)
|
||||
time.sleep(wait)
|
||||
response = ws.statement_execution.get_statement(statement_id) # type: ignore
|
||||
response = ws.statement_execution.get_statement(statement_id)
|
||||
if response.status is None or not job_pending(response.status.state):
|
||||
break
|
||||
wait_time += wait
|
||||
@ -228,7 +228,7 @@ def execute_function(
|
||||
if is_scalar(function):
|
||||
value = None
|
||||
if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
|
||||
value = str(data_array[0][0]) # type: ignore
|
||||
value = str(data_array[0][0])
|
||||
return FunctionExecutionResult(
|
||||
format="SCALAR", value=value, truncated=truncated
|
||||
)
|
||||
|
@ -51,8 +51,8 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
|
||||
if tpe == "array":
|
||||
element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
|
||||
if uc_type_json["containsNull"]:
|
||||
element_type = Optional[element_type] # type: ignore
|
||||
return List[element_type] # type: ignore
|
||||
element_type = Optional[element_type] # type: ignore[assignment]
|
||||
return List[element_type] # type: ignore[valid-type]
|
||||
elif tpe == "map":
|
||||
key_type = uc_type_json["keyType"]
|
||||
assert key_type == "string", TypeError(
|
||||
@ -60,14 +60,14 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
|
||||
)
|
||||
value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
|
||||
if uc_type_json["valueContainsNull"]:
|
||||
value_type: Type = Optional[value_type] # type: ignore
|
||||
return Dict[str, value_type] # type: ignore
|
||||
value_type: Type = Optional[value_type] # type: ignore[no-redef]
|
||||
return Dict[str, value_type] # type: ignore[valid-type]
|
||||
elif tpe == "struct":
|
||||
fields = {}
|
||||
for field in uc_type_json["fields"]:
|
||||
field_type = _uc_type_to_pydantic_type(field["type"])
|
||||
if field.get("nullable"):
|
||||
field_type = Optional[field_type] # type: ignore
|
||||
field_type = Optional[field_type] # type: ignore[assignment]
|
||||
comment = (
|
||||
uc_type_json["metadata"].get("comment")
|
||||
if "metadata" in uc_type_json
|
||||
@ -76,7 +76,7 @@ def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
|
||||
fields[field["name"]] = (field_type, Field(..., description=comment))
|
||||
uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
|
||||
type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
|
||||
return create_model(f"Struct_{type_hash}", **fields) # type: ignore
|
||||
return create_model(f"Struct_{type_hash}", **fields) # type: ignore[call-overload]
|
||||
else:
|
||||
raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")
|
||||
|
||||
@ -94,7 +94,7 @@ def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
|
||||
description = p.comment
|
||||
default: Any = ...
|
||||
if p.parameter_default:
|
||||
pydantic_type = Optional[pydantic_type] # type: ignore
|
||||
pydantic_type = Optional[pydantic_type] # type: ignore[assignment]
|
||||
default = None
|
||||
# TODO: Convert default value string to the correct type.
|
||||
# We might need to use statement execution API
|
||||
@ -108,9 +108,9 @@ def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
|
||||
pydantic_type,
|
||||
Field(default=default, description=description),
|
||||
)
|
||||
return create_model(
|
||||
return create_model( # type: ignore[call-overload]
|
||||
f"{function.catalog_name}__{function.schema_name}__{function.name}__params",
|
||||
**fields, # type: ignore
|
||||
**fields,
|
||||
)
|
||||
|
||||
|
||||
|
@ -217,10 +217,10 @@ class NucliaUnderstandingAPI(BaseTool): # type: ignore[override, override]
|
||||
logger.info(f"No matching id for {uuid}")
|
||||
else:
|
||||
self._results[matching_id]["status"] = "done"
|
||||
data = MessageToJson(
|
||||
data = MessageToJson( # type: ignore[call-arg]
|
||||
pb,
|
||||
preserving_proto_field_name=True,
|
||||
including_default_value_fields=True, # type: ignore
|
||||
including_default_value_fields=True,
|
||||
)
|
||||
self._results[matching_id]["data"] = data
|
||||
|
||||
|
@ -161,9 +161,7 @@ class ZapierNLARunAction(BaseTool): # type: ignore[override]
|
||||
)
|
||||
|
||||
|
||||
ZapierNLARunAction.__doc__ = (
|
||||
ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore
|
||||
)
|
||||
ZapierNLARunAction.__doc__ = ZapierNLAWrapper.run.__doc__ + ZapierNLARunAction.__doc__ # type: ignore[operator]
|
||||
|
||||
|
||||
# other useful actions
|
||||
@ -210,5 +208,5 @@ class ZapierNLAListActions(BaseTool): # type: ignore[override]
|
||||
|
||||
|
||||
ZapierNLAListActions.__doc__ = (
|
||||
ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore
|
||||
ZapierNLAWrapper.list.__doc__ + ZapierNLAListActions.__doc__ # type: ignore[operator]
|
||||
)
|
||||
|
@ -50,7 +50,7 @@ class BingSearchAPIWrapper(BaseModel):
|
||||
response = requests.get(
|
||||
self.bing_search_url,
|
||||
headers=headers,
|
||||
params=params, # type: ignore
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
search_results = response.json()
|
||||
|
@ -82,7 +82,7 @@ class GoogleScholarAPIWrapper(BaseModel):
|
||||
# 0 is the first page of results, 20 is the 2nd page of results,
|
||||
# 40 is the 3rd page of results, etc.
|
||||
results = (
|
||||
self.google_scholar_engine( # type: ignore
|
||||
self.google_scholar_engine(
|
||||
{
|
||||
"q": query,
|
||||
"start": page,
|
||||
@ -106,7 +106,7 @@ class GoogleScholarAPIWrapper(BaseModel):
|
||||
): # From the last page we would only need top_k_results%20 results
|
||||
# if k is not divisible by 20.
|
||||
results = (
|
||||
self.google_scholar_engine( # type: ignore
|
||||
self.google_scholar_engine(
|
||||
{
|
||||
"q": query,
|
||||
"start": page,
|
||||
|
@ -49,7 +49,6 @@ class MetaphorSearchAPIWrapper(BaseModel):
|
||||
"useAutoprompt": use_autoprompt,
|
||||
}
|
||||
response = requests.post(
|
||||
# type: ignore
|
||||
f"{METAPHOR_API_URL}/search",
|
||||
headers=headers,
|
||||
json=params,
|
||||
|
@ -53,7 +53,7 @@ if TYPE_CHECKING:
|
||||
try:
|
||||
from openapi_pydantic import OpenAPI
|
||||
except ImportError:
|
||||
OpenAPI = object # type: ignore
|
||||
OpenAPI = object
|
||||
|
||||
|
||||
class OpenAPISpec(OpenAPI):
|
||||
|
@ -134,7 +134,7 @@ class NutritionAIAPI(BaseModel):
|
||||
return requests.get(
|
||||
self.nutritionai_api_url,
|
||||
headers=self.auth_.headers,
|
||||
params=params, # type: ignore
|
||||
params=params,
|
||||
)
|
||||
|
||||
def _api_call_results(self, search_term: str) -> dict:
|
||||
|
@ -395,7 +395,7 @@ class SQLDatabase:
|
||||
try:
|
||||
# get the sample rows
|
||||
with self._engine.connect() as connection:
|
||||
sample_rows_result = connection.execute(command) # type: ignore
|
||||
sample_rows_result = connection.execute(command)
|
||||
# shorten values in the sample rows
|
||||
sample_rows = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
|
||||
|
@ -123,10 +123,10 @@ class SteamWebAPIWrapper(BaseModel):
|
||||
except ImportError:
|
||||
raise ImportError("steamspypi library is not installed.")
|
||||
users_games = self.get_users_games(steam_id)
|
||||
result = {} # type: ignore
|
||||
result: dict[str, int] = {}
|
||||
most_popular_genre = ""
|
||||
most_popular_genre_count = 0
|
||||
for game in users_games["games"]: # type: ignore
|
||||
for game in users_games["games"]: # type: ignore[call-overload]
|
||||
appid = game["appid"]
|
||||
data_request = {"request": "appdetails", "appid": appid}
|
||||
genreStore = steamspypi.download(data_request)
|
||||
@ -148,7 +148,7 @@ class SteamWebAPIWrapper(BaseModel):
|
||||
sorted_data = sorted(
|
||||
data.values(), key=lambda x: x.get("average_forever", 0), reverse=True
|
||||
)
|
||||
owned_games = [game["appid"] for game in users_games["games"]] # type: ignore
|
||||
owned_games = [game["appid"] for game in users_games["games"]] # type: ignore[call-overload]
|
||||
remaining_games = [
|
||||
game for game in sorted_data if game["appid"] not in owned_games
|
||||
]
|
||||
|
@ -58,7 +58,6 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
"include_images": include_images,
|
||||
}
|
||||
response = requests.post(
|
||||
# type: ignore
|
||||
f"{TAVILY_API_URL}/search",
|
||||
json=params,
|
||||
)
|
||||
|
@ -240,7 +240,6 @@ class YouSearchAPIWrapper(BaseModel):
|
||||
if self.endpoint_type == "snippet":
|
||||
self.endpoint_type = "search"
|
||||
response = requests.get(
|
||||
# type: ignore
|
||||
f"{YOU_API_URL}/{self.endpoint_type}",
|
||||
params=params,
|
||||
headers=headers,
|
||||
|
@ -71,4 +71,4 @@ def cosine_similarity_top_k(
|
||||
top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1]
|
||||
ret_idxs = np.unravel_index(top_k_idxs, score_array.shape)
|
||||
scores = score_array.ravel()[top_k_idxs].tolist()
|
||||
return list(zip(*ret_idxs)), scores # type: ignore
|
||||
return list(zip(*ret_idxs)), scores # type: ignore[return-value]
|
||||
|
@ -407,7 +407,7 @@ class AzureCosmosDBVectorSearch(VectorStore):
|
||||
for t, m, embedding in zip(texts, metadatas, embeddings)
|
||||
]
|
||||
# insert the documents in Cosmos DB
|
||||
insert_result = self._collection.insert_many(to_insert) # type: ignore
|
||||
insert_result = self._collection.insert_many(to_insert)
|
||||
return insert_result.inserted_ids
|
||||
|
||||
@classmethod
|
||||
|
@ -1571,7 +1571,7 @@ class AzureSearch(VectorStore):
|
||||
azure_search.add_embeddings(text_embeddings, metadatas, **kwargs)
|
||||
return azure_search
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore
|
||||
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore[override]
|
||||
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore.
|
||||
|
||||
Args:
|
||||
@ -1781,7 +1781,7 @@ async def _areorder_results_with_maximal_marginal_relevance(
|
||||
# Function can return -1 index
|
||||
if x == -1:
|
||||
break
|
||||
ret.append((documents[x], scores[x])) # type: ignore
|
||||
ret.append((documents[x], scores[x]))
|
||||
|
||||
return ret
|
||||
|
||||
@ -1816,7 +1816,7 @@ def _reorder_results_with_maximal_marginal_relevance(
|
||||
# Function can return -1 index
|
||||
if x == -1:
|
||||
break
|
||||
ret.append((documents[x], scores[x])) # type: ignore
|
||||
ret.append((documents[x], scores[x]))
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -656,7 +656,7 @@ class BigQueryVectorSearch(VectorStore):
|
||||
Returns:
|
||||
List of Documents most similar to the query vector, with similarity scores.
|
||||
"""
|
||||
emb = self.embedding_model.embed_query(query) # type: ignore
|
||||
emb = self.embedding_model.embed_query(query)
|
||||
return self.similarity_search_with_score_by_vector(
|
||||
emb, k, filter, brute_force, fraction_lists_to_search, **kwargs
|
||||
)
|
||||
@ -738,9 +738,7 @@ class BigQueryVectorSearch(VectorStore):
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self.embedding_model.embed_query( # type: ignore
|
||||
query
|
||||
)
|
||||
query_embedding = self.embedding_model.embed_query(query)
|
||||
doc_tuples = self._search_with_score_and_embeddings_by_vector(
|
||||
query_embedding, fetch_k, filter, brute_force, fraction_lists_to_search
|
||||
)
|
||||
|
@ -183,7 +183,7 @@ class Clarifai(VectorStore):
|
||||
try:
|
||||
from clarifai.client.search import Search
|
||||
from clarifai_grpc.grpc.api import resources_pb2
|
||||
from google.protobuf import json_format # type: ignore
|
||||
from google.protobuf import json_format
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
|
@ -275,7 +275,7 @@ class DeepLake(VectorStore):
|
||||
metadata=metadatas,
|
||||
embedding_data=texts,
|
||||
embedding_tensor="embedding",
|
||||
embedding_function=self._embedding_function.embed_documents, # type: ignore
|
||||
embedding_function=self._embedding_function.embed_documents, # type: ignore[union-attr]
|
||||
return_ids=True,
|
||||
**kwargs,
|
||||
)
|
||||
@ -464,8 +464,8 @@ class DeepLake(VectorStore):
|
||||
|
||||
if use_maximal_marginal_relevance:
|
||||
lambda_mult = kwargs.get("lambda_mult", 0.5)
|
||||
indices = maximal_marginal_relevance( # type: ignore
|
||||
embedding, # type: ignore
|
||||
indices = maximal_marginal_relevance(
|
||||
embedding, # type: ignore[arg-type]
|
||||
embeddings,
|
||||
k=min(k, len(texts)),
|
||||
lambda_mult=lambda_mult,
|
||||
@ -829,7 +829,7 @@ class DeepLake(VectorStore):
|
||||
use_maximal_marginal_relevance=True,
|
||||
lambda_mult=lambda_mult,
|
||||
exec_option=exec_option,
|
||||
embedding_function=embedding_function, # type: ignore
|
||||
embedding_function=embedding_function, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -103,7 +103,7 @@ class DocArrayIndex(VectorStore, ABC):
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
query_embedding = self.embedding.embed_query(query)
|
||||
query_doc = self.doc_cls(embedding=query_embedding) # type: ignore
|
||||
query_doc = self.doc_cls(embedding=query_embedding)
|
||||
docs, scores = self.doc_index.find(query_doc, search_field="embedding", limit=k)
|
||||
|
||||
result = [
|
||||
@ -152,7 +152,7 @@ class DocArrayIndex(VectorStore, ABC):
|
||||
List of Documents most similar to the query vector.
|
||||
"""
|
||||
|
||||
query_doc = self.doc_cls(embedding=embedding) # type: ignore
|
||||
query_doc = self.doc_cls(embedding=embedding)
|
||||
docs = self.doc_index.find(
|
||||
query_doc, search_field="embedding", limit=k
|
||||
).documents
|
||||
@ -187,7 +187,7 @@ class DocArrayIndex(VectorStore, ABC):
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self.embedding.embed_query(query)
|
||||
query_doc = self.doc_cls(embedding=query_embedding) # type: ignore
|
||||
query_doc = self.doc_cls(embedding=query_embedding)
|
||||
|
||||
docs = self.doc_index.find(
|
||||
query_doc, search_field="embedding", limit=fetch_k
|
||||
|
@ -71,7 +71,7 @@ class DocArrayHnswSearch(DocArrayIndex):
|
||||
num_threads=num_threads,
|
||||
**kwargs,
|
||||
)
|
||||
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir) # type: ignore
|
||||
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
|
||||
return cls(doc_index, embedding)
|
||||
|
||||
@classmethod
|
||||
|
@ -41,7 +41,7 @@ class DocArrayInMemorySearch(DocArrayIndex):
|
||||
from docarray.index import InMemoryExactNNIndex
|
||||
|
||||
doc_cls = cls._get_doc_cls(space=metric, **kwargs)
|
||||
doc_index = InMemoryExactNNIndex[doc_cls]() # type: ignore
|
||||
doc_index = InMemoryExactNNIndex[doc_cls]()
|
||||
return cls(doc_index, embedding)
|
||||
|
||||
@classmethod
|
||||
|
@ -265,7 +265,7 @@ class DocumentDBVectorSearch(VectorStore):
|
||||
for t, m, embedding in zip(texts, metadatas, embeddings)
|
||||
]
|
||||
# insert the documents in DocumentDB
|
||||
insert_result = self._collection.insert_many(to_insert) # type: ignore
|
||||
insert_result = self._collection.insert_many(to_insert)
|
||||
return insert_result.inserted_ids
|
||||
|
||||
@classmethod
|
||||
|
@ -220,7 +220,7 @@ class DuckDB(VectorStore):
|
||||
except ImportError:
|
||||
warnings.warn("You may need to `pip install pandas` to use this method.")
|
||||
|
||||
embedding = self._embedding.embed_query(query) # type: ignore
|
||||
embedding = self._embedding.embed_query(query)
|
||||
list_cosine_similarity = self.duckdb.FunctionExpression(
|
||||
"list_cosine_similarity",
|
||||
self.duckdb.ColumnExpression(self._vector_key),
|
||||
@ -265,7 +265,7 @@ class DuckDB(VectorStore):
|
||||
A list of Documents most similar to the query.
|
||||
"""
|
||||
|
||||
embedding = self._embedding.embed_query(query) # type: ignore
|
||||
embedding = self._embedding.embed_query(query)
|
||||
list_cosine_similarity = self.duckdb.FunctionExpression(
|
||||
"list_cosine_similarity",
|
||||
self.duckdb.ColumnExpression(self._vector_key),
|
||||
|
@ -254,7 +254,7 @@ repeated float %s = 1;
|
||||
texts_l = list(texts)
|
||||
if last_vector:
|
||||
texts_l.pop()
|
||||
embeds = self._embedding.embed_documents(texts_l) # type: ignore
|
||||
embeds = self._embedding.embed_documents(texts_l) # type: ignore[union-attr]
|
||||
if last_vector:
|
||||
embeds.append(last_vector)
|
||||
if not metadatas:
|
||||
@ -288,7 +288,7 @@ repeated float %s = 1;
|
||||
Returns:
|
||||
List[Tuple[Document, float]]
|
||||
"""
|
||||
embed = self._embedding.embed_query(query) # type: ignore
|
||||
embed = self._embedding.embed_query(query) # type: ignore[union-attr]
|
||||
documents = self.similarity_search_with_score_by_vector(embedding=embed, k=k)
|
||||
return documents
|
||||
|
||||
|
@ -201,7 +201,7 @@ class LanceDB(VectorStore):
|
||||
"""
|
||||
docs = []
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore
|
||||
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore[union-attr]
|
||||
for idx, text in enumerate(texts):
|
||||
embedding = embeddings[idx]
|
||||
metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
|
||||
@ -490,7 +490,7 @@ class LanceDB(VectorStore):
|
||||
embedding = self._embedding.embed_query(query)
|
||||
_query = (embedding, query)
|
||||
else:
|
||||
_query = query # type: ignore
|
||||
_query = query # type: ignore[assignment]
|
||||
|
||||
res = self._query(_query, k, filter=filter, name=name, **kwargs)
|
||||
return self.results_to_docs(res, score=score)
|
||||
|
@ -58,9 +58,9 @@ def get_embedding_store(
|
||||
embedding_type = None
|
||||
|
||||
if distance_strategy == DistanceStrategy.HAMMING:
|
||||
embedding_type = sqlalchemy.INTEGER # type: ignore
|
||||
embedding_type = sqlalchemy.INTEGER
|
||||
else:
|
||||
embedding_type = sqlalchemy.REAL # type: ignore
|
||||
embedding_type = sqlalchemy.REAL # type: ignore[assignment]
|
||||
|
||||
DynamicBase = declarative_base(class_registry=dict()) # type: Any
|
||||
|
||||
@ -74,7 +74,7 @@ def get_embedding_store(
|
||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||
# custom_id : any user defined id
|
||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore
|
||||
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore[arg-type,var-annotated]
|
||||
|
||||
return EmbeddingStore
|
||||
|
||||
|
@ -397,7 +397,7 @@ class MomentoVectorIndex(VectorStore):
|
||||
)
|
||||
selected = [response.hits[i].metadata for i in mmr_selected]
|
||||
return [
|
||||
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore
|
||||
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata)
|
||||
for metadata in selected
|
||||
]
|
||||
|
||||
@ -484,6 +484,6 @@ class MomentoVectorIndex(VectorStore):
|
||||
configuration=VectorIndexConfigurations.Default.latest(),
|
||||
credential_provider=CredentialProvider.from_string(api_key),
|
||||
)
|
||||
vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore
|
||||
vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore[call-arg]
|
||||
vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
|
||||
return vector_db
|
||||
|
@ -183,7 +183,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
for t, m, embedding in zip(texts, metadatas, embeddings)
|
||||
]
|
||||
# insert the documents in MongoDB Atlas
|
||||
insert_result = self._collection.insert_many(to_insert) # type: ignore
|
||||
insert_result = self._collection.insert_many(to_insert)
|
||||
return insert_result.inserted_ids
|
||||
|
||||
def _similarity_search_with_score(
|
||||
|
@ -857,7 +857,7 @@ class OracleVS(VectorStore):
|
||||
)
|
||||
|
||||
documents.append((document, distance, current_embedding))
|
||||
return documents # type: ignore
|
||||
return documents
|
||||
|
||||
@_handle_exceptions
|
||||
def max_marginal_relevance_search_with_score_by_vector(
|
||||
|
@ -49,7 +49,7 @@ class CollectionStore(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
|
||||
return session.query(cls).filter(cls.name == name).first() # type: ignore
|
||||
return session.query(cls).filter(cls.name == name).first()
|
||||
|
||||
@classmethod
|
||||
def get_or_create(
|
||||
@ -88,7 +88,7 @@ class EmbeddingStore(BaseModel):
|
||||
)
|
||||
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||
|
||||
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore
|
||||
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore[var-annotated]
|
||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||
|
||||
|
@ -33,7 +33,7 @@ try:
|
||||
from sqlalchemy import SQLColumnExpression
|
||||
except ImportError:
|
||||
# for sqlalchemy < 2
|
||||
SQLColumnExpression = Any # type: ignore
|
||||
SQLColumnExpression = Any # type: ignore[assignment,misc]
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
@ -126,7 +126,7 @@ def _get_embedding_collection_store(
|
||||
def get_by_name(
|
||||
cls, session: Session, name: str
|
||||
) -> Optional["CollectionStore"]:
|
||||
return session.query(cls).filter(cls.name == name).first() # type: ignore
|
||||
return session.query(cls).filter(cls.name == name).first()
|
||||
|
||||
@classmethod
|
||||
def get_or_create(
|
||||
@ -956,7 +956,7 @@ class PGVector(VectorStore):
|
||||
results: List[Any] = (
|
||||
session.query(
|
||||
self.EmbeddingStore,
|
||||
self.distance_strategy(embedding).label("distance"), # type: ignore
|
||||
self.distance_strategy(embedding).label("distance"),
|
||||
)
|
||||
.filter(*filter_by)
|
||||
.order_by(sqlalchemy.asc("distance"))
|
||||
|
@ -419,7 +419,7 @@ class Redis(VectorStore):
|
||||
|
||||
# type check for metadata
|
||||
if metadatas:
|
||||
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore
|
||||
if isinstance(metadatas, list) and len(metadatas) != len(texts):
|
||||
raise ValueError("Number of metadatas must match number of texts")
|
||||
if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
|
||||
raise ValueError("Metadatas must be a list of dicts")
|
||||
@ -427,7 +427,7 @@ class Redis(VectorStore):
|
||||
generated_schema = _generate_field_schema(metadatas[0])
|
||||
if index_schema:
|
||||
# read in the schema solely to compare to the generated schema
|
||||
user_schema = read_schema(index_schema) # type: ignore
|
||||
user_schema = read_schema(index_schema)
|
||||
|
||||
# the very rare case where a super user decides to pass the index
|
||||
# schema and a document loader is used that has metadata which
|
||||
@ -722,7 +722,7 @@ class Redis(VectorStore):
|
||||
|
||||
# type check for metadata
|
||||
if metadatas:
|
||||
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore
|
||||
if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore[arg-type]
|
||||
raise ValueError("Number of metadatas must match number of texts")
|
||||
if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)):
|
||||
raise ValueError("Metadatas must be a list of dicts")
|
||||
@ -850,7 +850,7 @@ class Redis(VectorStore):
|
||||
# Perform vector search
|
||||
# ignore type because redis-py is wrong about bytes
|
||||
try:
|
||||
results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore
|
||||
results = self.client.ft(self.index_name).search(redis_query, params_dict)
|
||||
except redis.exceptions.ResponseError as e:
|
||||
# split error message and see if it starts with "Syntax"
|
||||
if str(e).split(" ")[0] == "Syntax":
|
||||
@ -966,7 +966,7 @@ class Redis(VectorStore):
|
||||
# Perform vector search
|
||||
# ignore type because redis-py is wrong about bytes
|
||||
try:
|
||||
results = self.client.ft(self.index_name).search(redis_query, params_dict) # type: ignore
|
||||
results = self.client.ft(self.index_name).search(redis_query, params_dict)
|
||||
except redis.exceptions.ResponseError as e:
|
||||
# split error message and see if it starts with "Syntax"
|
||||
if str(e).split(" ")[0] == "Syntax":
|
||||
@ -1206,7 +1206,7 @@ class Redis(VectorStore):
|
||||
# read in schema (yaml file or dict) and
|
||||
# pass to the Pydantic validators
|
||||
if index_schema:
|
||||
schema_values = read_schema(index_schema) # type: ignore
|
||||
schema_values = read_schema(index_schema)
|
||||
schema = RedisModel(**schema_values)
|
||||
|
||||
# ensure user did not exclude the content field
|
||||
@ -1242,7 +1242,7 @@ class Redis(VectorStore):
|
||||
|
||||
def _create_index_if_not_exist(self, dim: int = 1536) -> None:
|
||||
try:
|
||||
from redis.commands.search.indexDefinition import ( # type: ignore
|
||||
from redis.commands.search.indexDefinition import (
|
||||
IndexDefinition,
|
||||
IndexType,
|
||||
)
|
||||
|
@ -140,7 +140,7 @@ class RedisTag(RedisFilterField):
|
||||
elif isinstance(other, str):
|
||||
other = [other]
|
||||
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore[arg-type]
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(
|
||||
@ -240,7 +240,7 @@ class RedisNum(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") == 90210
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
@ -254,7 +254,7 @@ class RedisNum(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") != 90210
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
@ -267,7 +267,7 @@ class RedisNum(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") > 18
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
@ -280,7 +280,7 @@ class RedisNum(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") < 18
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
@ -293,7 +293,7 @@ class RedisNum(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") >= 18
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
@ -306,7 +306,7 @@ class RedisNum(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") <= 18
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
|
||||
@ -336,7 +336,7 @@ class RedisText(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisText
|
||||
>>> filter = RedisText("job") == "engineer"
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
@ -350,7 +350,7 @@ class RedisText(RedisFilterField):
|
||||
>>> from langchain_community.vectorstores.redis import RedisText
|
||||
>>> filter = RedisText("job") != "engineer"
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __mod__(self, other: str) -> "RedisFilterExpression":
|
||||
@ -366,7 +366,7 @@ class RedisText(RedisFilterField):
|
||||
>>> filter = RedisText("job") % "engineer|doctor" # contains either term
|
||||
>>> filter = RedisText("job") % "engineer doctor" # contains both terms
|
||||
"""
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore[arg-type]
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
@ -14,7 +14,7 @@ from typing_extensions import TYPE_CHECKING, Literal
|
||||
from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.commands.search.field import ( # type: ignore
|
||||
from redis.commands.search.field import (
|
||||
NumericField,
|
||||
TagField,
|
||||
TextField,
|
||||
@ -47,13 +47,13 @@ class TextFieldSchema(RedisField):
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
def as_field(self) -> TextField:
|
||||
from redis.commands.search.field import TextField # type: ignore
|
||||
from redis.commands.search.field import TextField
|
||||
|
||||
return TextField(
|
||||
self.name,
|
||||
weight=self.weight,
|
||||
no_stem=self.no_stem,
|
||||
phonetic_matcher=self.phonetic_matcher, # type: ignore
|
||||
phonetic_matcher=self.phonetic_matcher,
|
||||
sortable=self.sortable,
|
||||
no_index=self.no_index,
|
||||
)
|
||||
@ -68,7 +68,7 @@ class TagFieldSchema(RedisField):
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
def as_field(self) -> TagField:
|
||||
from redis.commands.search.field import TagField # type: ignore
|
||||
from redis.commands.search.field import TagField
|
||||
|
||||
return TagField(
|
||||
self.name,
|
||||
@ -86,7 +86,7 @@ class NumericFieldSchema(RedisField):
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
def as_field(self) -> NumericField:
|
||||
from redis.commands.search.field import NumericField # type: ignore
|
||||
from redis.commands.search.field import NumericField
|
||||
|
||||
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
|
||||
|
||||
@ -131,7 +131,7 @@ class FlatVectorField(RedisVectorField): # type: ignore[override]
|
||||
block_size: Optional[int] = None
|
||||
|
||||
def as_field(self) -> VectorField:
|
||||
from redis.commands.search.field import VectorField # type: ignore
|
||||
from redis.commands.search.field import VectorField
|
||||
|
||||
field_data = super()._fields()
|
||||
if self.block_size is not None:
|
||||
@ -149,7 +149,7 @@ class HNSWVectorField(RedisVectorField): # type: ignore[override]
|
||||
epsilon: float = Field(default=0.01)
|
||||
|
||||
def as_field(self) -> VectorField:
|
||||
from redis.commands.search.field import VectorField # type: ignore
|
||||
from redis.commands.search.field import VectorField
|
||||
|
||||
field_data = super()._fields()
|
||||
field_data.update(
|
||||
@ -193,9 +193,9 @@ class RedisModel(BaseModel):
|
||||
|
||||
# ignore types as pydantic is handling type validation and conversion
|
||||
if vector_field["algorithm"] == "FLAT":
|
||||
self.vector.append(FlatVectorField(**vector_field)) # type: ignore
|
||||
self.vector.append(FlatVectorField(**vector_field))
|
||||
elif vector_field["algorithm"] == "HNSW":
|
||||
self.vector.append(HNSWVectorField(**vector_field)) # type: ignore
|
||||
self.vector.append(HNSWVectorField(**vector_field))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"algorithm must be either FLAT or HNSW. Got "
|
||||
|
@ -257,7 +257,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
match_result = [
|
||||
(
|
||||
Document(
|
||||
metadata=search.get("metadata", {}), # type: ignore
|
||||
metadata=search.get("metadata", {}),
|
||||
page_content=search.get("content", ""),
|
||||
),
|
||||
search.get("similarity", 0.0),
|
||||
@ -302,7 +302,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
match_result = [
|
||||
(
|
||||
Document(
|
||||
metadata=search.get("metadata", {}), # type: ignore
|
||||
metadata=search.get("metadata", {}),
|
||||
page_content=search.get("content", ""),
|
||||
),
|
||||
search.get("similarity", 0.0),
|
||||
@ -351,7 +351,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
"id": ids[idx],
|
||||
"content": documents[idx].page_content,
|
||||
"embedding": embedding,
|
||||
"metadata": documents[idx].metadata, # type: ignore
|
||||
"metadata": documents[idx].metadata,
|
||||
**kwargs,
|
||||
}
|
||||
for idx, embedding in enumerate(vectors)
|
||||
@ -360,7 +360,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
for i in range(0, len(rows), chunk_size):
|
||||
chunk = rows[i : i + chunk_size]
|
||||
|
||||
result = client.from_(table_name).upsert(chunk).execute() # type: ignore
|
||||
result = client.from_(table_name).upsert(chunk).execute()
|
||||
|
||||
if len(result.data) == 0:
|
||||
raise Exception("Error inserting: No rows added")
|
||||
|
@ -153,7 +153,7 @@ class UpstashVectorStore(VectorStore):
|
||||
self._namespace = namespace
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Union[Embeddings, bool]]: # type: ignore
|
||||
def embeddings(self) -> Optional[Union[Embeddings, bool]]: # type: ignore[override]
|
||||
"""Access the query embedding object if available."""
|
||||
return self._embeddings
|
||||
|
||||
@ -730,7 +730,7 @@ class UpstashVectorStore(VectorStore):
|
||||
)
|
||||
selected = [results[i].metadata for i in mmr_selected]
|
||||
return [
|
||||
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) # type: ignore
|
||||
Document(page_content=metadata.pop((self._text_key)), metadata=metadata)
|
||||
for metadata in selected
|
||||
]
|
||||
|
||||
@ -798,7 +798,7 @@ class UpstashVectorStore(VectorStore):
|
||||
)
|
||||
selected = [results[i].metadata for i in mmr_selected]
|
||||
return [
|
||||
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) # type: ignore
|
||||
Document(page_content=metadata.pop((self._text_key)), metadata=metadata)
|
||||
for metadata in selected
|
||||
]
|
||||
|
||||
|
@ -467,7 +467,7 @@ class Vectara(VectorStore):
|
||||
}
|
||||
|
||||
if config.lambda_val > 0:
|
||||
body["query"][0]["corpusKey"][0]["lexicalInterpolationConfig"] = { # type: ignore
|
||||
body["query"][0]["corpusKey"][0]["lexicalInterpolationConfig"] = { # type: ignore[index]
|
||||
"lambda": config.lambda_val
|
||||
}
|
||||
|
||||
@ -495,7 +495,7 @@ class Vectara(VectorStore):
|
||||
}
|
||||
]
|
||||
if chat:
|
||||
body["query"][0]["summary"][0]["chat"] = { # type: ignore
|
||||
body["query"][0]["summary"][0]["chat"] = { # type: ignore[index]
|
||||
"store": True,
|
||||
"conversationId": chat_conv_id,
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ lint = [
|
||||
]
|
||||
dev = ["jupyter<2.0.0,>=1.0.0", "setuptools<68.0.0,>=67.6.1", "langchain-core"]
|
||||
typing = [
|
||||
"mypy<2.0,>=1.12",
|
||||
"mypy<2.0,>=1.15",
|
||||
"types-pyyaml<7.0.0.0,>=6.0.12.2",
|
||||
"types-requests<3.0.0.0,>=2.28.11.5",
|
||||
"types-toml<1.0.0.0,>=0.10.8.1",
|
||||
@ -103,7 +103,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
|
||||
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin,cann"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "T201"]
|
||||
select = ["E", "F", "I", "PGH003", "T201"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
@ -27,7 +27,7 @@ def init_gptcache_map(cache_obj: Any) -> None:
|
||||
pre_embedding_func=get_prompt,
|
||||
data_manager=get_data_manager(data_path=cache_path),
|
||||
)
|
||||
init_gptcache_map._i = i + 1 # type: ignore
|
||||
init_gptcache_map._i = i + 1 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def init_gptcache_map_with_llm(cache_obj: Any, llm: str) -> None:
|
||||
|
@ -37,7 +37,7 @@ def test_messages(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> Non
|
||||
Message(content="message2", role="human", metadata={"key2": "value2"}),
|
||||
],
|
||||
)
|
||||
zep_chat.zep_client.memory.get_memory.return_value = mock_memory # type: ignore
|
||||
zep_chat.zep_client.memory.get_memory.return_value = mock_memory
|
||||
|
||||
result = zep_chat.messages
|
||||
|
||||
@ -52,25 +52,25 @@ def test_add_user_message(
|
||||
mocker: MockerFixture, zep_chat: ZepChatMessageHistory
|
||||
) -> None:
|
||||
zep_chat.add_user_message("test message")
|
||||
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore
|
||||
zep_chat.zep_client.memory.add_memory.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.requires("zep_python")
|
||||
def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
||||
zep_chat.add_ai_message("test message")
|
||||
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore
|
||||
zep_chat.zep_client.memory.add_memory.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.requires("zep_python")
|
||||
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
||||
zep_chat.add_message(AIMessage(content="test message"))
|
||||
zep_chat.zep_client.memory.add_memory.assert_called_once() # type: ignore
|
||||
zep_chat.zep_client.memory.add_memory.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.requires("zep_python")
|
||||
def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
||||
zep_chat.search("test query")
|
||||
zep_chat.zep_client.memory.search_memory.assert_called_once_with( # type: ignore
|
||||
zep_chat.zep_client.memory.search_memory.assert_called_once_with(
|
||||
"test_session", mocker.ANY, limit=None
|
||||
)
|
||||
|
||||
@ -78,6 +78,4 @@ def test_search(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
||||
@pytest.mark.requires("zep_python")
|
||||
def test_clear(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
||||
zep_chat.clear()
|
||||
zep_chat.zep_client.memory.delete_memory.assert_called_once_with( # type: ignore
|
||||
"test_session"
|
||||
)
|
||||
zep_chat.zep_client.memory.delete_memory.assert_called_once_with("test_session")
|
||||
|
@ -67,7 +67,7 @@ class AnswerWithJustification(BaseModel):
|
||||
|
||||
def test_chat_minimax_with_structured_output() -> None:
|
||||
"""Test MiniMaxChat with structured output."""
|
||||
llm = MiniMaxChat() # type: ignore
|
||||
llm = MiniMaxChat() # type: ignore[call-arg]
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
response = structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
@ -77,7 +77,7 @@ def test_chat_minimax_with_structured_output() -> None:
|
||||
|
||||
def test_chat_tongyi_with_structured_output_include_raw() -> None:
|
||||
"""Test MiniMaxChat with structured output."""
|
||||
llm = MiniMaxChat() # type: ignore
|
||||
llm = MiniMaxChat() # type: ignore[call-arg]
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
|
@ -170,7 +170,7 @@ class GenerateUsername(BaseModel):
|
||||
|
||||
|
||||
def test_tool_use() -> None:
|
||||
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore
|
||||
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore[call-arg]
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
|
||||
msgs: List = [
|
||||
HumanMessage(content="Sally has green hair, what would her username be?")
|
||||
@ -187,7 +187,7 @@ def test_tool_use() -> None:
|
||||
|
||||
tool_msg = ToolMessage(
|
||||
content="sally_green_hair",
|
||||
tool_call_id=ai_msg.tool_calls[0]["id"], # type: ignore
|
||||
tool_call_id=ai_msg.tool_calls[0]["id"],
|
||||
name=ai_msg.tool_calls[0]["name"],
|
||||
)
|
||||
msgs.extend([ai_msg, tool_msg])
|
||||
@ -201,7 +201,7 @@ def test_tool_use() -> None:
|
||||
gathered = message
|
||||
first = False
|
||||
else:
|
||||
gathered = gathered + message # type: ignore
|
||||
gathered = gathered + message # type: ignore[assignment]
|
||||
assert isinstance(gathered, AIMessageChunk)
|
||||
|
||||
streaming_tool_msg = ToolMessage(
|
||||
@ -215,7 +215,7 @@ def test_tool_use() -> None:
|
||||
|
||||
def test_manual_tool_call_msg() -> None:
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore
|
||||
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore[call-arg]
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
|
||||
msgs: List = [
|
||||
HumanMessage(content="Sally has green hair, what would her username be?"),
|
||||
@ -246,7 +246,7 @@ class AnswerWithJustification(BaseModel):
|
||||
|
||||
def test_chat_tongyi_with_structured_output() -> None:
|
||||
"""Test ChatTongyi with structured output."""
|
||||
llm = ChatTongyi() # type: ignore
|
||||
llm = ChatTongyi() # type: ignore[call-arg]
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
response = structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
@ -256,7 +256,7 @@ def test_chat_tongyi_with_structured_output() -> None:
|
||||
|
||||
def test_chat_tongyi_with_structured_output_include_raw() -> None:
|
||||
"""Test ChatTongyi with structured output."""
|
||||
llm = ChatTongyi() # type: ignore
|
||||
llm = ChatTongyi() # type: ignore[call-arg]
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
|
@ -75,7 +75,7 @@ async def test_vertexai_agenerate(model_name: str) -> None:
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await model.agenerate([[message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
|
||||
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore[union-attr]
|
||||
|
||||
sync_response = model.generate([[message]])
|
||||
assert response.generations[0][0] == sync_response.generations[0][0]
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import Dict
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -6,6 +8,9 @@ from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.quip import QuipLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
try:
|
||||
from quip_api.quip import QuipClient # noqa: F401
|
||||
|
||||
@ -15,7 +20,7 @@ except ImportError:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quip(): # type: ignore
|
||||
def mock_quip() -> Iterator[MagicMock]:
|
||||
# mock quip_client
|
||||
with patch("quip_api.quip.QuipClient") as mock_quip:
|
||||
yield mock_quip
|
||||
|
@ -105,7 +105,7 @@ def test_unstructured_api_file_loader_io_multiple_files() -> None:
|
||||
files = [stack.enter_context(open(file_path, "rb")) for file_path in file_paths]
|
||||
|
||||
loader = UnstructuredAPIFileIOLoader(
|
||||
file=files, # type: ignore
|
||||
file=files,
|
||||
api_key="FAKE_API_KEY",
|
||||
strategy="fast",
|
||||
mode="elements",
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user