diff --git a/.github/workflows/_lint.yml b/.github/workflows/_lint.yml index 1bb605d22fb..e2e3877c554 100644 --- a/.github/workflows/_lint.yml +++ b/.github/workflows/_lint.yml @@ -86,7 +86,7 @@ jobs: with: path: | ${{ env.WORKDIR }}/.mypy_cache - key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }} + key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} - name: Analysing the code with our lint @@ -105,7 +105,7 @@ jobs: # It doesn't matter how you change it, any change will cause a cache-bust. working-directory: ${{ inputs.working-directory }} run: | - poetry install --with test,test_integration + poetry install --with test - name: Get .mypy_cache_test to speed up mypy uses: actions/cache@v3 @@ -114,7 +114,7 @@ jobs: with: path: | ${{ env.WORKDIR }}/.mypy_cache_test - key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }} + key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} - name: Analysing the code with our lint working-directory: ${{ inputs.working-directory }} diff --git a/libs/community/Makefile b/libs/community/Makefile index 900deed91ab..d7bda9945aa 100644 --- a/libs/community/Makefile +++ b/libs/community/Makefile @@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests: poetry run ruff . [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) format format_diff: poetry run ruff format $(PYTHON_FILES) diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/libs/community/langchain_community/agent_toolkits/openapi/planner.py index e95e011e395..7b561a93039 100644 --- a/libs/community/langchain_community/agent_toolkits/openapi/planner.py +++ b/libs/community/langchain_community/agent_toolkits/openapi/planner.py @@ -84,7 +84,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): raise e data_params = data.get("params") response = self.requests_wrapper.get(data["url"], params=data_params) - response = response[: self.response_length] + response = response[: self.response_length] # type: ignore[index] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -115,7 +115,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): except json.JSONDecodeError as e: raise e response = self.requests_wrapper.post(data["url"], data["data"]) - response = response[: self.response_length] + response = response[: self.response_length] # type: ignore[index] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -146,7 +146,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool): except json.JSONDecodeError as e: raise e response = self.requests_wrapper.patch(data["url"], data["data"]) - response = response[: self.response_length] + response 
= response[: self.response_length] # type: ignore[index] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -177,7 +177,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool): except json.JSONDecodeError as e: raise e response = self.requests_wrapper.put(data["url"], data["data"]) - response = response[: self.response_length] + response = response[: self.response_length] # type: ignore[index] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() @@ -209,7 +209,7 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool): except json.JSONDecodeError as e: raise e response = self.requests_wrapper.delete(data["url"]) - response = response[: self.response_length] + response = response[: self.response_length] # type: ignore[index] return self.llm_chain.predict( response=response, instructions=data["output_instructions"] ).strip() diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py index 73915732cf1..65294cf094b 100644 --- a/libs/community/langchain_community/agent_toolkits/sql/base.py +++ b/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -177,12 +177,12 @@ def create_sql_agent( elif agent_type == AgentType.OPENAI_FUNCTIONS: if prompt is None: messages = [ - SystemMessage(content=prefix), + SystemMessage(content=prefix), # type: ignore[arg-type] HumanMessagePromptTemplate.from_template("{input}"), AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), MessagesPlaceholder(variable_name="agent_scratchpad"), ] - prompt = ChatPromptTemplate.from_messages(messages) + prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type] agent = RunnableAgent( runnable=create_openai_functions_agent(llm, tools, prompt), input_keys_arg=["input"], @@ -191,12 +191,12 @@ def create_sql_agent( elif agent_type == "openai-tools": if prompt is None: messages = [ - SystemMessage(content=prefix), + SystemMessage(content=prefix), # type: ignore[arg-type] HumanMessagePromptTemplate.from_template("{input}"), AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), MessagesPlaceholder(variable_name="agent_scratchpad"), ] - prompt = ChatPromptTemplate.from_messages(messages) + prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type] agent = RunnableMultiActionAgent( runnable=create_openai_tools_agent(llm, tools, prompt), input_keys_arg=["input"], diff --git a/libs/community/langchain_community/callbacks/mlflow_callback.py b/libs/community/langchain_community/callbacks/mlflow_callback.py index bda5aa4054d..ea862c23fd2 100644 --- a/libs/community/langchain_community/callbacks/mlflow_callback.py +++ b/libs/community/langchain_community/callbacks/mlflow_callback.py @@ -723,7 +723,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): ) return session_analysis_df - def _contain_llm_records(self): + def _contain_llm_records(self): # type: ignore[no-untyped-def] return bool(self.records["on_llm_start_records"]) def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None: diff --git a/libs/community/langchain_community/chat_message_histories/elasticsearch.py b/libs/community/langchain_community/chat_message_histories/elasticsearch.py index f98515e99dc..a9e076563e9 100644 --- a/libs/community/langchain_community/chat_message_histories/elasticsearch.py +++ b/libs/community/langchain_community/chat_message_histories/elasticsearch.py @@ -47,7 
+47,7 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory): ): self.index: str = index self.session_id: str = session_id - self.ensure_ascii: bool = esnsure_ascii + self.ensure_ascii: bool = esnsure_ascii # type: ignore[assignment] # Initialize Elasticsearch client from passed client arg or connection info if es_connection is not None: diff --git a/libs/community/langchain_community/chat_message_histories/tidb.py b/libs/community/langchain_community/chat_message_histories/tidb.py index bfa36ad06ff..973e97820a2 100644 --- a/libs/community/langchain_community/chat_message_histories/tidb.py +++ b/libs/community/langchain_community/chat_message_histories/tidb.py @@ -40,7 +40,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory): self.session_id = session_id self.table_name = table_name self.earliest_time = earliest_time - self.cache = [] + self.cache = [] # type: ignore[var-annotated] # Set up SQLAlchemy engine and session self.engine = create_engine(connection_string) @@ -102,7 +102,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory): logger.error(f"Error loading messages to cache: {e}") @property - def messages(self) -> List[BaseMessage]: + def messages(self) -> List[BaseMessage]: # type: ignore[override] """returns all messages""" if len(self.cache) == 0: self.reload_cache() diff --git a/libs/community/langchain_community/chat_message_histories/zep.py b/libs/community/langchain_community/chat_message_histories/zep.py index 45de39f150e..b38b1e9b49d 100644 --- a/libs/community/langchain_community/chat_message_histories/zep.py +++ b/libs/community/langchain_community/chat_message_histories/zep.py @@ -149,7 +149,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory): return None return zep_memory - def add_user_message( + def add_user_message( # type: ignore[override] self, message: str, metadata: Optional[Dict[str, Any]] = None ) -> None: """Convenience method for adding a human message string to the store. @@ -160,7 +160,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory): """ self.add_message(HumanMessage(content=message), metadata=metadata) - def add_ai_message( + def add_ai_message( # type: ignore[override] self, message: str, metadata: Optional[Dict[str, Any]] = None ) -> None: """Convenience method for adding an AI message string to the store. diff --git a/libs/community/langchain_community/chat_models/azureml_endpoint.py b/libs/community/langchain_community/chat_models/azureml_endpoint.py index 58192d6cdce..36229582723 100644 --- a/libs/community/langchain_community/chat_models/azureml_endpoint.py +++ b/libs/community/langchain_community/chat_models/azureml_endpoint.py @@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import ( class LlamaContentFormatter(ContentFormatterBase): - def __init__(self): + def __init__(self): # type: ignore[no-untyped-def] raise TypeError( "`LlamaContentFormatter` is deprecated for chat models. Use " "`LlamaChatContentFormatter` instead." 
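The hunks that follow in this chat-model file all add `# type: ignore[override]` to methods such as `format_request_payload` and `format_response_payload`. The pattern being silenced is a subclass redeclaring an inherited method with a different parameter list, which mypy rejects as a Liskov-substitution violation: an override's parameter types must be supertypes of the base's. A minimal sketch of the pattern, using hypothetical classes rather than the real formatter hierarchy:

```python
from typing import List


class Base:
    def format_request_payload(self, prompt: str) -> bytes:
        return prompt.encode()


class ChatVariant(Base):
    # Incompatible override: List[str] is not a supertype of str, so mypy
    # reports
    #   Signature of "format_request_payload" incompatible with supertype
    #   "Base"  [override]
    # unless the line carries the targeted ignore below.
    def format_request_payload(self, messages: List[str]) -> bytes:  # type: ignore[override]
        return "\n".join(messages).encode()
```

The targeted form `# type: ignore[override]` is preferable to a bare `# type: ignore` because every other error class on that line stays visible.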
@@ -72,7 +72,7 @@ class LlamaChatContentFormatter(ContentFormatterBase): def supported_api_types(self) -> List[AzureMLEndpointApiType]: return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless] - def format_request_payload( + def format_request_payload( # type: ignore[override] self, messages: List[BaseMessage], model_kwargs: Dict, @@ -98,9 +98,9 @@ class LlamaChatContentFormatter(ContentFormatterBase): raise ValueError( f"`api_type` {api_type} is not supported by this formatter" ) - return str.encode(request_payload) + return str.encode(request_payload) # type: ignore[return-value] - def format_response_payload( + def format_response_payload( # type: ignore[override] self, output: bytes, api_type: AzureMLEndpointApiType ) -> ChatGeneration: """Formats response""" @@ -108,7 +108,7 @@ class LlamaChatContentFormatter(ContentFormatterBase): try: choice = json.loads(output)["output"] except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return ChatGeneration( message=BaseMessage( content=choice.strip(), @@ -125,7 +125,7 @@ class LlamaChatContentFormatter(ContentFormatterBase): "model. Expected `dict` but `{type(choice)}` was received." ) except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return ChatGeneration( message=BaseMessage( content=choice["message"]["content"].strip(), diff --git a/libs/community/langchain_community/chat_models/edenai.py b/libs/community/langchain_community/chat_models/edenai.py index 8c72f1d791d..a97dbf28b9c 100644 --- a/libs/community/langchain_community/chat_models/edenai.py +++ b/libs/community/langchain_community/chat_models/edenai.py @@ -175,7 +175,7 @@ class ChatEdenAI(BaseChatModel): """Call out to EdenAI's chat endpoint.""" url = f"{self.edenai_api_url}/text/chat/stream" headers = { - "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr] "User-Agent": self.get_user_agent(), } formatted_data = _format_edenai_messages(messages=messages) @@ -216,7 +216,7 @@ class ChatEdenAI(BaseChatModel): ) -> AsyncIterator[ChatGenerationChunk]: url = f"{self.edenai_api_url}/text/chat/stream" headers = { - "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr] "User-Agent": self.get_user_agent(), } formatted_data = _format_edenai_messages(messages=messages) @@ -265,7 +265,7 @@ class ChatEdenAI(BaseChatModel): url = f"{self.edenai_api_url}/text/chat" headers = { - "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr] "User-Agent": self.get_user_agent(), } formatted_data = _format_edenai_messages(messages=messages) @@ -323,7 +323,7 @@ class ChatEdenAI(BaseChatModel): url = f"{self.edenai_api_url}/text/chat" headers = { - "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr] "User-Agent": self.get_user_agent(), } formatted_data = _format_edenai_messages(messages=messages) diff --git 
a/libs/community/langchain_community/chat_models/ernie.py b/libs/community/langchain_community/chat_models/ernie.py index 9954ef4272f..6f2a4d649bf 100644 --- a/libs/community/langchain_community/chat_models/ernie.py +++ b/libs/community/langchain_community/chat_models/ernie.py @@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel): generations = [ ChatGeneration( message=AIMessage( - content=response.get("result"), + content=response.get("result"), # type: ignore[arg-type] additional_kwargs={**additional_kwargs}, ) ) diff --git a/libs/community/langchain_community/chat_models/gpt_router.py b/libs/community/langchain_community/chat_models/gpt_router.py index 498d8542c8d..571d458e4b3 100644 --- a/libs/community/langchain_community/chat_models/gpt_router.py +++ b/libs/community/langchain_community/chat_models/gpt_router.py @@ -56,7 +56,7 @@ class GPTRouterModel(BaseModel): provider_name: str -def get_ordered_generation_requests( +def get_ordered_generation_requests( # type: ignore[no-untyped-def, no-untyped-def] models_priority_list: List[GPTRouterModel], **kwargs ): """ @@ -100,7 +100,7 @@ def completion_with_retry( models_priority_list: List[GPTRouterModel], run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, -) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]: +) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]: # type: ignore[type-arg] """Use tenacity to retry the completion call.""" retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) @@ -122,7 +122,7 @@ async def acompletion_with_retry( models_priority_list: List[GPTRouterModel], run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, -) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]: +) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]: # type: ignore[type-arg] """Use tenacity to retry the async completion call.""" retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) @@ -282,7 +282,7 @@ class GPTRouter(BaseChatModel): ) return self._create_chat_result(response) - def _create_chat_generation_chunk( + def _create_chat_generation_chunk( # type: ignore[no-untyped-def, no-untyped-def] self, data: Mapping[str, Any], default_chunk_class ): chunk = _convert_delta_to_message_chunk( @@ -293,7 +293,7 @@ class GPTRouter(BaseChatModel): dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) + chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) # type: ignore[assignment] return chunk, default_chunk_class def _stream( diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py index b83184ac6e5..bb52c7795d3 100644 --- a/libs/community/langchain_community/chat_models/huggingface.py +++ b/libs/community/langchain_community/chat_models/huggingface.py @@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel): elif isinstance(self.llm, HuggingFaceHub): # no need to look up model_id for HuggingFaceHub LLM - self.model_id = self.llm.repo_id + self.model_id = self.llm.repo_id # type: ignore[assignment] return else: diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py index 8492a5f8c56..9a977a1728a 100644 --- a/libs/community/langchain_community/chat_models/konko.py +++ 
b/libs/community/langchain_community/chat_models/konko.py @@ -169,7 +169,7 @@ class ChatKonko(ChatOpenAI): } if openai_api_key: - headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value() + headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value() # type: ignore[union-attr] models_response = requests.get(models_url, headers=headers) diff --git a/libs/community/langchain_community/chat_models/ollama.py b/libs/community/langchain_community/chat_models/ollama.py index e0327e7e50c..a6bb1bb6988 100644 --- a/libs/community/langchain_community/chat_models/ollama.py +++ b/libs/community/langchain_community/chat_models/ollama.py @@ -74,10 +74,10 @@ class ChatOllama(BaseChatModel, _OllamaCommon): if isinstance(message, ChatMessage): message_text = f"\n\n{message.role.capitalize()}: {message.content}" elif isinstance(message, HumanMessage): - if message.content[0].get("type") == "text": - message_text = f"[INST] {message.content[0]['text']} [/INST]" - elif message.content[0].get("type") == "image_url": - message_text = message.content[0]["image_url"]["url"] + if message.content[0].get("type") == "text": # type: ignore[union-attr] + message_text = f"[INST] {message.content[0]['text']} [/INST]" # type: ignore[index] + elif message.content[0].get("type") == "image_url": # type: ignore[union-attr] + message_text = message.content[0]["image_url"]["url"] # type: ignore[index, index] elif isinstance(message, AIMessage): message_text = f"{message.content}" elif isinstance(message, SystemMessage): @@ -112,11 +112,11 @@ class ChatOllama(BaseChatModel, _OllamaCommon): content = message.content else: for content_part in message.content: - if content_part.get("type") == "text": - content += f"\n{content_part['text']}" - elif content_part.get("type") == "image_url": - if isinstance(content_part.get("image_url"), str): - image_url_components = content_part["image_url"].split(",") + if content_part.get("type") == "text": # type: ignore[union-attr] + content += f"\n{content_part['text']}" # type: ignore[index] + elif content_part.get("type") == "image_url": # type: ignore[union-attr] + if isinstance(content_part.get("image_url"), str): # type: ignore[union-attr] + image_url_components = content_part["image_url"].split(",") # type: ignore[index] # Support data:image/jpeg;base64, format # and base64 strings if len(image_url_components) > 1: @@ -142,7 +142,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon): } ) - return ollama_messages + return ollama_messages # type: ignore[return-value] def _create_chat_stream( self, @@ -337,7 +337,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon): verbose=self.verbose, ) except OllamaEndpointNotFoundError: - async for chunk in self._legacy_astream(messages, stop, **kwargs): + async for chunk in self._legacy_astream(messages, stop, **kwargs): # type: ignore[attr-defined] yield chunk @deprecated("0.0.3", alternative="_stream") diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 004ffdd3b12..4623864c851 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -197,7 +197,7 @@ class ChatTongyi(BaseChatModel): return { "model": self.model_name, "top_p": self.top_p, - "api_key": self.dashscope_api_key.get_secret_value(), + "api_key": self.dashscope_api_key.get_secret_value(), # type: ignore[union-attr] "result_format": "message", **self.model_kwargs, } diff --git 
a/libs/community/langchain_community/chat_models/vertexai.py b/libs/community/langchain_community/chat_models/vertexai.py index cecd63ac41f..c5a81ea32a9 100644 --- a/libs/community/langchain_community/chat_models/vertexai.py +++ b/libs/community/langchain_community/chat_models/vertexai.py @@ -121,7 +121,7 @@ def _parse_chat_history_gemini( elif path.startswith("data:image/"): # extract base64 component from image uri try: - encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group( + encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group( # type: ignore[union-attr] 1 ) except AttributeError: diff --git a/libs/community/langchain_community/chat_models/yandex.py b/libs/community/langchain_community/chat_models/yandex.py index 61de024089e..94e05803405 100644 --- a/libs/community/langchain_community/chat_models/yandex.py +++ b/libs/community/langchain_community/chat_models/yandex.py @@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]: return chat_history -class ChatYandexGPT(_BaseYandexGPT, BaseChatModel): +class ChatYandexGPT(_BaseYandexGPT, BaseChatModel): # type: ignore[misc] """Wrapper around YandexGPT large language models. There are two authentication options for the service account @@ -156,7 +156,7 @@ def _make_request( messages=[Message(**message) for message in message_history], ) stub = TextGenerationServiceStub(channel) - res = stub.Completion(request, metadata=self._grpc_metadata) + res = stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] return list(res)[0].alternatives[0].message.text @@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st messages=[Message(**message) for message in message_history], ) stub = TextGenerationAsyncServiceStub(channel) - operation = await stub.Completion(request, metadata=self._grpc_metadata) + operation = await stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] async with grpc.aio.secure_channel( operation_api_url, channel_credentials ) as operation_channel: @@ -210,7 +210,8 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st await asyncio.sleep(1) operation_request = GetOperationRequest(operation_id=operation.id) operation = await operation_stub.Get( - operation_request, metadata=self._grpc_metadata + operation_request, + metadata=self._grpc_metadata, # type: ignore[attr-defined] ) completion_response = CompletionResponse() diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index e8e56f0b95c..113c2be4a1e 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel): return attributes - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] super().__init__(*args, **kwargs) try: import zhipuai @@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel): "Please install it via 'pip install zhipuai'" ) - def invoke(self, prompt): + def invoke(self, prompt): # type: ignore[no-untyped-def] if self.model == "chatglm_turbo": return self.zhipuai.model_api.invoke( model=self.model, @@ -195,7 +195,7 @@ class ChatZhipuAI(BaseChatModel): ) return None - def sse_invoke(self, prompt): + def sse_invoke(self, prompt): # type: ignore[no-untyped-def] if self.model == "chatglm_turbo": return 
self.zhipuai.model_api.sse_invoke( model=self.model, @@ -218,7 +218,7 @@ class ChatZhipuAI(BaseChatModel): ) return None - async def async_invoke(self, prompt): + async def async_invoke(self, prompt): # type: ignore[no-untyped-def] loop = asyncio.get_running_loop() partial_func = partial( self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt @@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel): ) return response - async def async_invoke_result(self, task_id): + async def async_invoke_result(self, task_id): # type: ignore[no-untyped-def] loop = asyncio.get_running_loop() response = await loop.run_in_executor( None, @@ -270,11 +270,14 @@ class ChatZhipuAI(BaseChatModel): else: stream_iter = self._stream( - prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + prompt=prompt, # type: ignore[arg-type] + stop=stop, + run_manager=run_manager, + **kwargs, ) return generate_from_stream(stream_iter) - async def _agenerate( + async def _agenerate( # type: ignore[override] self, messages: List[BaseMessage], stop: Optional[List[str]] = None, @@ -307,7 +310,7 @@ class ChatZhipuAI(BaseChatModel): generations=[ChatGeneration(message=AIMessage(content=content))] ) - def _stream( + def _stream( # type: ignore[override] self, prompt: List[Dict[str, str]], stop: Optional[List[str]] = None, diff --git a/libs/community/langchain_community/document_loaders/assemblyai.py b/libs/community/langchain_community/document_loaders/assemblyai.py index d3947d9f71b..eb4671c0d23 100644 --- a/libs/community/langchain_community/document_loaders/assemblyai.py +++ b/libs/community/langchain_community/document_loaders/assemblyai.py @@ -123,7 +123,7 @@ class AssemblyAIAudioLoaderById(BaseLoader): """ - def __init__(self, transcript_id, api_key, transcript_format): + def __init__(self, transcript_id, api_key, transcript_format): # type: ignore[no-untyped-def] """ Initializes the AssemblyAI AssemblyAIAudioLoaderById. 
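The `# type: ignore[no-untyped-def]` just added to the AssemblyAI loader's `__init__` recurs throughout this commit: under mypy's `disallow_untyped_defs`, any function lacking full annotations is an error, and the commit suppresses rather than annotates. A sketch of both options, with illustrative names (the real loader takes a `TranscriptFormat` enum, stubbed here):

```python
from enum import Enum


class TranscriptFormat(Enum):  # illustrative stand-in for the real enum
    TEXT = "text"


class LoaderSilenced:
    # What this commit does: keep the untyped signature and suppress
    # "Function is missing a type annotation  [no-untyped-def]".
    def __init__(self, transcript_id, api_key, transcript_format):  # type: ignore[no-untyped-def]
        self.transcript_id = transcript_id


class LoaderAnnotated:
    # The fix the ignore defers: annotate the parameters and return type,
    # which clears the error without any suppression comment.
    def __init__(
        self, transcript_id: str, api_key: str, transcript_format: TranscriptFormat
    ) -> None:
        self.transcript_id = transcript_id
```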
diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py index 2af1e77fde1..4cc1621a39d 100644 --- a/libs/community/langchain_community/document_loaders/astradb.py +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -65,7 +65,7 @@ class AstraDBLoader(BaseLoader): return list(self.lazy_load()) def lazy_load(self) -> Iterator[Document]: - queue = Queue(self.nb_prefetched) + queue = Queue(self.nb_prefetched) # type: ignore[var-annotated] t = threading.Thread(target=self.fetch_results, args=(queue,)) t.start() while True: @@ -95,7 +95,7 @@ class AstraDBLoader(BaseLoader): item = await run_in_executor(None, lambda it: next(it, done), iterator) if item is done: break - yield item + yield item # type: ignore[misc] return async_collection = await self.astra_env.async_astra_db.collection( self.collection_name @@ -116,13 +116,13 @@ class AstraDBLoader(BaseLoader): }, ) - def fetch_results(self, queue: Queue): + def fetch_results(self, queue: Queue): # type: ignore[no-untyped-def] self.fetch_page_result(queue) while self.find_options.get("pageState"): self.fetch_page_result(queue) queue.put(None) - def fetch_page_result(self, queue: Queue): + def fetch_page_result(self, queue: Queue): # type: ignore[no-untyped-def] res = self.collection.find( filter=self.filter, options=self.find_options, diff --git a/libs/community/langchain_community/document_loaders/base.py b/libs/community/langchain_community/document_loaders/base.py index 7a3e5a2706c..fc266ff51d6 100644 --- a/libs/community/langchain_community/document_loaders/base.py +++ b/libs/community/langchain_community/document_loaders/base.py @@ -64,10 +64,10 @@ class BaseLoader(ABC): iterator = await run_in_executor(None, self.lazy_load) done = object() while True: - doc = await run_in_executor(None, next, iterator, done) + doc = await run_in_executor(None, next, iterator, done) # type: ignore[call-arg, arg-type] if doc is done: break - yield doc + yield doc # type: ignore[misc] class BaseBlobParser(ABC): diff --git a/libs/community/langchain_community/document_loaders/cassandra.py b/libs/community/langchain_community/document_loaders/cassandra.py index 3167711228a..a3b7732c131 100644 --- a/libs/community/langchain_community/document_loaders/cassandra.py +++ b/libs/community/langchain_community/document_loaders/cassandra.py @@ -33,14 +33,14 @@ class CassandraLoader(BaseLoader): page_content_mapper: Callable[[Any], str] = str, metadata_mapper: Callable[[Any], dict] = lambda _: {}, *, - query_parameters: Union[dict, Sequence] = None, - query_timeout: Optional[float] = _NOT_SET, + query_parameters: Union[dict, Sequence] = None, # type: ignore[assignment] + query_timeout: Optional[float] = _NOT_SET, # type: ignore[assignment] query_trace: bool = False, - query_custom_payload: dict = None, + query_custom_payload: dict = None, # type: ignore[assignment] query_execution_profile: Any = _NOT_SET, query_paging_state: Any = None, query_host: Host = None, - query_execute_as: str = None, + query_execute_as: str = None, # type: ignore[assignment] ) -> None: """ Document Loader for Apache Cassandra. 
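The `# type: ignore[assignment]` markers on the `CassandraLoader` signature above cover the implicit-`Optional` pattern: parameters declared like `query_parameters: Union[dict, Sequence] = None`. Newer mypy releases enable `no_implicit_optional` by default, so a `None` default for a non-`Optional` parameter is an `[assignment]` error. A small sketch of the error and the suppression-free alternative, using a hypothetical function name:

```python
from typing import Optional, Sequence, Union


def load_silenced(
    query_parameters: Union[dict, Sequence] = None,  # type: ignore[assignment]
) -> Optional[Union[dict, Sequence]]:
    # Without the ignore, mypy reports: Incompatible default for argument
    # "query_parameters" (default has type "None", argument has type
    # "Union[dict[Any, Any], Sequence[Any]]")  [assignment]
    return query_parameters


def load_annotated(
    query_parameters: Optional[Union[dict, Sequence]] = None,
) -> Optional[Union[dict, Sequence]]:
    # Widening the annotation to Optional states the same contract
    # explicitly, with no suppression needed.
    return query_parameters
```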
@@ -85,7 +85,7 @@ class CassandraLoader(BaseLoader): self.query = f"SELECT * FROM {_keyspace}.{table};" self.metadata = {"table": table, "keyspace": _keyspace} else: - self.query = query + self.query = query # type: ignore[assignment] self.metadata = {} self.session = session or check_resolve_session(session) diff --git a/libs/community/langchain_community/document_loaders/chm.py b/libs/community/langchain_community/document_loaders/chm.py index c036ee8b8f3..381087d9917 100644 --- a/libs/community/langchain_community/document_loaders/chm.py +++ b/libs/community/langchain_community/document_loaders/chm.py @@ -27,7 +27,7 @@ class UnstructuredCHMLoader(UnstructuredFileLoader): def _get_elements(self) -> List: from unstructured.partition.html import partition_html - with CHMParser(self.file_path) as f: + with CHMParser(self.file_path) as f: # type: ignore[arg-type] return [ partition_html(text=item["content"], **self.unstructured_kwargs) for item in f.load_all() @@ -45,10 +45,10 @@ class CHMParser(object): self.file = chm.CHMFile() self.file.LoadCHM(path) - def __enter__(self): + def __enter__(self): # type: ignore[no-untyped-def] return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # type: ignore[no-untyped-def] if self.file: self.file.CloseCHM() diff --git a/libs/community/langchain_community/document_loaders/doc_intelligence.py b/libs/community/langchain_community/document_loaders/doc_intelligence.py index d1326afdbe3..2aae5212a59 100644 --- a/libs/community/langchain_community/document_loaders/doc_intelligence.py +++ b/libs/community/langchain_community/document_loaders/doc_intelligence.py @@ -89,4 +89,4 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader): blob = Blob.from_path(self.file_path) yield from self.parser.parse(blob) else: - yield from self.parser.parse_url(self.url_path) + yield from self.parser.parse_url(self.url_path) # type: ignore[arg-type] diff --git a/libs/community/langchain_community/document_loaders/mediawikidump.py b/libs/community/langchain_community/document_loaders/mediawikidump.py index 360420c536d..4868e5f89cd 100644 --- a/libs/community/langchain_community/document_loaders/mediawikidump.py +++ b/libs/community/langchain_community/document_loaders/mediawikidump.py @@ -60,7 +60,7 @@ class MWDumpLoader(BaseLoader): self.skip_redirects = skip_redirects self.stop_on_error = stop_on_error - def _load_dump_file(self): + def _load_dump_file(self): # type: ignore[no-untyped-def] try: import mwxml except ImportError as e: @@ -70,7 +70,7 @@ class MWDumpLoader(BaseLoader): return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding)) - def _load_single_page_from_dump(self, page) -> Document: + def _load_single_page_from_dump(self, page) -> Document: # type: ignore[no-untyped-def, return] """Parse a single page.""" try: import mwparserfromhell diff --git a/libs/community/langchain_community/document_loaders/parsers/vsdx.py b/libs/community/langchain_community/document_loaders/parsers/vsdx.py index 109521e48cc..d4dde56de0c 100644 --- a/libs/community/langchain_community/document_loaders/parsers/vsdx.py +++ b/libs/community/langchain_community/document_loaders/parsers/vsdx.py @@ -11,7 +11,7 @@ from langchain_community.document_loaders.blob_loaders import Blob class VsdxParser(BaseBlobParser, ABC): - def parse(self, blob: Blob) -> Iterator[Document]: + def parse(self, blob: Blob) -> Iterator[Document]: # type: ignore[override] """Parse a vsdx file.""" return self.lazy_parse(blob) @@ -21,7 +21,7 @@ 
class VsdxParser(BaseBlobParser, ABC): with blob.as_bytes_io() as pdf_file_obj: with zipfile.ZipFile(pdf_file_obj, "r") as zfile: - pages = self.get_pages_content(zfile, blob.source) + pages = self.get_pages_content(zfile, blob.source) # type: ignore[arg-type] yield from [ Document( @@ -60,13 +60,13 @@ class VsdxParser(BaseBlobParser, ABC): if "visio/pages/pages.xml" not in zfile.namelist(): print("WARNING - No pages.xml file found in {}".format(source)) - return + return # type: ignore[return-value] if "visio/pages/_rels/pages.xml.rels" not in zfile.namelist(): print("WARNING - No pages.xml.rels file found in {}".format(source)) - return + return # type: ignore[return-value] if "docProps/app.xml" not in zfile.namelist(): print("WARNING - No app.xml file found in {}".format(source)) - return + return # type: ignore[return-value] pagesxml_content: dict = xmltodict.parse(zfile.read("visio/pages/pages.xml")) appxml_content: dict = xmltodict.parse(zfile.read("docProps/app.xml")) @@ -79,7 +79,7 @@ class VsdxParser(BaseBlobParser, ABC): rel["@Name"].strip() for rel in pagesxml_content["Pages"]["Page"] ] else: - disordered_names: List[str] = [ + disordered_names: List[str] = [ # type: ignore[no-redef] pagesxml_content["Pages"]["Page"]["@Name"].strip() ] if isinstance(pagesxmlrels_content["Relationships"]["Relationship"], list): @@ -88,7 +88,7 @@ class VsdxParser(BaseBlobParser, ABC): for rel in pagesxmlrels_content["Relationships"]["Relationship"] ] else: - disordered_paths: List[str] = [ + disordered_paths: List[str] = [ # type: ignore[no-redef] "visio/pages/" + pagesxmlrels_content["Relationships"]["Relationship"]["@Target"] ] diff --git a/libs/community/langchain_community/embeddings/baichuan.py b/libs/community/langchain_community/embeddings/baichuan.py index b31e6092d59..9f0bf92ea8a 100644 --- a/libs/community/langchain_community/embeddings/baichuan.py +++ b/libs/community/langchain_community/embeddings/baichuan.py @@ -89,7 +89,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings): print(f"Exception occurred while trying to get embeddings: {str(e)}") return None - def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: + def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override] """Public method to get embeddings for a list of documents. Args: @@ -100,7 +100,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings): """ return self._embed(texts) - def embed_query(self, text: str) -> Optional[List[float]]: + def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override] """Public method to get embedding for a single query text. 
Args: diff --git a/libs/community/langchain_community/embeddings/edenai.py b/libs/community/langchain_community/embeddings/edenai.py index 3d6b1ec16d0..9446969d332 100644 --- a/libs/community/langchain_community/embeddings/edenai.py +++ b/libs/community/langchain_community/embeddings/edenai.py @@ -56,7 +56,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings): headers = { "accept": "application/json", "content-type": "application/json", - "authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", + "authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr] "User-Agent": self.get_user_agent(), } diff --git a/libs/community/langchain_community/embeddings/embaas.py b/libs/community/langchain_community/embeddings/embaas.py index 2f0b31f4439..00f23116bf4 100644 --- a/libs/community/langchain_community/embeddings/embaas.py +++ b/libs/community/langchain_community/embeddings/embaas.py @@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings): def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]: """Sends a request to the Embaas API and handles the response.""" headers = { - "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}", # type: ignore[union-attr] "Content-Type": "application/json", } diff --git a/libs/community/langchain_community/embeddings/gradient_ai.py b/libs/community/langchain_community/embeddings/gradient_ai.py index 6fd05d3c699..c0133d02386 100644 --- a/libs/community/langchain_community/embeddings/gradient_ai.py +++ b/libs/community/langchain_community/embeddings/gradient_ai.py @@ -162,5 +162,5 @@ class TinyAsyncGradientEmbeddingClient: #: :meta private: It might be entirely removed in the future. 
""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.") diff --git a/libs/community/langchain_community/embeddings/llm_rails.py b/libs/community/langchain_community/embeddings/llm_rails.py index 44bc0171803..c8afa11b491 100644 --- a/libs/community/langchain_community/embeddings/llm_rails.py +++ b/libs/community/langchain_community/embeddings/llm_rails.py @@ -56,7 +56,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings): """ response = requests.post( "https://api.llmrails.com/v1/embeddings", - headers={"X-API-KEY": self.api_key.get_secret_value()}, + headers={"X-API-KEY": self.api_key.get_secret_value()}, # type: ignore[union-attr] json={"input": texts, "model": self.model}, timeout=60, ) diff --git a/libs/community/langchain_community/embeddings/minimax.py b/libs/community/langchain_community/embeddings/minimax.py index a13eb98367d..38e56c553b8 100644 --- a/libs/community/langchain_community/embeddings/minimax.py +++ b/libs/community/langchain_community/embeddings/minimax.py @@ -110,7 +110,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings): # HTTP headers for authorization headers = { - "Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}", # type: ignore[union-attr] "Content-Type": "application/json", } diff --git a/libs/community/langchain_community/embeddings/mlflow.py b/libs/community/langchain_community/embeddings/mlflow.py index 6b24dacb025..1b1abb41032 100644 --- a/libs/community/langchain_community/embeddings/mlflow.py +++ b/libs/community/langchain_community/embeddings/mlflow.py @@ -71,7 +71,8 @@ class MlflowEmbeddings(Embeddings, BaseModel): embeddings: List[List[float]] = [] for txt in _chunk(texts, 20): resp = self._client.predict( - endpoint=self.endpoint, inputs={"input": txt, **params} + endpoint=self.endpoint, + inputs={"input": txt, **params}, # type: ignore[arg-type] ) embeddings.extend(r["embedding"] for r in resp["data"]) return embeddings diff --git a/libs/community/langchain_community/embeddings/oci_generative_ai.py b/libs/community/langchain_community/embeddings/oci_generative_ai.py index 6d47fec6f32..afcbb62024a 100644 --- a/libs/community/langchain_community/embeddings/oci_generative_ai.py +++ b/libs/community/langchain_community/embeddings/oci_generative_ai.py @@ -63,16 +63,16 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): If not specified , DEFAULT will be used """ - model_id: str = None + model_id: str = None # type: ignore[assignment] """Id of the model to call, e.g., cohere.embed-english-light-v2.0""" model_kwargs: Optional[Dict] = None """Keyword arguments to pass to the model""" - service_endpoint: str = None + service_endpoint: str = None # type: ignore[assignment] """service endpoint url""" - compartment_id: str = None + compartment_id: str = None # type: ignore[assignment] """OCID of compartment""" truncate: Optional[str] = "END" @@ -109,7 +109,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): client_kwargs.pop("signer", None) elif values["auth_type"] == OCIAuthType(2).name: - def make_security_token_signer(oci_config): + def make_security_token_signer(oci_config): # type: ignore[no-untyped-def] pk = oci.signer.load_private_key_from_file( oci_config.get("key_file"), None ) diff --git a/libs/community/langchain_community/embeddings/spacy_embeddings.py 
b/libs/community/langchain_community/embeddings/spacy_embeddings.py index 645d5afc963..954b2550fb4 100644 --- a/libs/community/langchain_community/embeddings/spacy_embeddings.py +++ b/libs/community/langchain_community/embeddings/spacy_embeddings.py @@ -78,7 +78,7 @@ class SpacyEmbeddings(BaseModel, Embeddings): Returns: A list of embeddings, one for each document. """ - return [self.nlp(text).vector.tolist() for text in texts] + return [self.nlp(text).vector.tolist() for text in texts] # type: ignore[misc] def embed_query(self, text: str) -> List[float]: """ @@ -90,7 +90,7 @@ class SpacyEmbeddings(BaseModel, Embeddings): Returns: The embedding for the text. """ - return self.nlp(text).vector.tolist() + return self.nlp(text).vector.tolist() # type: ignore[misc] async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """ diff --git a/libs/community/langchain_community/embeddings/yandex.py b/libs/community/langchain_community/embeddings/yandex.py index d3e6ed69e45..34f71feafae 100644 --- a/libs/community/langchain_community/embeddings/yandex.py +++ b/libs/community/langchain_community/embeddings/yandex.py @@ -42,10 +42,10 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", model_uri="emb:///text-search-query/latest") """ - iam_token: SecretStr = "" + iam_token: SecretStr = "" # type: ignore[assignment] """Yandex Cloud IAM token for service account with the `ai.languageModels.user` role""" - api_key: SecretStr = "" + api_key: SecretStr = "" # type: ignore[assignment] """Yandex Cloud Api Key for service account with the `ai.languageModels.user` role""" model_uri: str = "" @@ -146,7 +146,7 @@ def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any: return _completion_with_retry(**kwargs) -def _make_request(self: YandexGPTEmbeddings, texts: List[str]): +def _make_request(self: YandexGPTEmbeddings, texts: List[str]): # type: ignore[no-untyped-def] try: import grpc from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501 @@ -167,7 +167,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]): for text in texts: request = TextEmbeddingRequest(model_uri=self.model_uri, text=text) stub = EmbeddingsServiceStub(channel) - res = stub.TextEmbedding(request, metadata=self._grpc_metadata) + res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] result.append(list(res.embedding)) time.sleep(self.sleep_interval) diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index b8970c06ecc..125fd96b151 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -56,7 +56,7 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]: cleaned_list.append(value_sanitize(item)) else: cleaned_list.append(item) - new_dict[key] = cleaned_list + new_dict[key] = cleaned_list # type: ignore[assignment] else: new_dict[key] = value return new_dict diff --git a/libs/community/langchain_community/graphs/ontotext_graphdb_graph.py b/libs/community/langchain_community/graphs/ontotext_graphdb_graph.py index c1072a97f81..bf5d1a71777 100644 --- a/libs/community/langchain_community/graphs/ontotext_graphdb_graph.py +++ b/libs/community/langchain_community/graphs/ontotext_graphdb_graph.py @@ -95,12 +95,13 @@ class OntotextGraphDBGraph: if local_file: ontology_schema_graph = self._load_ontology_schema_from_file( - 
local_file, local_file_format + local_file, + local_file_format, # type: ignore[arg-type] ) else: - self._validate_user_query(query_ontology) + self._validate_user_query(query_ontology) # type: ignore[arg-type] ontology_schema_graph = self._load_ontology_schema_with_query( - query_ontology + query_ontology # type: ignore[arg-type] ) self.schema = ontology_schema_graph.serialize(format="turtle") @@ -139,7 +140,7 @@ class OntotextGraphDBGraph: ) @staticmethod - def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None): + def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None): # type: ignore[no-untyped-def, assignment] """ Parse the ontology schema statements from the provided file """ @@ -176,7 +177,7 @@ class OntotextGraphDBGraph: "Invalid query type. Only CONSTRUCT queries are supported." ) - def _load_ontology_schema_with_query(self, query: str): + def _load_ontology_schema_with_query(self, query: str): # type: ignore[no-untyped-def] """ Execute the query for collecting the ontology schema statements """ diff --git a/libs/community/langchain_community/graphs/tigergraph_graph.py b/libs/community/langchain_community/graphs/tigergraph_graph.py index cff2f4e2ce7..f32d43e6e5e 100644 --- a/libs/community/langchain_community/graphs/tigergraph_graph.py +++ b/libs/community/langchain_community/graphs/tigergraph_graph.py @@ -31,7 +31,7 @@ class TigerGraph(GraphStore): def schema(self) -> Dict[str, Any]: return self._schema - def get_schema(self) -> str: + def get_schema(self) -> str: # type: ignore[override] if self._schema: return str(self._schema) else: @@ -71,10 +71,10 @@ class TigerGraph(GraphStore): """ return self._conn.getSchema(force=True) - def refresh_schema(self): + def refresh_schema(self): # type: ignore[no-untyped-def] self.generate_schema() - def query(self, query: str) -> Dict[str, Any]: + def query(self, query: str) -> Dict[str, Any]: # type: ignore[override] """Query the TigerGraph database.""" answer = self._conn.ai.query(query) return answer diff --git a/libs/community/langchain_community/llms/azureml_endpoint.py b/libs/community/langchain_community/llms/azureml_endpoint.py index 3480a801f96..3601a7629a2 100644 --- a/libs/community/langchain_community/llms/azureml_endpoint.py +++ b/libs/community/langchain_community/llms/azureml_endpoint.py @@ -165,7 +165,7 @@ class GPT2ContentFormatter(ContentFormatterBase): def supported_api_types(self) -> List[AzureMLEndpointApiType]: return [AzureMLEndpointApiType.realtime] - def format_request_payload( + def format_request_payload( # type: ignore[override] self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType ) -> bytes: prompt = ContentFormatterBase.escape_special_characters(prompt) @@ -174,13 +174,13 @@ class GPT2ContentFormatter(ContentFormatterBase): ) return str.encode(request_payload) - def format_response_payload( + def format_response_payload( # type: ignore[override] self, output: bytes, api_type: AzureMLEndpointApiType ) -> Generation: try: choice = json.loads(output)[0]["0"] except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return Generation(text=choice) @@ -207,7 +207,7 @@ class HFContentFormatter(ContentFormatterBase): def supported_api_types(self) -> List[AzureMLEndpointApiType]: return [AzureMLEndpointApiType.realtime] - def format_request_payload( + def format_request_payload( # 
type: ignore[override] self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType ) -> bytes: ContentFormatterBase.escape_special_characters(prompt) @@ -216,13 +216,13 @@ class HFContentFormatter(ContentFormatterBase): ) return str.encode(request_payload) - def format_response_payload( + def format_response_payload( # type: ignore[override] self, output: bytes, api_type: AzureMLEndpointApiType ) -> Generation: try: choice = json.loads(output)[0]["0"]["generated_text"] except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return Generation(text=choice) @@ -233,7 +233,7 @@ class DollyContentFormatter(ContentFormatterBase): def supported_api_types(self) -> List[AzureMLEndpointApiType]: return [AzureMLEndpointApiType.realtime] - def format_request_payload( + def format_request_payload( # type: ignore[override] self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType ) -> bytes: prompt = ContentFormatterBase.escape_special_characters(prompt) @@ -245,13 +245,13 @@ class DollyContentFormatter(ContentFormatterBase): ) return str.encode(request_payload) - def format_response_payload( + def format_response_payload( # type: ignore[override] self, output: bytes, api_type: AzureMLEndpointApiType ) -> Generation: try: choice = json.loads(output)[0] except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return Generation(text=choice) @@ -262,7 +262,7 @@ class LlamaContentFormatter(ContentFormatterBase): def supported_api_types(self) -> List[AzureMLEndpointApiType]: return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless] - def format_request_payload( + def format_request_payload( # type: ignore[override] self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType ) -> bytes: """Formats the request according to the chosen api""" @@ -284,7 +284,7 @@ class LlamaContentFormatter(ContentFormatterBase): ) return str.encode(request_payload) - def format_response_payload( + def format_response_payload( # type: ignore[override] self, output: bytes, api_type: AzureMLEndpointApiType ) -> Generation: """Formats response""" @@ -292,7 +292,7 @@ class LlamaContentFormatter(ContentFormatterBase): try: choice = json.loads(output)[0]["0"] except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return Generation(text=choice) if api_type == AzureMLEndpointApiType.serverless: try: @@ -304,7 +304,7 @@ class LlamaContentFormatter(ContentFormatterBase): "received." 
) except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] return Generation( text=choice["text"].strip(), generation_info=dict( @@ -397,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel): ) -> AzureMLEndpointApiType: """Validate that endpoint api type is compatible with the URL format.""" endpoint_url = values.get("endpoint_url") - if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( + if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( # type: ignore[union-attr] "/score" ): raise ValueError( @@ -407,8 +407,8 @@ class AzureMLBaseEndpoint(BaseModel): "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead." ) if field_value == AzureMLEndpointApiType.serverless and not ( - endpoint_url.endswith("/v1/completions") - or endpoint_url.endswith("/v1/chat/completions") + endpoint_url.endswith("/v1/completions") # type: ignore[union-attr] + or endpoint_url.endswith("/v1/chat/completions") # type: ignore[union-attr] ): raise ValueError( "Endpoints of type `serverless` should follow the format " @@ -426,7 +426,9 @@ class AzureMLBaseEndpoint(BaseModel): deployment_name = values.get("deployment_name") http_client = AzureMLEndpointClient( - endpoint_url, endpoint_key.get_secret_value(), deployment_name + endpoint_url, # type: ignore + endpoint_key.get_secret_value(), # type: ignore + deployment_name, # type: ignore ) return http_client diff --git a/libs/community/langchain_community/llms/baichuan.py b/libs/community/langchain_community/llms/baichuan.py index 2627b81bd24..293c95d2c84 100644 --- a/libs/community/langchain_community/llms/baichuan.py +++ b/libs/community/langchain_community/llms/baichuan.py @@ -56,11 +56,11 @@ class BaichuanLLM(LLM): def _post(self, request: Any) -> Any: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}", + "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}", # type: ignore[union-attr] } try: response = requests.post( - self.baichuan_api_host, + self.baichuan_api_host, # type: ignore[arg-type] headers=headers, json=request, timeout=self.timeout, diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py index c66d0d071a8..002eb90fc0b 100644 --- a/libs/community/langchain_community/llms/bedrock.py +++ b/libs/community/langchain_community/llms/bedrock.py @@ -395,8 +395,8 @@ class BedrockBase(BaseModel, ABC): """ return { "amazon-bedrock-guardrailDetails": { - "guardrailId": self.guardrails.get("id"), - "guardrailVersion": self.guardrails.get("version"), + "guardrailId": self.guardrails.get("id"), # type: ignore[union-attr] + "guardrailVersion": self.guardrails.get("version"), # type: ignore[union-attr] } } @@ -427,7 +427,7 @@ class BedrockBase(BaseModel, ABC): if self._guardrails_enabled: request_options["guardrail"] = "ENABLED" - if self.guardrails.get("trace"): + if self.guardrails.get("trace"): # type: ignore[union-attr] request_options["trace"] = "ENABLED" try: @@ -446,7 +446,7 @@ class BedrockBase(BaseModel, ABC): # Verify and raise a callback error if any intervention occurs or a signal is # sent from a Bedrock service, # such as when guardrails are triggered. 
- services_trace = self._get_bedrock_services_signal(body) + services_trace = self._get_bedrock_services_signal(body) # type: ignore[arg-type] if services_trace.get("signal") and run_manager is not None: run_manager.on_llm_error( @@ -468,7 +468,7 @@ class BedrockBase(BaseModel, ABC): if ( self._guardrails_enabled - and self.guardrails.get("trace") + and self.guardrails.get("trace") # type: ignore[union-attr] and self._is_guardrails_intervention(body) ): return { @@ -526,7 +526,7 @@ class BedrockBase(BaseModel, ABC): if self._guardrails_enabled: request_options["guardrail"] = "ENABLED" - if self.guardrails.get("trace"): + if self.guardrails.get("trace"): # type: ignore[union-attr] request_options["trace"] = "ENABLED" try: @@ -540,7 +540,7 @@ class BedrockBase(BaseModel, ABC): ): yield chunk # verify and raise callback error if any middleware intervened - self._get_bedrock_services_signal(chunk.generation_info) + self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type] if run_manager is not None: run_manager.on_llm_new_token(chunk.text, chunk=chunk) @@ -588,7 +588,7 @@ class BedrockBase(BaseModel, ABC): ): await run_manager.on_llm_new_token(chunk.text, chunk=chunk) elif run_manager is not None: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + run_manager.on_llm_new_token(chunk.text, chunk=chunk) # type: ignore[unused-coroutine] class Bedrock(LLM, BedrockBase): diff --git a/libs/community/langchain_community/llms/oci_generative_ai.py b/libs/community/langchain_community/llms/oci_generative_ai.py index 092cbda5548..0ed977e24b0 100644 --- a/libs/community/langchain_community/llms/oci_generative_ai.py +++ b/libs/community/langchain_community/llms/oci_generative_ai.py @@ -42,10 +42,10 @@ class OCIGenAIBase(BaseModel, ABC): If not specified , DEFAULT will be used """ - model_id: str = None + model_id: str = None # type: ignore[assignment] """Id of the model to call, e.g., cohere.command""" - provider: str = None + provider: str = None # type: ignore[assignment] """Provider name of the model. Default to None, will try to be derived from the model_id otherwise, requires user input @@ -54,10 +54,10 @@ class OCIGenAIBase(BaseModel, ABC): model_kwargs: Optional[Dict] = None """Keyword arguments to pass to the model""" - service_endpoint: str = None + service_endpoint: str = None # type: ignore[assignment] """service endpoint url""" - compartment_id: str = None + compartment_id: str = None # type: ignore[assignment] """OCID of compartment""" is_stream: bool = False @@ -94,7 +94,7 @@ class OCIGenAIBase(BaseModel, ABC): client_kwargs.pop("signer", None) elif values["auth_type"] == OCIAuthType(2).name: - def make_security_token_signer(oci_config): + def make_security_token_signer(oci_config): # type: ignore[no-untyped-def] pk = oci.signer.load_private_key_from_file( oci_config.get("key_file"), None ) diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index a06ab72641b..c6aba99accf 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -297,7 +297,7 @@ class _OllamaCommon(BaseLanguageModel): "Ollama call failed with status code 404." ) else: - optional_detail = await response.json().get("error") + optional_detail = await response.json().get("error") # type: ignore[attr-defined] raise ValueError( f"Ollama call failed with status code {response.status}." 
f" Details: {optional_detail}" @@ -380,7 +380,7 @@ class Ollama(BaseLLM, _OllamaCommon): """Return type of llm.""" return "ollama-llm" - def _generate( + def _generate( # type: ignore[override] self, prompts: List[str], stop: Optional[List[str]] = None, @@ -416,7 +416,7 @@ class Ollama(BaseLLM, _OllamaCommon): generations.append([final_chunk]) return LLMResult(generations=generations) - async def _agenerate( + async def _agenerate( # type: ignore[override] self, prompts: List[str], stop: Optional[List[str]] = None, @@ -445,7 +445,7 @@ class Ollama(BaseLLM, _OllamaCommon): prompt, stop=stop, images=images, - run_manager=run_manager, + run_manager=run_manager, # type: ignore[arg-type] verbose=self.verbose, **kwargs, ) diff --git a/libs/community/langchain_community/llms/pipelineai.py b/libs/community/langchain_community/llms/pipelineai.py index d309c9cab55..170722d920a 100644 --- a/libs/community/langchain_community/llms/pipelineai.py +++ b/libs/community/langchain_community/llms/pipelineai.py @@ -102,7 +102,7 @@ class PipelineAI(LLM, BaseModel): "Could not import pipeline-ai python package. " "Please install it with `pip install pipeline-ai`." ) - client = PipelineCloud(token=self.pipeline_api_key.get_secret_value()) + client = PipelineCloud(token=self.pipeline_api_key.get_secret_value()) # type: ignore[union-attr] params = self.pipeline_kwargs or {} params = {**params, **kwargs} diff --git a/libs/community/langchain_community/llms/stochasticai.py b/libs/community/langchain_community/llms/stochasticai.py index d645a019da7..4cfc6d5ee94 100644 --- a/libs/community/langchain_community/llms/stochasticai.py +++ b/libs/community/langchain_community/llms/stochasticai.py @@ -107,7 +107,7 @@ class StochasticAI(LLM): url=self.api_url, json={"prompt": prompt, "params": params}, headers={ - "apiKey": f"{self.stochasticai_api_key.get_secret_value()}", + "apiKey": f"{self.stochasticai_api_key.get_secret_value()}", # type: ignore[union-attr] "Accept": "application/json", "Content-Type": "application/json", }, @@ -119,7 +119,7 @@ class StochasticAI(LLM): response_get = requests.get( url=response_post_json["data"]["responseUrl"], headers={ - "apiKey": f"{self.stochasticai_api_key.get_secret_value()}", + "apiKey": f"{self.stochasticai_api_key.get_secret_value()}", # type: ignore[union-attr] "Accept": "application/json", "Content-Type": "application/json", }, diff --git a/libs/community/langchain_community/llms/vertexai.py b/libs/community/langchain_community/llms/vertexai.py index 3833bbd785c..58e209f7665 100644 --- a/libs/community/langchain_community/llms/vertexai.py +++ b/libs/community/langchain_community/llms/vertexai.py @@ -49,7 +49,7 @@ def is_gemini_model(model_name: str) -> bool: return model_name is not None and "gemini" in model_name -def completion_with_retry( +def completion_with_retry( # type: ignore[no-redef] llm: VertexAI, prompt: List[Union[str, "Image"]], stream: bool = False, @@ -330,7 +330,7 @@ class VertexAI(_VertexAICommon, BaseLLM): generation += chunk generations.append([generation]) else: - res = completion_with_retry( + res = completion_with_retry( # type: ignore[misc] self, [prompt], stream=should_stream, @@ -373,7 +373,7 @@ class VertexAI(_VertexAICommon, BaseLLM): **kwargs: Any, ) -> Iterator[GenerationChunk]: params = self._prepare_params(stop=stop, stream=True, **kwargs) - for stream_resp in completion_with_retry( + for stream_resp in completion_with_retry( # type: ignore[misc] self, [prompt], stream=True, diff --git a/libs/community/langchain_community/llms/watsonxllm.py 
diff --git a/libs/community/langchain_community/llms/watsonxllm.py b/libs/community/langchain_community/llms/watsonxllm.py
index d60c7284600..d3225730c39 100644
--- a/libs/community/langchain_community/llms/watsonxllm.py
+++ b/libs/community/langchain_community/llms/watsonxllm.py
@@ -250,9 +250,9 @@ class WatsonxLLM(BaseLLM):
         }

     def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
-        params: Dict[str, Any] = {**self.params} if self.params else None
+        params: Dict[str, Any] = {**self.params} if self.params else {}
         if stop is not None:
-            params = (params or {}) | {"stop_sequences": stop}
+            params["stop_sequences"] = stop
         return params

     def _create_llm_result(self, response: List[dict]) -> LLMResult:
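Unlike most hunks in this PR, the `_get_chat_params` change above fixes behavior and typing together: the old code defaulted to `None` despite the `Dict[str, Any]` annotation and re-merged with `|`. A standalone restatement of the fixed logic (the free function form is for illustration only):

```python
from typing import Any, Dict, List, Optional


def get_chat_params(
    params: Optional[Dict[str, Any]], stop: Optional[List[str]] = None
) -> Dict[str, Any]:
    # The default is now an empty dict, so the annotation is honest and
    # callers never see None; stop sequences are added by plain assignment
    # instead of a `|` re-merge.
    merged: Dict[str, Any] = {**params} if params else {}
    if stop is not None:
        merged["stop_sequences"] = stop
    return merged


assert get_chat_params(None, stop=["\n"]) == {"stop_sequences": ["\n"]}
assert get_chat_params({"max_new_tokens": 20}) == {"max_new_tokens": 20}
```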
diff --git a/libs/community/langchain_community/llms/yandex.py b/libs/community/langchain_community/llms/yandex.py
index f8e05373318..d4cf5840ca1 100644
--- a/libs/community/langchain_community/llms/yandex.py
+++ b/libs/community/langchain_community/llms/yandex.py
@@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)


 class _BaseYandexGPT(Serializable):
-    iam_token: SecretStr = ""
+    iam_token: SecretStr = ""  # type: ignore[assignment]
     """Yandex Cloud IAM token for service or user account
     with the `ai.languageModels.user` role"""
-    api_key: SecretStr = ""
+    api_key: SecretStr = ""  # type: ignore[assignment]
     """Yandex Cloud Api Key for service account
     with the `ai.languageModels.user` role"""
     folder_id: str = ""
@@ -211,7 +211,7 @@ def _make_request(
         messages=[Message(role="user", text=prompt)],
     )
     stub = TextGenerationServiceStub(channel)
-    res = stub.Completion(request, metadata=self._grpc_metadata)
+    res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
     return list(res)[0].alternatives[0].message.text
@@ -253,7 +253,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
         messages=[Message(role="user", text=prompt)],
     )
     stub = TextGenerationAsyncServiceStub(channel)
-    operation = await stub.Completion(request, metadata=self._grpc_metadata)
+    operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
     async with grpc.aio.secure_channel(
         operation_api_url, channel_credentials
     ) as operation_channel:
@@ -262,7 +262,8 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
             await asyncio.sleep(1)
             operation_request = GetOperationRequest(operation_id=operation.id)
             operation = await operation_stub.Get(
-                operation_request, metadata=self._grpc_metadata
+                operation_request,
+                metadata=self._grpc_metadata,  # type: ignore[attr-defined]
             )

     completion_response = CompletionResponse()
diff --git a/libs/community/langchain_community/tools/amadeus/closest_airport.py b/libs/community/langchain_community/tools/amadeus/closest_airport.py
index cf108f3b11d..1a3f4c4da2c 100644
--- a/libs/community/langchain_community/tools/amadeus/closest_airport.py
+++ b/libs/community/langchain_community/tools/amadeus/closest_airport.py
@@ -58,4 +58,4 @@ class AmadeusClosestAirport(AmadeusBaseTool):
             ' Location Identifier" '
         )

-        return self.llm.invoke(content)
+        return self.llm.invoke(content)  # type: ignore[union-attr]
diff --git a/libs/community/langchain_community/tools/shell/tool.py b/libs/community/langchain_community/tools/shell/tool.py
index e26deb365dc..15445441a30 100644
--- a/libs/community/langchain_community/tools/shell/tool.py
+++ b/libs/community/langchain_community/tools/shell/tool.py
@@ -93,10 +93,10 @@ class ShellTool(BaseTool):
                     return self.process.run(commands)
                 else:
                     logger.info("Invalid input. User aborted command execution.")
-                    return None
+                    return None  # type: ignore[return-value]
             else:
                 return self.process.run(commands)
         except Exception as e:
             logger.error(f"Error during command execution: {e}")
-            return None
+            return None  # type: ignore[return-value]
diff --git a/libs/community/langchain_community/utilities/brave_search.py b/libs/community/langchain_community/utilities/brave_search.py
index 8f3df0666c6..fd282fc3465 100644
--- a/libs/community/langchain_community/utilities/brave_search.py
+++ b/libs/community/langchain_community/utilities/brave_search.py
@@ -48,7 +48,7 @@ class BraveSearchWrapper(BaseModel):
         results = self._search_request(query)
         return [
             Document(
-                page_content=item.get("description"),
+                page_content=item.get("description"),  # type: ignore[arg-type]
                 metadata={"title": item.get("title"), "link": item.get("url")},
             )
             for item in results
diff --git a/libs/community/langchain_community/utilities/requests.py b/libs/community/langchain_community/utilities/requests.py
index 673df3d5e57..fd183b4cfad 100644
--- a/libs/community/langchain_community/utilities/requests.py
+++ b/libs/community/langchain_community/utilities/requests.py
@@ -141,9 +141,9 @@ class GenericRequestsWrapper(BaseModel):
         self, response: aiohttp.ClientResponse
     ) -> Union[str, Dict[str, Any]]:
         if self.response_content_type == "text":
-            return response.text()
+            return response.text()  # type: ignore[return-value]
         elif self.response_content_type == "json":
-            return response.json()
+            return response.json()  # type: ignore[return-value]
         else:
             raise ValueError(f"Invalid return type: {self.response_content_type}")
@@ -176,33 +176,33 @@ class GenericRequestsWrapper(BaseModel):
     async def aget(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
         """GET the URL and return the text asynchronously."""
         async with self.requests.aget(url, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def apost(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """POST to the URL and return the text asynchronously."""
         async with self.requests.apost(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def apatch(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """PATCH the URL and return the text asynchronously."""
         async with self.requests.apatch(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def aput(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """PUT the URL and return the text asynchronously."""
         async with self.requests.aput(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def adelete(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
         """DELETE the URL and return the text asynchronously."""
         async with self.requests.adelete(url, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]


 class JsonRequestsWrapper(GenericRequestsWrapper):
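The `return-value` ignores in `_aget_resp_content` above exist because aiohttp's `.text()` and `.json()` are coroutines, so the un-awaited returns do not match the declared `Union[str, Dict[str, Any]]`. A sketch of an awaited variant that satisfies the annotation without ignores (an alternative, not what this PR does):

```python
from typing import Any, Dict, Union

import aiohttp


async def aget_resp_content(
    response: aiohttp.ClientResponse, response_content_type: str
) -> Union[str, Dict[str, Any]]:
    # Awaiting .text() / .json() yields the declared Union[str, dict]
    # instead of an un-awaited coroutine.
    if response_content_type == "text":
        return await response.text()
    elif response_content_type == "json":
        return await response.json()
    raise ValueError(f"Invalid return type: {response_content_type}")
```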
diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py
index e56423733cb..298b8ca7e22 100644
--- a/libs/community/langchain_community/utilities/sql_database.py
+++ b/libs/community/langchain_community/utilities/sql_database.py
@@ -381,7 +381,7 @@ class SQLDatabase:

         If the statement returns no rows, an empty list is returned.
         """
-        with self._engine.begin() as connection:  # type: Connection
+        with self._engine.begin() as connection:  # type: Connection  # type: ignore[name-defined]
             if self._schema is not None:
                 if self.dialect == "snowflake":
                     connection.exec_driver_sql(
@@ -444,7 +444,7 @@ class SQLDatabase:
             ]

             if not include_columns:
-                res = [tuple(row.values()) for row in res]
+                res = [tuple(row.values()) for row in res]  # type: ignore[misc]

             if not res:
                 return ""
diff --git a/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py b/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py
index c9d5b85a6bd..12d02ae19d3 100644
--- a/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py
+++ b/libs/community/langchain_community/vectorstores/alibabacloud_opensearch.py
@@ -356,7 +356,7 @@ class AlibabaCloudOpenSearch(VectorStore):
                 "fields" not in item
                 or self.config.field_name_mapping["document"] not in item["fields"]
             ):
-                query_result_list.append(Document())
+                query_result_list.append(Document())  # type: ignore[call-arg]
            else:
                 fields = item["fields"]
                 query_result_list.append(
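The AstraDB diff that follows silences many `union-attr` errors on lazily-initialized clients (`self.astra_db`, `self.collection`, `self.async_astra_db`). A narrowing helper is one alternative; a minimal sketch (the `ensure_client` name is hypothetical, not in the PR):

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def ensure_client(client: Optional[T], name: str = "client") -> T:
    # Raise early when a lazily-initialized handle is still None; because the
    # return type is T rather than Optional[T], callers get full attribute
    # access without per-line `# type: ignore[union-attr]` comments.
    if client is None:
        raise ValueError(f"Missing {name}")
    return client


# Hypothetical usage, mirroring _ensure_astra_db_client in the diff below:
#     astra_db = ensure_client(self.astra_db, "AstraDB client")
#     astra_db.create_collection(...)
```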
""" - await self.async_astra_db.create_collection( + await self.async_astra_db.create_collection( # type: ignore[union-attr] dimension=self._get_embedding_dimension(), collection_name=self.collection_name, metric=self.metric, @@ -328,7 +328,7 @@ class AstraDB(VectorStore): await self._ensure_db_setup() if not self.async_astra_db: await run_in_executor(None, self.clear) - await self.async_collection.delete_many({}) + await self.async_collection.delete_many({}) # type: ignore[union-attr] def delete_by_document_id(self, document_id: str) -> bool: """ @@ -336,7 +336,7 @@ class AstraDB(VectorStore): Return True if a document has indeed been deleted, False if ID not found. """ self._ensure_astra_db_client() - deletion_response = self.collection.delete_one(document_id) + deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr] return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 ) == 1 @@ -434,7 +434,7 @@ class AstraDB(VectorStore): Use with caution. """ self._ensure_astra_db_client() - self.astra_db.delete_collection( + self.astra_db.delete_collection( # type: ignore[union-attr] collection_name=self.collection_name, ) @@ -448,7 +448,7 @@ class AstraDB(VectorStore): await self._ensure_db_setup() if not self.async_astra_db: await run_in_executor(None, self.delete_collection) - await self.async_astra_db.delete_collection( + await self.async_astra_db.delete_collection( # type: ignore[union-attr] collection_name=self.collection_name, ) @@ -571,7 +571,7 @@ class AstraDB(VectorStore): ) def _handle_batch(document_batch: List[DocDict]) -> List[str]: - im_result = self.collection.insert_many( + im_result = self.collection.insert_many( # type: ignore[union-attr] documents=document_batch, options={"ordered": False}, partial_failures_allowed=True, @@ -581,7 +581,7 @@ class AstraDB(VectorStore): ) def _handle_missing_document(missing_document: DocDict) -> str: - replacement_result = self.collection.find_one_and_replace( + replacement_result = self.collection.find_one_and_replace( # type: ignore[union-attr] filter={"_id": missing_document["_id"]}, replacement=missing_document, ) @@ -672,7 +672,7 @@ class AstraDB(VectorStore): ) async def _handle_batch(document_batch: List[DocDict]) -> List[str]: - im_result = await self.async_collection.insert_many( + im_result = await self.async_collection.insert_many( # type: ignore[union-attr] documents=document_batch, options={"ordered": False}, partial_failures_allowed=True, @@ -682,7 +682,7 @@ class AstraDB(VectorStore): ) async def _handle_missing_document(missing_document: DocDict) -> str: - replacement_result = await self.async_collection.find_one_and_replace( + replacement_result = await self.async_collection.find_one_and_replace( # type: ignore[union-attr] filter={"_id": missing_document["_id"]}, replacement=missing_document, ) @@ -729,7 +729,7 @@ class AstraDB(VectorStore): metadata_parameter = self._filter_to_metadata(filter) # hits = list( - self.collection.paginated_find( + self.collection.paginated_find( # type: ignore[union-attr] filter=metadata_parameter, sort={"$vector": embedding}, options={"limit": k, "includeSimilarity": True}, @@ -771,7 +771,7 @@ class AstraDB(VectorStore): if not self.async_collection: return await run_in_executor( None, - self.asimilarity_search_with_score_id_by_vector, + self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type] embedding, k, filter, @@ -962,7 +962,7 @@ class AstraDB(VectorStore): ) @staticmethod - def _get_mmr_hits(embedding, k, lambda_mult, 
@@ -962,7 +962,7 @@ class AstraDB(VectorStore):
         )

     @staticmethod
-    def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):
+    def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):  # type: ignore[no-untyped-def]
         mmr_chosen_indices = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),
             [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
@@ -1008,7 +1008,7 @@ class AstraDB(VectorStore):
         metadata_parameter = self._filter_to_metadata(filter)

         prefetch_hits = list(
-            self.collection.paginated_find(
+            self.collection.paginated_find(  # type: ignore[union-attr]
                 filter=metadata_parameter,
                 sort={"$vector": embedding},
                 options={"limit": fetch_k, "includeSimilarity": True},
@@ -1228,7 +1228,7 @@ class AstraDB(VectorStore):
             batch_concurrency=kwargs.get("batch_concurrency"),
             overwrite_concurrency=kwargs.get("overwrite_concurrency"),
         )
-        return astra_db_store
+        return astra_db_store  # type: ignore[return-value]

     @classmethod
     async def afrom_texts(
@@ -1263,7 +1263,7 @@ class AstraDB(VectorStore):
             batch_concurrency=kwargs.get("batch_concurrency"),
             overwrite_concurrency=kwargs.get("overwrite_concurrency"),
         )
-        return astra_db_store
+        return astra_db_store  # type: ignore[return-value]

     @classmethod
     def from_documents(
diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py
index 992db81f227..9006985fa2d 100644
--- a/libs/community/langchain_community/vectorstores/azuresearch.py
+++ b/libs/community/langchain_community/vectorstores/azuresearch.py
@@ -339,7 +339,7 @@ class AzureSearch(VectorStore):
         # batching support if embedding function is an Embeddings object
         if isinstance(self.embedding_function, Embeddings):
             try:
-                embeddings = self.embedding_function.embed_documents(texts)
+                embeddings = self.embedding_function.embed_documents(texts)  # type: ignore[arg-type]
             except NotImplementedError:
                 embeddings = [self.embedding_function.embed_query(x) for x in texts]
         else:
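The AzureSearch hunk above keeps the batch-then-fallback embedding logic and only annotates it. Restated as a standalone, fully typed function (the signature and callable branch are illustrative, not part of the PR):

```python
from typing import Callable, List, Union

from langchain_core.embeddings import Embeddings


def embed_texts(
    embedding_function: Union[Embeddings, Callable[[str], List[float]]],
    texts: List[str],
) -> List[List[float]]:
    # Batch with embed_documents when we have a true Embeddings object,
    # falling back to one embed_query call per text if batching is not
    # implemented; plain callables are applied per text.
    if isinstance(embedding_function, Embeddings):
        try:
            return embedding_function.embed_documents(texts)
        except NotImplementedError:
            return [embedding_function.embed_query(t) for t in texts]
    return [embedding_function(t) for t in texts]
```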
diff --git a/libs/community/langchain_community/vectorstores/bigquery_vector_search.py b/libs/community/langchain_community/vectorstores/bigquery_vector_search.py
index 64a1f4b7655..28edc6536ce 100644
--- a/libs/community/langchain_community/vectorstores/bigquery_vector_search.py
+++ b/libs/community/langchain_community/vectorstores/bigquery_vector_search.py
@@ -222,7 +222,7 @@ class BigQueryVectorSearch(VectorStore):
             self._logger.debug("Vector index already exists.")
             self._have_index = True

-    def _create_index_in_background(self):
+    def _create_index_in_background(self):  # type: ignore[no-untyped-def]
         if self._have_index or self._creating_index:
             # Already have an index or in the process of creating one.
             return
@@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
         thread = Thread(target=self._create_index, daemon=True)
         thread.start()

-    def _create_index(self):
+    def _create_index(self):  # type: ignore[no-untyped-def]
         from google.api_core.exceptions import ClientError

         table = self.bq_client.get_table(self.vectors_table)
@@ -289,7 +289,7 @@ class BigQueryVectorSearch(VectorStore):
     def full_table_id(self) -> str:
         return self._full_table_id

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
diff --git a/libs/community/langchain_community/vectorstores/deeplake.py b/libs/community/langchain_community/vectorstores/deeplake.py
index 52fe55fa7ce..659c24c6ca8 100644
--- a/libs/community/langchain_community/vectorstores/deeplake.py
+++ b/libs/community/langchain_community/vectorstores/deeplake.py
@@ -905,7 +905,7 @@ class DeepLake(VectorStore):
         return self.vectorstore.dataset

     @classmethod
-    def _validate_kwargs(cls, kwargs, method_name):
+    def _validate_kwargs(cls, kwargs, method_name):  # type: ignore[no-untyped-def]
         if kwargs:
             valid_items = cls._get_valid_args(method_name)
             unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@@ -917,14 +917,14 @@ class DeepLake(VectorStore):
             )

     @classmethod
-    def _get_valid_args(cls, method_name):
+    def _get_valid_args(cls, method_name):  # type: ignore[no-untyped-def]
         if method_name == "search":
             return cls._valid_search_kwargs
         else:
             return []

     @staticmethod
-    def _get_unsupported_items(kwargs, valid_items):
+    def _get_unsupported_items(kwargs, valid_items):  # type: ignore[no-untyped-def]
         kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
         unsupported_items = None
         if kwargs:
diff --git a/libs/community/langchain_community/vectorstores/faiss.py b/libs/community/langchain_community/vectorstores/faiss.py
index 044209add23..0341756fc5f 100644
--- a/libs/community/langchain_community/vectorstores/faiss.py
+++ b/libs/community/langchain_community/vectorstores/faiss.py
@@ -305,7 +305,7 @@ class FAISS(VectorStore):
         if filter is not None:
             if isinstance(filter, dict):

-                def filter_func(metadata):
+                def filter_func(metadata):  # type: ignore[no-untyped-def]
                     if all(
                         metadata.get(key) in value
                         if isinstance(value, list)
@@ -607,7 +607,7 @@ class FAISS(VectorStore):
         filtered_indices = []
         if isinstance(filter, dict):

-            def filter_func(metadata):
+            def filter_func(metadata):  # type: ignore[no-untyped-def]
                 if all(
                     metadata.get(key) in value
                     if isinstance(value, list)
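The two FAISS `filter_func` closures above implement dict-based metadata filtering: every key must match, and a list value means "the metadata value is one of these". A typed restatement that needs no `no-untyped-def` ignore (a sketch; the factory form is illustrative):

```python
from typing import Any, Callable, Dict


def make_filter_func(filter: Dict[str, Any]) -> Callable[[Dict[str, Any]], bool]:
    # Same semantics as the filter_func closures in the FAISS hunks above.
    def filter_func(metadata: Dict[str, Any]) -> bool:
        return all(
            metadata.get(key) in value
            if isinstance(value, list)
            else metadata.get(key) == value
            for key, value in filter.items()
        )

    return filter_func


assert make_filter_func({"page": "0"})({"page": "0"}) is True
assert make_filter_func({"page": ["0", "1"]})({"page": "2"}) is False
```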
diff --git a/libs/community/langchain_community/vectorstores/hanavector.py b/libs/community/langchain_community/vectorstores/hanavector.py
index 04eec65a448..5aba9334828 100644
--- a/libs/community/langchain_community/vectorstores/hanavector.py
+++ b/libs/community/langchain_community/vectorstores/hanavector.py
@@ -117,7 +117,7 @@ class HanaDB(VectorStore):
             self.vector_column_length,
         )

-    def _table_exists(self, table_name) -> bool:
+    def _table_exists(self, table_name) -> bool:  # type: ignore[no-untyped-def]
         sql_str = (
             "SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
             " AND TABLE_NAME = ?"
@@ -133,7 +133,7 @@ class HanaDB(VectorStore):
                 cur.close()
         return False

-    def _check_column(self, table_name, column_name, column_type, column_length=None):
+    def _check_column(self, table_name, column_name, column_type, column_length=None):  # type: ignore[no-untyped-def]
         sql_str = (
             "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
             "SCHEMA_NAME = CURRENT_SCHEMA "
@@ -166,17 +166,17 @@ class HanaDB(VectorStore):
     def embeddings(self) -> Embeddings:
         return self.embedding

-    def _sanitize_name(input_str: str) -> str:
+    def _sanitize_name(input_str: str) -> str:  # type: ignore[misc]
         # Remove characters that are not alphanumeric or underscores
         return re.sub(r"[^a-zA-Z0-9_]", "", input_str)

-    def _sanitize_int(input_int: any) -> int:
+    def _sanitize_int(input_int: any) -> int:  # type: ignore[valid-type]
         value = int(str(input_int))
         if value < -1:
             raise ValueError(f"Value ({value}) must not be smaller than -1")
         return int(str(input_int))

-    def _sanitize_list_float(embedding: List[float]) -> List[float]:
+    def _sanitize_list_float(embedding: List[float]) -> List[float]:  # type: ignore[misc]
         for value in embedding:
             if not isinstance(value, float):
                 raise ValueError(f"Value ({value}) does not have type float")
@@ -185,14 +185,14 @@ class HanaDB(VectorStore):
     # Compile pattern only once, for better performance
     _compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")

-    def _sanitize_metadata_keys(metadata: dict) -> dict:
+    def _sanitize_metadata_keys(metadata: dict) -> dict:  # type: ignore[misc]
         for key in metadata.keys():
             if not HanaDB._compiled_pattern.match(key):
                 raise ValueError(f"Invalid metadata key {key}")

         return metadata

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: Iterable[str],
         metadatas: Optional[List[dict]] = None,
@@ -243,7 +243,7 @@ class HanaDB(VectorStore):
         return []

     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[no-untyped-def, override]
         cls: Type[HanaDB],
         texts: List[str],
         embedding: Embeddings,
@@ -277,7 +277,7 @@ class HanaDB(VectorStore):
         instance.add_texts(texts, metadatas)
         return instance

-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self, query: str, k: int = 4, filter: Optional[dict] = None
     ) -> List[Document]:
         """Return docs most similar to query.
@@ -382,7 +382,7 @@ class HanaDB(VectorStore):
         )
         return [(result_item[0], result_item[1]) for result_item in whole_result]

-    def similarity_search_by_vector(
+    def similarity_search_by_vector(  # type: ignore[override]
         self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
     ) -> List[Document]:
         """Return docs most similar to embedding vector.
@@ -401,7 +401,7 @@ class HanaDB(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

-    def _create_where_by_filter(self, filter):
+    def _create_where_by_filter(self, filter):  # type: ignore[no-untyped-def]
         query_tuple = []
         where_str = ""
         if filter:
@@ -427,7 +427,7 @@ class HanaDB(VectorStore):

         return where_str, query_tuple

-    def delete(
+    def delete(  # type: ignore[override]
         self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
     ) -> Optional[bool]:
         """Delete entries by filter with metadata values
@@ -459,7 +459,7 @@ class HanaDB(VectorStore):

         return True

-    async def adelete(
+    async def adelete(  # type: ignore[override]
         self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
     ) -> Optional[bool]:
         """Delete by vector ID or other criteria.
@@ -473,7 +473,7 @@ class HanaDB(VectorStore):
         """
         return await run_in_executor(None, self.delete, ids=ids, filter=filter)

-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         k: int = 4,
@@ -511,11 +511,11 @@ class HanaDB(VectorStore):
             filter=filter,
         )

-    def _parse_float_array_from_string(array_as_string: str) -> List[float]:
+    def _parse_float_array_from_string(array_as_string: str) -> List[float]:  # type: ignore[misc]
         array_wo_brackets = array_as_string[1:-1]
         return [float(x) for x in array_wo_brackets.split(",")]

-    def max_marginal_relevance_search_by_vector(
+    def max_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
@@ -533,7 +533,7 @@ class HanaDB(VectorStore):

         return [whole_result[i][0] for i in mmr_doc_indexes]

-    async def amax_marginal_relevance_search_by_vector(
+    async def amax_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
diff --git a/libs/community/langchain_community/vectorstores/jaguar.py b/libs/community/langchain_community/vectorstores/jaguar.py
index 2771530d66d..a42cdf3641e 100644
--- a/libs/community/langchain_community/vectorstores/jaguar.py
+++ b/libs/community/langchain_community/vectorstores/jaguar.py
@@ -135,7 +135,7 @@ class Jaguar(VectorStore):
     def embeddings(self) -> Optional[Embeddings]:
         return self._embedding

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -351,7 +351,7 @@ class Jaguar(VectorStore):
             return False

     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[override]
         cls,
         texts: List[str],
         embedding: Embeddings,
@@ -383,7 +383,7 @@ class Jaguar(VectorStore):
         q = "truncate store " + podstore
         self.run(q)

-    def delete(self, zids: List[str], **kwargs: Any) -> None:
+    def delete(self, zids: List[str], **kwargs: Any) -> None:  # type: ignore[override]
         """
         Delete records in jaguardb by a list of zero-ids
         Args:
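The `misc` ignores on the HanaDB sanitizers above arise because the helpers are defined in the class body without `self` or a `@staticmethod` decorator, and `_sanitize_int` uses the non-type `any`. A typed sketch of the same helpers as proper static methods (the class name is illustrative):

```python
import re


class Sanitizers:
    # Compile pattern only once, as the HanaDB code does.
    _key_pattern = re.compile(r"^[_a-zA-Z][_a-zA-Z0-9]*$")

    @staticmethod
    def sanitize_name(input_str: str) -> str:
        # Remove characters that are not alphanumeric or underscores.
        return re.sub(r"[^a-zA-Z0-9_]", "", input_str)

    @staticmethod
    def sanitize_int(input_int: object) -> int:
        # `object` replaces the invalid `any` annotation flagged as valid-type.
        value = int(str(input_int))
        if value < -1:
            raise ValueError(f"Value ({value}) must not be smaller than -1")
        return value

    @staticmethod
    def sanitize_metadata_keys(metadata: dict) -> dict:
        for key in metadata:
            if not Sanitizers._key_pattern.match(key):
                raise ValueError(f"Invalid metadata key {key}")
        return metadata
```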
-        return self.col.delete(expr=expr, **kwargs)
+        return self.col.delete(expr=expr, **kwargs)  # type: ignore[union-attr]

     @classmethod
     def from_texts(
diff --git a/libs/community/langchain_community/vectorstores/pgembedding.py b/libs/community/langchain_community/vectorstores/pgembedding.py
index d5c37d5942e..cbdf65c6c4f 100644
--- a/libs/community/langchain_community/vectorstores/pgembedding.py
+++ b/libs/community/langchain_community/vectorstores/pgembedding.py
@@ -398,7 +398,7 @@ class PGEmbedding(VectorStore):
         docs = [
             (
                 Document(
-                    page_content=result.EmbeddingStore.document,
+                    page_content=result.EmbeddingStore.document,  # type: ignore[arg-type]
                     metadata=result.EmbeddingStore.cmetadata,
                 ),
                 result.distance if self.embedding_function is not None else 0.0,
diff --git a/libs/community/langchain_community/vectorstores/pgvecto_rs.py b/libs/community/langchain_community/vectorstores/pgvecto_rs.py
index 2b14cd8036e..2f1dbf42720 100644
--- a/libs/community/langchain_community/vectorstores/pgvecto_rs.py
+++ b/libs/community/langchain_community/vectorstores/pgvecto_rs.py
@@ -133,7 +133,7 @@ class PGVecto_rs(VectorStore):
             Record.from_text(text, embedding, meta)
             for text, embedding, meta in zip(texts, embeddings, metadatas or [])
         ]
-        self._store.insert(records)
+        self._store.insert(records)  # type: ignore[union-attr]
         return [str(record.id) for record in records]

     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
@@ -177,7 +177,7 @@ class PGVecto_rs(VectorStore):
             real_filter = meta_contains(filter)
         else:
             real_filter = filter
-        results = self._store.search(
+        results = self._store.search(  # type: ignore[union-attr]
             query_vector,
             distance_func_map[distance_func],
             k,
diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py
index fe439d86c5a..755c72bb4dc 100644
--- a/libs/community/langchain_community/vectorstores/pgvector.py
+++ b/libs/community/langchain_community/vectorstores/pgvector.py
@@ -238,7 +238,7 @@ class PGVector(VectorStore):
     def create_vector_extension(self) -> None:
         try:
-            with Session(self._bind) as session:
+            with Session(self._bind) as session:  # type: ignore[arg-type]
                 # The advisor lock fixes issue arising from concurrent
                 # creation of the vector extension.
                 # https://github.com/langchain-ai/langchain/issues/12933
@@ -256,24 +256,24 @@ class PGVector(VectorStore):
             raise Exception(f"Failed to create vector extension: {e}") from e

     def create_tables_if_not_exists(self) -> None:
-        with Session(self._bind) as session, session.begin():
+        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
             Base.metadata.create_all(session.get_bind())

     def drop_tables(self) -> None:
-        with Session(self._bind) as session, session.begin():
+        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
             Base.metadata.drop_all(session.get_bind())

     def create_collection(self) -> None:
         if self.pre_delete_collection:
             self.delete_collection()
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             self.CollectionStore.get_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )

     def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 self.logger.warning("Collection not found")
@@ -284,7 +284,7 @@ class PGVector(VectorStore):
     @contextlib.contextmanager
     def _make_session(self) -> Generator[Session, None, None]:
         """Create a context manager for the session, bind to _conn string."""
-        yield Session(self._bind)
+        yield Session(self._bind)  # type: ignore[arg-type]

     def delete(
         self,
@@ -298,7 +298,7 @@ class PGVector(VectorStore):
             ids: List of ids to delete.
             collection_only: Only delete ids in the collection.
         """
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             if ids is not None:
                 self.logger.debug(
                     "Trying to delete vectors by ids (represented by the model "
@@ -383,7 +383,7 @@ class PGVector(VectorStore):
         if not metadatas:
             metadatas = [{} for _ in texts]

-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -508,7 +508,7 @@ class PGVector(VectorStore):
         ]
         return docs

-    def _create_filter_clause(self, key, value):
+    def _create_filter_clause(self, key, value):  # type: ignore[no-untyped-def]
         IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
         EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
@@ -575,7 +575,7 @@ class PGVector(VectorStore):
         filter: Optional[Dict[str, str]] = None,
     ) -> List[Any]:
         """Query the collection."""
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
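All the `Session(self._bind)  # type: ignore[arg-type]` lines above wrap the same SQLAlchemy pattern; the ignores exist because `self._bind` is typed more broadly than what `Session` accepts. The pattern itself, in isolation (SQLite URL is just for the sketch):

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")

# Session accepts an Engine directly; session.begin() opens a transaction
# that commits on success and rolls back on error, which is exactly the
# `with Session(...) as session, session.begin():` shape used by PGVector.
with Session(engine) as session, session.begin():
    session.execute(text("SELECT 1"))
```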
diff --git a/libs/community/langchain_community/vectorstores/surrealdb.py b/libs/community/langchain_community/vectorstores/surrealdb.py
index d21f5bf0e02..ef65c5ec6b0 100644
--- a/libs/community/langchain_community/vectorstores/surrealdb.py
+++ b/libs/community/langchain_community/vectorstores/surrealdb.py
@@ -115,7 +115,7 @@ class SurrealDBStore(VectorStore):
         for idx, text in enumerate(texts):
             data = {"text": text, "embedding": embeddings[idx]}
             if metadatas is not None and idx < len(metadatas):
-                data["metadata"] = metadatas[idx]
+                data["metadata"] = metadatas[idx]  # type: ignore[assignment]
             record = await self.sdb.create(
                 self.collection,
                 data,
diff --git a/libs/community/langchain_community/vectorstores/tencentvectordb.py b/libs/community/langchain_community/vectorstores/tencentvectordb.py
index d1c05b647d4..e185a3ab123 100644
--- a/libs/community/langchain_community/vectorstores/tencentvectordb.py
+++ b/libs/community/langchain_community/vectorstores/tencentvectordb.py
@@ -316,7 +316,7 @@ class TencentVectorDB(VectorStore):
             meta = result.get(self.field_metadata)
             if meta is not None:
                 meta = json.loads(meta)
-            doc = Document(page_content=result.get(self.field_text), metadata=meta)
+            doc = Document(page_content=result.get(self.field_text), metadata=meta)  # type: ignore[arg-type]
             pair = (doc, result.get("score", 0.0))
             ret.append(pair)
         return ret
@@ -374,7 +374,7 @@ class TencentVectorDB(VectorStore):
             meta = result.get(self.field_metadata)
             if meta is not None:
                 meta = json.loads(meta)
-            doc = Document(page_content=result.get(self.field_text), metadata=meta)
+            doc = Document(page_content=result.get(self.field_text), metadata=meta)  # type: ignore[arg-type]
             documents.append(doc)
             ordered_result_embeddings.append(result.get(self.field_vector))
         # Get the new order of results.
diff --git a/libs/community/langchain_community/vectorstores/thirdai_neuraldb.py b/libs/community/langchain_community/vectorstores/thirdai_neuraldb.py
index f447d73745c..444f7d14f8d 100644
--- a/libs/community/langchain_community/vectorstores/thirdai_neuraldb.py
+++ b/libs/community/langchain_community/vectorstores/thirdai_neuraldb.py
@@ -24,7 +24,7 @@ class NeuralDBVectorStore(VectorStore):
         underscore_attrs_are_private = True

     @staticmethod
-    def _verify_thirdai_library(thirdai_key: Optional[str] = None):
+    def _verify_thirdai_library(thirdai_key: Optional[str] = None):  # type: ignore[no-untyped-def]
         try:
             from thirdai import licensing
@@ -38,7 +38,7 @@ class NeuralDBVectorStore(VectorStore):
             )

     @classmethod
-    def from_scratch(
+    def from_scratch(  # type: ignore[no-untyped-def, no-untyped-def]
         cls,
         thirdai_key: Optional[str] = None,
         **model_kwargs,
@@ -69,10 +69,10 @@ class NeuralDBVectorStore(VectorStore):
         NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb

-        return cls(db=ndb.NeuralDB(**model_kwargs))
+        return cls(db=ndb.NeuralDB(**model_kwargs))  # type: ignore[call-arg]

     @classmethod
-    def from_bazaar(
+    def from_bazaar(  # type: ignore[no-untyped-def]
         cls,
         base: str,
         bazaar_cache: Optional[str] = None,
@@ -111,10 +111,10 @@ class NeuralDBVectorStore(VectorStore):
             os.mkdir(cache)
         model_bazaar = ndb.Bazaar(cache)
         model_bazaar.fetch()
-        return cls(db=model_bazaar.get_model(base))
+        return cls(db=model_bazaar.get_model(base))  # type: ignore[call-arg]

     @classmethod
-    def from_checkpoint(
+    def from_checkpoint(  # type: ignore[no-untyped-def]
         cls,
         checkpoint: Union[str, Path],
         thirdai_key: Optional[str] = None,
@@ -146,7 +146,7 @@ class NeuralDBVectorStore(VectorStore):
         NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb

-        return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))
+        return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))  # type: ignore[call-arg]

     @classmethod
     def from_texts(
@@ -187,11 +187,11 @@ class NeuralDBVectorStore(VectorStore):
         df = pd.DataFrame({"texts": texts})
         if metadatas:
             df = pd.concat([df, pd.DataFrame.from_records(metadatas)], axis=1)
-        temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)
+        temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)  # type: ignore[call-overload]
         df.to_csv(temp)
         source_id = self.insert([ndb.CSV(temp.name)], **kwargs)[0]
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
-        return [str(offset + i) for i in range(len(texts))]
+        return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]

     @root_validator()
     def validate_environments(cls, values: Dict) -> Dict:
@@ -205,7 +205,7 @@ class NeuralDBVectorStore(VectorStore):
             )
         return values

-    def insert(
+    def insert(  # type: ignore[no-untyped-def, no-untyped-def]
         self,
         sources: List[Any],
         train: bool = True,
@@ -229,7 +229,7 @@ class NeuralDBVectorStore(VectorStore):
             **kwargs,
         )

-    def _preprocess_sources(self, sources):
+    def _preprocess_sources(self, sources):  # type: ignore[no-untyped-def]
         """Checks if the provided sources are string paths. If they are, convert
         to NeuralDB document objects.
@@ -261,7 +261,7 @@ class NeuralDBVectorStore(VectorStore):
                 )
         return preprocessed_sources

-    def upvote(self, query: str, document_id: Union[int, str]):
+    def upvote(self, query: str, document_id: Union[int, str]):  # type: ignore[no-untyped-def]
         """The vectorstore upweights the score of a document for a specific query.
         This is useful for fine-tuning the vectorstore to user behavior.
@@ -271,7 +271,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.text_to_result(query, int(document_id))

-    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):
+    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):  # type: ignore[no-untyped-def]
         """Given a batch of (query, document id) pairs, the vectorstore upweights
         the scores of the document for the corresponding queries.
         This is useful for fine-tuning the vectorstore to user behavior.
@@ -284,7 +284,7 @@ class NeuralDBVectorStore(VectorStore):
             [(query, int(doc_id)) for query, doc_id in query_id_pairs]
         )

-    def associate(self, source: str, target: str):
+    def associate(self, source: str, target: str):  # type: ignore[no-untyped-def]
         """The vectorstore associates a source phrase with a target phrase.
         When the vectorstore sees the source phrase, it will also consider results
         that are relevant to the target phrase.
@@ -295,7 +295,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.associate(source, target)

-    def associate_batch(self, text_pairs: List[Tuple[str, str]]):
+    def associate_batch(self, text_pairs: List[Tuple[str, str]]):  # type: ignore[no-untyped-def]
         """Given a batch of (source, target) pairs, the vectorstore associates
         each source phrase with the corresponding target phrase.
@@ -334,7 +334,7 @@ class NeuralDBVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e

-    def save(self, path: str):
+    def save(self, path: str):  # type: ignore[no-untyped-def]
         """Saves a NeuralDB instance to disk.
         Can be loaded into memory by calling NeuralDB.from_checkpoint(path)
diff --git a/libs/community/langchain_community/vectorstores/vectara.py b/libs/community/langchain_community/vectorstores/vectara.py
index 0fc91004216..3e0ecd3bfe7 100644
--- a/libs/community/langchain_community/vectorstores/vectara.py
+++ b/libs/community/langchain_community/vectorstores/vectara.py
@@ -384,7 +384,7 @@ class Vectara(VectorStore):
                 f"(code {response.status_code}, reason {response.reason}, details "
                 f"{response.text})",
             )
-            return [], ""
+            return [], ""  # type: ignore[return-value]

         result = response.json()
@@ -454,7 +454,7 @@ class Vectara(VectorStore):
         docs = self.vectara_query(query, config)
         return docs

-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self,
         query: str,
         **kwargs: Any,
@@ -474,7 +474,7 @@ class Vectara(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         fetch_k: int = 50,
diff --git a/libs/community/langchain_community/vectorstores/vikngdb.py b/libs/community/langchain_community/vectorstores/vikngdb.py
index 2f235f0bf4c..942f4b9591a 100644
--- a/libs/community/langchain_community/vectorstores/vikngdb.py
+++ b/libs/community/langchain_community/vectorstores/vikngdb.py
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)


 class VikingDBConfig(object):
-    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):
+    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):  # type: ignore[no-untyped-def]
         self.host = host
         self.region = region
         self.ak = ak
@@ -47,11 +47,11 @@ class VikingDB(VectorStore):
         self.index_params = index_params
         self.drop_old = drop_old
         self.service = VikingDBService(
-            connection_args.host,
-            connection_args.region,
-            connection_args.ak,
-            connection_args.sk,
-            connection_args.scheme,
+            connection_args.host,  # type: ignore[union-attr]
+            connection_args.region,  # type: ignore[union-attr]
+            connection_args.ak,  # type: ignore[union-attr]
+            connection_args.sk,  # type: ignore[union-attr]
+            connection_args.scheme,  # type: ignore[union-attr]
         )

         try:
@@ -143,7 +143,7 @@ class VikingDB(VectorStore):
                 scalar_index=scalar_index,
             )

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -183,7 +183,7 @@ class VikingDB(VectorStore):
             if metadatas is not None and index < len(metadatas):
                 names = list(metadatas[index].keys())
                 for name in names:
-                    field[name] = metadatas[index].get(name)
+                    field[name] = metadatas[index].get(name)  # type: ignore[assignment]
             data.append(Data(field))

         total_count = len(data)
@@ -191,10 +191,10 @@ class VikingDB(VectorStore):
             end = min(i + batch_size, total_count)
             insert_data = data[i:end]
             # print(insert_data)
-            self.collection.upsert_data(insert_data)
+            self.collection.upsert_data(insert_data)  # type: ignore[union-attr]
         return pks

-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self,
         query: str,
         params: Optional[dict] = None,
@@ -216,7 +216,7 @@ class VikingDB(VectorStore):
         )
         return res

-    def similarity_search_by_vector(
+    def similarity_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         params: Optional[dict] = None,
@@ -251,7 +251,7 @@ class VikingDB(VectorStore):
         if params.get("partition") is not None:
             partition = params["partition"]
-        res = self.index.search_by_vector(
+        res = self.index.search_by_vector(  # type: ignore[union-attr]
             embedding,
             filter=filter,
             limit=limit,
@@ -269,7 +269,7 @@ class VikingDB(VectorStore):
             ret.append(pair)
         return ret

-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         k: int = 4,
@@ -286,7 +286,7 @@ class VikingDB(VectorStore):
             **kwargs,
         )

-    def max_marginal_relevance_search_by_vector(
+    def max_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
@@ -311,7 +311,7 @@ class VikingDB(VectorStore):
         if params.get("partition") is not None:
             partition = params["partition"]
-        res = self.index.search_by_vector(
+        res = self.index.search_by_vector(  # type: ignore[union-attr]
             embedding,
             filter=filter,
             limit=limit,
@@ -347,10 +347,10 @@ class VikingDB(VectorStore):
     ) -> None:
         if self.collection is None:
             logger.debug("No existing collection to search.")
-        self.collection.delete_data(ids)
+        self.collection.delete_data(ids)  # type: ignore[union-attr]

     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[no-untyped-def, override]
         cls,
         texts: List[str],
         embedding: Embeddings,
diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock
index f674b1d20a6..70d943ea094 100644
--- a/libs/community/poetry.lock
+++ b/libs/community/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.

 [[package]]
 name = "aenum"
@@ -3944,7 +3944,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.1.17"
+version = "0.1.18"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -9252,4 +9252,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"
+content-hash = "1ab63edcddcef2deb01e6fff5c376f7b0773435bb9d5b55bc1d50d19a8f1dee2"
pytest = "^7.3.0" -pytest-cov = "^4.0.0" +pytest-cov = "^4.1.0" pytest-dotenv = "^0.5.2" duckdb-engine = "^0.9.2" pytest-watcher = "^0.2.6" diff --git a/libs/community/tests/integration_tests/chat_message_histories/test_tidb.py b/libs/community/tests/integration_tests/chat_message_histories/test_tidb.py index 17601af48b5..e6a3d2303dc 100644 --- a/libs/community/tests/integration_tests/chat_message_histories/test_tidb.py +++ b/libs/community/tests/integration_tests/chat_message_histories/test_tidb.py @@ -57,7 +57,7 @@ def test_add_messages() -> None: assert len(message_store_another.messages) == 0 -def test_tidb_recent_chat_message(): +def test_tidb_recent_chat_message(): # type: ignore[no-untyped-def] """Test the TiDBChatMessageHistory with earliest_time parameter.""" import time from datetime import datetime diff --git a/libs/community/tests/integration_tests/chat_models/test_konko.py b/libs/community/tests/integration_tests/chat_models/test_konko.py index 9f38f740eb4..7287d5b11dd 100644 --- a/libs/community/tests/integration_tests/chat_models/test_konko.py +++ b/libs/community/tests/integration_tests/chat_models/test_konko.py @@ -40,7 +40,7 @@ def test_konko_key_masked_when_passed_via_constructor( captured = capsys.readouterr() assert captured.out == "**********" - print(chat.konko_secret_key, end="") + print(chat.konko_secret_key, end="") # type: ignore[attr-defined] captured = capsys.readouterr() assert captured.out == "**********" @@ -49,7 +49,7 @@ def test_uses_actual_secret_value_from_secret_str() -> None: """Test that actual secret is retrieved using `.get_secret_value()`.""" chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key") assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key" - assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key" + assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key" # type: ignore[attr-defined] def test_konko_chat_test() -> None: diff --git a/libs/community/tests/integration_tests/chat_models/test_llama_edge.py b/libs/community/tests/integration_tests/chat_models/test_llama_edge.py index 50919902dd4..92e674c0fbf 100644 --- a/libs/community/tests/integration_tests/chat_models/test_llama_edge.py +++ b/libs/community/tests/integration_tests/chat_models/test_llama_edge.py @@ -47,6 +47,6 @@ def test_chat_wasm_service_streaming() -> None: output = "" for chunk in chat.stream(messages): print(chunk.content, end="", flush=True) - output += chunk.content + output += chunk.content # type: ignore[operator] assert "Paris" in output diff --git a/libs/community/tests/integration_tests/document_loaders/test_astradb.py b/libs/community/tests/integration_tests/document_loaders/test_astradb.py index 0a8518885ae..8f9146aacb5 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_astradb.py +++ b/libs/community/tests/integration_tests/document_loaders/test_astradb.py @@ -167,5 +167,5 @@ class TestAstraDB: find_options={"limit": 30}, extraction_function=lambda x: x["foo"], ) - doc = await anext(loader.alazy_load()) + doc = await anext(loader.alazy_load()) # type: ignore[name-defined] assert doc.page_content == "bar" diff --git a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py index 037ee0a3470..5562188eced 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py +++ 
diff --git a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py
index 037ee0a3470..5562188eced 100644
--- a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py
@@ -14,7 +14,7 @@ CASSANDRA_TABLE = "docloader_test_table"


 @pytest.fixture(autouse=True, scope="session")
-def keyspace() -> str:
+def keyspace() -> str:  # type: ignore[misc]
     import cassio
     from cassandra.cluster import Cluster
     from cassio.config import check_resolve_session, resolve_keyspace
diff --git a/libs/community/tests/integration_tests/embeddings/test_baichuan.py b/libs/community/tests/integration_tests/embeddings/test_baichuan.py
index 008dc0f97df..7512b060580 100644
--- a/libs/community/tests/integration_tests/embeddings/test_baichuan.py
+++ b/libs/community/tests/integration_tests/embeddings/test_baichuan.py
@@ -7,8 +7,8 @@ def test_baichuan_embedding_documents() -> None:
     documents = ["今天天气不错", "今天阳光灿烂"]
     embedding = BaichuanTextEmbeddings()
     output = embedding.embed_documents(documents)
-    assert len(output) == 2
-    assert len(output[0]) == 1024
+    assert len(output) == 2  # type: ignore[arg-type]
+    assert len(output[0]) == 1024  # type: ignore[index]


 def test_baichuan_embedding_query() -> None:
@@ -16,4 +16,4 @@ def test_baichuan_embedding_query() -> None:
     document = "所有的小学生都会学过只因兔同笼问题。"
     embedding = BaichuanTextEmbeddings()
     output = embedding.embed_query(document)
-    assert len(output) == 1024
+    assert len(output) == 1024  # type: ignore[arg-type]
diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py
index e4b47f907ea..a209c56d099 100644
--- a/libs/community/tests/integration_tests/graphs/test_neo4j.py
+++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py
@@ -85,7 +85,7 @@ def test_neo4j_timeout() -> None:
         graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})")
     except Exception as e:
         assert (
-            e.code
+            e.code  # type: ignore[attr-defined]
             == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
         )
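On `keyspace() -> str:  # type: ignore[misc]` above: mypy's `misc` error on a fixture usually means the body is a generator (it yields) while the annotation promises a plain `str`. If that is the case here, the annotation can say so instead of being ignored. A sketch assuming a yielding fixture; the body is hypothetical:

```python
from typing import Iterator

import pytest


@pytest.fixture(scope="session")
def keyspace() -> Iterator[str]:
    # A fixture that yields should be annotated as returning an Iterator of
    # the yielded type; this satisfies mypy without `# type: ignore[misc]`.
    name = "docloader_test_keyspace"  # hypothetical setup
    yield name
    # teardown would go here
```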
             response_json = json.loads(output)
             return response_json[0]["0"]
diff --git a/libs/community/tests/integration_tests/llms/test_bedrock.py b/libs/community/tests/integration_tests/llms/test_bedrock.py
index 45a7bcd0bfa..97dbf4f4e8a 100644
--- a/libs/community/tests/integration_tests/llms/test_bedrock.py
+++ b/libs/community/tests/integration_tests/llms/test_bedrock.py
@@ -37,12 +37,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
             if reason == "GUARDRAIL_INTERVENED":
                 self.guardrails_intervened = True

-    def get_response(self):
+    def get_response(self):  # type: ignore[no-untyped-def]
         return self.guardrails_intervened


 @pytest.fixture(autouse=True)
-def bedrock_runtime_client():
+def bedrock_runtime_client():  # type: ignore[no-untyped-def]
     import boto3

     try:
@@ -56,7 +56,7 @@ def bedrock_runtime_client():


 @pytest.fixture(autouse=True)
-def bedrock_client():
+def bedrock_client():  # type: ignore[no-untyped-def]
     import boto3

     try:
@@ -70,7 +70,7 @@ def bedrock_client():


 @pytest.fixture
-def bedrock_models(bedrock_client):
+def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
     """List bedrock models."""
     response = bedrock_client.list_foundation_models().get("modelSummaries")
     models = {}
@@ -79,7 +79,7 @@ def bedrock_models(bedrock_client):
     return models


-def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",
@@ -92,7 +92,7 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)


-def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
     bedrock_runtime_client, bedrock_models
 ):
     try:
@@ -112,7 +112,7 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)


-def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(  # type: ignore[no-untyped-def]
     bedrock_runtime_client, bedrock_models
 ):
     try:
diff --git a/libs/community/tests/integration_tests/storage/test_astradb.py b/libs/community/tests/integration_tests/storage/test_astradb.py
index 37a955ad31e..643b4e93a31 100644
--- a/libs/community/tests/integration_tests/storage/test_astradb.py
+++ b/libs/community/tests/integration_tests/storage/test_astradb.py
@@ -16,7 +16,7 @@ def _has_env_vars() -> bool:
 
 
 @pytest.fixture
-def astra_db():
+def astra_db():  # type: ignore[no-untyped-def]
     from astrapy.db import AstraDB
 
     return AstraDB(
@@ -26,14 +26,14 @@
     )
 
 
-def init_store(astra_db, collection_name: str):
+def init_store(astra_db, collection_name: str):  # type: ignore[no-untyped-def]
     astra_db.create_collection(collection_name)
     store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
     store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
     return store
 
 
-def init_bytestore(astra_db, collection_name: str):
+def init_bytestore(astra_db, collection_name: str):  # type: ignore[no-untyped-def]
     astra_db.create_collection(collection_name)
     store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
     store.mset([("key1", b"value1"), ("key2", b"value2")])
@@ -43,7 +43,7 @@
 @pytest.mark.requires("astrapy")
 @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
 class TestAstraDBStore:
-    def test_mget(self, astra_db) -> None:
+    def test_mget(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test AstraDBStore mget method."""
         collection_name = "lc_test_store_mget"
         try:
@@ -52,7 +52,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_mset(self, astra_db) -> None:
+    def test_mset(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that multiple keys can be set with AstraDBStore."""
         collection_name = "lc_test_store_mset"
         try:
@@ -64,7 +64,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_mdelete(self, astra_db) -> None:
+    def test_mdelete(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that deletion works as expected."""
         collection_name = "lc_test_store_mdelete"
         try:
@@ -75,7 +75,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_yield_keys(self, astra_db) -> None:
+    def test_yield_keys(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         collection_name = "lc_test_store_yield_keys"
         try:
             store = init_store(astra_db, collection_name)
@@ -85,7 +85,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_bytestore_mget(self, astra_db) -> None:
+    def test_bytestore_mget(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test AstraDBByteStore mget method."""
         collection_name = "lc_test_bytestore_mget"
         try:
@@ -94,7 +94,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_bytestore_mset(self, astra_db) -> None:
+    def test_bytestore_mset(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that multiple keys can be set with AstraDBByteStore."""
         collection_name = "lc_test_bytestore_mset"
         try:
diff --git a/libs/community/tests/integration_tests/utilities/test_google_trends.py b/libs/community/tests/integration_tests/utilities/test_google_trends.py
index 0455f16a580..c2583d26339 100644
--- a/libs/community/tests/integration_tests/utilities/test_google_trends.py
+++ b/libs/community/tests/integration_tests/utilities/test_google_trends.py
@@ -6,7 +6,7 @@ from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
 
 
 @patch("serpapi.SerpApiClient.get_json")
-def test_unexpected_response(mocked_serpapiclient):
+def test_unexpected_response(mocked_serpapiclient):  # type: ignore[no-untyped-def]
     os.environ["SERPAPI_API_KEY"] = "123abcd"
     resp = {
         "search_metadata": {
diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/common.py b/libs/community/tests/integration_tests/vectorstores/qdrant/common.py
index ec5d14ef318..c09fcc6a7a3 100644
--- a/libs/community/tests/integration_tests/vectorstores/qdrant/common.py
+++ b/libs/community/tests/integration_tests/vectorstores/qdrant/common.py
@@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
     return True
 
 
-def assert_documents_equals(actual: List[Document], expected: List[Document]):
+def assert_documents_equals(actual: List[Document], expected: List[Document]):  # type: ignore[no-untyped-def]
     assert len(actual) == len(expected)
 
     for actual_doc, expected_doc in zip(actual, expected):
diff --git a/libs/community/tests/integration_tests/vectorstores/test_bigquery_vector_search.py b/libs/community/tests/integration_tests/vectorstores/test_bigquery_vector_search.py
index 57da87d6696..d5146258c5e 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_bigquery_vector_search.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_bigquery_vector_search.py
@@ -32,7 +32,7 @@ def store(request: pytest.FixtureRequest) -> BigQueryVectorSearch:
         TestBigQueryVectorStore.dataset_name, exists_ok=True
     )
     TestBigQueryVectorStore.store = BigQueryVectorSearch(
-        project_id=os.environ.get("PROJECT", None),
+        project_id=os.environ.get("PROJECT", None),  # type: ignore[arg-type]
         embedding=FakeEmbeddings(),
         dataset_name=TestBigQueryVectorStore.dataset_name,
         table_name=TEST_TABLE_NAME,
diff --git a/libs/community/tests/integration_tests/vectorstores/test_deeplake.py b/libs/community/tests/integration_tests/vectorstores/test_deeplake.py
index 99d4d6def2e..7f86795d972 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_deeplake.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_deeplake.py
@@ -52,7 +52,7 @@ def test_deeplake_with_metadatas() -> None:
     assert output == [Document(page_content="foo", metadata={"page": "0"})]
 
 
-def test_deeplake_with_persistence(deeplake_datastore) -> None:
+def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test end to end construction and search, with persistence."""
     output = deeplake_datastore.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo", metadata={"page": "0"})]
@@ -72,7 +72,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None:
     # Or on program exit
 
 
-def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
+def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test overwrite behavior"""
     dataset_path = deeplake_datastore.vectorstore.dataset_handler.path
 
@@ -108,7 +108,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
     output = docsearch.similarity_search("foo", k=1)
 
 
-def test_similarity_search(deeplake_datastore) -> None:
+def test_similarity_search(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test similarity search."""
     distance_metric = "cos"
     output = deeplake_datastore.similarity_search(
diff --git a/libs/community/tests/integration_tests/vectorstores/test_hanavector.py b/libs/community/tests/integration_tests/vectorstores/test_hanavector.py
index dfcdb8c7040..6e6f5b44a6e 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_hanavector.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_hanavector.py
@@ -38,7 +38,7 @@ embedding = NormalizedFakeEmbeddings()
 
 
 class ConfigData:
-    def __init__(self):
+    def __init__(self):  # type: ignore[no-untyped-def]
         self.conn = None
         self.schema_name = ""
 
@@ -46,7 +46,7 @@ class ConfigData:
 test_setup = ConfigData()
 
 
-def generateSchemaName(cursor):
+def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
     cursor.execute(
         "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
         "DUMMY;"
@@ -59,7 +59,7 @@ def generateSchemaName(cursor):
     return f"VEC_{uid}"
 
 
-def setup_module(module):
+def setup_module(module):  # type: ignore[no-untyped-def]
     test_setup.conn = dbapi.connect(
         address=os.environ.get("HANA_DB_ADDRESS"),
         port=os.environ.get("HANA_DB_PORT"),
@@ -81,7 +81,7 @@ def setup_module(module):
     cur.close()
 
 
-def teardown_module(module):
+def teardown_module(module):  # type: ignore[no-untyped-def]
     try:
         cur = test_setup.conn.cursor()
         sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE"
@@ -100,13 +100,13 @@ def texts() -> List[str]:
 @pytest.fixture
 def metadatas() -> List[str]:
     return [
-        {"start": 0, "end": 100, "quality": "good", "ready": True},
-        {"start": 100, "end": 200, "quality": "bad", "ready": False},
-        {"start": 200, "end": 300, "quality": "ugly", "ready": True},
+        {"start": 0, "end": 100, "quality": "good", "ready": True},  # type: ignore[list-item]
+        {"start": 100, "end": 200, "quality": "bad", "ready": False},  # type: ignore[list-item]
+        {"start": 200, "end": 300, "quality": "ugly", "ready": True},  # type: ignore[list-item]
     ]
 
 
-def drop_table(connection, table_name):
+def drop_table(connection, table_name):  # type: ignore[no-untyped-def]
     try:
         cur = connection.cursor()
         sql_str = f"DROP TABLE {table_name}"
@@ -825,7 +825,7 @@ def test_hanavector_filter_prepared_statement_params(
     rows = cur.fetchall()
     assert len(rows) == 1
 
-    query_value = "good"
+    query_value = "good"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
@@ -839,14 +839,14 @@ def test_hanavector_filter_prepared_statement_params(
     assert len(rows) == 1
 
     # query_value = True
-    query_value = "true"
+    query_value = "true"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
     assert len(rows) == 2
 
     # query_value = False
-    query_value = "false"
+    query_value = "false"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
diff --git a/libs/community/tests/integration_tests/vectorstores/test_lantern.py b/libs/community/tests/integration_tests/vectorstores/test_lantern.py
index 8de7d803644..bde3c5b6965 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_lantern.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_lantern.py
@@ -31,7 +31,7 @@ def fix_distance_precision(
 class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
     """Fake embeddings functionality for testing."""
 
-    def __init__(self):
+    def __init__(self):  # type: ignore[no-untyped-def]
         super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
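The three `assignment` ignores above exist because `query_value` is first bound to a `bool` and later rebound to `str` values. A small sketch of the ignore-free alternative, declaring the union once (the variable name comes from the hunk; the annotation is the suggestion):

```python
from typing import Union

# One up-front union annotation lets every later rebinding type-check.
query_value: Union[bool, str] = True
query_value = "true"  # accepted: str is part of the declared union

# Distinct names per role are the other common fix:
ready_expected = True
ready_param = "true"
```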
diff --git a/libs/community/tests/integration_tests/vectorstores/test_thirdai_neuraldb.py b/libs/community/tests/integration_tests/vectorstores/test_thirdai_neuraldb.py
index bd5eafca00d..370e8ff54fa 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_thirdai_neuraldb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_thirdai_neuraldb.py
@@ -7,7 +7,7 @@ from langchain_community.vectorstores import NeuralDBVectorStore
 
 
 @pytest.fixture(scope="session")
-def test_csv():
+def test_csv():  # type: ignore[no-untyped-def]
     csv = "thirdai-test.csv"
     with open(csv, "w") as o:
         o.write("column_1,column_2\n")
@@ -16,13 +16,13 @@ def test_csv():
     os.remove(csv)
 
 
-def assert_result_correctness(documents):
+def assert_result_correctness(documents):  # type: ignore[no-untyped-def]
     assert len(documents) == 1
     assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_scratch(test_csv):
+def test_neuraldb_retriever_from_scratch(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_scratch()
     retriever.insert([test_csv])
     documents = retriever.similarity_search("column")
@@ -30,7 +30,7 @@ def test_neuraldb_retriever_from_scratch(test_csv):
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_checkpoint(test_csv):
+def test_neuraldb_retriever_from_checkpoint(test_csv):  # type: ignore[no-untyped-def]
     checkpoint = "thirdai-test-save.ndb"
     if os.path.exists(checkpoint):
         shutil.rmtree(checkpoint)
@@ -47,7 +47,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv):
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_bazaar(test_csv):
+def test_neuraldb_retriever_from_bazaar(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_bazaar("General QnA")
     retriever.insert([test_csv])
     documents = retriever.similarity_search("column")
@@ -55,7 +55,7 @@ def test_neuraldb_retriever_from_bazaar(test_csv):
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_other_methods(test_csv):
+def test_neuraldb_retriever_other_methods(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_scratch()
     retriever.insert([test_csv])
     # Make sure they don't throw an error.
diff --git a/libs/community/tests/integration_tests/vectorstores/test_vectara.py b/libs/community/tests/integration_tests/vectorstores/test_vectara.py
index 1aaf4856571..6e2f4f2900a 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_vectara.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_vectara.py
@@ -25,7 +25,7 @@ def get_abbr(s: str) -> str:
 
 
 @pytest.fixture(scope="function")
-def vectara1():
+def vectara1():  # type: ignore[no-untyped-def]
     # Set up code
     # create a new Vectara instance
     vectara1: Vectara = Vectara()
@@ -54,7 +54,7 @@ def vectara1():
         vectara1._delete_doc(doc_id)
 
 
-def test_vectara_add_documents(vectara1) -> None:
+def test_vectara_add_documents(vectara1) -> None:  # type: ignore[no-untyped-def]
     """Test add_documents."""
 
     # test without filter
@@ -164,7 +164,7 @@ models can greatly improve the training of DNNs and other deep discriminative mo
 
 
 @pytest.fixture(scope="function")
-def vectara3():
+def vectara3():  # type: ignore[no-untyped-def]
     # Set up code
     vectara3: Vectara = Vectara()
 
@@ -210,7 +210,7 @@ def vectara3():
         vectara3._delete_doc(doc_id)
 
 
-def test_vectara_mmr(vectara3) -> None:
+def test_vectara_mmr(vectara3) -> None:  # type: ignore[no-untyped-def]
     # test max marginal relevance
     output1 = vectara3.max_marginal_relevance_search(
         "generative AI",
@@ -241,7 +241,7 @@ def test_vectara_mmr(vectara3) -> None:
     )
 
 
-def test_vectara_with_summary(vectara3) -> None:
+def test_vectara_with_summary(vectara3) -> None:  # type: ignore[no-untyped-def]
     """Test vectara summary."""
     # test summarization
     num_results = 10
diff --git a/libs/community/tests/unit_tests/chat_models/test_edenai.py b/libs/community/tests/unit_tests/chat_models/test_edenai.py
index 2fa85512c01..dfafc5af988 100644
--- a/libs/community/tests/unit_tests/chat_models/test_edenai.py
+++ b/libs/community/tests/unit_tests/chat_models/test_edenai.py
@@ -35,6 +35,6 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
     ("role", "role_response"),
     [("ai", "assistant"), ("human", "user"), ("chat", "user")],
 )
-def test_edenai_message_role(role: str, role_response) -> None:
+def test_edenai_message_role(role: str, role_response) -> None:  # type: ignore[no-untyped-def]
     role = _message_role(role)
     assert role == role_response
diff --git a/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py b/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py
index fa57109ef02..1bec4a7bf63 100644
--- a/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py
+++ b/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py
@@ -29,7 +29,7 @@ class GradientEmbeddingsModel(MagicMock):
         embeddings = []
         for i, inp in enumerate(inputs):
             # verify correct ordering
-            inp = inp["input"]
+            inp = inp["input"]  # type: ignore[assignment]
             if "pizza" in inp:
                 v = [1.0, 0.0, 0.0]
             elif "document" in inp:
@@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
         output.embeddings = embeddings
         return output
 
-    async def aembed(self, *args) -> Any:
+    async def aembed(self, *args) -> Any:  # type: ignore[no-untyped-def]
         return self.embed(*args)
 
 
 class MockGradient(MagicMock):
     """Mock Gradient package."""
 
-    def __init__(self, access_token: str, workspace_id, host):
+    def __init__(self, access_token: str, workspace_id, host):  # type: ignore[no-untyped-def]
         assert access_token == _GRADIENT_SECRET
         assert workspace_id == _GRADIENT_WORKSPACE_ID
         assert host == _GRADIENT_BASE_URL
diff --git a/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py b/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py
index 12d9c447d5d..d43f3108b9c 100644
--- a/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py
+++ b/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py
@@ -8,7 +8,7 @@ from langchain_community.embeddings import OCIGenAIEmbeddings
 
 
 class MockResponseDict(dict):
-    def __getattr__(self, val):
+    def __getattr__(self, val):  # type: ignore[no-untyped-def]
         return self[val]
 
 
@@ -25,7 +25,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
         client=oci_gen_ai_client,
     )
 
-    def mocked_response(invocation_obj):
+    def mocked_response(invocation_obj):  # type: ignore[no-untyped-def]
         docs = invocation_obj.inputs
 
         embeddings = []
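`MockResponseDict` (here and again in test_oci_generative_ai.py below) is a dict that exposes its keys as attributes. A typed sketch of the same pattern that needs no ignore, and that raises `AttributeError` as `__getattr__` is expected to rather than leaking `KeyError` — an editor's variant, not the repository's code:

```python
from typing import Any


class MockResponseDict(dict):
    """Dict subclass whose keys are also readable as attributes."""

    def __getattr__(self, val: str) -> Any:
        try:
            return self[val]
        except KeyError as exc:
            # Attribute protocol contract: missing names raise AttributeError.
            raise AttributeError(val) from exc


resp = MockResponseDict({"status": 200})
assert resp.status == resp["status"] == 200
```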
"oversized_list": [1, 2, {"key": oversized_list}]} expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]} diff --git a/libs/community/tests/unit_tests/graphs/test_ontotext_graphdb_graph.py b/libs/community/tests/unit_tests/graphs/test_ontotext_graphdb_graph.py index 8b025fed36f..9beb2da271d 100644 --- a/libs/community/tests/unit_tests/graphs/test_ontotext_graphdb_graph.py +++ b/libs/community/tests/unit_tests/graphs/test_ontotext_graphdb_graph.py @@ -15,7 +15,7 @@ class TestOntotextGraphDBGraph(unittest.TestCase): with self.assertRaises(TypeError) as e: OntotextGraphDBGraph._validate_user_query( - [ + [ # type: ignore[arg-type] "PREFIX starwars: " "PREFIX rdfs: " "DESCRIBE starwars: ?term " diff --git a/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py b/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py index 694d88f0c24..b1c36ec7a5f 100644 --- a/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py +++ b/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py @@ -8,7 +8,7 @@ from langchain_community.llms import OCIGenAI class MockResponseDict(dict): - def __getattr__(self, val): + def __getattr__(self, val): # type: ignore[no-untyped-def] return self[val] @@ -23,7 +23,7 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None: provider = llm._get_provider() - def mocked_response(*args): + def mocked_response(*args): # type: ignore[no-untyped-def] response_text = "This is the completion." if provider == "cohere": diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index 63a323eb35e..6dcb0bd38e3 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -4,11 +4,11 @@ from pytest import MonkeyPatch from langchain_community.llms.ollama import Ollama -def mock_response_stream(): +def mock_response_stream(): # type: ignore[no-untyped-def] mock_response = [b'{ "response": "Response chunk 1" }'] class MockRaw: - def read(self, chunk_size): + def read(self, chunk_size): # type: ignore[no-untyped-def] try: return mock_response.pop() except IndexError: @@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: timeout=300, ) - def mock_post(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -72,7 +72,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: """Test that top level params are sent to the endpoint as top level params""" llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -118,7 
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
 
-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -165,7 +165,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
 
-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
diff --git a/libs/core/Makefile b/libs/core/Makefile
index 2c685dfd449..ab8e9cadf03 100644
--- a/libs/core/Makefile
+++ b/libs/core/Makefile
@@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
 	[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
 
 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
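The Makefile hunk just above and the one that follows (like the community Makefile earlier in the patch) replace `mkdir $(MYPY_CACHE) ||` with `mkdir -p $(MYPY_CACHE) &&`. Shell `||` and `&&` share precedence and associate left, so the old `test || mkdir || mypy` ran mypy only when *both* the test and the mkdir failed — a successful mkdir silently skipped type-checking. A small Python model of the two chains (command names are illustrative stand-ins):

```python
from typing import List


def run(name: str, ok: bool, log: List[str]) -> bool:
    """Stand-in for one shell command: record that it ran, return its status."""
    log.append(name)
    return ok


# Old recipe: `test || mkdir || mypy` groups as `(test || mkdir) || mypy`.
log_old: List[str] = []
run("test", False, log_old) or run("mkdir", True, log_old) or run("mypy", True, log_old)
assert log_old == ["test", "mkdir"]  # mkdir succeeded, so mypy never ran

# New recipe: `test || mkdir -p && mypy` groups as `(test || mkdir -p) && mypy`.
log_new: List[str] = []
(run("test", False, log_new) or run("mkdir -p", True, log_new)) and run(
    "mypy", True, log_new
)
assert log_new == ["test", "mkdir -p", "mypy"]  # mypy now runs after a good mkdir
```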
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) format format_diff: [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index cd4e1442918..99620322d8d 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -210,7 +210,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): assistant = client.beta.assistants.create( name=name, instructions=instructions, - tools=[convert_to_openai_tool(tool) for tool in tools], + tools=[convert_to_openai_tool(tool) for tool in tools], # type: ignore model=model, ) return cls(assistant_id=assistant.id, client=client, **kwargs) @@ -331,7 +331,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): assistant = await async_client.beta.assistants.create( name=name, instructions=instructions, - tools=openai_tools, + tools=openai_tools, # type: ignore model=model, ) return cls(assistant_id=assistant.id, async_client=async_client, **kwargs) diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 00d6cbd3f35..6deb63ed062 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -52,7 +52,7 @@ class OpenAIModerationChain(Chain): openai.api_key = openai_api_key if openai_organization: openai.organization = openai_organization - values["client"] = openai.Moderation + values["client"] = openai.Moderation # type: ignore except ImportError: raise ImportError( "Could not import openai python package. "