infra: add -p to mkdir in lint steps (#17013)

Previously, if this step did not find a mypy cache, mypy wouldn't run.

This change makes it always run.

Mypy ignore comments are added for existing, previously uncaught issues to unblock other PRs.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Harrison Chase 2024-02-05 11:22:06 -08:00 committed by GitHub
parent db6af21395
commit 4eda647fdd
103 changed files with 378 additions and 369 deletions

View File

@@ -86,7 +86,7 @@ jobs:
 with:
 path: |
 ${{ env.WORKDIR }}/.mypy_cache
-key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
+key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}
 - name: Analysing the code with our lint
@@ -105,7 +105,7 @@ jobs:
 # It doesn't matter how you change it, any change will cause a cache-bust.
 working-directory: ${{ inputs.working-directory }}
 run: |
-poetry install --with test,test_integration
+poetry install --with test
 - name: Get .mypy_cache_test to speed up mypy
 uses: actions/cache@v3
@@ -114,7 +114,7 @@ jobs:
 with:
 path: |
 ${{ env.WORKDIR }}/.mypy_cache_test
-key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
+key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}
 - name: Analysing the code with our lint
 working-directory: ${{ inputs.working-directory }}

View File

@@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests:
 poetry run ruff .
 [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
 [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
-[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
 format format_diff:
 poetry run ruff format $(PYTHON_FILES)
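
Why the `||` to `&&` swap above fixes the skipped lint: in POSIX shell, `a || b || c` stops at the first command that succeeds, and `mkdir -p` succeeds whether or not the directory already exists. A minimal illustrative sketch (`echo` stands in for the mypy invocation):

    # Old chaining: mypy only ran when mkdir failed.
    mkdir -p .mypy_cache || echo "mypy"   # echo is skipped: mkdir -p succeeded

    # New chaining: mypy runs whenever the cache dir exists or was just created.
    mkdir -p .mypy_cache && echo "mypy"   # echo runs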

View File

@@ -84,7 +84,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
 raise e
 data_params = data.get("params")
 response = self.requests_wrapper.get(data["url"], params=data_params)
-response = response[: self.response_length]
+response = response[: self.response_length]  # type: ignore[index]
 return self.llm_chain.predict(
 response=response, instructions=data["output_instructions"]
 ).strip()
@@ -115,7 +115,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
 except json.JSONDecodeError as e:
 raise e
 response = self.requests_wrapper.post(data["url"], data["data"])
-response = response[: self.response_length]
+response = response[: self.response_length]  # type: ignore[index]
 return self.llm_chain.predict(
 response=response, instructions=data["output_instructions"]
 ).strip()
@@ -146,7 +146,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
 except json.JSONDecodeError as e:
 raise e
 response = self.requests_wrapper.patch(data["url"], data["data"])
-response = response[: self.response_length]
+response = response[: self.response_length]  # type: ignore[index]
 return self.llm_chain.predict(
 response=response, instructions=data["output_instructions"]
 ).strip()
@@ -177,7 +177,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
 except json.JSONDecodeError as e:
 raise e
 response = self.requests_wrapper.put(data["url"], data["data"])
-response = response[: self.response_length]
+response = response[: self.response_length]  # type: ignore[index]
 return self.llm_chain.predict(
 response=response, instructions=data["output_instructions"]
 ).strip()
@@ -209,7 +209,7 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
 except json.JSONDecodeError as e:
 raise e
 response = self.requests_wrapper.delete(data["url"])
-response = response[: self.response_length]
+response = response[: self.response_length]  # type: ignore[index]
 return self.llm_chain.predict(
 response=response, instructions=data["output_instructions"]
 ).strip()
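
The bracketed code after `# type: ignore[...]` limits the suppression to one mypy error class on that line rather than silencing everything. A hypothetical sketch of why `[index]` fires on these slices, assuming the wrapper's return type is a union that is not uniformly sliceable (the declared type is not shown in this diff):

    from typing import Union

    def fetch() -> Union[str, dict]:  # illustrative type, not the wrapper's
        return "response body"

    resp = fetch()
    trimmed = resp[:100]  # type: ignore[index]  # a slice is an invalid dict key

    # The ignore-free alternative narrows the union first:
    if isinstance(resp, str):
        trimmed = resp[:100]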

View File

@@ -177,12 +177,12 @@ def create_sql_agent(
 elif agent_type == AgentType.OPENAI_FUNCTIONS:
 if prompt is None:
 messages = [
-SystemMessage(content=prefix),
+SystemMessage(content=prefix),  # type: ignore[arg-type]
 HumanMessagePromptTemplate.from_template("{input}"),
 AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
 MessagesPlaceholder(variable_name="agent_scratchpad"),
 ]
-prompt = ChatPromptTemplate.from_messages(messages)
+prompt = ChatPromptTemplate.from_messages(messages)  # type: ignore[arg-type]
 agent = RunnableAgent(
 runnable=create_openai_functions_agent(llm, tools, prompt),
 input_keys_arg=["input"],
@@ -191,12 +191,12 @@ def create_sql_agent(
 elif agent_type == "openai-tools":
 if prompt is None:
 messages = [
-SystemMessage(content=prefix),
+SystemMessage(content=prefix),  # type: ignore[arg-type]
 HumanMessagePromptTemplate.from_template("{input}"),
 AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
 MessagesPlaceholder(variable_name="agent_scratchpad"),
 ]
-prompt = ChatPromptTemplate.from_messages(messages)
+prompt = ChatPromptTemplate.from_messages(messages)  # type: ignore[arg-type]
 agent = RunnableMultiActionAgent(
 runnable=create_openai_tools_agent(llm, tools, prompt),
 input_keys_arg=["input"],

View File

@@ -723,7 +723,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
 )
 return session_analysis_df
-def _contain_llm_records(self):
+def _contain_llm_records(self):  # type: ignore[no-untyped-def]
 return bool(self.records["on_llm_start_records"])
 def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
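
`no-untyped-def` is the error code mypy raises for unannotated functions when `disallow_untyped_defs` is enabled; the ignore defers adding annotations. A minimal sketch (hypothetical class) of the two ways to satisfy the check:

    class Handler:
        def untyped(self):  # type: ignore[no-untyped-def]  # annotation deferred
            return True

        def typed(self) -> bool:  # fully annotated, so no ignore is needed
            return True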

View File

@@ -47,7 +47,7 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
 ):
 self.index: str = index
 self.session_id: str = session_id
-self.ensure_ascii: bool = esnsure_ascii
+self.ensure_ascii: bool = esnsure_ascii  # type: ignore[assignment]
 # Initialize Elasticsearch client from passed client arg or connection info
 if es_connection is not None:

View File

@@ -40,7 +40,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory):
 self.session_id = session_id
 self.table_name = table_name
 self.earliest_time = earliest_time
-self.cache = []
+self.cache = []  # type: ignore[var-annotated]
 # Set up SQLAlchemy engine and session
 self.engine = create_engine(connection_string)
@@ -102,7 +102,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory):
 logger.error(f"Error loading messages to cache: {e}")
 @property
-def messages(self) -> List[BaseMessage]:
+def messages(self) -> List[BaseMessage]:  # type: ignore[override]
 """returns all messages"""
 if len(self.cache) == 0:
 self.reload_cache()

View File

@@ -149,7 +149,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
 return None
 return zep_memory
-def add_user_message(
+def add_user_message(  # type: ignore[override]
 self, message: str, metadata: Optional[Dict[str, Any]] = None
 ) -> None:
 """Convenience method for adding a human message string to the store.
@@ -160,7 +160,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
 """
 self.add_message(HumanMessage(content=message), metadata=metadata)
-def add_ai_message(
+def add_ai_message(  # type: ignore[override]
 self, message: str, metadata: Optional[Dict[str, Any]] = None
 ) -> None:
 """Convenience method for adding an AI message string to the store.

View File

@@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import (
 class LlamaContentFormatter(ContentFormatterBase):
-def __init__(self):
+def __init__(self):  # type: ignore[no-untyped-def]
 raise TypeError(
 "`LlamaContentFormatter` is deprecated for chat models. Use "
 "`LlamaChatContentFormatter` instead."
@@ -72,7 +72,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 def supported_api_types(self) -> List[AzureMLEndpointApiType]:
 return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
-def format_request_payload(
+def format_request_payload(  # type: ignore[override]
 self,
 messages: List[BaseMessage],
 model_kwargs: Dict,
@@ -98,9 +98,9 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 raise ValueError(
 f"`api_type` {api_type} is not supported by this formatter"
 )
-return str.encode(request_payload)
+return str.encode(request_payload)  # type: ignore[return-value]
-def format_response_payload(
+def format_response_payload(  # type: ignore[override]
 self, output: bytes, api_type: AzureMLEndpointApiType
 ) -> ChatGeneration:
 """Formats response"""
@@ -108,7 +108,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 try:
 choice = json.loads(output)["output"]
 except (KeyError, IndexError, TypeError) as e:
-raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
 return ChatGeneration(
 message=BaseMessage(
 content=choice.strip(),
@@ -125,7 +125,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 "model. Expected `dict` but `{type(choice)}` was received."
 )
 except (KeyError, IndexError, TypeError) as e:
-raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
 return ChatGeneration(
 message=BaseMessage(
 content=choice["message"]["content"].strip(),

View File

@@ -175,7 +175,7 @@ class ChatEdenAI(BaseChatModel):
 """Call out to EdenAI's chat endpoint."""
 url = f"{self.edenai_api_url}/text/chat/stream"
 headers = {
-"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "User-Agent": self.get_user_agent(),
 }
 formatted_data = _format_edenai_messages(messages=messages)
@@ -216,7 +216,7 @@ class ChatEdenAI(BaseChatModel):
 ) -> AsyncIterator[ChatGenerationChunk]:
 url = f"{self.edenai_api_url}/text/chat/stream"
 headers = {
-"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "User-Agent": self.get_user_agent(),
 }
 formatted_data = _format_edenai_messages(messages=messages)
@@ -265,7 +265,7 @@ class ChatEdenAI(BaseChatModel):
 url = f"{self.edenai_api_url}/text/chat"
 headers = {
-"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "User-Agent": self.get_user_agent(),
 }
 formatted_data = _format_edenai_messages(messages=messages)
@@ -323,7 +323,7 @@ class ChatEdenAI(BaseChatModel):
 url = f"{self.edenai_api_url}/text/chat"
 headers = {
-"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "User-Agent": self.get_user_agent(),
 }
 formatted_data = _format_edenai_messages(messages=messages)
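
The `union-attr` code here suggests these API-key fields are declared as optional secrets, so `.get_secret_value()` is unsafe on the `None` branch at type-check time. A sketch under that assumption (the field types are inferred from the error code, not shown in this diff):

    from typing import Optional
    from pydantic import SecretStr

    api_key: Optional[SecretStr] = SecretStr("sk-...")

    token = api_key.get_secret_value()  # type: ignore[union-attr]  # may be None

    # The ignore-free alternative narrows away None first:
    if api_key is not None:
        token = api_key.get_secret_value()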

View File

@@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel):
 generations = [
 ChatGeneration(
 message=AIMessage(
-content=response.get("result"),
+content=response.get("result"),  # type: ignore[arg-type]
 additional_kwargs={**additional_kwargs},
 )
 )

View File

@@ -56,7 +56,7 @@ class GPTRouterModel(BaseModel):
 provider_name: str
-def get_ordered_generation_requests(
+def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
 models_priority_list: List[GPTRouterModel], **kwargs
 ):
 """
@@ -100,7 +100,7 @@ def completion_with_retry(
 models_priority_list: List[GPTRouterModel],
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
 """Use tenacity to retry the completion call."""
 retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -122,7 +122,7 @@ async def acompletion_with_retry(
 models_priority_list: List[GPTRouterModel],
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
 """Use tenacity to retry the async completion call."""
 retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -282,7 +282,7 @@ class GPTRouter(BaseChatModel):
 )
 return self._create_chat_result(response)
-def _create_chat_generation_chunk(
+def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
 self, data: Mapping[str, Any], default_chunk_class
 ):
 chunk = _convert_delta_to_message_chunk(
@@ -293,7 +293,7 @@ class GPTRouter(BaseChatModel):
 dict(finish_reason=finish_reason) if finish_reason is not None else None
 )
 default_chunk_class = chunk.__class__
-chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
 return chunk, default_chunk_class
 def _stream(
def _stream( def _stream(

View File

@@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel):
 elif isinstance(self.llm, HuggingFaceHub):
 # no need to look up model_id for HuggingFaceHub LLM
-self.model_id = self.llm.repo_id
+self.model_id = self.llm.repo_id  # type: ignore[assignment]
 return
 else:
View File

@@ -169,7 +169,7 @@ class ChatKonko(ChatOpenAI):
 }
 if openai_api_key:
-headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()
+headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()  # type: ignore[union-attr]
 models_response = requests.get(models_url, headers=headers)

View File

@@ -74,10 +74,10 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
 if isinstance(message, ChatMessage):
 message_text = f"\n\n{message.role.capitalize()}: {message.content}"
 elif isinstance(message, HumanMessage):
-if message.content[0].get("type") == "text":
+if message.content[0].get("type") == "text":  # type: ignore[union-attr]
-message_text = f"[INST] {message.content[0]['text']} [/INST]"
+message_text = f"[INST] {message.content[0]['text']} [/INST]"  # type: ignore[index]
-elif message.content[0].get("type") == "image_url":
+elif message.content[0].get("type") == "image_url":  # type: ignore[union-attr]
-message_text = message.content[0]["image_url"]["url"]
+message_text = message.content[0]["image_url"]["url"]  # type: ignore[index, index]
 elif isinstance(message, AIMessage):
 message_text = f"{message.content}"
 elif isinstance(message, SystemMessage):
@@ -112,11 +112,11 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
 content = message.content
 else:
 for content_part in message.content:
-if content_part.get("type") == "text":
+if content_part.get("type") == "text":  # type: ignore[union-attr]
-content += f"\n{content_part['text']}"
+content += f"\n{content_part['text']}"  # type: ignore[index]
-elif content_part.get("type") == "image_url":
+elif content_part.get("type") == "image_url":  # type: ignore[union-attr]
-if isinstance(content_part.get("image_url"), str):
+if isinstance(content_part.get("image_url"), str):  # type: ignore[union-attr]
-image_url_components = content_part["image_url"].split(",")
+image_url_components = content_part["image_url"].split(",")  # type: ignore[index]
 # Support data:image/jpeg;base64,<image> format
 # and base64 strings
 if len(image_url_components) > 1:
@@ -142,7 +142,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
 }
 )
-return ollama_messages
+return ollama_messages  # type: ignore[return-value]
 def _create_chat_stream(
 self,
@@ -337,7 +337,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
 verbose=self.verbose,
 )
 except OllamaEndpointNotFoundError:
-async for chunk in self._legacy_astream(messages, stop, **kwargs):
+async for chunk in self._legacy_astream(messages, stop, **kwargs):  # type: ignore[attr-defined]
 yield chunk
 @deprecated("0.0.3", alternative="_stream")

View File

@@ -197,7 +197,7 @@ class ChatTongyi(BaseChatModel):
 return {
 "model": self.model_name,
 "top_p": self.top_p,
-"api_key": self.dashscope_api_key.get_secret_value(),
+"api_key": self.dashscope_api_key.get_secret_value(),  # type: ignore[union-attr]
 "result_format": "message",
 **self.model_kwargs,
 }

View File

@@ -121,7 +121,7 @@ def _parse_chat_history_gemini(
 elif path.startswith("data:image/"):
 # extract base64 component from image uri
 try:
-encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(
+encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(  # type: ignore[union-attr]
 1
 )
 except AttributeError:

View File

@@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
 return chat_history
-class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
+class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):  # type: ignore[misc]
 """Wrapper around YandexGPT large language models.
 There are two authentication options for the service account
@@ -156,7 +156,7 @@ def _make_request(
 messages=[Message(**message) for message in message_history],
 )
 stub = TextGenerationServiceStub(channel)
-res = stub.Completion(request, metadata=self._grpc_metadata)
+res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
 return list(res)[0].alternatives[0].message.text
@@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
 messages=[Message(**message) for message in message_history],
 )
 stub = TextGenerationAsyncServiceStub(channel)
-operation = await stub.Completion(request, metadata=self._grpc_metadata)
+operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
 async with grpc.aio.secure_channel(
 operation_api_url, channel_credentials
 ) as operation_channel:
@@ -210,7 +210,8 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
 await asyncio.sleep(1)
 operation_request = GetOperationRequest(operation_id=operation.id)
 operation = await operation_stub.Get(
-operation_request, metadata=self._grpc_metadata
+operation_request,
+metadata=self._grpc_metadata,  # type: ignore[attr-defined]
 )
 completion_response = CompletionResponse()

View File

@@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel):
 return attributes
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
 super().__init__(*args, **kwargs)
 try:
 import zhipuai
@@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel):
 "Please install it via 'pip install zhipuai'"
 )
-def invoke(self, prompt):
+def invoke(self, prompt):  # type: ignore[no-untyped-def]
 if self.model == "chatglm_turbo":
 return self.zhipuai.model_api.invoke(
 model=self.model,
@@ -195,7 +195,7 @@ class ChatZhipuAI(BaseChatModel):
 )
 return None
-def sse_invoke(self, prompt):
+def sse_invoke(self, prompt):  # type: ignore[no-untyped-def]
 if self.model == "chatglm_turbo":
 return self.zhipuai.model_api.sse_invoke(
 model=self.model,
@@ -218,7 +218,7 @@ class ChatZhipuAI(BaseChatModel):
 )
 return None
-async def async_invoke(self, prompt):
+async def async_invoke(self, prompt):  # type: ignore[no-untyped-def]
 loop = asyncio.get_running_loop()
 partial_func = partial(
 self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
@@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel):
 )
 return response
-async def async_invoke_result(self, task_id):
+async def async_invoke_result(self, task_id):  # type: ignore[no-untyped-def]
 loop = asyncio.get_running_loop()
 response = await loop.run_in_executor(
 None,
@@ -270,11 +270,14 @@ class ChatZhipuAI(BaseChatModel):
 else:
 stream_iter = self._stream(
-prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+prompt=prompt,  # type: ignore[arg-type]
+stop=stop,
+run_manager=run_manager,
+**kwargs,
 )
 return generate_from_stream(stream_iter)
-async def _agenerate(
+async def _agenerate(  # type: ignore[override]
 self,
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
@@ -307,7 +310,7 @@ class ChatZhipuAI(BaseChatModel):
 generations=[ChatGeneration(message=AIMessage(content=content))]
 )
-def _stream(
+def _stream(  # type: ignore[override]
 self,
 prompt: List[Dict[str, str]],
 stop: Optional[List[str]] = None,

View File

@@ -123,7 +123,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
 """
-def __init__(self, transcript_id, api_key, transcript_format):
+def __init__(self, transcript_id, api_key, transcript_format):  # type: ignore[no-untyped-def]
 """
 Initializes the AssemblyAI AssemblyAIAudioLoaderById.

View File

@@ -65,7 +65,7 @@ class AstraDBLoader(BaseLoader):
 return list(self.lazy_load())
 def lazy_load(self) -> Iterator[Document]:
-queue = Queue(self.nb_prefetched)
+queue = Queue(self.nb_prefetched)  # type: ignore[var-annotated]
 t = threading.Thread(target=self.fetch_results, args=(queue,))
 t.start()
 while True:
@@ -95,7 +95,7 @@ class AstraDBLoader(BaseLoader):
 item = await run_in_executor(None, lambda it: next(it, done), iterator)
 if item is done:
 break
-yield item
+yield item  # type: ignore[misc]
 return
 async_collection = await self.astra_env.async_astra_db.collection(
 self.collection_name
@@ -116,13 +116,13 @@ class AstraDBLoader(BaseLoader):
 },
 )
-def fetch_results(self, queue: Queue):
+def fetch_results(self, queue: Queue):  # type: ignore[no-untyped-def]
 self.fetch_page_result(queue)
 while self.find_options.get("pageState"):
 self.fetch_page_result(queue)
 queue.put(None)
-def fetch_page_result(self, queue: Queue):
+def fetch_page_result(self, queue: Queue):  # type: ignore[no-untyped-def]
 res = self.collection.find(
 filter=self.filter,
 options=self.find_options,

View File

@@ -64,10 +64,10 @@ class BaseLoader(ABC):
 iterator = await run_in_executor(None, self.lazy_load)
 done = object()
 while True:
-doc = await run_in_executor(None, next, iterator, done)
+doc = await run_in_executor(None, next, iterator, done)  # type: ignore[call-arg, arg-type]
 if doc is done:
 break
-yield doc
+yield doc  # type: ignore[misc]
 class BaseBlobParser(ABC):
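
This hunk also shows the sentinel pattern these loaders use to drive a blocking iterator from async code: `next(iterator, done)` runs in a worker thread, and the unique `done` object signals exhaustion. A standalone sketch of the same pattern (illustrative names, not the library's API):

    import asyncio
    from typing import AsyncIterator, Iterator

    async def aiterate(it: Iterator[str]) -> AsyncIterator[str]:
        loop = asyncio.get_running_loop()
        done = object()  # unique sentinel, never equal to a real item
        while True:
            # next(it, done) in a worker thread keeps the event loop free
            item = await loop.run_in_executor(None, next, it, done)
            if item is done:
                break
            yield item  # typed as object here, hence the ignores above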

View File

@@ -33,14 +33,14 @@ class CassandraLoader(BaseLoader):
 page_content_mapper: Callable[[Any], str] = str,
 metadata_mapper: Callable[[Any], dict] = lambda _: {},
 *,
-query_parameters: Union[dict, Sequence] = None,
+query_parameters: Union[dict, Sequence] = None,  # type: ignore[assignment]
-query_timeout: Optional[float] = _NOT_SET,
+query_timeout: Optional[float] = _NOT_SET,  # type: ignore[assignment]
 query_trace: bool = False,
-query_custom_payload: dict = None,
+query_custom_payload: dict = None,  # type: ignore[assignment]
 query_execution_profile: Any = _NOT_SET,
 query_paging_state: Any = None,
 query_host: Host = None,
-query_execute_as: str = None,
+query_execute_as: str = None,  # type: ignore[assignment]
 ) -> None:
 """
 Document Loader for Apache Cassandra.
@@ -85,7 +85,7 @@ class CassandraLoader(BaseLoader):
 self.query = f"SELECT * FROM {_keyspace}.{table};"
 self.metadata = {"table": table, "keyspace": _keyspace}
 else:
-self.query = query
+self.query = query  # type: ignore[assignment]
 self.metadata = {}
 self.session = session or check_resolve_session(session)

View File

@@ -27,7 +27,7 @@ class UnstructuredCHMLoader(UnstructuredFileLoader):
 def _get_elements(self) -> List:
 from unstructured.partition.html import partition_html
-with CHMParser(self.file_path) as f:
+with CHMParser(self.file_path) as f:  # type: ignore[arg-type]
 return [
 partition_html(text=item["content"], **self.unstructured_kwargs)
 for item in f.load_all()
@@ -45,10 +45,10 @@ class CHMParser(object):
 self.file = chm.CHMFile()
 self.file.LoadCHM(path)
-def __enter__(self):
+def __enter__(self):  # type: ignore[no-untyped-def]
 return self
-def __exit__(self, exc_type, exc_value, traceback):
+def __exit__(self, exc_type, exc_value, traceback):  # type: ignore[no-untyped-def]
 if self.file:
 self.file.CloseCHM()

View File

@@ -89,4 +89,4 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader):
 blob = Blob.from_path(self.file_path)
 yield from self.parser.parse(blob)
 else:
-yield from self.parser.parse_url(self.url_path)
+yield from self.parser.parse_url(self.url_path)  # type: ignore[arg-type]

View File

@@ -60,7 +60,7 @@ class MWDumpLoader(BaseLoader):
 self.skip_redirects = skip_redirects
 self.stop_on_error = stop_on_error
-def _load_dump_file(self):
+def _load_dump_file(self):  # type: ignore[no-untyped-def]
 try:
 import mwxml
 except ImportError as e:
@@ -70,7 +70,7 @@ class MWDumpLoader(BaseLoader):
 return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))
-def _load_single_page_from_dump(self, page) -> Document:
+def _load_single_page_from_dump(self, page) -> Document:  # type: ignore[no-untyped-def, return]
 """Parse a single page."""
 try:
 import mwparserfromhell

View File

@@ -11,7 +11,7 @@ from langchain_community.document_loaders.blob_loaders import Blob
 class VsdxParser(BaseBlobParser, ABC):
-def parse(self, blob: Blob) -> Iterator[Document]:
+def parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[override]
 """Parse a vsdx file."""
 return self.lazy_parse(blob)
@@ -21,7 +21,7 @@ class VsdxParser(BaseBlobParser, ABC):
 with blob.as_bytes_io() as pdf_file_obj:
 with zipfile.ZipFile(pdf_file_obj, "r") as zfile:
-pages = self.get_pages_content(zfile, blob.source)
+pages = self.get_pages_content(zfile, blob.source)  # type: ignore[arg-type]
 yield from [
 Document(
@@ -60,13 +60,13 @@ class VsdxParser(BaseBlobParser, ABC):
 if "visio/pages/pages.xml" not in zfile.namelist():
 print("WARNING - No pages.xml file found in {}".format(source))
-return
+return  # type: ignore[return-value]
 if "visio/pages/_rels/pages.xml.rels" not in zfile.namelist():
 print("WARNING - No pages.xml.rels file found in {}".format(source))
-return
+return  # type: ignore[return-value]
 if "docProps/app.xml" not in zfile.namelist():
 print("WARNING - No app.xml file found in {}".format(source))
-return
+return  # type: ignore[return-value]
 pagesxml_content: dict = xmltodict.parse(zfile.read("visio/pages/pages.xml"))
 appxml_content: dict = xmltodict.parse(zfile.read("docProps/app.xml"))
@@ -79,7 +79,7 @@ class VsdxParser(BaseBlobParser, ABC):
 rel["@Name"].strip() for rel in pagesxml_content["Pages"]["Page"]
 ]
 else:
-disordered_names: List[str] = [
+disordered_names: List[str] = [  # type: ignore[no-redef]
 pagesxml_content["Pages"]["Page"]["@Name"].strip()
 ]
 if isinstance(pagesxmlrels_content["Relationships"]["Relationship"], list):
@@ -88,7 +88,7 @@ class VsdxParser(BaseBlobParser, ABC):
 for rel in pagesxmlrels_content["Relationships"]["Relationship"]
 ]
 else:
-disordered_paths: List[str] = [
+disordered_paths: List[str] = [  # type: ignore[no-redef]
 "visio/pages/"
 + pagesxmlrels_content["Relationships"]["Relationship"]["@Target"]
 ]

View File

@@ -89,7 +89,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
 print(f"Exception occurred while trying to get embeddings: {str(e)}")
 return None
-def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:
+def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:  # type: ignore[override]
 """Public method to get embeddings for a list of documents.
 Args:
@@ -100,7 +100,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
 """
 return self._embed(texts)
-def embed_query(self, text: str) -> Optional[List[float]]:
+def embed_query(self, text: str) -> Optional[List[float]]:  # type: ignore[override]
 """Public method to get embedding for a single query text.
 Args:

View File

@@ -56,7 +56,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
 headers = {
 "accept": "application/json",
 "content-type": "application/json",
-"authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+"authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "User-Agent": self.get_user_agent(),
 }

View File

@@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
 def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
 """Sends a request to the Embaas API and handles the response."""
 headers = {
-"Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",
+"Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "Content-Type": "application/json",
 }

View File

@@ -162,5 +162,5 @@ class TinyAsyncGradientEmbeddingClient: #: :meta private:
 It might be entirely removed in the future.
 """
-def __init__(self, *args, **kwargs) -> None:
+def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
 raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")

View File

@@ -56,7 +56,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings):
 """
 response = requests.post(
 "https://api.llmrails.com/v1/embeddings",
-headers={"X-API-KEY": self.api_key.get_secret_value()},
+headers={"X-API-KEY": self.api_key.get_secret_value()},  # type: ignore[union-attr]
 json={"input": texts, "model": self.model},
 timeout=60,
 )

View File

@@ -110,7 +110,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
 # HTTP headers for authorization
 headers = {
-"Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}",
+"Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}",  # type: ignore[union-attr]
 "Content-Type": "application/json",
 }

View File

@@ -71,7 +71,8 @@ class MlflowEmbeddings(Embeddings, BaseModel):
 embeddings: List[List[float]] = []
 for txt in _chunk(texts, 20):
 resp = self._client.predict(
-endpoint=self.endpoint, inputs={"input": txt, **params}
+endpoint=self.endpoint,
+inputs={"input": txt, **params},  # type: ignore[arg-type]
 )
 embeddings.extend(r["embedding"] for r in resp["data"])
 return embeddings

View File

@@ -63,16 +63,16 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
 If not specified , DEFAULT will be used
 """
-model_id: str = None
+model_id: str = None  # type: ignore[assignment]
 """Id of the model to call, e.g., cohere.embed-english-light-v2.0"""
 model_kwargs: Optional[Dict] = None
 """Keyword arguments to pass to the model"""
-service_endpoint: str = None
+service_endpoint: str = None  # type: ignore[assignment]
 """service endpoint url"""
-compartment_id: str = None
+compartment_id: str = None  # type: ignore[assignment]
 """OCID of compartment"""
 truncate: Optional[str] = "END"
@@ -109,7 +109,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
 client_kwargs.pop("signer", None)
 elif values["auth_type"] == OCIAuthType(2).name:
-def make_security_token_signer(oci_config):
+def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
 pk = oci.signer.load_private_key_from_file(
 oci_config.get("key_file"), None
 )

View File

@@ -78,7 +78,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
 Returns:
 A list of embeddings, one for each document.
 """
-return [self.nlp(text).vector.tolist() for text in texts]
+return [self.nlp(text).vector.tolist() for text in texts]  # type: ignore[misc]
 def embed_query(self, text: str) -> List[float]:
 """
@@ -90,7 +90,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
 Returns:
 The embedding for the text.
 """
-return self.nlp(text).vector.tolist()
+return self.nlp(text).vector.tolist()  # type: ignore[misc]
 async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
 """

View File

@@ -42,10 +42,10 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
 embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", model_uri="emb://<folder-id>/text-search-query/latest")
 """
-iam_token: SecretStr = ""
+iam_token: SecretStr = ""  # type: ignore[assignment]
 """Yandex Cloud IAM token for service account
 with the `ai.languageModels.user` role"""
-api_key: SecretStr = ""
+api_key: SecretStr = ""  # type: ignore[assignment]
 """Yandex Cloud Api Key for service account
 with the `ai.languageModels.user` role"""
 model_uri: str = ""
@@ -146,7 +146,7 @@ def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
 return _completion_with_retry(**kwargs)
-def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
+def _make_request(self: YandexGPTEmbeddings, texts: List[str]):  # type: ignore[no-untyped-def]
 try:
 import grpc
 from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
@@ -167,7 +167,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
 for text in texts:
 request = TextEmbeddingRequest(model_uri=self.model_uri, text=text)
 stub = EmbeddingsServiceStub(channel)
-res = stub.TextEmbedding(request, metadata=self._grpc_metadata)
+res = stub.TextEmbedding(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
 result.append(list(res.embedding))
 time.sleep(self.sleep_interval)

View File

@@ -56,7 +56,7 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
 cleaned_list.append(value_sanitize(item))
 else:
 cleaned_list.append(item)
-new_dict[key] = cleaned_list
+new_dict[key] = cleaned_list  # type: ignore[assignment]
 else:
 new_dict[key] = value
 return new_dict

View File

@@ -95,12 +95,13 @@ class OntotextGraphDBGraph:
 if local_file:
 ontology_schema_graph = self._load_ontology_schema_from_file(
-local_file, local_file_format
+local_file,
+local_file_format,  # type: ignore[arg-type]
 )
 else:
-self._validate_user_query(query_ontology)
+self._validate_user_query(query_ontology)  # type: ignore[arg-type]
 ontology_schema_graph = self._load_ontology_schema_with_query(
-query_ontology
+query_ontology  # type: ignore[arg-type]
 )
 self.schema = ontology_schema_graph.serialize(format="turtle")
@@ -139,7 +140,7 @@ class OntotextGraphDBGraph:
 )
 @staticmethod
-def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):
+def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):  # type: ignore[no-untyped-def, assignment]
 """
 Parse the ontology schema statements from the provided file
 """
@@ -176,7 +177,7 @@ class OntotextGraphDBGraph:
 "Invalid query type. Only CONSTRUCT queries are supported."
 )
-def _load_ontology_schema_with_query(self, query: str):
+def _load_ontology_schema_with_query(self, query: str):  # type: ignore[no-untyped-def]
 """
 Execute the query for collecting the ontology schema statements
 """

View File

@@ -31,7 +31,7 @@ class TigerGraph(GraphStore):
 def schema(self) -> Dict[str, Any]:
 return self._schema
-def get_schema(self) -> str:
+def get_schema(self) -> str:  # type: ignore[override]
 if self._schema:
 return str(self._schema)
 else:
@@ -71,10 +71,10 @@ class TigerGraph(GraphStore):
 """
 return self._conn.getSchema(force=True)
-def refresh_schema(self):
+def refresh_schema(self):  # type: ignore[no-untyped-def]
 self.generate_schema()
-def query(self, query: str) -> Dict[str, Any]:
+def query(self, query: str) -> Dict[str, Any]:  # type: ignore[override]
 """Query the TigerGraph database."""
 answer = self._conn.ai.query(query)
 return answer

View File

@ -165,7 +165,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.realtime]
def format_request_payload( def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes: ) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt) prompt = ContentFormatterBase.escape_special_characters(prompt)
@ -174,13 +174,13 @@ class GPT2ContentFormatter(ContentFormatterBase):
) )
return str.encode(request_payload) return str.encode(request_payload)
def format_response_payload( def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation: ) -> Generation:
try: try:
choice = json.loads(output)[0]["0"] choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice) return Generation(text=choice)
@ -207,7 +207,7 @@ class HFContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.realtime]
def format_request_payload( def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes: ) -> bytes:
ContentFormatterBase.escape_special_characters(prompt) ContentFormatterBase.escape_special_characters(prompt)
@ -216,13 +216,13 @@ class HFContentFormatter(ContentFormatterBase):
) )
return str.encode(request_payload) return str.encode(request_payload)
def format_response_payload( def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation: ) -> Generation:
try: try:
choice = json.loads(output)[0]["0"]["generated_text"] choice = json.loads(output)[0]["0"]["generated_text"]
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice) return Generation(text=choice)
@ -233,7 +233,7 @@ class DollyContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.realtime]
def format_request_payload( def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes: ) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt) prompt = ContentFormatterBase.escape_special_characters(prompt)
@ -245,13 +245,13 @@ class DollyContentFormatter(ContentFormatterBase):
) )
return str.encode(request_payload) return str.encode(request_payload)
def format_response_payload( def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation: ) -> Generation:
try: try:
choice = json.loads(output)[0] choice = json.loads(output)[0]
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
             return Generation(text=choice)
@@ -262,7 +262,7 @@ class LlamaContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]

-    def format_request_payload(
+    def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         """Formats the request according to the chosen api"""
@@ -284,7 +284,7 @@ class LlamaContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)

-    def format_response_payload(
+    def format_response_payload(  # type: ignore[override]
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         """Formats response"""
@@ -292,7 +292,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             try:
                 choice = json.loads(output)[0]["0"]
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
             return Generation(text=choice)
         if api_type == AzureMLEndpointApiType.serverless:
             try:
@@ -304,7 +304,7 @@ class LlamaContentFormatter(ContentFormatterBase):
                         "received."
                     )
            except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(
                text=choice["text"].strip(),
                generation_info=dict(
@@ -397,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
     ) -> AzureMLEndpointApiType:
         """Validate that endpoint api type is compatible with the URL format."""
         endpoint_url = values.get("endpoint_url")
-        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(  # type: ignore[union-attr]
             "/score"
         ):
             raise ValueError(
@@ -407,8 +407,8 @@ class AzureMLBaseEndpoint(BaseModel):
                 "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
             )
         if field_value == AzureMLEndpointApiType.serverless and not (
-            endpoint_url.endswith("/v1/completions")
-            or endpoint_url.endswith("/v1/chat/completions")
+            endpoint_url.endswith("/v1/completions")  # type: ignore[union-attr]
+            or endpoint_url.endswith("/v1/chat/completions")  # type: ignore[union-attr]
        ):
            raise ValueError(
                "Endpoints of type `serverless` should follow the format "
@@ -426,7 +426,9 @@ class AzureMLBaseEndpoint(BaseModel):
         deployment_name = values.get("deployment_name")

         http_client = AzureMLEndpointClient(
-            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+            endpoint_url,  # type: ignore
+            endpoint_key.get_secret_value(),  # type: ignore
+            deployment_name,  # type: ignore
         )
         return http_client
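Note: the `union-attr` ignores above exist because `values.get(...)` is typed `Optional[str]` and the validators call `str` methods on it directly. A minimal sketch of the narrowing that would satisfy mypy without ignores (a hypothetical standalone helper, not this module's actual code):

    from typing import Dict, Optional

    def validated_endpoint_url(values: Dict[str, Optional[str]]) -> str:
        # .get() is typed Optional[str]; guard before calling str methods on it
        endpoint_url = values.get("endpoint_url")
        if endpoint_url is None:
            raise ValueError("endpoint_url must be set")
        # past this guard, mypy narrows endpoint_url to str
        if not endpoint_url.endswith("/score"):
            raise ValueError("Endpoints of type `realtime` should end in `/score`")
        return endpoint_url

The commit opts for ignores here, which keeps validator behavior untouched.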
View File
@@ -56,11 +56,11 @@ class BaichuanLLM(LLM):
     def _post(self, request: Any) -> Any:
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",
+            "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",  # type: ignore[union-attr]
         }
         try:
             response = requests.post(
-                self.baichuan_api_host,
+                self.baichuan_api_host,  # type: ignore[arg-type]
                 headers=headers,
                 json=request,
                 timeout=self.timeout,
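Note: the same pattern recurs wherever a field is `Optional[SecretStr]` but dereferenced unconditionally. A sketch of an explicit guard (hypothetical helper, assuming a pydantic-style `SecretStr`):

    from typing import Dict, Optional

    from pydantic import SecretStr

    def auth_headers(api_key: Optional[SecretStr]) -> Dict[str, str]:
        # fail fast once instead of ignoring the Optional at every call site
        if api_key is None:
            raise ValueError("API key is not configured")
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key.get_secret_value()}",
        }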
View File
@@ -395,8 +395,8 @@ class BedrockBase(BaseModel, ABC):
         """
         return {
             "amazon-bedrock-guardrailDetails": {
-                "guardrailId": self.guardrails.get("id"),
-                "guardrailVersion": self.guardrails.get("version"),
+                "guardrailId": self.guardrails.get("id"),  # type: ignore[union-attr]
+                "guardrailVersion": self.guardrails.get("version"),  # type: ignore[union-attr]
             }
         }
@@ -427,7 +427,7 @@ class BedrockBase(BaseModel, ABC):
         if self._guardrails_enabled:
             request_options["guardrail"] = "ENABLED"
-            if self.guardrails.get("trace"):
+            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                 request_options["trace"] = "ENABLED"

         try:
@@ -446,7 +446,7 @@ class BedrockBase(BaseModel, ABC):
             # Verify and raise a callback error if any intervention occurs or a signal is
             # sent from a Bedrock service,
             # such as when guardrails are triggered.
-            services_trace = self._get_bedrock_services_signal(body)
+            services_trace = self._get_bedrock_services_signal(body)  # type: ignore[arg-type]

             if services_trace.get("signal") and run_manager is not None:
                 run_manager.on_llm_error(
@@ -468,7 +468,7 @@ class BedrockBase(BaseModel, ABC):
         if (
             self._guardrails_enabled
-            and self.guardrails.get("trace")
+            and self.guardrails.get("trace")  # type: ignore[union-attr]
             and self._is_guardrails_intervention(body)
         ):
             return {
@@ -526,7 +526,7 @@ class BedrockBase(BaseModel, ABC):
         if self._guardrails_enabled:
             request_options["guardrail"] = "ENABLED"
-            if self.guardrails.get("trace"):
+            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                 request_options["trace"] = "ENABLED"

         try:
@@ -540,7 +540,7 @@ class BedrockBase(BaseModel, ABC):
         ):
             yield chunk
             # verify and raise callback error if any middleware intervened
-            self._get_bedrock_services_signal(chunk.generation_info)
+            self._get_bedrock_services_signal(chunk.generation_info)  # type: ignore[arg-type]

         if run_manager is not None:
             run_manager.on_llm_new_token(chunk.text, chunk=chunk)
@@ -588,7 +588,7 @@ class BedrockBase(BaseModel, ABC):
             ):
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             elif run_manager is not None:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)  # type: ignore[unused-coroutine]


 class Bedrock(LLM, BedrockBase):
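Note: `guardrails` is evidently an Optional mapping, so every `.get(...)` needs a `union-attr` ignore. Normalizing it once keeps the call sites clean; a minimal sketch under that assumption (hypothetical standalone function):

    from typing import Any, Dict, Mapping, Optional

    def guardrail_details(guardrails: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
        # treat a missing guardrails mapping as empty instead of ignoring the Optional
        gr: Mapping[str, Any] = guardrails or {}
        return {
            "amazon-bedrock-guardrailDetails": {
                "guardrailId": gr.get("id"),
                "guardrailVersion": gr.get("version"),
            }
        }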
View File
@@ -42,10 +42,10 @@ class OCIGenAIBase(BaseModel, ABC):
     If not specified , DEFAULT will be used
     """

-    model_id: str = None
+    model_id: str = None  # type: ignore[assignment]
     """Id of the model to call, e.g., cohere.command"""

-    provider: str = None
+    provider: str = None  # type: ignore[assignment]
     """Provider name of the model. Default to None,
     will try to be derived from the model_id
     otherwise, requires user input
@@ -54,10 +54,10 @@ class OCIGenAIBase(BaseModel, ABC):
     model_kwargs: Optional[Dict] = None
     """Keyword arguments to pass to the model"""

-    service_endpoint: str = None
+    service_endpoint: str = None  # type: ignore[assignment]
     """service endpoint url"""

-    compartment_id: str = None
+    compartment_id: str = None  # type: ignore[assignment]
     """OCID of compartment"""

     is_stream: bool = False
@@ -94,7 +94,7 @@ class OCIGenAIBase(BaseModel, ABC):
             client_kwargs.pop("signer", None)
         elif values["auth_type"] == OCIAuthType(2).name:

-            def make_security_token_signer(oci_config):
+            def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                 pk = oci.signer.load_private_key_from_file(
                     oci_config.get("key_file"), None
                 )
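Note: `model_id: str = None` needs `ignore[assignment]` because `None` is not a `str`. The type-correct spelling is `Optional[str]`, at the cost of changing the field's declared type, which is presumably why the ignores were preferred here. A sketch of the alternative (hypothetical class name, pydantic-style model assumed):

    from typing import Optional

    from pydantic import BaseModel

    class OCISettingsSketch(BaseModel):
        # Optional[...] makes a None default type-correct; callers must handle None
        model_id: Optional[str] = None
        provider: Optional[str] = None
        service_endpoint: Optional[str] = None
        compartment_id: Optional[str] = None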
View File
@@ -297,7 +297,7 @@ class _OllamaCommon(BaseLanguageModel):
                         "Ollama call failed with status code 404."
                     )
                 else:
-                    optional_detail = await response.json().get("error")
+                    optional_detail = await response.json().get("error")  # type: ignore[attr-defined]
                     raise ValueError(
                         f"Ollama call failed with status code {response.status}."
                         f" Details: {optional_detail}"
@@ -380,7 +380,7 @@ class Ollama(BaseLLM, _OllamaCommon):
         """Return type of llm."""
         return "ollama-llm"

-    def _generate(
+    def _generate(  # type: ignore[override]
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
@@ -416,7 +416,7 @@ class Ollama(BaseLLM, _OllamaCommon):
             generations.append([final_chunk])
         return LLMResult(generations=generations)

-    async def _agenerate(
+    async def _agenerate(  # type: ignore[override]
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
@@ -445,7 +445,7 @@ class Ollama(BaseLLM, _OllamaCommon):
                 prompt,
                 stop=stop,
                 images=images,
-                run_manager=run_manager,
+                run_manager=run_manager,  # type: ignore[arg-type]
                 verbose=self.verbose,
                 **kwargs,
             )
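Note: the `attr-defined` ignore in the 404 branch looks like it papers over a real ordering issue rather than a stub gap: with aiohttp, `response.json()` is a coroutine, so `.get("error")` has to run on the awaited result, not on the coroutine object. A sketch of the likely intent (assuming an `aiohttp.ClientResponse`):

    import aiohttp

    async def error_detail(response: aiohttp.ClientResponse) -> str:
        # await the coroutine first, then index into the decoded body
        body = await response.json()
        return body.get("error", "unknown error")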
View File
@@ -102,7 +102,7 @@ class PipelineAI(LLM, BaseModel):
                 "Could not import pipeline-ai python package. "
                 "Please install it with `pip install pipeline-ai`."
             )
-        client = PipelineCloud(token=self.pipeline_api_key.get_secret_value())
+        client = PipelineCloud(token=self.pipeline_api_key.get_secret_value())  # type: ignore[union-attr]
         params = self.pipeline_kwargs or {}
         params = {**params, **kwargs}
View File
@@ -107,7 +107,7 @@ class StochasticAI(LLM):
             url=self.api_url,
             json={"prompt": prompt, "params": params},
             headers={
-                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",
+                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",  # type: ignore[union-attr]
                 "Accept": "application/json",
                 "Content-Type": "application/json",
             },
@@ -119,7 +119,7 @@ class StochasticAI(LLM):
         response_get = requests.get(
             url=response_post_json["data"]["responseUrl"],
             headers={
-                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",
+                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",  # type: ignore[union-attr]
                 "Accept": "application/json",
                 "Content-Type": "application/json",
             },
View File
@@ -49,7 +49,7 @@ def is_gemini_model(model_name: str) -> bool:
     return model_name is not None and "gemini" in model_name


-def completion_with_retry(
+def completion_with_retry(  # type: ignore[no-redef]
     llm: VertexAI,
     prompt: List[Union[str, "Image"]],
     stream: bool = False,
@@ -330,7 +330,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
                     generation += chunk
                 generations.append([generation])
             else:
-                res = completion_with_retry(
+                res = completion_with_retry(  # type: ignore[misc]
                     self,
                     [prompt],
                     stream=should_stream,
@@ -373,7 +373,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in completion_with_retry(
+        for stream_resp in completion_with_retry(  # type: ignore[misc]
             self,
             [prompt],
             stream=True,
View File
@@ -250,9 +250,9 @@ class WatsonxLLM(BaseLLM):
         }

     def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
-        params: Dict[str, Any] = {**self.params} if self.params else None
+        params: Dict[str, Any] = {**self.params} if self.params else {}
         if stop is not None:
-            params = (params or {}) | {"stop_sequences": stop}
+            params["stop_sequences"] = stop
         return params

     def _create_llm_result(self, response: List[dict]) -> LLMResult:
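Note: unlike most hunks in this commit, this one changes behavior rather than adding an ignore: the old code assigned `None` to a variable annotated `Dict[str, Any]` and relied on the `|` dict-merge operator (Python 3.9+). The new version always produces a dict. A standalone mirror of the fixed logic (hypothetical function name):

    from typing import Any, Dict, List, Optional

    def get_chat_params(
        base: Optional[Dict[str, Any]], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        # always return a dict, never None, so the annotation holds
        params: Dict[str, Any] = {**base} if base else {}
        if stop is not None:
            params["stop_sequences"] = stop
        return params

    assert get_chat_params(None) == {}
    assert get_chat_params({"a": 1}, ["\n"]) == {"a": 1, "stop_sequences": ["\n"]}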
View File
@@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)

 class _BaseYandexGPT(Serializable):
-    iam_token: SecretStr = ""
+    iam_token: SecretStr = ""  # type: ignore[assignment]
     """Yandex Cloud IAM token for service or user account
     with the `ai.languageModels.user` role"""
-    api_key: SecretStr = ""
+    api_key: SecretStr = ""  # type: ignore[assignment]
     """Yandex Cloud Api Key for service account
     with the `ai.languageModels.user` role"""
     folder_id: str = ""
@@ -211,7 +211,7 @@ def _make_request(
             messages=[Message(role="user", text=prompt)],
         )
         stub = TextGenerationServiceStub(channel)
-        res = stub.Completion(request, metadata=self._grpc_metadata)
+        res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
         return list(res)[0].alternatives[0].message.text
@@ -253,7 +253,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
             messages=[Message(role="user", text=prompt)],
         )
         stub = TextGenerationAsyncServiceStub(channel)
-        operation = await stub.Completion(request, metadata=self._grpc_metadata)
+        operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
         async with grpc.aio.secure_channel(
             operation_api_url, channel_credentials
         ) as operation_channel:
@@ -262,7 +262,8 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
                 await asyncio.sleep(1)
                 operation_request = GetOperationRequest(operation_id=operation.id)
                 operation = await operation_stub.Get(
-                    operation_request, metadata=self._grpc_metadata
+                    operation_request,
+                    metadata=self._grpc_metadata,  # type: ignore[attr-defined]
                 )

     completion_response = CompletionResponse()
View File
@@ -58,4 +58,4 @@ class AmadeusClosestAirport(AmadeusBaseTool):
             ' Location Identifier" '
         )

-        return self.llm.invoke(content)
+        return self.llm.invoke(content)  # type: ignore[union-attr]
View File
@@ -93,10 +93,10 @@ class ShellTool(BaseTool):
                     return self.process.run(commands)
                 else:
                     logger.info("Invalid input. User aborted command execution.")
-                    return None
+                    return None  # type: ignore[return-value]
             else:
                 return self.process.run(commands)
         except Exception as e:
             logger.error(f"Error during command execution: {e}")
-            return None
+            return None  # type: ignore[return-value]
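Note: `return None` under a signature that presumably declares `str` is what forces `ignore[return-value]`. The alternative is widening the return type, which pushes None-handling onto callers; a sketch of that trade-off (hypothetical standalone function):

    from typing import Optional

    def run_commands(commands: str, approved: bool) -> Optional[str]:
        # Optional[str] makes the abort path type-correct,
        # but every caller must now handle None
        if not approved:
            return None
        return f"ran: {commands}"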
View File
@@ -48,7 +48,7 @@ class BraveSearchWrapper(BaseModel):
         results = self._search_request(query)
         return [
             Document(
-                page_content=item.get("description"),
+                page_content=item.get("description"),  # type: ignore[arg-type]
                 metadata={"title": item.get("title"), "link": item.get("url")},
             )
             for item in results
View File
@@ -141,9 +141,9 @@ class GenericRequestsWrapper(BaseModel):
         self, response: aiohttp.ClientResponse
     ) -> Union[str, Dict[str, Any]]:
         if self.response_content_type == "text":
-            return response.text()
+            return response.text()  # type: ignore[return-value]
         elif self.response_content_type == "json":
-            return response.json()
+            return response.json()  # type: ignore[return-value]
         else:
             raise ValueError(f"Invalid return type: {self.response_content_type}")
@@ -176,33 +176,33 @@ class GenericRequestsWrapper(BaseModel):
     async def aget(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
         """GET the URL and return the text asynchronously."""
         async with self.requests.aget(url, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def apost(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """POST to the URL and return the text asynchronously."""
         async with self.requests.apost(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def apatch(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """PATCH the URL and return the text asynchronously."""
         async with self.requests.apatch(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def aput(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """PUT the URL and return the text asynchronously."""
         async with self.requests.aput(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def adelete(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
         """DELETE the URL and return the text asynchronously."""
         async with self.requests.adelete(url, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]


 class JsonRequestsWrapper(GenericRequestsWrapper):
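Note: the `return-value`/`misc` pairs here track a genuine async wrinkle: aiohttp's `response.text()` and `response.json()` are coroutines, so returning them un-awaited from a method typed `Union[str, Dict[str, Any]]` cannot type-check. A sketch with the awaits pushed into the helper (assuming an `aiohttp.ClientResponse`):

    from typing import Any, Dict, Union

    import aiohttp

    async def aget_resp_content(
        response: aiohttp.ClientResponse, content_type: str
    ) -> Union[str, Dict[str, Any]]:
        # awaiting inside the helper keeps the declared return type honest
        if content_type == "text":
            return await response.text()
        elif content_type == "json":
            return await response.json()
        raise ValueError(f"Invalid return type: {content_type}")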
View File
@@ -381,7 +381,7 @@ class SQLDatabase:
         If the statement returns no rows, an empty list is returned.
         """
-        with self._engine.begin() as connection:  # type: Connection
+        with self._engine.begin() as connection:  # type: Connection  # type: ignore[name-defined]
             if self._schema is not None:
                 if self.dialect == "snowflake":
                     connection.exec_driver_sql(
@@ -444,7 +444,7 @@ class SQLDatabase:
             ]

             if not include_columns:
-                res = [tuple(row.values()) for row in res]
+                res = [tuple(row.values()) for row in res]  # type: ignore[misc]

             if not res:
                 return ""
View File
@@ -356,7 +356,7 @@ class AlibabaCloudOpenSearch(VectorStore):
                     "fields" not in item
                     or self.config.field_name_mapping["document"] not in item["fields"]
                 ):
-                    query_result_list.append(Document())
+                    query_result_list.append(Document())  # type: ignore[call-arg]
                 else:
                     fields = item["fields"]
                     query_result_list.append(
View File
@@ -140,7 +140,7 @@ class AstraDB(VectorStore):
                 if isinstance(v, list):
                     metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v]
                 else:
-                    metadata_filter[k] = AstraDB._filter_to_metadata(v)
+                    metadata_filter[k] = AstraDB._filter_to_metadata(v)  # type: ignore[assignment]
             else:
                 metadata_filter[f"metadata.{k}"] = v
@@ -253,13 +253,13 @@ class AstraDB(VectorStore):
         else:
             self.clear()

-    def _ensure_astra_db_client(self):
+    def _ensure_astra_db_client(self):  # type: ignore[no-untyped-def]
         if not self.astra_db:
             raise ValueError("Missing AstraDB client")

     async def _setup_db(self, pre_delete_collection: bool) -> None:
         if pre_delete_collection:
-            await self.async_astra_db.delete_collection(
+            await self.async_astra_db.delete_collection(  # type: ignore[union-attr]
                 collection_name=self.collection_name,
             )
         await self._aprovision_collection()
@@ -282,7 +282,7 @@ class AstraDB(VectorStore):
         Internal-usage method, no object members are set,
         other than working on the underlying actual storage.
         """
-        self.astra_db.create_collection(
+        self.astra_db.create_collection(  # type: ignore[union-attr]
             dimension=self._get_embedding_dimension(),
             collection_name=self.collection_name,
             metric=self.metric,
@@ -295,7 +295,7 @@ class AstraDB(VectorStore):
         Internal-usage method, no object members are set,
         other than working on the underlying actual storage.
         """
-        await self.async_astra_db.create_collection(
+        await self.async_astra_db.create_collection(  # type: ignore[union-attr]
             dimension=self._get_embedding_dimension(),
             collection_name=self.collection_name,
             metric=self.metric,
@@ -328,7 +328,7 @@ class AstraDB(VectorStore):
         await self._ensure_db_setup()
         if not self.async_astra_db:
             await run_in_executor(None, self.clear)
-        await self.async_collection.delete_many({})
+        await self.async_collection.delete_many({})  # type: ignore[union-attr]

     def delete_by_document_id(self, document_id: str) -> bool:
         """
@@ -336,7 +336,7 @@ class AstraDB(VectorStore):
         Return True if a document has indeed been deleted, False if ID not found.
         """
         self._ensure_astra_db_client()
-        deletion_response = self.collection.delete_one(document_id)
+        deletion_response = self.collection.delete_one(document_id)  # type: ignore[union-attr]
         return ((deletion_response or {}).get("status") or {}).get(
             "deletedCount", 0
         ) == 1
@@ -434,7 +434,7 @@ class AstraDB(VectorStore):
         Use with caution.
         """
         self._ensure_astra_db_client()
-        self.astra_db.delete_collection(
+        self.astra_db.delete_collection(  # type: ignore[union-attr]
             collection_name=self.collection_name,
         )
@@ -448,7 +448,7 @@ class AstraDB(VectorStore):
         await self._ensure_db_setup()
         if not self.async_astra_db:
             await run_in_executor(None, self.delete_collection)
-        await self.async_astra_db.delete_collection(
+        await self.async_astra_db.delete_collection(  # type: ignore[union-attr]
             collection_name=self.collection_name,
         )
@@ -571,7 +571,7 @@ class AstraDB(VectorStore):
             )

             def _handle_batch(document_batch: List[DocDict]) -> List[str]:
-                im_result = self.collection.insert_many(
+                im_result = self.collection.insert_many(  # type: ignore[union-attr]
                     documents=document_batch,
                     options={"ordered": False},
                     partial_failures_allowed=True,
@@ -581,7 +581,7 @@ class AstraDB(VectorStore):
             )

             def _handle_missing_document(missing_document: DocDict) -> str:
-                replacement_result = self.collection.find_one_and_replace(
+                replacement_result = self.collection.find_one_and_replace(  # type: ignore[union-attr]
                     filter={"_id": missing_document["_id"]},
                     replacement=missing_document,
                 )
@@ -672,7 +672,7 @@ class AstraDB(VectorStore):
             )

             async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
-                im_result = await self.async_collection.insert_many(
+                im_result = await self.async_collection.insert_many(  # type: ignore[union-attr]
                     documents=document_batch,
                     options={"ordered": False},
                     partial_failures_allowed=True,
@@ -682,7 +682,7 @@ class AstraDB(VectorStore):
             )

             async def _handle_missing_document(missing_document: DocDict) -> str:
-                replacement_result = await self.async_collection.find_one_and_replace(
+                replacement_result = await self.async_collection.find_one_and_replace(  # type: ignore[union-attr]
                     filter={"_id": missing_document["_id"]},
                     replacement=missing_document,
                 )
@@ -729,7 +729,7 @@ class AstraDB(VectorStore):
         metadata_parameter = self._filter_to_metadata(filter)
         #
         hits = list(
-            self.collection.paginated_find(
+            self.collection.paginated_find(  # type: ignore[union-attr]
                 filter=metadata_parameter,
                 sort={"$vector": embedding},
                 options={"limit": k, "includeSimilarity": True},
@@ -771,7 +771,7 @@ class AstraDB(VectorStore):
         if not self.async_collection:
             return await run_in_executor(
                 None,
-                self.asimilarity_search_with_score_id_by_vector,
+                self.asimilarity_search_with_score_id_by_vector,  # type: ignore[arg-type]
                 embedding,
                 k,
                 filter,
@@ -962,7 +962,7 @@ class AstraDB(VectorStore):
         )

     @staticmethod
-    def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):
+    def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):  # type: ignore[no-untyped-def]
         mmr_chosen_indices = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),
             [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
@@ -1008,7 +1008,7 @@ class AstraDB(VectorStore):
         metadata_parameter = self._filter_to_metadata(filter)

         prefetch_hits = list(
-            self.collection.paginated_find(
+            self.collection.paginated_find(  # type: ignore[union-attr]
                 filter=metadata_parameter,
                 sort={"$vector": embedding},
                 options={"limit": fetch_k, "includeSimilarity": True},
@@ -1228,7 +1228,7 @@ class AstraDB(VectorStore):
             batch_concurrency=kwargs.get("batch_concurrency"),
             overwrite_concurrency=kwargs.get("overwrite_concurrency"),
         )
-        return astra_db_store
+        return astra_db_store  # type: ignore[return-value]

     @classmethod
     async def afrom_texts(
@@ -1263,7 +1263,7 @@ class AstraDB(VectorStore):
             batch_concurrency=kwargs.get("batch_concurrency"),
             overwrite_concurrency=kwargs.get("overwrite_concurrency"),
         )
-        return astra_db_store
+        return astra_db_store  # type: ignore[return-value]

     @classmethod
     def from_documents(
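Note: most ignores in this file are `union-attr` on `self.astra_db` / `self.collection`, which are presumably Optional. `_ensure_astra_db_client` raises when the client is missing, but because it returns nothing, mypy cannot narrow the attribute at the call sites that follow. A returning variant would; a generic sketch of that pattern:

    from typing import Optional, TypeVar

    T = TypeVar("T")

    def ensure_client(client: Optional[T]) -> T:
        # returning the value gives callers a non-Optional handle to use
        if client is None:
            raise ValueError("Missing AstraDB client")
        return client

Call sites would then do `client = ensure_client(self.astra_db)` and call methods on the narrowed local instead of the attribute.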
View File
@@ -339,7 +339,7 @@ class AzureSearch(VectorStore):
         # batching support if embedding function is an Embeddings object
         if isinstance(self.embedding_function, Embeddings):
             try:
-                embeddings = self.embedding_function.embed_documents(texts)
+                embeddings = self.embedding_function.embed_documents(texts)  # type: ignore[arg-type]
             except NotImplementedError:
                 embeddings = [self.embedding_function.embed_query(x) for x in texts]
         else:
View File
@@ -222,7 +222,7 @@ class BigQueryVectorSearch(VectorStore):
             self._logger.debug("Vector index already exists.")
             self._have_index = True

-    def _create_index_in_background(self):
+    def _create_index_in_background(self):  # type: ignore[no-untyped-def]
         if self._have_index or self._creating_index:
             # Already have an index or in the process of creating one.
             return
@@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
         thread = Thread(target=self._create_index, daemon=True)
         thread.start()

-    def _create_index(self):
+    def _create_index(self):  # type: ignore[no-untyped-def]
         from google.api_core.exceptions import ClientError

         table = self.bq_client.get_table(self.vectors_table)
@@ -289,7 +289,7 @@ class BigQueryVectorSearch(VectorStore):
     def full_table_id(self) -> str:
         return self._full_table_id

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
View File
@@ -905,7 +905,7 @@ class DeepLake(VectorStore):
         return self.vectorstore.dataset

     @classmethod
-    def _validate_kwargs(cls, kwargs, method_name):
+    def _validate_kwargs(cls, kwargs, method_name):  # type: ignore[no-untyped-def]
         if kwargs:
             valid_items = cls._get_valid_args(method_name)
             unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@@ -917,14 +917,14 @@ class DeepLake(VectorStore):
             )

     @classmethod
-    def _get_valid_args(cls, method_name):
+    def _get_valid_args(cls, method_name):  # type: ignore[no-untyped-def]
         if method_name == "search":
             return cls._valid_search_kwargs
         else:
             return []

     @staticmethod
-    def _get_unsupported_items(kwargs, valid_items):
+    def _get_unsupported_items(kwargs, valid_items):  # type: ignore[no-untyped-def]
         kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
         unsupported_items = None
         if kwargs:
View File
@@ -305,7 +305,7 @@ class FAISS(VectorStore):
         if filter is not None:
             if isinstance(filter, dict):

-                def filter_func(metadata):
+                def filter_func(metadata):  # type: ignore[no-untyped-def]
                     if all(
                         metadata.get(key) in value
                         if isinstance(value, list)
@@ -607,7 +607,7 @@ class FAISS(VectorStore):
         filtered_indices = []
         if isinstance(filter, dict):

-            def filter_func(metadata):
+            def filter_func(metadata):  # type: ignore[no-untyped-def]
                 if all(
                     metadata.get(key) in value
                     if isinstance(value, list)
View File
@@ -117,7 +117,7 @@ class HanaDB(VectorStore):
             self.vector_column_length,
         )

-    def _table_exists(self, table_name) -> bool:
+    def _table_exists(self, table_name) -> bool:  # type: ignore[no-untyped-def]
         sql_str = (
             "SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
             " AND TABLE_NAME = ?"
@@ -133,7 +133,7 @@ class HanaDB(VectorStore):
                 cur.close()
         return False

-    def _check_column(self, table_name, column_name, column_type, column_length=None):
+    def _check_column(self, table_name, column_name, column_type, column_length=None):  # type: ignore[no-untyped-def]
         sql_str = (
             "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
             "SCHEMA_NAME = CURRENT_SCHEMA "
@@ -166,17 +166,17 @@ class HanaDB(VectorStore):
     def embeddings(self) -> Embeddings:
         return self.embedding

-    def _sanitize_name(input_str: str) -> str:
+    def _sanitize_name(input_str: str) -> str:  # type: ignore[misc]
         # Remove characters that are not alphanumeric or underscores
         return re.sub(r"[^a-zA-Z0-9_]", "", input_str)

-    def _sanitize_int(input_int: any) -> int:
+    def _sanitize_int(input_int: any) -> int:  # type: ignore[valid-type]
         value = int(str(input_int))
         if value < -1:
             raise ValueError(f"Value ({value}) must not be smaller than -1")
         return int(str(input_int))

-    def _sanitize_list_float(embedding: List[float]) -> List[float]:
+    def _sanitize_list_float(embedding: List[float]) -> List[float]:  # type: ignore[misc]
         for value in embedding:
             if not isinstance(value, float):
                 raise ValueError(f"Value ({value}) does not have type float")
@@ -185,14 +185,14 @@ class HanaDB(VectorStore):
     # Compile pattern only once, for better performance
     _compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")

-    def _sanitize_metadata_keys(metadata: dict) -> dict:
+    def _sanitize_metadata_keys(metadata: dict) -> dict:  # type: ignore[misc]
         for key in metadata.keys():
             if not HanaDB._compiled_pattern.match(key):
                 raise ValueError(f"Invalid metadata key {key}")
         return metadata

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: Iterable[str],
         metadatas: Optional[List[dict]] = None,
@@ -243,7 +243,7 @@ class HanaDB(VectorStore):
         return []

     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[no-untyped-def, override]
         cls: Type[HanaDB],
         texts: List[str],
         embedding: Embeddings,
@@ -277,7 +277,7 @@ class HanaDB(VectorStore):
         instance.add_texts(texts, metadatas)
         return instance

-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self, query: str, k: int = 4, filter: Optional[dict] = None
     ) -> List[Document]:
         """Return docs most similar to query.
@@ -382,7 +382,7 @@ class HanaDB(VectorStore):
         )
         return [(result_item[0], result_item[1]) for result_item in whole_result]

-    def similarity_search_by_vector(
+    def similarity_search_by_vector(  # type: ignore[override]
         self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
     ) -> List[Document]:
         """Return docs most similar to embedding vector.
@@ -401,7 +401,7 @@ class HanaDB(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

-    def _create_where_by_filter(self, filter):
+    def _create_where_by_filter(self, filter):  # type: ignore[no-untyped-def]
         query_tuple = []
         where_str = ""
         if filter:
@@ -427,7 +427,7 @@ class HanaDB(VectorStore):
         return where_str, query_tuple

-    def delete(
+    def delete(  # type: ignore[override]
         self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
     ) -> Optional[bool]:
         """Delete entries by filter with metadata values
@@ -459,7 +459,7 @@ class HanaDB(VectorStore):
         return True

-    async def adelete(
+    async def adelete(  # type: ignore[override]
         self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
     ) -> Optional[bool]:
         """Delete by vector ID or other criteria.
@@ -473,7 +473,7 @@ class HanaDB(VectorStore):
         """
         return await run_in_executor(None, self.delete, ids=ids, filter=filter)

-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         k: int = 4,
@@ -511,11 +511,11 @@ class HanaDB(VectorStore):
             filter=filter,
         )

-    def _parse_float_array_from_string(array_as_string: str) -> List[float]:
+    def _parse_float_array_from_string(array_as_string: str) -> List[float]:  # type: ignore[misc]
         array_wo_brackets = array_as_string[1:-1]
         return [float(x) for x in array_wo_brackets.split(",")]

-    def max_marginal_relevance_search_by_vector(
+    def max_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
@@ -533,7 +533,7 @@ class HanaDB(VectorStore):
         return [whole_result[i][0] for i in mmr_doc_indexes]

-    async def amax_marginal_relevance_search_by_vector(
+    async def amax_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
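Note: two distinct complaints are silenced in this file. The `valid-type` one is the lowercase `any` in `_sanitize_int`: the annotation should be `typing.Any`, since builtin `any()` is a function, not a type. The `misc` ones flag methods declared without `self`. A sketch of the `Any` fix alone (standalone function, hypothetical name):

    from typing import Any

    def sanitize_int(value: Any) -> int:
        # typing.Any is the annotation; lowercase any() is the builtin function
        result = int(str(value))
        if result < -1:
            raise ValueError(f"Value ({result}) must not be smaller than -1")
        return result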
View File
@@ -135,7 +135,7 @@ class Jaguar(VectorStore):
     def embeddings(self) -> Optional[Embeddings]:
         return self._embedding

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -351,7 +351,7 @@ class Jaguar(VectorStore):
             return False

     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[override]
         cls,
         texts: List[str],
         embedding: Embeddings,
@@ -383,7 +383,7 @@ class Jaguar(VectorStore):
         q = "truncate store " + podstore
         self.run(q)

-    def delete(self, zids: List[str], **kwargs: Any) -> None:
+    def delete(self, zids: List[str], **kwargs: Any) -> None:  # type: ignore[override]
         """
         Delete records in jaguardb by a list of zero-ids
         Args:
View File
@@ -554,10 +554,10 @@ class Milvus(VectorStore):
         }

         if not self.auto_id:
-            insert_dict[self._primary_field] = ids
+            insert_dict[self._primary_field] = ids  # type: ignore[assignment]

         if self._metadata_field is not None:
-            for d in metadatas:
+            for d in metadatas:  # type: ignore[union-attr]
                 insert_dict.setdefault(self._metadata_field, []).append(d)
         else:
             # Collect the metadata into the insert dict.
@@ -901,7 +901,7 @@ class Milvus(VectorStore):
                 ret.append(documents[x])
         return ret

-    def delete(
+    def delete(  # type: ignore[no-untyped-def]
         self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
     ):
         """Delete by vector ID or boolean expression.
@@ -923,7 +923,7 @@ class Milvus(VectorStore):
             assert isinstance(
                 expr, str
             ), "Either ids list or expr string must be provided."
-        return self.col.delete(expr=expr, **kwargs)
+        return self.col.delete(expr=expr, **kwargs)  # type: ignore[union-attr]

     @classmethod
     def from_texts(
View File
@@ -398,7 +398,7 @@ class PGEmbedding(VectorStore):
         docs = [
             (
                 Document(
-                    page_content=result.EmbeddingStore.document,
+                    page_content=result.EmbeddingStore.document,  # type: ignore[arg-type]
                     metadata=result.EmbeddingStore.cmetadata,
                 ),
                 result.distance if self.embedding_function is not None else 0.0,
View File
@@ -133,7 +133,7 @@ class PGVecto_rs(VectorStore):
             Record.from_text(text, embedding, meta)
             for text, embedding, meta in zip(texts, embeddings, metadatas or [])
         ]
-        self._store.insert(records)
+        self._store.insert(records)  # type: ignore[union-attr]
         return [str(record.id) for record in records]

     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
@@ -177,7 +177,7 @@ class PGVecto_rs(VectorStore):
             real_filter = meta_contains(filter)
         else:
             real_filter = filter
-        results = self._store.search(
+        results = self._store.search(  # type: ignore[union-attr]
             query_vector,
             distance_func_map[distance_func],
             k,
View File
@@ -238,7 +238,7 @@ class PGVector(VectorStore):

     def create_vector_extension(self) -> None:
         try:
-            with Session(self._bind) as session:
+            with Session(self._bind) as session:  # type: ignore[arg-type]
                 # The advisor lock fixes issue arising from concurrent
                 # creation of the vector extension.
                 # https://github.com/langchain-ai/langchain/issues/12933
@@ -256,24 +256,24 @@ class PGVector(VectorStore):
             raise Exception(f"Failed to create vector extension: {e}") from e

     def create_tables_if_not_exists(self) -> None:
-        with Session(self._bind) as session, session.begin():
+        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
             Base.metadata.create_all(session.get_bind())

     def drop_tables(self) -> None:
-        with Session(self._bind) as session, session.begin():
+        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
             Base.metadata.drop_all(session.get_bind())

     def create_collection(self) -> None:
         if self.pre_delete_collection:
             self.delete_collection()
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             self.CollectionStore.get_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )

     def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 self.logger.warning("Collection not found")
@@ -284,7 +284,7 @@ class PGVector(VectorStore):
     @contextlib.contextmanager
     def _make_session(self) -> Generator[Session, None, None]:
         """Create a context manager for the session, bind to _conn string."""
-        yield Session(self._bind)
+        yield Session(self._bind)  # type: ignore[arg-type]

     def delete(
         self,
@@ -298,7 +298,7 @@ class PGVector(VectorStore):
             ids: List of ids to delete.
             collection_only: Only delete ids in the collection.
         """
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             if ids is not None:
                 self.logger.debug(
                     "Trying to delete vectors by ids (represented by the model "
@@ -383,7 +383,7 @@ class PGVector(VectorStore):
         if not metadatas:
             metadatas = [{} for _ in texts]

-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -508,7 +508,7 @@ class PGVector(VectorStore):
         ]
         return docs

-    def _create_filter_clause(self, key, value):
+    def _create_filter_clause(self, key, value):  # type: ignore[no-untyped-def]
         IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
         EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
@@ -575,7 +575,7 @@ class PGVector(VectorStore):
         filter: Optional[Dict[str, str]] = None,
     ) -> List[Any]:
         """Query the collection."""
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
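Note: every `Session(self._bind)` here carries `ignore[arg-type]`, which suggests `_bind` is typed more loosely than the `Session` stubs accept. Keeping the bind typed as an `Engine` end to end is one way to satisfy the stubs; a minimal sketch (hypothetical helper, assuming SQLAlchemy 1.4/2.x):

    from sqlalchemy import create_engine
    from sqlalchemy.engine import Engine
    from sqlalchemy.orm import Session

    def open_session(bind: Engine) -> Session:
        # a bind typed as Engine is accepted by Session without ignores
        return Session(bind)

    session = open_session(create_engine("sqlite://"))
    session.close()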
View File
@@ -115,7 +115,7 @@ class SurrealDBStore(VectorStore):
         for idx, text in enumerate(texts):
             data = {"text": text, "embedding": embeddings[idx]}
             if metadatas is not None and idx < len(metadatas):
-                data["metadata"] = metadatas[idx]
+                data["metadata"] = metadatas[idx]  # type: ignore[assignment]
             record = await self.sdb.create(
                 self.collection,
                 data,
View File
@@ -316,7 +316,7 @@ class TencentVectorDB(VectorStore):
             meta = result.get(self.field_metadata)
             if meta is not None:
                 meta = json.loads(meta)
-            doc = Document(page_content=result.get(self.field_text), metadata=meta)
+            doc = Document(page_content=result.get(self.field_text), metadata=meta)  # type: ignore[arg-type]
             pair = (doc, result.get("score", 0.0))
             ret.append(pair)
         return ret
@@ -374,7 +374,7 @@ class TencentVectorDB(VectorStore):
             meta = result.get(self.field_metadata)
             if meta is not None:
                 meta = json.loads(meta)
-            doc = Document(page_content=result.get(self.field_text), metadata=meta)
+            doc = Document(page_content=result.get(self.field_text), metadata=meta)  # type: ignore[arg-type]
             documents.append(doc)
             ordered_result_embeddings.append(result.get(self.field_vector))
         # Get the new order of results.
View File
@@ -24,7 +24,7 @@ class NeuralDBVectorStore(VectorStore):
         underscore_attrs_are_private = True

     @staticmethod
-    def _verify_thirdai_library(thirdai_key: Optional[str] = None):
+    def _verify_thirdai_library(thirdai_key: Optional[str] = None):  # type: ignore[no-untyped-def]
         try:
             from thirdai import licensing
@@ -38,7 +38,7 @@ class NeuralDBVectorStore(VectorStore):
             )

     @classmethod
-    def from_scratch(
+    def from_scratch(  # type: ignore[no-untyped-def, no-untyped-def]
         cls,
         thirdai_key: Optional[str] = None,
         **model_kwargs,
@@ -69,10 +69,10 @@ class NeuralDBVectorStore(VectorStore):
         NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb

-        return cls(db=ndb.NeuralDB(**model_kwargs))
+        return cls(db=ndb.NeuralDB(**model_kwargs))  # type: ignore[call-arg]

     @classmethod
-    def from_bazaar(
+    def from_bazaar(  # type: ignore[no-untyped-def]
         cls,
         base: str,
         bazaar_cache: Optional[str] = None,
@@ -111,10 +111,10 @@ class NeuralDBVectorStore(VectorStore):
             os.mkdir(cache)
         model_bazaar = ndb.Bazaar(cache)
         model_bazaar.fetch()
-        return cls(db=model_bazaar.get_model(base))
+        return cls(db=model_bazaar.get_model(base))  # type: ignore[call-arg]

     @classmethod
-    def from_checkpoint(
+    def from_checkpoint(  # type: ignore[no-untyped-def]
         cls,
         checkpoint: Union[str, Path],
         thirdai_key: Optional[str] = None,
@@ -146,7 +146,7 @@ class NeuralDBVectorStore(VectorStore):
         NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb

-        return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))
+        return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))  # type: ignore[call-arg]

     @classmethod
     def from_texts(
@@ -187,11 +187,11 @@ class NeuralDBVectorStore(VectorStore):
         df = pd.DataFrame({"texts": texts})
         if metadatas:
             df = pd.concat([df, pd.DataFrame.from_records(metadatas)], axis=1)
-        temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)
+        temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)  # type: ignore[call-overload]
         df.to_csv(temp)
         source_id = self.insert([ndb.CSV(temp.name)], **kwargs)[0]
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
-        return [str(offset + i) for i in range(len(texts))]
+        return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]

     @root_validator()
     def validate_environments(cls, values: Dict) -> Dict:
@@ -205,7 +205,7 @@ class NeuralDBVectorStore(VectorStore):
             )
         return values

-    def insert(
+    def insert(  # type: ignore[no-untyped-def, no-untyped-def]
         self,
         sources: List[Any],
         train: bool = True,
@@ -229,7 +229,7 @@ class NeuralDBVectorStore(VectorStore):
             **kwargs,
         )

-    def _preprocess_sources(self, sources):
+    def _preprocess_sources(self, sources):  # type: ignore[no-untyped-def]
         """Checks if the provided sources are string paths. If they are, convert
         to NeuralDB document objects.
@@ -261,7 +261,7 @@ class NeuralDBVectorStore(VectorStore):
         )
         return preprocessed_sources

-    def upvote(self, query: str, document_id: Union[int, str]):
+    def upvote(self, query: str, document_id: Union[int, str]):  # type: ignore[no-untyped-def]
         """The vectorstore upweights the score of a document for a specific query.

         This is useful for fine-tuning the vectorstore to user behavior.
@@ -271,7 +271,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.text_to_result(query, int(document_id))

-    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):
+    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):  # type: ignore[no-untyped-def]
         """Given a batch of (query, document id) pairs, the vectorstore upweights
         the scores of the document for the corresponding queries.
         This is useful for fine-tuning the vectorstore to user behavior.
@@ -284,7 +284,7 @@ class NeuralDBVectorStore(VectorStore):
             [(query, int(doc_id)) for query, doc_id in query_id_pairs]
         )

-    def associate(self, source: str, target: str):
+    def associate(self, source: str, target: str):  # type: ignore[no-untyped-def]
         """The vectorstore associates a source phrase with a target phrase.
         When the vectorstore sees the source phrase, it will also consider results
         that are relevant to the target phrase.
@@ -295,7 +295,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.associate(source, target)

-    def associate_batch(self, text_pairs: List[Tuple[str, str]]):
+    def associate_batch(self, text_pairs: List[Tuple[str, str]]):  # type: ignore[no-untyped-def]
         """Given a batch of (source, target) pairs, the vectorstore associates
         each source phrase with the corresponding target phrase.
@@ -334,7 +334,7 @@ class NeuralDBVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e

-    def save(self, path: str):
+    def save(self, path: str):  # type: ignore[no-untyped-def]
         """Saves a NeuralDB instance to disk. Can be loaded into memory by
         calling NeuralDB.from_checkpoint(path)
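Note: the `call-overload` ignore on `NamedTemporaryFile` is version skew: `delete_on_close` was added in Python 3.12, so stubs targeting older interpreters reject the keyword. A portable sketch that avoids the new keyword entirely:

    import tempfile

    def make_temp_csv() -> str:
        # delete=False keeps the file on disk after close on all supported versions;
        # delete_on_close exists only on Python 3.12+
        temp = tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False)
        temp.close()
        return temp.name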
View File
@@ -384,7 +384,7 @@ class Vectara(VectorStore):
                 f"(code {response.status_code}, reason {response.reason}, details "
                 f"{response.text})",
             )
-            return [], ""
+            return [], ""  # type: ignore[return-value]

         result = response.json()
@@ -454,7 +454,7 @@ class Vectara(VectorStore):
         docs = self.vectara_query(query, config)
         return docs

-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self,
         query: str,
         **kwargs: Any,
@@ -474,7 +474,7 @@ class Vectara(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         fetch_k: int = 50,
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)

 class VikingDBConfig(object):
-    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):
+    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):  # type: ignore[no-untyped-def]
         self.host = host
         self.region = region
         self.ak = ak
@@ -47,11 +47,11 @@ class VikingDB(VectorStore):
         self.index_params = index_params
         self.drop_old = drop_old
         self.service = VikingDBService(
-            connection_args.host,
-            connection_args.region,
-            connection_args.ak,
-            connection_args.sk,
-            connection_args.scheme,
+            connection_args.host,  # type: ignore[union-attr]
+            connection_args.region,  # type: ignore[union-attr]
+            connection_args.ak,  # type: ignore[union-attr]
+            connection_args.sk,  # type: ignore[union-attr]
+            connection_args.scheme,  # type: ignore[union-attr]
         )

         try:
@@ -143,7 +143,7 @@ class VikingDB(VectorStore):
             scalar_index=scalar_index,
         )

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -183,7 +183,7 @@ class VikingDB(VectorStore):
             if metadatas is not None and index < len(metadatas):
                 names = list(metadatas[index].keys())
                 for name in names:
-                    field[name] = metadatas[index].get(name)
+                    field[name] = metadatas[index].get(name)  # type: ignore[assignment]
             data.append(Data(field))

         total_count = len(data)
@@ -191,10 +191,10 @@ class VikingDB(VectorStore):
             end = min(i + batch_size, total_count)
             insert_data = data[i:end]
             # print(insert_data)
-            self.collection.upsert_data(insert_data)
+            self.collection.upsert_data(insert_data)  # type: ignore[union-attr]
         return pks

-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self,
         query: str,
         params: Optional[dict] = None,
@@ -216,7 +216,7 @@ class VikingDB(VectorStore):
         )
         return res

-    def similarity_search_by_vector(
+    def similarity_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         params: Optional[dict] = None,
@@ -251,7 +251,7 @@ class VikingDB(VectorStore):
             if params.get("partition") is not None:
                 partition = params["partition"]
-        res = self.index.search_by_vector(
+        res = self.index.search_by_vector(  # type: ignore[union-attr]
             embedding,
             filter=filter,
             limit=limit,
@@ -269,7 +269,7 @@ class VikingDB(VectorStore):
             ret.append(pair)
         return ret

-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         k: int = 4,
@@ -286,7 +286,7 @@ class VikingDB(VectorStore):
             **kwargs,
         )

-    def max_marginal_relevance_search_by_vector(
+    def max_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
@@ -311,7 +311,7 @@ class VikingDB(VectorStore):
             if params.get("partition") is not None:
                 partition = params["partition"]
-        res = self.index.search_by_vector(
+        res = self.index.search_by_vector(  # type: ignore[union-attr]
             embedding,
             filter=filter,
             limit=limit,
@@ -347,10 +347,10 @@ class VikingDB(VectorStore):
     ) -> None:
         if self.collection is None:
             logger.debug("No existing collection to search.")
-        self.collection.delete_data(ids)
+        self.collection.delete_data(ids)  # type: ignore[union-attr]

     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[no-untyped-def, override]
         cls,
         texts: List[str],
         embedding: Embeddings,
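The `union-attr` ignores in this file all have the same shape: an attribute is accessed on a value whose declared type includes `None` (here, presumably, `connection_args`, `self.collection`, and `self.index` are Optional), so mypy warns that the `None` member of the union has no such attribute. A small sketch of the pattern and of the narrowing that avoids it; all names are illustrative, not from the diff:

    from typing import Optional

    class Collection:
        def delete_data(self, ids: list) -> None: ...

    class Store:
        def __init__(self) -> None:
            self.collection: Optional[Collection] = None

        def delete(self, ids: list) -> None:
            # mypy: Item "None" of "Optional[Collection]" has no attribute
            # "delete_data"  [union-attr]
            self.collection.delete_data(ids)  # type: ignore[union-attr]

        def delete_narrowed(self, ids: list) -> None:
            if self.collection is None:
                raise ValueError("collection is not initialized")
            self.collection.delete_data(ids)  # narrowed, no ignore needed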
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.

 [[package]]
 name = "aenum"
@@ -3944,7 +3944,7 @@ files = [
 [[package]]
 name = "langchain-core"
-version = "0.1.17"
+version = "0.1.18"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -9252,4 +9252,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"
+content-hash = "1ab63edcddcef2deb01e6fff5c376f7b0773435bb9d5b55bc1d50d19a8f1dee2"
@@ -101,7 +101,7 @@ optional = true
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
+pytest-cov = "^4.1.0"
 pytest-dotenv = "^0.5.2"
 duckdb-engine = "^0.9.2"
 pytest-watcher = "^0.2.6"
@@ -57,7 +57,7 @@ def test_add_messages() -> None:
     assert len(message_store_another.messages) == 0

-def test_tidb_recent_chat_message():
+def test_tidb_recent_chat_message():  # type: ignore[no-untyped-def]
     """Test the TiDBChatMessageHistory with earliest_time parameter."""
     import time
     from datetime import datetime
@@ -40,7 +40,7 @@ def test_konko_key_masked_when_passed_via_constructor(
     captured = capsys.readouterr()
     assert captured.out == "**********"

-    print(chat.konko_secret_key, end="")
+    print(chat.konko_secret_key, end="")  # type: ignore[attr-defined]
     captured = capsys.readouterr()
     assert captured.out == "**********"
@@ -49,7 +49,7 @@ def test_uses_actual_secret_value_from_secret_str() -> None:
     """Test that actual secret is retrieved using `.get_secret_value()`."""
     chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")
     assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key"
-    assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"
+    assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"  # type: ignore[attr-defined]

 def test_konko_chat_test() -> None:
@@ -47,6 +47,6 @@ def test_chat_wasm_service_streaming() -> None:
     output = ""
     for chunk in chat.stream(messages):
         print(chunk.content, end="", flush=True)
-        output += chunk.content
+        output += chunk.content  # type: ignore[operator]

     assert "Paris" in output
@@ -167,5 +167,5 @@ class TestAstraDB:
             find_options={"limit": 30},
             extraction_function=lambda x: x["foo"],
         )
-        doc = await anext(loader.alazy_load())
+        doc = await anext(loader.alazy_load())  # type: ignore[name-defined]
         assert doc.page_content == "bar"
@@ -14,7 +14,7 @@ CASSANDRA_TABLE = "docloader_test_table"

 @pytest.fixture(autouse=True, scope="session")
-def keyspace() -> str:
+def keyspace() -> str:  # type: ignore[misc]
     import cassio
     from cassandra.cluster import Cluster
     from cassio.config import check_resolve_session, resolve_keyspace
@@ -7,8 +7,8 @@ def test_baichuan_embedding_documents() -> None:
     documents = ["今天天气不错", "今天阳光灿烂"]
     embedding = BaichuanTextEmbeddings()
     output = embedding.embed_documents(documents)
-    assert len(output) == 2
-    assert len(output[0]) == 1024
+    assert len(output) == 2  # type: ignore[arg-type]
+    assert len(output[0]) == 1024  # type: ignore[index]

 def test_baichuan_embedding_query() -> None:
@@ -16,4 +16,4 @@ def test_baichuan_embedding_query() -> None:
     document = "所有的小学生都会学过只因兔同笼问题。"
     embedding = BaichuanTextEmbeddings()
     output = embedding.embed_query(document)
-    assert len(output) == 1024
+    assert len(output) == 1024  # type: ignore[arg-type]
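The `arg-type` and `index` ignores above follow from the embedding methods apparently being declared to return an Optional result: passing `None` to `len()` is an `arg-type` error, and indexing `None` is an `index` error. A stub reproduction, with the Optional signature assumed rather than copied from the wrapper:

    from typing import List, Optional

    def embed_query(text: str) -> Optional[List[float]]:
        # stand-in for an API-backed embedder that may return None on failure
        return [0.0] * 1024

    output = embed_query("hello")
    # mypy: Argument 1 to "len" has incompatible type "Optional[List[float]]"  [arg-type]
    assert len(output) == 1024  # type: ignore[arg-type]

An `assert output is not None` before the length check would narrow the type and make the ignore unnecessary.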
@@ -85,7 +85,7 @@ def test_neo4j_timeout() -> None:
         graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})")
     except Exception as e:
         assert (
-            e.code
+            e.code  # type: ignore[attr-defined]
             == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
         )
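`attr-defined` appears when an attribute is read from a type that does not declare it. The test above catches a bare `Exception`, and `Exception` has no `code` attribute, even though the Neo4j driver's exception object carries one at runtime. A sketch with an invented exception class:

    class DriverError(Exception):
        def __init__(self, code: str) -> None:
            super().__init__(code)
            self.code = code

    try:
        raise DriverError("Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration")
    except Exception as e:
        # static type of `e` is Exception, so `.code` is unknown to mypy
        print(e.code)  # type: ignore[attr-defined]

Catching `DriverError` instead of `Exception` would type-check without the ignore.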
@@ -62,7 +62,7 @@ def test_custom_formatter() -> None:
         content_type = "application/json"
         accepts = "application/json"

-        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
             input_str = json.dumps(
                 {
                     "inputs": [prompt],
@@ -72,7 +72,7 @@ def test_custom_formatter() -> None:
             )
             return input_str.encode("utf-8")

-        def format_response_payload(self, output: bytes) -> str:
+        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
             response_json = json.loads(output)
             return response_json[0]["summary_text"]
@@ -104,7 +104,7 @@ def test_invalid_request_format() -> None:
         content_type = "application/json"
         accepts = "application/json"

-        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
             input_str = json.dumps(
                 {
                     "incorrect_input": {"input_string": [prompt]},
@@ -113,7 +113,7 @@ def test_invalid_request_format() -> None:
             )
             return str.encode(input_str)

-        def format_response_payload(self, output: bytes) -> str:
+        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
             response_json = json.loads(output)
             return response_json[0]["0"]
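`ignore[override]` silences mypy's Liskov-substitution check: a subclass method whose parameter or return types are incompatible with the base class's declaration is reported as an incompatible override. The formatter classes above change types relative to their base, hence the ignores. A compact illustration with invented class names:

    class BaseFormatter:
        def format_response_payload(self, output: bytes) -> bytes:
            return output

    class SummaryFormatter(BaseFormatter):
        # Return type changed from bytes to str: an incompatible override,
        # since code holding a BaseFormatter expects bytes back.
        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
            return output.decode("utf-8")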
@@ -37,12 +37,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
         if reason == "GUARDRAIL_INTERVENED":
             self.guardrails_intervened = True

-    def get_response(self):
+    def get_response(self):  # type: ignore[no-untyped-def]
         return self.guardrails_intervened

 @pytest.fixture(autouse=True)
-def bedrock_runtime_client():
+def bedrock_runtime_client():  # type: ignore[no-untyped-def]
     import boto3

     try:
@@ -56,7 +56,7 @@ def bedrock_runtime_client():

 @pytest.fixture(autouse=True)
-def bedrock_client():
+def bedrock_client():  # type: ignore[no-untyped-def]
     import boto3

     try:
@@ -70,7 +70,7 @@ def bedrock_client():

 @pytest.fixture
-def bedrock_models(bedrock_client):
+def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
     """List bedrock models."""
     response = bedrock_client.list_foundation_models().get("modelSummaries")
     models = {}
@@ -79,7 +79,7 @@ def bedrock_models(bedrock_client):
     return models

-def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",
@@ -92,7 +92,7 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)

-def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
     bedrock_runtime_client, bedrock_models
 ):
     try:
@@ -112,7 +112,7 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)

-def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(  # type: ignore[no-untyped-def]
     bedrock_runtime_client, bedrock_models
 ):
     try:
@@ -16,7 +16,7 @@ def _has_env_vars() -> bool:

 @pytest.fixture
-def astra_db():
+def astra_db():  # type: ignore[no-untyped-def]
     from astrapy.db import AstraDB

     return AstraDB(
@@ -26,14 +26,14 @@ def astra_db():
     )

-def init_store(astra_db, collection_name: str):
+def init_store(astra_db, collection_name: str):  # type: ignore[no-untyped-def, no-untyped-def]
     astra_db.create_collection(collection_name)
     store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
     store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
     return store

-def init_bytestore(astra_db, collection_name: str):
+def init_bytestore(astra_db, collection_name: str):  # type: ignore[no-untyped-def, no-untyped-def]
     astra_db.create_collection(collection_name)
     store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
     store.mset([("key1", b"value1"), ("key2", b"value2")])
@@ -43,7 +43,7 @@ def init_bytestore(astra_db, collection_name: str):
 @pytest.mark.requires("astrapy")
 @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
 class TestAstraDBStore:
-    def test_mget(self, astra_db) -> None:
+    def test_mget(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test AstraDBStore mget method."""
         collection_name = "lc_test_store_mget"
         try:
@@ -52,7 +52,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)

-    def test_mset(self, astra_db) -> None:
+    def test_mset(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that multiple keys can be set with AstraDBStore."""
         collection_name = "lc_test_store_mset"
         try:
@@ -64,7 +64,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)

-    def test_mdelete(self, astra_db) -> None:
+    def test_mdelete(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that deletion works as expected."""
         collection_name = "lc_test_store_mdelete"
         try:
@@ -75,7 +75,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)

-    def test_yield_keys(self, astra_db) -> None:
+    def test_yield_keys(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         collection_name = "lc_test_store_yield_keys"
         try:
             store = init_store(astra_db, collection_name)
@@ -85,7 +85,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)

-    def test_bytestore_mget(self, astra_db) -> None:
+    def test_bytestore_mget(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test AstraDBByteStore mget method."""
         collection_name = "lc_test_bytestore_mget"
         try:
@@ -94,7 +94,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)

-    def test_bytestore_mset(self, astra_db) -> None:
+    def test_bytestore_mset(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that multiple keys can be set with AstraDBByteStore."""
         collection_name = "lc_test_bytestore_mset"
         try:
@@ -6,7 +6,7 @@ from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper

 @patch("serpapi.SerpApiClient.get_json")
-def test_unexpected_response(mocked_serpapiclient):
+def test_unexpected_response(mocked_serpapiclient):  # type: ignore[no-untyped-def]
     os.environ["SERPAPI_API_KEY"] = "123abcd"
     resp = {
         "search_metadata": {
@@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
     return True

-def assert_documents_equals(actual: List[Document], expected: List[Document]):
+def assert_documents_equals(actual: List[Document], expected: List[Document]):  # type: ignore[no-untyped-def]
     assert len(actual) == len(expected)

     for actual_doc, expected_doc in zip(actual, expected):
@@ -32,7 +32,7 @@ def store(request: pytest.FixtureRequest) -> BigQueryVectorSearch:
         TestBigQueryVectorStore.dataset_name, exists_ok=True
     )
     TestBigQueryVectorStore.store = BigQueryVectorSearch(
-        project_id=os.environ.get("PROJECT", None),
+        project_id=os.environ.get("PROJECT", None),  # type: ignore[arg-type]
         embedding=FakeEmbeddings(),
         dataset_name=TestBigQueryVectorStore.dataset_name,
         table_name=TEST_TABLE_NAME,
@@ -52,7 +52,7 @@ def test_deeplake_with_metadatas() -> None:
     assert output == [Document(page_content="foo", metadata={"page": "0"})]

-def test_deeplake_with_persistence(deeplake_datastore) -> None:
+def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test end to end construction and search, with persistence."""
     output = deeplake_datastore.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo", metadata={"page": "0"})]
@@ -72,7 +72,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None:
     # Or on program exit

-def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
+def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test overwrite behavior"""
     dataset_path = deeplake_datastore.vectorstore.dataset_handler.path
@@ -108,7 +108,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
     output = docsearch.similarity_search("foo", k=1)

-def test_similarity_search(deeplake_datastore) -> None:
+def test_similarity_search(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test similarity search."""
     distance_metric = "cos"
     output = deeplake_datastore.similarity_search(
@@ -38,7 +38,7 @@ embedding = NormalizedFakeEmbeddings()

 class ConfigData:
-    def __init__(self):
+    def __init__(self):  # type: ignore[no-untyped-def]
         self.conn = None
         self.schema_name = ""
@@ -46,7 +46,7 @@ class ConfigData:
 test_setup = ConfigData()

-def generateSchemaName(cursor):
+def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
     cursor.execute(
         "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
         "DUMMY;"
@@ -59,7 +59,7 @@ def generateSchemaName(cursor):
     return f"VEC_{uid}"

-def setup_module(module):
+def setup_module(module):  # type: ignore[no-untyped-def]
     test_setup.conn = dbapi.connect(
         address=os.environ.get("HANA_DB_ADDRESS"),
         port=os.environ.get("HANA_DB_PORT"),
@@ -81,7 +81,7 @@ def setup_module(module):
     cur.close()

-def teardown_module(module):
+def teardown_module(module):  # type: ignore[no-untyped-def]
     try:
         cur = test_setup.conn.cursor()
         sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE"
@@ -100,13 +100,13 @@ def texts() -> List[str]:
 @pytest.fixture
 def metadatas() -> List[str]:
     return [
-        {"start": 0, "end": 100, "quality": "good", "ready": True},
-        {"start": 100, "end": 200, "quality": "bad", "ready": False},
-        {"start": 200, "end": 300, "quality": "ugly", "ready": True},
+        {"start": 0, "end": 100, "quality": "good", "ready": True},  # type: ignore[list-item]
+        {"start": 100, "end": 200, "quality": "bad", "ready": False},  # type: ignore[list-item]
+        {"start": 200, "end": 300, "quality": "ugly", "ready": True},  # type: ignore[list-item]
     ]

-def drop_table(connection, table_name):
+def drop_table(connection, table_name):  # type: ignore[no-untyped-def]
     try:
         cur = connection.cursor()
         sql_str = f"DROP TABLE {table_name}"
@@ -825,7 +825,7 @@ def test_hanavector_filter_prepared_statement_params(
     rows = cur.fetchall()
     assert len(rows) == 1

-    query_value = "good"
+    query_value = "good"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
@@ -839,14 +839,14 @@ def test_hanavector_filter_prepared_statement_params(
     assert len(rows) == 1

     # query_value = True
-    query_value = "true"
+    query_value = "true"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
     assert len(rows) == 2

     # query_value = False
-    query_value = "false"
+    query_value = "false"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
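The `assignment` ignores mark re-bindings that change a variable's type: mypy infers the type of `query_value` from its first assignment (a non-string earlier in this test, judging by the `# query_value = True` / `# query_value = False` comments), so the later string assignments conflict. Minimal reproduction:

    query_value = 100  # inferred as int
    # mypy: Incompatible types in assignment (expression has type "str",
    # variable has type "int")  [assignment]
    query_value = "good"  # type: ignore[assignment]

Declaring the variable as a union up front, or using distinct names, would avoid the ignores.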
@@ -31,7 +31,7 @@ def fix_distance_precision(
 class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
     """Fake embeddings functionality for testing."""

-    def __init__(self):
+    def __init__(self):  # type: ignore[no-untyped-def]
         super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -7,7 +7,7 @@ from langchain_community.vectorstores import NeuralDBVectorStore

 @pytest.fixture(scope="session")
-def test_csv():
+def test_csv():  # type: ignore[no-untyped-def]
     csv = "thirdai-test.csv"
     with open(csv, "w") as o:
         o.write("column_1,column_2\n")
@@ -16,13 +16,13 @@ def test_csv():
     os.remove(csv)

-def assert_result_correctness(documents):
+def assert_result_correctness(documents):  # type: ignore[no-untyped-def]
     assert len(documents) == 1
     assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"

 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_scratch(test_csv):
+def test_neuraldb_retriever_from_scratch(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_scratch()
     retriever.insert([test_csv])
     documents = retriever.similarity_search("column")
@@ -30,7 +30,7 @@ def test_neuraldb_retriever_from_scratch(test_csv):

 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_checkpoint(test_csv):
+def test_neuraldb_retriever_from_checkpoint(test_csv):  # type: ignore[no-untyped-def]
     checkpoint = "thirdai-test-save.ndb"
     if os.path.exists(checkpoint):
         shutil.rmtree(checkpoint)
@@ -47,7 +47,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv):

 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_bazaar(test_csv):
+def test_neuraldb_retriever_from_bazaar(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_bazaar("General QnA")
     retriever.insert([test_csv])
     documents = retriever.similarity_search("column")
@@ -55,7 +55,7 @@ def test_neuraldb_retriever_from_bazaar(test_csv):

 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_other_methods(test_csv):
+def test_neuraldb_retriever_other_methods(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_scratch()
     retriever.insert([test_csv])
     # Make sure they don't throw an error.
@@ -25,7 +25,7 @@ def get_abbr(s: str) -> str:

 @pytest.fixture(scope="function")
-def vectara1():
+def vectara1():  # type: ignore[no-untyped-def]
     # Set up code
     # create a new Vectara instance
     vectara1: Vectara = Vectara()
@@ -54,7 +54,7 @@ def vectara1():
         vectara1._delete_doc(doc_id)

-def test_vectara_add_documents(vectara1) -> None:
+def test_vectara_add_documents(vectara1) -> None:  # type: ignore[no-untyped-def]
     """Test add_documents."""

     # test without filter
@@ -164,7 +164,7 @@ models can greatly improve the training of DNNs and other deep discriminative mo

 @pytest.fixture(scope="function")
-def vectara3():
+def vectara3():  # type: ignore[no-untyped-def]
     # Set up code
     vectara3: Vectara = Vectara()
@@ -210,7 +210,7 @@ def vectara3():
         vectara3._delete_doc(doc_id)

-def test_vectara_mmr(vectara3) -> None:
+def test_vectara_mmr(vectara3) -> None:  # type: ignore[no-untyped-def]
     # test max marginal relevance
     output1 = vectara3.max_marginal_relevance_search(
         "generative AI",
@@ -241,7 +241,7 @@ def test_vectara_mmr(vectara3) -> None:
     )

-def test_vectara_with_summary(vectara3) -> None:
+def test_vectara_with_summary(vectara3) -> None:  # type: ignore[no-untyped-def]
     """Test vectara summary."""
     # test summarization
     num_results = 10
@@ -35,6 +35,6 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
     ("role", "role_response"),
     [("ai", "assistant"), ("human", "user"), ("chat", "user")],
 )
-def test_edenai_message_role(role: str, role_response) -> None:
+def test_edenai_message_role(role: str, role_response) -> None:  # type: ignore[no-untyped-def]
     role = _message_role(role)
     assert role == role_response
@@ -29,7 +29,7 @@ class GradientEmbeddingsModel(MagicMock):
         embeddings = []
         for i, inp in enumerate(inputs):
             # verify correct ordering
-            inp = inp["input"]
+            inp = inp["input"]  # type: ignore[assignment]
             if "pizza" in inp:
                 v = [1.0, 0.0, 0.0]
             elif "document" in inp:
@@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
         output.embeddings = embeddings
         return output

-    async def aembed(self, *args) -> Any:
+    async def aembed(self, *args) -> Any:  # type: ignore[no-untyped-def]
         return self.embed(*args)

 class MockGradient(MagicMock):
     """Mock Gradient package."""

-    def __init__(self, access_token: str, workspace_id, host):
+    def __init__(self, access_token: str, workspace_id, host):  # type: ignore[no-untyped-def]
         assert access_token == _GRADIENT_SECRET
         assert workspace_id == _GRADIENT_WORKSPACE_ID
         assert host == _GRADIENT_BASE_URL
@@ -8,7 +8,7 @@ from langchain_community.embeddings import OCIGenAIEmbeddings

 class MockResponseDict(dict):
-    def __getattr__(self, val):
+    def __getattr__(self, val):  # type: ignore[no-untyped-def]
         return self[val]
@@ -25,7 +25,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
         client=oci_gen_ai_client,
     )

-    def mocked_response(invocation_obj):
+    def mocked_response(invocation_obj):  # type: ignore[no-untyped-def]
         docs = invocation_obj.inputs
         embeddings = []
@@ -1,14 +1,14 @@
 from langchain_community.graphs.neo4j_graph import value_sanitize

-def test_value_sanitize_with_small_list():
+def test_value_sanitize_with_small_list():  # type: ignore[no-untyped-def]
     small_list = list(range(15))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "small_list": small_list}
     expected_output = {"key1": "value1", "small_list": small_list}
     assert value_sanitize(input_dict) == expected_output

-def test_value_sanitize_with_oversized_list():
+def test_value_sanitize_with_oversized_list():  # type: ignore[no-untyped-def]
     oversized_list = list(range(150))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "oversized_list": oversized_list}
     expected_output = {
@@ -18,14 +18,14 @@ def test_value_sanitize_with_oversized_list():
     assert value_sanitize(input_dict) == expected_output

-def test_value_sanitize_with_nested_oversized_list():
+def test_value_sanitize_with_nested_oversized_list():  # type: ignore[no-untyped-def]
     oversized_list = list(range(150))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
     expected_output = {"key1": "value1", "oversized_list": {}}
     assert value_sanitize(input_dict) == expected_output

-def test_value_sanitize_with_dict_in_list():
+def test_value_sanitize_with_dict_in_list():  # type: ignore[no-untyped-def]
     oversized_list = list(range(150))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
     expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
@@ -15,7 +15,7 @@ class TestOntotextGraphDBGraph(unittest.TestCase):
         with self.assertRaises(TypeError) as e:
             OntotextGraphDBGraph._validate_user_query(
-                [
+                [  # type: ignore[arg-type]
                     "PREFIX starwars: <https://swapi.co/ontology/> "
                     "PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
                     "DESCRIBE starwars: ?term "
@@ -8,7 +8,7 @@ from langchain_community.llms import OCIGenAI

 class MockResponseDict(dict):
-    def __getattr__(self, val):
+    def __getattr__(self, val):  # type: ignore[no-untyped-def]
         return self[val]
@@ -23,7 +23,7 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
     provider = llm._get_provider()

-    def mocked_response(*args):
+    def mocked_response(*args):  # type: ignore[no-untyped-def]
         response_text = "This is the completion."

         if provider == "cohere":
@@ -4,11 +4,11 @@ from pytest import MonkeyPatch
 from langchain_community.llms.ollama import Ollama

-def mock_response_stream():
+def mock_response_stream():  # type: ignore[no-untyped-def]
     mock_response = [b'{ "response": "Response chunk 1" }']

     class MockRaw:
-        def read(self, chunk_size):
+        def read(self, chunk_size):  # type: ignore[no-untyped-def]
             try:
                 return mock_response.pop()
             except IndexError:
@@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
         timeout=300,
     )

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
 def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -72,7 +72,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
     """Test that top level params are sent to the endpoint as top level params"""
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -118,7 +118,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -165,7 +165,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
 	[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
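This Makefile hunk is the behavioral fix named in the commit title. In the old recipe, `... || mkdir $(MYPY_CACHE) || poetry run mypy ...` short-circuits: on a fresh checkout the cache directory does not exist, `mkdir` succeeds, and the final `||` skips mypy entirely; mypy only ran once the directory was already there. The new `mkdir -p $(MYPY_CACHE) && poetry run mypy ...` cannot fail on an existing directory, and since POSIX shells group `a || b && c` as `(a || b) && c`, the mypy step now always runs. A quick shell demonstration:

    $ rm -rf .mypy_cache
    $ mkdir .mypy_cache 2>/dev/null || echo "fallback"   # mkdir succeeds, echo skipped
    $ mkdir .mypy_cache 2>/dev/null || echo "fallback"   # dir exists, mkdir fails, echo runs
    fallback
    $ mkdir -p .mypy_cache && echo "next step"           # -p tolerates an existing dir
    next step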
Some files were not shown because too many files have changed in this diff.