Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-25 08:03:39 +00:00
infra: add -p to mkdir in lint steps (#17013)
Previously, if this step did not find an existing mypy cache directory, mypy would not run at all; this change makes it run unconditionally. Mypy ignore comments are added for existing, previously uncaught issues to unblock other PRs.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
parent db6af21395
commit 4eda647fdd
6  .github/workflows/_lint.yml (vendored)
@@ -86,7 +86,7 @@ jobs:
      with:
        path: |
          ${{ env.WORKDIR }}/.mypy_cache
-       key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
+       key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}


    - name: Analysing the code with our lint
@@ -105,7 +105,7 @@ jobs:
      # It doesn't matter how you change it, any change will cause a cache-bust.
      working-directory: ${{ inputs.working-directory }}
      run: |
-       poetry install --with test,test_integration
+       poetry install --with test

    - name: Get .mypy_cache_test to speed up mypy
      uses: actions/cache@v3
@@ -114,7 +114,7 @@ jobs:
      with:
        path: |
          ${{ env.WORKDIR }}/.mypy_cache_test
-       key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
+       key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}

    - name: Analysing the code with our lint
      working-directory: ${{ inputs.working-directory }}
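In all three hunks above, the only change is the argument to `hashFiles`: the `poetry.lock` path is now built from `inputs.working-directory` rather than `env.WORKDIR`. `hashFiles` returns an empty string when its pattern matches no files, so a key built from a wrong or unset directory would silently stop tracking dependency changes; pinning the path to the workflow input presumably avoids relying on `env.WORKDIR` being populated when the key is evaluated.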
@@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests:
	poetry run ruff .
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
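This Makefile hunk is the fix the commit title and message describe. In the old recipe, `mkdir -p $(MYPY_CACHE) || poetry run mypy ...` reached mypy only when `mkdir` failed, and `mkdir -p` succeeds even when the directory already exists, so mypy was effectively never run. With `&&`, mypy runs whenever the cache directory has been ensured. One subtlety: `||` and `&&` have equal precedence and associate left to right in POSIX shell, so when `$(PYTHON_FILES)` is empty the leading test succeeds, `mkdir` is short-circuited, and the trailing `&&` still fires, invoking mypy with no file arguments.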
@@ -84,7 +84,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
            raise e
        data_params = data.get("params")
        response = self.requests_wrapper.get(data["url"], params=data_params)
-       response = response[: self.response_length]
+       response = response[: self.response_length]  # type: ignore[index]
        return self.llm_chain.predict(
            response=response, instructions=data["output_instructions"]
        ).strip()
@@ -115,7 +115,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
        except json.JSONDecodeError as e:
            raise e
        response = self.requests_wrapper.post(data["url"], data["data"])
-       response = response[: self.response_length]
+       response = response[: self.response_length]  # type: ignore[index]
        return self.llm_chain.predict(
            response=response, instructions=data["output_instructions"]
        ).strip()
@@ -146,7 +146,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
        except json.JSONDecodeError as e:
            raise e
        response = self.requests_wrapper.patch(data["url"], data["data"])
-       response = response[: self.response_length]
+       response = response[: self.response_length]  # type: ignore[index]
        return self.llm_chain.predict(
            response=response, instructions=data["output_instructions"]
        ).strip()
@@ -177,7 +177,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
        except json.JSONDecodeError as e:
            raise e
        response = self.requests_wrapper.put(data["url"], data["data"])
-       response = response[: self.response_length]
+       response = response[: self.response_length]  # type: ignore[index]
        return self.llm_chain.predict(
            response=response, instructions=data["output_instructions"]
        ).strip()
@@ -209,7 +209,7 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
        except json.JSONDecodeError as e:
            raise e
        response = self.requests_wrapper.delete(data["url"])
-       response = response[: self.response_length]
+       response = response[: self.response_length]  # type: ignore[index]
        return self.llm_chain.predict(
            response=response, instructions=data["output_instructions"]
        ).strip()
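The five hunks above are typical of the Python-side changes in this commit: per-line, error-code-scoped suppressions rather than behavioral fixes. A minimal sketch of the mechanism, with an illustrative function and union type that are not taken from the diff: mypy flags slicing a value whose inferred type may not support slice indexing, and `# type: ignore[index]` silences exactly that error code on that line while every other check stays active.

```python
from typing import Union


def truncate(response: Union[str, dict], limit: int) -> Union[str, dict]:
    # mypy reports [index] here ("Invalid index type" for the dict arm of the
    # union, since a dict cannot be sliced), so the commit-style fix is a
    # per-line, per-code suppression instead of narrowing the union upstream.
    return response[:limit]  # type: ignore[index]


print(truncate("a long response body", 6))  # -> "a long"
```

A bare `# type: ignore` would also silence unrelated errors on the same line; scoping it to `[index]` keeps the suppression as narrow as possible.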
@@ -177,12 +177,12 @@ def create_sql_agent(
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        if prompt is None:
            messages = [
-               SystemMessage(content=prefix),
+               SystemMessage(content=prefix),  # type: ignore[arg-type]
                HumanMessagePromptTemplate.from_template("{input}"),
                AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
-           prompt = ChatPromptTemplate.from_messages(messages)
+           prompt = ChatPromptTemplate.from_messages(messages)  # type: ignore[arg-type]
        agent = RunnableAgent(
            runnable=create_openai_functions_agent(llm, tools, prompt),
            input_keys_arg=["input"],
@@ -191,12 +191,12 @@ def create_sql_agent(
    elif agent_type == "openai-tools":
        if prompt is None:
            messages = [
-               SystemMessage(content=prefix),
+               SystemMessage(content=prefix),  # type: ignore[arg-type]
                HumanMessagePromptTemplate.from_template("{input}"),
                AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
-           prompt = ChatPromptTemplate.from_messages(messages)
+           prompt = ChatPromptTemplate.from_messages(messages)  # type: ignore[arg-type]
        agent = RunnableMultiActionAgent(
            runnable=create_openai_tools_agent(llm, tools, prompt),
            input_keys_arg=["input"],
@@ -723,7 +723,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        )
        return session_analysis_df

-   def _contain_llm_records(self):
+   def _contain_llm_records(self):  # type: ignore[no-untyped-def]
        return bool(self.records["on_llm_start_records"])

    def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
@@ -47,7 +47,7 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    ):
        self.index: str = index
        self.session_id: str = session_id
-       self.ensure_ascii: bool = esnsure_ascii
+       self.ensure_ascii: bool = esnsure_ascii  # type: ignore[assignment]

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
@@ -40,7 +40,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory):
        self.session_id = session_id
        self.table_name = table_name
        self.earliest_time = earliest_time
-       self.cache = []
+       self.cache = []  # type: ignore[var-annotated]

        # Set up SQLAlchemy engine and session
        self.engine = create_engine(connection_string)
@@ -102,7 +102,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory):
            logger.error(f"Error loading messages to cache: {e}")

    @property
-   def messages(self) -> List[BaseMessage]:
+   def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """returns all messages"""
        if len(self.cache) == 0:
            self.reload_cache()
@@ -149,7 +149,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
            return None
        return zep_memory

-   def add_user_message(
+   def add_user_message(  # type: ignore[override]
        self, message: str, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Convenience method for adding a human message string to the store.
@@ -160,7 +160,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
        """
        self.add_message(HumanMessage(content=message), metadata=metadata)

-   def add_ai_message(
+   def add_ai_message(  # type: ignore[override]
        self, message: str, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Convenience method for adding an AI message string to the store.
@@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import (


class LlamaContentFormatter(ContentFormatterBase):
-   def __init__(self):
+   def __init__(self):  # type: ignore[no-untyped-def]
        raise TypeError(
            "`LlamaContentFormatter` is deprecated for chat models. Use "
            "`LlamaChatContentFormatter` instead."
@@ -72,7 +72,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]

-   def format_request_payload(
+   def format_request_payload(  # type: ignore[override]
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
@@ -98,9 +98,9 @@ class LlamaChatContentFormatter(ContentFormatterBase):
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
-       return str.encode(request_payload)
+       return str.encode(request_payload)  # type: ignore[return-value]

-   def format_response_payload(
+   def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> ChatGeneration:
        """Formats response"""
@@ -108,7 +108,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
        try:
            choice = json.loads(output)["output"]
        except (KeyError, IndexError, TypeError) as e:
-           raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+           raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return ChatGeneration(
            message=BaseMessage(
                content=choice.strip(),
@@ -125,7 +125,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
                "model. Expected `dict` but `{type(choice)}` was received."
            )
        except (KeyError, IndexError, TypeError) as e:
-           raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+           raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return ChatGeneration(
            message=BaseMessage(
                content=choice["message"]["content"].strip(),
@@ -175,7 +175,7 @@ class ChatEdenAI(BaseChatModel):
        """Call out to EdenAI's chat endpoint."""
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
-           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
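The `union-attr` suppressions on `get_secret_value()` calls, here and in the Konko, Tongyi, Embaas, MiniMax, and LLMRails hunks below, appear to share one cause: the key fields are declared as optional (`Optional[SecretStr]`), so mypy cannot rule out `None` at the call site. Each ignore asserts that validation has already populated the key by the time a request is built; narrowing the field type, or asserting non-`None` once after validation, would be the stricter fix.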
@@ -216,7 +216,7 @@ class ChatEdenAI(BaseChatModel):
    ) -> AsyncIterator[ChatGenerationChunk]:
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
-           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
@@ -265,7 +265,7 @@ class ChatEdenAI(BaseChatModel):

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
-           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
@@ -323,7 +323,7 @@ class ChatEdenAI(BaseChatModel):

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
-           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+           "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
@@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel):
        generations = [
            ChatGeneration(
                message=AIMessage(
-                   content=response.get("result"),
+                   content=response.get("result"),  # type: ignore[arg-type]
                    additional_kwargs={**additional_kwargs},
                )
            )
@@ -56,7 +56,7 @@ class GPTRouterModel(BaseModel):
    provider_name: str


-def get_ordered_generation_requests(
+def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
    models_priority_list: List[GPTRouterModel], **kwargs
):
    """
@@ -100,7 +100,7 @@ def completion_with_retry(
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@@ -122,7 +122,7 @@ async def acompletion_with_retry(
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
    """Use tenacity to retry the async completion call."""

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -282,7 +282,7 @@ class GPTRouter(BaseChatModel):
        )
        return self._create_chat_result(response)

-   def _create_chat_generation_chunk(
+   def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
        self, data: Mapping[str, Any], default_chunk_class
    ):
        chunk = _convert_delta_to_message_chunk(
@@ -293,7 +293,7 @@ class GPTRouter(BaseChatModel):
            dict(finish_reason=finish_reason) if finish_reason is not None else None
        )
        default_chunk_class = chunk.__class__
-       chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+       chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
        return chunk, default_chunk_class

    def _stream(
@@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel):

        elif isinstance(self.llm, HuggingFaceHub):
            # no need to look up model_id for HuggingFaceHub LLM
-           self.model_id = self.llm.repo_id
+           self.model_id = self.llm.repo_id  # type: ignore[assignment]
            return

        else:
@@ -169,7 +169,7 @@ class ChatKonko(ChatOpenAI):
        }

        if openai_api_key:
-           headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()
+           headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()  # type: ignore[union-attr]

        models_response = requests.get(models_url, headers=headers)

@@ -74,10 +74,10 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
-           if message.content[0].get("type") == "text":
-               message_text = f"[INST] {message.content[0]['text']} [/INST]"
-           elif message.content[0].get("type") == "image_url":
-               message_text = message.content[0]["image_url"]["url"]
+           if message.content[0].get("type") == "text":  # type: ignore[union-attr]
+               message_text = f"[INST] {message.content[0]['text']} [/INST]"  # type: ignore[index]
+           elif message.content[0].get("type") == "image_url":  # type: ignore[union-attr]
+               message_text = message.content[0]["image_url"]["url"]  # type: ignore[index, index]
        elif isinstance(message, AIMessage):
            message_text = f"{message.content}"
        elif isinstance(message, SystemMessage):
@@ -112,11 +112,11 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
            content = message.content
        else:
            for content_part in message.content:
-               if content_part.get("type") == "text":
-                   content += f"\n{content_part['text']}"
-               elif content_part.get("type") == "image_url":
-                   if isinstance(content_part.get("image_url"), str):
-                       image_url_components = content_part["image_url"].split(",")
+               if content_part.get("type") == "text":  # type: ignore[union-attr]
+                   content += f"\n{content_part['text']}"  # type: ignore[index]
+               elif content_part.get("type") == "image_url":  # type: ignore[union-attr]
+                   if isinstance(content_part.get("image_url"), str):  # type: ignore[union-attr]
+                       image_url_components = content_part["image_url"].split(",")  # type: ignore[index]
                    # Support data:image/jpeg;base64,<image> format
                    # and base64 strings
                    if len(image_url_components) > 1:
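The ollama ignores all trace back to one union: `BaseMessage.content` is typed roughly as `Union[str, List[Union[str, Dict]]]`, so calling `.get(...)` or indexing with a string key on a content part cannot be type-checked without narrowing. The commit suppresses each error in place; a narrowing alternative is sketched below with hypothetical helper names that are not part of the diff.

```python
from typing import Dict, List, Union

MessageContent = Union[str, List[Union[str, Dict]]]


def first_text_part(content: MessageContent) -> str:
    # isinstance checks narrow the union step by step, so no per-line
    # ignores are needed.
    if isinstance(content, str):
        return content
    part = content[0]
    if isinstance(part, dict) and part.get("type") == "text":
        return part["text"]
    raise ValueError("content has no leading text part")


print(first_text_part([{"type": "text", "text": "hello"}]))  # -> hello
```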
@@ -142,7 +142,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                }
            )

-       return ollama_messages
+       return ollama_messages  # type: ignore[return-value]

    def _create_chat_stream(
        self,
@@ -337,7 +337,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                verbose=self.verbose,
            )
        except OllamaEndpointNotFoundError:
-           async for chunk in self._legacy_astream(messages, stop, **kwargs):
+           async for chunk in self._legacy_astream(messages, stop, **kwargs):  # type: ignore[attr-defined]
                yield chunk

    @deprecated("0.0.3", alternative="_stream")
@@ -197,7 +197,7 @@ class ChatTongyi(BaseChatModel):
        return {
            "model": self.model_name,
            "top_p": self.top_p,
-           "api_key": self.dashscope_api_key.get_secret_value(),
+           "api_key": self.dashscope_api_key.get_secret_value(),  # type: ignore[union-attr]
            "result_format": "message",
            **self.model_kwargs,
        }
@@ -121,7 +121,7 @@ def _parse_chat_history_gemini(
        elif path.startswith("data:image/"):
            # extract base64 component from image uri
            try:
-               encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(
+               encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(  # type: ignore[union-attr]
                    1
                )
            except AttributeError:
@@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
    return chat_history


-class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
+class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):  # type: ignore[misc]
    """Wrapper around YandexGPT large language models.

    There are two authentication options for the service account
@@ -156,7 +156,7 @@ def _make_request(
            messages=[Message(**message) for message in message_history],
        )
        stub = TextGenerationServiceStub(channel)
-       res = stub.Completion(request, metadata=self._grpc_metadata)
+       res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
        return list(res)[0].alternatives[0].message.text


@@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
            messages=[Message(**message) for message in message_history],
        )
        stub = TextGenerationAsyncServiceStub(channel)
-       operation = await stub.Completion(request, metadata=self._grpc_metadata)
+       operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
        async with grpc.aio.secure_channel(
            operation_api_url, channel_credentials
        ) as operation_channel:
@@ -210,7 +210,8 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
                await asyncio.sleep(1)
                operation_request = GetOperationRequest(operation_id=operation.id)
                operation = await operation_stub.Get(
-                   operation_request, metadata=self._grpc_metadata
+                   operation_request,
+                   metadata=self._grpc_metadata,  # type: ignore[attr-defined]
                )

            completion_response = CompletionResponse()
@@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel):

        return attributes

-   def __init__(self, *args, **kwargs):
+   def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        super().__init__(*args, **kwargs)
        try:
            import zhipuai
@@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel):
                "Please install it via 'pip install zhipuai'"
            )

-   def invoke(self, prompt):
+   def invoke(self, prompt):  # type: ignore[no-untyped-def]
        if self.model == "chatglm_turbo":
            return self.zhipuai.model_api.invoke(
                model=self.model,
@@ -195,7 +195,7 @@ class ChatZhipuAI(BaseChatModel):
            )
        return None

-   def sse_invoke(self, prompt):
+   def sse_invoke(self, prompt):  # type: ignore[no-untyped-def]
        if self.model == "chatglm_turbo":
            return self.zhipuai.model_api.sse_invoke(
                model=self.model,
@@ -218,7 +218,7 @@ class ChatZhipuAI(BaseChatModel):
            )
        return None

-   async def async_invoke(self, prompt):
+   async def async_invoke(self, prompt):  # type: ignore[no-untyped-def]
        loop = asyncio.get_running_loop()
        partial_func = partial(
            self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
@@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel):
        )
        return response

-   async def async_invoke_result(self, task_id):
+   async def async_invoke_result(self, task_id):  # type: ignore[no-untyped-def]
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
@@ -270,11 +270,14 @@ class ChatZhipuAI(BaseChatModel):

        else:
            stream_iter = self._stream(
-               prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+               prompt=prompt,  # type: ignore[arg-type]
+               stop=stop,
+               run_manager=run_manager,
+               **kwargs,
            )
            return generate_from_stream(stream_iter)

-   async def _agenerate(
+   async def _agenerate(  # type: ignore[override]
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
@@ -307,7 +310,7 @@ class ChatZhipuAI(BaseChatModel):
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

-   def _stream(
+   def _stream(  # type: ignore[override]
        self,
        prompt: List[Dict[str, str]],
        stop: Optional[List[str]] = None,
@@ -123,7 +123,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):

    """

-   def __init__(self, transcript_id, api_key, transcript_format):
+   def __init__(self, transcript_id, api_key, transcript_format):  # type: ignore[no-untyped-def]
        """
        Initializes the AssemblyAI AssemblyAIAudioLoaderById.

@@ -65,7 +65,7 @@ class AstraDBLoader(BaseLoader):
        return list(self.lazy_load())

    def lazy_load(self) -> Iterator[Document]:
-       queue = Queue(self.nb_prefetched)
+       queue = Queue(self.nb_prefetched)  # type: ignore[var-annotated]
        t = threading.Thread(target=self.fetch_results, args=(queue,))
        t.start()
        while True:
@@ -95,7 +95,7 @@ class AstraDBLoader(BaseLoader):
            item = await run_in_executor(None, lambda it: next(it, done), iterator)
            if item is done:
                break
-           yield item
+           yield item  # type: ignore[misc]
        return
        async_collection = await self.astra_env.async_astra_db.collection(
            self.collection_name
@@ -116,13 +116,13 @@ class AstraDBLoader(BaseLoader):
            },
        )

-   def fetch_results(self, queue: Queue):
+   def fetch_results(self, queue: Queue):  # type: ignore[no-untyped-def]
        self.fetch_page_result(queue)
        while self.find_options.get("pageState"):
            self.fetch_page_result(queue)
        queue.put(None)

-   def fetch_page_result(self, queue: Queue):
+   def fetch_page_result(self, queue: Queue):  # type: ignore[no-untyped-def]
        res = self.collection.find(
            filter=self.filter,
            options=self.find_options,
@@ -64,10 +64,10 @@ class BaseLoader(ABC):
        iterator = await run_in_executor(None, self.lazy_load)
        done = object()
        while True:
-           doc = await run_in_executor(None, next, iterator, done)
+           doc = await run_in_executor(None, next, iterator, done)  # type: ignore[call-arg, arg-type]
            if doc is done:
                break
-           yield doc
+           yield doc  # type: ignore[misc]


class BaseBlobParser(ABC):
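The two ignores here accommodate the sentinel-object pattern used to drive a synchronous iterator from async code: `next()` is typed against the iterator's element type, and the bare `object()` default does not fit it, nor does yielding it. A standalone sketch of the same pattern, with illustrative names not taken from the diff:

```python
import asyncio
from typing import AsyncIterator, Iterator


async def drive(sync_iter: Iterator[int]) -> AsyncIterator[int]:
    loop = asyncio.get_running_loop()
    # The sentinel default keeps StopIteration from being raised inside the
    # executor; a StopIteration escaping an async generator would surface
    # as a RuntimeError under PEP 479.
    done = object()
    while True:
        item = await loop.run_in_executor(None, next, sync_iter, done)  # type: ignore[call-arg, arg-type]
        if item is done:
            break
        yield item  # type: ignore[misc]


async def main() -> None:
    async for x in drive(iter([1, 2, 3])):
        print(x)


asyncio.run(main())
```

The `AstraDBLoader` hunk above sidesteps the `call-arg` half of this by wrapping `next` in a lambda with the default baked in.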
@@ -33,14 +33,14 @@ class CassandraLoader(BaseLoader):
        page_content_mapper: Callable[[Any], str] = str,
        metadata_mapper: Callable[[Any], dict] = lambda _: {},
        *,
-       query_parameters: Union[dict, Sequence] = None,
-       query_timeout: Optional[float] = _NOT_SET,
+       query_parameters: Union[dict, Sequence] = None,  # type: ignore[assignment]
+       query_timeout: Optional[float] = _NOT_SET,  # type: ignore[assignment]
        query_trace: bool = False,
-       query_custom_payload: dict = None,
+       query_custom_payload: dict = None,  # type: ignore[assignment]
        query_execution_profile: Any = _NOT_SET,
        query_paging_state: Any = None,
        query_host: Host = None,
-       query_execute_as: str = None,
+       query_execute_as: str = None,  # type: ignore[assignment]
    ) -> None:
        """
        Document Loader for Apache Cassandra.
@@ -85,7 +85,7 @@ class CassandraLoader(BaseLoader):
            self.query = f"SELECT * FROM {_keyspace}.{table};"
            self.metadata = {"table": table, "keyspace": _keyspace}
        else:
-           self.query = query
+           self.query = query  # type: ignore[assignment]
            self.metadata = {}

        self.session = session or check_resolve_session(session)
@@ -27,7 +27,7 @@ class UnstructuredCHMLoader(UnstructuredFileLoader):
    def _get_elements(self) -> List:
        from unstructured.partition.html import partition_html

-       with CHMParser(self.file_path) as f:
+       with CHMParser(self.file_path) as f:  # type: ignore[arg-type]
            return [
                partition_html(text=item["content"], **self.unstructured_kwargs)
                for item in f.load_all()
@@ -45,10 +45,10 @@ class CHMParser(object):
        self.file = chm.CHMFile()
        self.file.LoadCHM(path)

-   def __enter__(self):
+   def __enter__(self):  # type: ignore[no-untyped-def]
        return self

-   def __exit__(self, exc_type, exc_value, traceback):
+   def __exit__(self, exc_type, exc_value, traceback):  # type: ignore[no-untyped-def]
        if self.file:
            self.file.CloseCHM()

@@ -89,4 +89,4 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader):
            blob = Blob.from_path(self.file_path)
            yield from self.parser.parse(blob)
        else:
-           yield from self.parser.parse_url(self.url_path)
+           yield from self.parser.parse_url(self.url_path)  # type: ignore[arg-type]
@@ -60,7 +60,7 @@ class MWDumpLoader(BaseLoader):
        self.skip_redirects = skip_redirects
        self.stop_on_error = stop_on_error

-   def _load_dump_file(self):
+   def _load_dump_file(self):  # type: ignore[no-untyped-def]
        try:
            import mwxml
        except ImportError as e:
@@ -70,7 +70,7 @@ class MWDumpLoader(BaseLoader):

        return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))

-   def _load_single_page_from_dump(self, page) -> Document:
+   def _load_single_page_from_dump(self, page) -> Document:  # type: ignore[no-untyped-def, return]
        """Parse a single page."""
        try:
            import mwparserfromhell
@@ -11,7 +11,7 @@ from langchain_community.document_loaders.blob_loaders import Blob


class VsdxParser(BaseBlobParser, ABC):
-   def parse(self, blob: Blob) -> Iterator[Document]:
+   def parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[override]
        """Parse a vsdx file."""
        return self.lazy_parse(blob)

@@ -21,7 +21,7 @@ class VsdxParser(BaseBlobParser, ABC):

        with blob.as_bytes_io() as pdf_file_obj:
            with zipfile.ZipFile(pdf_file_obj, "r") as zfile:
-               pages = self.get_pages_content(zfile, blob.source)
+               pages = self.get_pages_content(zfile, blob.source)  # type: ignore[arg-type]

        yield from [
            Document(
@@ -60,13 +60,13 @@ class VsdxParser(BaseBlobParser, ABC):

        if "visio/pages/pages.xml" not in zfile.namelist():
            print("WARNING - No pages.xml file found in {}".format(source))
-           return
+           return  # type: ignore[return-value]
        if "visio/pages/_rels/pages.xml.rels" not in zfile.namelist():
            print("WARNING - No pages.xml.rels file found in {}".format(source))
-           return
+           return  # type: ignore[return-value]
        if "docProps/app.xml" not in zfile.namelist():
            print("WARNING - No app.xml file found in {}".format(source))
-           return
+           return  # type: ignore[return-value]

        pagesxml_content: dict = xmltodict.parse(zfile.read("visio/pages/pages.xml"))
        appxml_content: dict = xmltodict.parse(zfile.read("docProps/app.xml"))
@@ -79,7 +79,7 @@ class VsdxParser(BaseBlobParser, ABC):
                rel["@Name"].strip() for rel in pagesxml_content["Pages"]["Page"]
            ]
        else:
-           disordered_names: List[str] = [
+           disordered_names: List[str] = [  # type: ignore[no-redef]
                pagesxml_content["Pages"]["Page"]["@Name"].strip()
            ]
        if isinstance(pagesxmlrels_content["Relationships"]["Relationship"], list):
@@ -88,7 +88,7 @@ class VsdxParser(BaseBlobParser, ABC):
                for rel in pagesxmlrels_content["Relationships"]["Relationship"]
            ]
        else:
-           disordered_paths: List[str] = [
+           disordered_paths: List[str] = [  # type: ignore[no-redef]
                "visio/pages/"
                + pagesxmlrels_content["Relationships"]["Relationship"]["@Target"]
            ]
@@ -89,7 +89,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
            print(f"Exception occurred while trying to get embeddings: {str(e)}")
            return None

-   def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:
+   def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:  # type: ignore[override]
        """Public method to get embeddings for a list of documents.

        Args:
@@ -100,7 +100,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
        """
        return self._embed(texts)

-   def embed_query(self, text: str) -> Optional[List[float]]:
+   def embed_query(self, text: str) -> Optional[List[float]]:  # type: ignore[override]
        """Public method to get embedding for a single query text.

        Args:
@@ -56,7 +56,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
-           "authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
+           "authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "User-Agent": self.get_user_agent(),
        }

@@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
    def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
        """Sends a request to the Embaas API and handles the response."""
        headers = {
-           "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",
+           "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "Content-Type": "application/json",
        }

@@ -162,5 +162,5 @@ class TinyAsyncGradientEmbeddingClient: #: :meta private:
    It might be entirely removed in the future.
    """

-   def __init__(self, *args, **kwargs) -> None:
+   def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")
@@ -56,7 +56,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings):
        """
        response = requests.post(
            "https://api.llmrails.com/v1/embeddings",
-           headers={"X-API-KEY": self.api_key.get_secret_value()},
+           headers={"X-API-KEY": self.api_key.get_secret_value()},  # type: ignore[union-attr]
            json={"input": texts, "model": self.model},
            timeout=60,
        )
@@ -110,7 +110,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):

        # HTTP headers for authorization
        headers = {
-           "Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}",
+           "Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "Content-Type": "application/json",
        }

@@ -71,7 +71,8 @@ class MlflowEmbeddings(Embeddings, BaseModel):
        embeddings: List[List[float]] = []
        for txt in _chunk(texts, 20):
            resp = self._client.predict(
-               endpoint=self.endpoint, inputs={"input": txt, **params}
+               endpoint=self.endpoint,
+               inputs={"input": txt, **params},  # type: ignore[arg-type]
            )
            embeddings.extend(r["embedding"] for r in resp["data"])
        return embeddings
@@ -63,16 +63,16 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
    If not specified , DEFAULT will be used
    """

-   model_id: str = None
+   model_id: str = None  # type: ignore[assignment]
    """Id of the model to call, e.g., cohere.embed-english-light-v2.0"""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

-   service_endpoint: str = None
+   service_endpoint: str = None  # type: ignore[assignment]
    """service endpoint url"""

-   compartment_id: str = None
+   compartment_id: str = None  # type: ignore[assignment]
    """OCID of compartment"""

    truncate: Optional[str] = "END"
@@ -109,7 +109,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
            client_kwargs.pop("signer", None)
        elif values["auth_type"] == OCIAuthType(2).name:

-           def make_security_token_signer(oci_config):
+           def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                pk = oci.signer.load_private_key_from_file(
                    oci_config.get("key_file"), None
                )
@@ -78,7 +78,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
        Returns:
            A list of embeddings, one for each document.
        """
-       return [self.nlp(text).vector.tolist() for text in texts]
+       return [self.nlp(text).vector.tolist() for text in texts]  # type: ignore[misc]

    def embed_query(self, text: str) -> List[float]:
        """
@@ -90,7 +90,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
        Returns:
            The embedding for the text.
        """
-       return self.nlp(text).vector.tolist()
+       return self.nlp(text).vector.tolist()  # type: ignore[misc]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """
@@ -42,10 +42,10 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
            embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", model_uri="emb://<folder-id>/text-search-query/latest")
    """

-   iam_token: SecretStr = ""
+   iam_token: SecretStr = ""  # type: ignore[assignment]
    """Yandex Cloud IAM token for service account
    with the `ai.languageModels.user` role"""
-   api_key: SecretStr = ""
+   api_key: SecretStr = ""  # type: ignore[assignment]
    """Yandex Cloud Api Key for service account
    with the `ai.languageModels.user` role"""
    model_uri: str = ""
@@ -146,7 +146,7 @@ def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
    return _completion_with_retry(**kwargs)


-def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
+def _make_request(self: YandexGPTEmbeddings, texts: List[str]):  # type: ignore[no-untyped-def]
    try:
        import grpc
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
@@ -167,7 +167,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
        for text in texts:
            request = TextEmbeddingRequest(model_uri=self.model_uri, text=text)
            stub = EmbeddingsServiceStub(channel)
-           res = stub.TextEmbedding(request, metadata=self._grpc_metadata)
+           res = stub.TextEmbedding(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
            result.append(list(res.embedding))
            time.sleep(self.sleep_interval)

@@ -56,7 +56,7 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
                    cleaned_list.append(value_sanitize(item))
                else:
                    cleaned_list.append(item)
-               new_dict[key] = cleaned_list
+               new_dict[key] = cleaned_list  # type: ignore[assignment]
        else:
            new_dict[key] = value
    return new_dict
@@ -95,12 +95,13 @@ class OntotextGraphDBGraph:

        if local_file:
            ontology_schema_graph = self._load_ontology_schema_from_file(
-               local_file, local_file_format
+               local_file,
+               local_file_format,  # type: ignore[arg-type]
            )
        else:
-           self._validate_user_query(query_ontology)
+           self._validate_user_query(query_ontology)  # type: ignore[arg-type]
            ontology_schema_graph = self._load_ontology_schema_with_query(
-               query_ontology
+               query_ontology  # type: ignore[arg-type]
            )
        self.schema = ontology_schema_graph.serialize(format="turtle")

@@ -139,7 +140,7 @@ class OntotextGraphDBGraph:
        )

    @staticmethod
-   def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):
+   def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):  # type: ignore[no-untyped-def, assignment]
        """
        Parse the ontology schema statements from the provided file
        """
@@ -176,7 +177,7 @@ class OntotextGraphDBGraph:
                "Invalid query type. Only CONSTRUCT queries are supported."
            )

-   def _load_ontology_schema_with_query(self, query: str):
+   def _load_ontology_schema_with_query(self, query: str):  # type: ignore[no-untyped-def]
        """
        Execute the query for collecting the ontology schema statements
        """
@@ -31,7 +31,7 @@ class TigerGraph(GraphStore):
    def schema(self) -> Dict[str, Any]:
        return self._schema

-   def get_schema(self) -> str:
+   def get_schema(self) -> str:  # type: ignore[override]
        if self._schema:
            return str(self._schema)
        else:
@@ -71,10 +71,10 @@ class TigerGraph(GraphStore):
        """
        return self._conn.getSchema(force=True)

-   def refresh_schema(self):
+   def refresh_schema(self):  # type: ignore[no-untyped-def]
        self.generate_schema()

-   def query(self, query: str) -> Dict[str, Any]:
+   def query(self, query: str) -> Dict[str, Any]:  # type: ignore[override]
        """Query the TigerGraph database."""
        answer = self._conn.ai.query(query)
        return answer
@ -165,7 +165,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
|
|||||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||||
return [AzureMLEndpointApiType.realtime]
|
return [AzureMLEndpointApiType.realtime]
|
||||||
|
|
||||||
def format_request_payload(
|
def format_request_payload( # type: ignore[override]
|
||||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
||||||
@ -174,13 +174,13 @@ class GPT2ContentFormatter(ContentFormatterBase):
|
|||||||
)
|
)
|
||||||
return str.encode(request_payload)
|
return str.encode(request_payload)
|
||||||
|
|
||||||
def format_response_payload(
|
def format_response_payload( # type: ignore[override]
|
||||||
self, output: bytes, api_type: AzureMLEndpointApiType
|
self, output: bytes, api_type: AzureMLEndpointApiType
|
||||||
) -> Generation:
|
) -> Generation:
|
||||||
try:
|
try:
|
||||||
choice = json.loads(output)[0]["0"]
|
choice = json.loads(output)[0]["0"]
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
|
||||||
return Generation(text=choice)
|
return Generation(text=choice)
|
||||||
|
|
||||||
|
|
||||||
@ -207,7 +207,7 @@ class HFContentFormatter(ContentFormatterBase):
|
|||||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||||
return [AzureMLEndpointApiType.realtime]
|
return [AzureMLEndpointApiType.realtime]
|
||||||
|
|
||||||
def format_request_payload(
|
def format_request_payload( # type: ignore[override]
|
||||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
ContentFormatterBase.escape_special_characters(prompt)
|
ContentFormatterBase.escape_special_characters(prompt)
|
||||||
@ -216,13 +216,13 @@ class HFContentFormatter(ContentFormatterBase):
|
|||||||
)
|
)
|
||||||
return str.encode(request_payload)
|
return str.encode(request_payload)
|
||||||
|
|
||||||
def format_response_payload(
|
def format_response_payload( # type: ignore[override]
|
||||||
self, output: bytes, api_type: AzureMLEndpointApiType
|
self, output: bytes, api_type: AzureMLEndpointApiType
|
||||||
) -> Generation:
|
) -> Generation:
|
||||||
try:
|
try:
|
||||||
choice = json.loads(output)[0]["0"]["generated_text"]
|
choice = json.loads(output)[0]["0"]["generated_text"]
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
|
||||||
return Generation(text=choice)
|
return Generation(text=choice)
|
||||||
|
|
||||||
|
|
||||||
@ -233,7 +233,7 @@ class DollyContentFormatter(ContentFormatterBase):
|
|||||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||||
return [AzureMLEndpointApiType.realtime]
|
return [AzureMLEndpointApiType.realtime]
|
||||||
|
|
||||||
def format_request_payload(
|
def format_request_payload( # type: ignore[override]
|
||||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
||||||
@ -245,13 +245,13 @@ class DollyContentFormatter(ContentFormatterBase):
|
|||||||
)
|
)
|
||||||
return str.encode(request_payload)
|
return str.encode(request_payload)
|
||||||
|
|
||||||
def format_response_payload(
|
def format_response_payload( # type: ignore[override]
|
||||||
self, output: bytes, api_type: AzureMLEndpointApiType
|
self, output: bytes, api_type: AzureMLEndpointApiType
|
||||||
) -> Generation:
|
) -> Generation:
|
||||||
try:
|
try:
|
||||||
choice = json.loads(output)[0]
|
choice = json.loads(output)[0]
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
|
||||||
return Generation(text=choice)
|
return Generation(text=choice)
|
||||||
|
|
||||||
|
|
||||||
@ -262,7 +262,7 @@ class LlamaContentFormatter(ContentFormatterBase):
|
|||||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||||
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
|
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
|
||||||
|
|
||||||
def format_request_payload(
|
def format_request_payload( # type: ignore[override]
|
||||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Formats the request according to the chosen api"""
|
"""Formats the request according to the chosen api"""
|
||||||
@ -284,7 +284,7 @@ class LlamaContentFormatter(ContentFormatterBase):
|
|||||||
)
|
)
|
||||||
return str.encode(request_payload)
|
return str.encode(request_payload)
|
||||||
|
|
||||||
def format_response_payload(
|
def format_response_payload( # type: ignore[override]
|
||||||
self, output: bytes, api_type: AzureMLEndpointApiType
|
self, output: bytes, api_type: AzureMLEndpointApiType
|
||||||
) -> Generation:
|
) -> Generation:
|
||||||
"""Formats response"""
|
"""Formats response"""
|
||||||
@ -292,7 +292,7 @@ class LlamaContentFormatter(ContentFormatterBase):
|
|||||||
try:
|
try:
|
||||||
choice = json.loads(output)[0]["0"]
|
choice = json.loads(output)[0]["0"]
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
|
||||||
return Generation(text=choice)
|
return Generation(text=choice)
|
||||||
if api_type == AzureMLEndpointApiType.serverless:
|
if api_type == AzureMLEndpointApiType.serverless:
|
||||||
try:
|
try:
|
||||||
@ -304,7 +304,7 @@ class LlamaContentFormatter(ContentFormatterBase):
|
|||||||
"received."
|
"received."
|
||||||
)
|
)
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
|
||||||
return Generation(
|
return Generation(
|
||||||
text=choice["text"].strip(),
|
text=choice["text"].strip(),
|
||||||
generation_info=dict(
|
generation_info=dict(
|
||||||
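Every formatter above silences `[override]` on format_request_payload and format_response_payload: the subclasses change the parameter list relative to the base class, which mypy flags as an incompatible (non-substitutable) override. A hedged sketch of the situation these ignores suppress; BaseFormatter and MyFormatter are illustrative names, not the actual classes:

    class BaseFormatter:
        def format_response_payload(self, output: bytes) -> str:
            return output.decode()

    class MyFormatter(BaseFormatter):
        # Adding a required parameter breaks Liskov substitutability,
        # so mypy reports [override]; the commit opts to suppress it.
        def format_response_payload(  # type: ignore[override]
            self, output: bytes, api_type: str
        ) -> str:
            return output.decode()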
@@ -397,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
     ) -> AzureMLEndpointApiType:
         """Validate that endpoint api type is compatible with the URL format."""
         endpoint_url = values.get("endpoint_url")
-        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(  # type: ignore[union-attr]
             "/score"
         ):
             raise ValueError(
@@ -407,8 +407,8 @@ class AzureMLBaseEndpoint(BaseModel):
                 "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
             )
         if field_value == AzureMLEndpointApiType.serverless and not (
-            endpoint_url.endswith("/v1/completions")
-            or endpoint_url.endswith("/v1/chat/completions")
+            endpoint_url.endswith("/v1/completions")  # type: ignore[union-attr]
+            or endpoint_url.endswith("/v1/chat/completions")  # type: ignore[union-attr]
         ):
             raise ValueError(
                 "Endpoints of type `serverless` should follow the format "
@@ -426,7 +426,9 @@ class AzureMLBaseEndpoint(BaseModel):
         deployment_name = values.get("deployment_name")

         http_client = AzureMLEndpointClient(
-            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+            endpoint_url,  # type: ignore
+            endpoint_key.get_secret_value(),  # type: ignore
+            deployment_name,  # type: ignore
         )
         return http_client

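The `[union-attr]` ignores in this validator exist because `values.get("endpoint_url")` is Optional, so `.endswith(...)` may be called on None as far as mypy can tell. The suppression-free alternative is explicit narrowing before use; a small sketch under the assumption that values is a plain dict of optional strings:

    from typing import Dict, Optional

    values: Dict[str, Optional[str]] = {"endpoint_url": "https://host/score"}

    endpoint_url = values.get("endpoint_url")
    if endpoint_url is None:
        raise ValueError("endpoint_url must be set")
    # After the None check, mypy knows endpoint_url is str:
    assert endpoint_url.endswith("/score")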
@@ -56,11 +56,11 @@ class BaichuanLLM(LLM):
     def _post(self, request: Any) -> Any:
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",
+            "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",  # type: ignore[union-attr]
         }
         try:
             response = requests.post(
-                self.baichuan_api_host,
+                self.baichuan_api_host,  # type: ignore[arg-type]
                 headers=headers,
                 json=request,
                 timeout=self.timeout,
@@ -395,8 +395,8 @@ class BedrockBase(BaseModel, ABC):
         """
         return {
             "amazon-bedrock-guardrailDetails": {
-                "guardrailId": self.guardrails.get("id"),
-                "guardrailVersion": self.guardrails.get("version"),
+                "guardrailId": self.guardrails.get("id"),  # type: ignore[union-attr]
+                "guardrailVersion": self.guardrails.get("version"),  # type: ignore[union-attr]
             }
         }

@@ -427,7 +427,7 @@ class BedrockBase(BaseModel, ABC):

         if self._guardrails_enabled:
             request_options["guardrail"] = "ENABLED"
-            if self.guardrails.get("trace"):
+            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                 request_options["trace"] = "ENABLED"

         try:
@@ -446,7 +446,7 @@ class BedrockBase(BaseModel, ABC):
             # Verify and raise a callback error if any intervention occurs or a signal is
             # sent from a Bedrock service,
             # such as when guardrails are triggered.
-            services_trace = self._get_bedrock_services_signal(body)
+            services_trace = self._get_bedrock_services_signal(body)  # type: ignore[arg-type]

             if services_trace.get("signal") and run_manager is not None:
                 run_manager.on_llm_error(
@@ -468,7 +468,7 @@ class BedrockBase(BaseModel, ABC):

         if (
             self._guardrails_enabled
-            and self.guardrails.get("trace")
+            and self.guardrails.get("trace")  # type: ignore[union-attr]
             and self._is_guardrails_intervention(body)
         ):
             return {
@@ -526,7 +526,7 @@ class BedrockBase(BaseModel, ABC):

         if self._guardrails_enabled:
             request_options["guardrail"] = "ENABLED"
-            if self.guardrails.get("trace"):
+            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                 request_options["trace"] = "ENABLED"

         try:
@@ -540,7 +540,7 @@ class BedrockBase(BaseModel, ABC):
             ):
                 yield chunk
                 # verify and raise callback error if any middleware intervened
-                self._get_bedrock_services_signal(chunk.generation_info)
+                self._get_bedrock_services_signal(chunk.generation_info)  # type: ignore[arg-type]

                 if run_manager is not None:
                     run_manager.on_llm_new_token(chunk.text, chunk=chunk)
@@ -588,7 +588,7 @@ class BedrockBase(BaseModel, ABC):
             ):
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             elif run_manager is not None:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)  # type: ignore[unused-coroutine]


 class Bedrock(LLM, BedrockBase):
@@ -42,10 +42,10 @@ class OCIGenAIBase(BaseModel, ABC):
         If not specified , DEFAULT will be used
     """

-    model_id: str = None
+    model_id: str = None  # type: ignore[assignment]
     """Id of the model to call, e.g., cohere.command"""

-    provider: str = None
+    provider: str = None  # type: ignore[assignment]
     """Provider name of the model. Default to None,
     will try to be derived from the model_id
     otherwise, requires user input
@@ -54,10 +54,10 @@ class OCIGenAIBase(BaseModel, ABC):
     model_kwargs: Optional[Dict] = None
     """Keyword arguments to pass to the model"""

-    service_endpoint: str = None
+    service_endpoint: str = None  # type: ignore[assignment]
     """service endpoint url"""

-    compartment_id: str = None
+    compartment_id: str = None  # type: ignore[assignment]
     """OCID of compartment"""

     is_stream: bool = False
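Fields such as `model_id: str = None` are annotated str but defaulted to None, which mypy reports as `[assignment]`; the commit defers the real fix with ignores. The suppression-free spelling would declare the fields Optional, at the cost of None-handling at each use site. A sketch under the assumption that these classes are pydantic models (field names copied from the hunk, the model itself simplified and hypothetical):

    from typing import Optional

    from pydantic import BaseModel

    class OCIGenAISketch(BaseModel):
        # Instead of `model_id: str = None  # type: ignore[assignment]`:
        model_id: Optional[str] = None
        service_endpoint: Optional[str] = None
        compartment_id: Optional[str] = None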
@@ -94,7 +94,7 @@ class OCIGenAIBase(BaseModel, ABC):
             client_kwargs.pop("signer", None)
         elif values["auth_type"] == OCIAuthType(2).name:

-            def make_security_token_signer(oci_config):
+            def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                 pk = oci.signer.load_private_key_from_file(
                     oci_config.get("key_file"), None
                 )
@@ -297,7 +297,7 @@ class _OllamaCommon(BaseLanguageModel):
                         "Ollama call failed with status code 404."
                     )
                 else:
-                    optional_detail = await response.json().get("error")
+                    optional_detail = await response.json().get("error")  # type: ignore[attr-defined]
                 raise ValueError(
                     f"Ollama call failed with status code {response.status}."
                     f" Details: {optional_detail}"
@@ -380,7 +380,7 @@ class Ollama(BaseLLM, _OllamaCommon):
         """Return type of llm."""
         return "ollama-llm"

-    def _generate(
+    def _generate(  # type: ignore[override]
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
@@ -416,7 +416,7 @@ class Ollama(BaseLLM, _OllamaCommon):
             generations.append([final_chunk])
         return LLMResult(generations=generations)

-    async def _agenerate(
+    async def _agenerate(  # type: ignore[override]
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
@@ -445,7 +445,7 @@ class Ollama(BaseLLM, _OllamaCommon):
                 prompt,
                 stop=stop,
                 images=images,
-                run_manager=run_manager,
+                run_manager=run_manager,  # type: ignore[arg-type]
                 verbose=self.verbose,
                 **kwargs,
             )
@@ -102,7 +102,7 @@ class PipelineAI(LLM, BaseModel):
                 "Could not import pipeline-ai python package. "
                 "Please install it with `pip install pipeline-ai`."
             )
-        client = PipelineCloud(token=self.pipeline_api_key.get_secret_value())
+        client = PipelineCloud(token=self.pipeline_api_key.get_secret_value())  # type: ignore[union-attr]
         params = self.pipeline_kwargs or {}
         params = {**params, **kwargs}

@@ -107,7 +107,7 @@ class StochasticAI(LLM):
             url=self.api_url,
             json={"prompt": prompt, "params": params},
             headers={
-                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",
+                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",  # type: ignore[union-attr]
                 "Accept": "application/json",
                 "Content-Type": "application/json",
             },
@@ -119,7 +119,7 @@ class StochasticAI(LLM):
         response_get = requests.get(
             url=response_post_json["data"]["responseUrl"],
             headers={
-                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",
+                "apiKey": f"{self.stochasticai_api_key.get_secret_value()}",  # type: ignore[union-attr]
                 "Accept": "application/json",
                 "Content-Type": "application/json",
             },
@@ -49,7 +49,7 @@ def is_gemini_model(model_name: str) -> bool:
     return model_name is not None and "gemini" in model_name


-def completion_with_retry(
+def completion_with_retry(  # type: ignore[no-redef]
     llm: VertexAI,
     prompt: List[Union[str, "Image"]],
     stream: bool = False,
@@ -330,7 +330,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
                     generation += chunk
                 generations.append([generation])
             else:
-                res = completion_with_retry(
+                res = completion_with_retry(  # type: ignore[misc]
                     self,
                     [prompt],
                     stream=should_stream,
@@ -373,7 +373,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in completion_with_retry(
+        for stream_resp in completion_with_retry(  # type: ignore[misc]
             self,
             [prompt],
             stream=True,
@@ -250,9 +250,9 @@ class WatsonxLLM(BaseLLM):
         }

     def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
-        params: Dict[str, Any] = {**self.params} if self.params else None
+        params: Dict[str, Any] = {**self.params} if self.params else {}
         if stop is not None:
-            params = (params or {}) | {"stop_sequences": stop}
+            params["stop_sequences"] = stop
         return params

     def _create_llm_result(self, response: List[dict]) -> LLMResult:
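The WatsonxLLM hunk above is a behavior fix rather than a pure suppression: a value annotated Dict[str, Any] was defaulted to None, and the stop sequences were merged by rebuilding the dict. Returning {} and assigning the key in place satisfies the annotation and every caller that indexes the result. A standalone sketch of the fixed shape (function extracted and renamed for illustration):

    from typing import Any, Dict, List, Optional

    def get_chat_params(
        params: Optional[Dict[str, Any]], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        # Empty dict instead of None keeps the return type honest.
        merged: Dict[str, Any] = {**params} if params else {}
        if stop is not None:
            merged["stop_sequences"] = stop
        return merged

    assert get_chat_params(None, ["\n"]) == {"stop_sequences": ["\n"]}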
@@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)


 class _BaseYandexGPT(Serializable):
-    iam_token: SecretStr = ""
+    iam_token: SecretStr = ""  # type: ignore[assignment]
     """Yandex Cloud IAM token for service or user account
     with the `ai.languageModels.user` role"""
-    api_key: SecretStr = ""
+    api_key: SecretStr = ""  # type: ignore[assignment]
     """Yandex Cloud Api Key for service account
     with the `ai.languageModels.user` role"""
     folder_id: str = ""
@@ -211,7 +211,7 @@ def _make_request(
         messages=[Message(role="user", text=prompt)],
     )
     stub = TextGenerationServiceStub(channel)
-    res = stub.Completion(request, metadata=self._grpc_metadata)
+    res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
     return list(res)[0].alternatives[0].message.text


@@ -253,7 +253,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
         messages=[Message(role="user", text=prompt)],
     )
     stub = TextGenerationAsyncServiceStub(channel)
-    operation = await stub.Completion(request, metadata=self._grpc_metadata)
+    operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
     async with grpc.aio.secure_channel(
         operation_api_url, channel_credentials
     ) as operation_channel:
@@ -262,7 +262,8 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
             await asyncio.sleep(1)
             operation_request = GetOperationRequest(operation_id=operation.id)
             operation = await operation_stub.Get(
-                operation_request, metadata=self._grpc_metadata
+                operation_request,
+                metadata=self._grpc_metadata,  # type: ignore[attr-defined]
             )

     completion_response = CompletionResponse()
@@ -58,4 +58,4 @@ class AmadeusClosestAirport(AmadeusBaseTool):
             ' Location Identifier" '
         )

-        return self.llm.invoke(content)
+        return self.llm.invoke(content)  # type: ignore[union-attr]
@@ -93,10 +93,10 @@ class ShellTool(BaseTool):
                     return self.process.run(commands)
                 else:
                     logger.info("Invalid input. User aborted command execution.")
-                    return None
+                    return None  # type: ignore[return-value]
             else:
                 return self.process.run(commands)

         except Exception as e:
             logger.error(f"Error during command execution: {e}")
-            return None
+            return None  # type: ignore[return-value]
@@ -48,7 +48,7 @@ class BraveSearchWrapper(BaseModel):
         results = self._search_request(query)
         return [
             Document(
-                page_content=item.get("description"),
+                page_content=item.get("description"),  # type: ignore[arg-type]
                 metadata={"title": item.get("title"), "link": item.get("url")},
             )
             for item in results
@@ -141,9 +141,9 @@ class GenericRequestsWrapper(BaseModel):
         self, response: aiohttp.ClientResponse
     ) -> Union[str, Dict[str, Any]]:
         if self.response_content_type == "text":
-            return response.text()
+            return response.text()  # type: ignore[return-value]
         elif self.response_content_type == "json":
-            return response.json()
+            return response.json()  # type: ignore[return-value]
         else:
             raise ValueError(f"Invalid return type: {self.response_content_type}")

@@ -176,33 +176,33 @@ class GenericRequestsWrapper(BaseModel):
     async def aget(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
         """GET the URL and return the text asynchronously."""
         async with self.requests.aget(url, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def apost(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """POST to the URL and return the text asynchronously."""
         async with self.requests.apost(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def apatch(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """PATCH the URL and return the text asynchronously."""
         async with self.requests.apatch(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def aput(
         self, url: str, data: Dict[str, Any], **kwargs: Any
     ) -> Union[str, Dict[str, Any]]:
         """PUT the URL and return the text asynchronously."""
         async with self.requests.aput(url, data, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]

     async def adelete(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
         """DELETE the URL and return the text asynchronously."""
         async with self.requests.adelete(url, **kwargs) as response:
-            return await self._aget_resp_content(response)
+            return await self._aget_resp_content(response)  # type: ignore[misc]


 class JsonRequestsWrapper(GenericRequestsWrapper):
@@ -381,7 +381,7 @@ class SQLDatabase:

         If the statement returns no rows, an empty list is returned.
         """
-        with self._engine.begin() as connection:  # type: Connection
+        with self._engine.begin() as connection:  # type: Connection  # type: ignore[name-defined]
             if self._schema is not None:
                 if self.dialect == "snowflake":
                     connection.exec_driver_sql(
@@ -444,7 +444,7 @@ class SQLDatabase:
             ]

             if not include_columns:
-                res = [tuple(row.values()) for row in res]
+                res = [tuple(row.values()) for row in res]  # type: ignore[misc]

             if not res:
                 return ""
@@ -356,7 +356,7 @@ class AlibabaCloudOpenSearch(VectorStore):
                 "fields" not in item
                 or self.config.field_name_mapping["document"] not in item["fields"]
             ):
-                query_result_list.append(Document())
+                query_result_list.append(Document())  # type: ignore[call-arg]
             else:
                 fields = item["fields"]
                 query_result_list.append(
@@ -140,7 +140,7 @@ class AstraDB(VectorStore):
                 if isinstance(v, list):
                     metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v]
                 else:
-                    metadata_filter[k] = AstraDB._filter_to_metadata(v)
+                    metadata_filter[k] = AstraDB._filter_to_metadata(v)  # type: ignore[assignment]
             else:
                 metadata_filter[f"metadata.{k}"] = v

@@ -253,13 +253,13 @@ class AstraDB(VectorStore):
         else:
             self.clear()

-    def _ensure_astra_db_client(self):
+    def _ensure_astra_db_client(self):  # type: ignore[no-untyped-def]
         if not self.astra_db:
             raise ValueError("Missing AstraDB client")

     async def _setup_db(self, pre_delete_collection: bool) -> None:
         if pre_delete_collection:
-            await self.async_astra_db.delete_collection(
+            await self.async_astra_db.delete_collection(  # type: ignore[union-attr]
                 collection_name=self.collection_name,
             )
         await self._aprovision_collection()
@@ -282,7 +282,7 @@ class AstraDB(VectorStore):
         Internal-usage method, no object members are set,
         other than working on the underlying actual storage.
         """
-        self.astra_db.create_collection(
+        self.astra_db.create_collection(  # type: ignore[union-attr]
             dimension=self._get_embedding_dimension(),
             collection_name=self.collection_name,
             metric=self.metric,
@@ -295,7 +295,7 @@ class AstraDB(VectorStore):
         Internal-usage method, no object members are set,
         other than working on the underlying actual storage.
         """
-        await self.async_astra_db.create_collection(
+        await self.async_astra_db.create_collection(  # type: ignore[union-attr]
             dimension=self._get_embedding_dimension(),
             collection_name=self.collection_name,
             metric=self.metric,
@@ -328,7 +328,7 @@ class AstraDB(VectorStore):
         await self._ensure_db_setup()
         if not self.async_astra_db:
             await run_in_executor(None, self.clear)
-        await self.async_collection.delete_many({})
+        await self.async_collection.delete_many({})  # type: ignore[union-attr]

     def delete_by_document_id(self, document_id: str) -> bool:
         """
@@ -336,7 +336,7 @@ class AstraDB(VectorStore):
         Return True if a document has indeed been deleted, False if ID not found.
         """
         self._ensure_astra_db_client()
-        deletion_response = self.collection.delete_one(document_id)
+        deletion_response = self.collection.delete_one(document_id)  # type: ignore[union-attr]
         return ((deletion_response or {}).get("status") or {}).get(
             "deletedCount", 0
         ) == 1
@@ -434,7 +434,7 @@ class AstraDB(VectorStore):
         Use with caution.
         """
         self._ensure_astra_db_client()
-        self.astra_db.delete_collection(
+        self.astra_db.delete_collection(  # type: ignore[union-attr]
             collection_name=self.collection_name,
         )

@@ -448,7 +448,7 @@ class AstraDB(VectorStore):
         await self._ensure_db_setup()
         if not self.async_astra_db:
             await run_in_executor(None, self.delete_collection)
-        await self.async_astra_db.delete_collection(
+        await self.async_astra_db.delete_collection(  # type: ignore[union-attr]
             collection_name=self.collection_name,
         )

@@ -571,7 +571,7 @@ class AstraDB(VectorStore):
         )

         def _handle_batch(document_batch: List[DocDict]) -> List[str]:
-            im_result = self.collection.insert_many(
+            im_result = self.collection.insert_many(  # type: ignore[union-attr]
                 documents=document_batch,
                 options={"ordered": False},
                 partial_failures_allowed=True,
@@ -581,7 +581,7 @@ class AstraDB(VectorStore):
             )

         def _handle_missing_document(missing_document: DocDict) -> str:
-            replacement_result = self.collection.find_one_and_replace(
+            replacement_result = self.collection.find_one_and_replace(  # type: ignore[union-attr]
                 filter={"_id": missing_document["_id"]},
                 replacement=missing_document,
             )
@@ -672,7 +672,7 @@ class AstraDB(VectorStore):
         )

         async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
-            im_result = await self.async_collection.insert_many(
+            im_result = await self.async_collection.insert_many(  # type: ignore[union-attr]
                 documents=document_batch,
                 options={"ordered": False},
                 partial_failures_allowed=True,
@@ -682,7 +682,7 @@ class AstraDB(VectorStore):
             )

         async def _handle_missing_document(missing_document: DocDict) -> str:
-            replacement_result = await self.async_collection.find_one_and_replace(
+            replacement_result = await self.async_collection.find_one_and_replace(  # type: ignore[union-attr]
                 filter={"_id": missing_document["_id"]},
                 replacement=missing_document,
             )
@@ -729,7 +729,7 @@ class AstraDB(VectorStore):
         metadata_parameter = self._filter_to_metadata(filter)
         #
         hits = list(
-            self.collection.paginated_find(
+            self.collection.paginated_find(  # type: ignore[union-attr]
                 filter=metadata_parameter,
                 sort={"$vector": embedding},
                 options={"limit": k, "includeSimilarity": True},
@@ -771,7 +771,7 @@ class AstraDB(VectorStore):
         if not self.async_collection:
             return await run_in_executor(
                 None,
                self.asimilarity_search_with_score_id_by_vector,  # type: ignore[arg-type]
                 embedding,
                 k,
                 filter,
@@ -962,7 +962,7 @@ class AstraDB(VectorStore):
         )

     @staticmethod
-    def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):
+    def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):  # type: ignore[no-untyped-def]
         mmr_chosen_indices = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),
             [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
@@ -1008,7 +1008,7 @@ class AstraDB(VectorStore):
         metadata_parameter = self._filter_to_metadata(filter)

         prefetch_hits = list(
-            self.collection.paginated_find(
+            self.collection.paginated_find(  # type: ignore[union-attr]
                 filter=metadata_parameter,
                 sort={"$vector": embedding},
                 options={"limit": fetch_k, "includeSimilarity": True},
@@ -1228,7 +1228,7 @@ class AstraDB(VectorStore):
             batch_concurrency=kwargs.get("batch_concurrency"),
             overwrite_concurrency=kwargs.get("overwrite_concurrency"),
         )
-        return astra_db_store
+        return astra_db_store  # type: ignore[return-value]

     @classmethod
     async def afrom_texts(
@@ -1263,7 +1263,7 @@ class AstraDB(VectorStore):
             batch_concurrency=kwargs.get("batch_concurrency"),
             overwrite_concurrency=kwargs.get("overwrite_concurrency"),
         )
-        return astra_db_store
+        return astra_db_store  # type: ignore[return-value]

     @classmethod
     def from_documents(
@@ -339,7 +339,7 @@ class AzureSearch(VectorStore):
         # batching support if embedding function is an Embeddings object
         if isinstance(self.embedding_function, Embeddings):
             try:
-                embeddings = self.embedding_function.embed_documents(texts)
+                embeddings = self.embedding_function.embed_documents(texts)  # type: ignore[arg-type]
             except NotImplementedError:
                 embeddings = [self.embedding_function.embed_query(x) for x in texts]
         else:
@@ -222,7 +222,7 @@ class BigQueryVectorSearch(VectorStore):
             self._logger.debug("Vector index already exists.")
             self._have_index = True

-    def _create_index_in_background(self):
+    def _create_index_in_background(self):  # type: ignore[no-untyped-def]
         if self._have_index or self._creating_index:
             # Already have an index or in the process of creating one.
             return
@@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
         thread = Thread(target=self._create_index, daemon=True)
         thread.start()

-    def _create_index(self):
+    def _create_index(self):  # type: ignore[no-untyped-def]
         from google.api_core.exceptions import ClientError

         table = self.bq_client.get_table(self.vectors_table)
@@ -289,7 +289,7 @@ class BigQueryVectorSearch(VectorStore):
     def full_table_id(self) -> str:
         return self._full_table_id

-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -905,7 +905,7 @@ class DeepLake(VectorStore):
         return self.vectorstore.dataset

     @classmethod
-    def _validate_kwargs(cls, kwargs, method_name):
+    def _validate_kwargs(cls, kwargs, method_name):  # type: ignore[no-untyped-def]
         if kwargs:
             valid_items = cls._get_valid_args(method_name)
             unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@@ -917,14 +917,14 @@ class DeepLake(VectorStore):
             )

     @classmethod
-    def _get_valid_args(cls, method_name):
+    def _get_valid_args(cls, method_name):  # type: ignore[no-untyped-def]
         if method_name == "search":
             return cls._valid_search_kwargs
         else:
             return []

     @staticmethod
-    def _get_unsupported_items(kwargs, valid_items):
+    def _get_unsupported_items(kwargs, valid_items):  # type: ignore[no-untyped-def]
         kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
         unsupported_items = None
         if kwargs:
@@ -305,7 +305,7 @@ class FAISS(VectorStore):
         if filter is not None:
             if isinstance(filter, dict):

-                def filter_func(metadata):
+                def filter_func(metadata):  # type: ignore[no-untyped-def]
                     if all(
                         metadata.get(key) in value
                         if isinstance(value, list)
@@ -607,7 +607,7 @@ class FAISS(VectorStore):
         filtered_indices = []
         if isinstance(filter, dict):

-            def filter_func(metadata):
+            def filter_func(metadata):  # type: ignore[no-untyped-def]
                 if all(
                     metadata.get(key) in value
                     if isinstance(value, list)
@@ -117,7 +117,7 @@ class HanaDB(VectorStore):
                 self.vector_column_length,
             )

-    def _table_exists(self, table_name) -> bool:
+    def _table_exists(self, table_name) -> bool:  # type: ignore[no-untyped-def]
         sql_str = (
             "SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
             " AND TABLE_NAME = ?"
@@ -133,7 +133,7 @@ class HanaDB(VectorStore):
             cur.close()
         return False

-    def _check_column(self, table_name, column_name, column_type, column_length=None):
+    def _check_column(self, table_name, column_name, column_type, column_length=None):  # type: ignore[no-untyped-def]
         sql_str = (
             "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
             "SCHEMA_NAME = CURRENT_SCHEMA "
@@ -166,17 +166,17 @@ class HanaDB(VectorStore):
     def embeddings(self) -> Embeddings:
         return self.embedding

-    def _sanitize_name(input_str: str) -> str:
+    def _sanitize_name(input_str: str) -> str:  # type: ignore[misc]
         # Remove characters that are not alphanumeric or underscores
         return re.sub(r"[^a-zA-Z0-9_]", "", input_str)

-    def _sanitize_int(input_int: any) -> int:
+    def _sanitize_int(input_int: any) -> int:  # type: ignore[valid-type]
         value = int(str(input_int))
         if value < -1:
             raise ValueError(f"Value ({value}) must not be smaller than -1")
         return int(str(input_int))

-    def _sanitize_list_float(embedding: List[float]) -> List[float]:
+    def _sanitize_list_float(embedding: List[float]) -> List[float]:  # type: ignore[misc]
         for value in embedding:
             if not isinstance(value, float):
                 raise ValueError(f"Value ({value}) does not have type float")
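The _sanitize_* helpers above are defined in the class body with neither self nor @staticmethod, which is what the `[misc]` ignores paper over (and lowercase `any` instead of `typing.Any` triggers `[valid-type]`). The suppression-free form would be @staticmethod plus Any; a sketch, assuming no caller relies on the current unbound definitions:

    import re
    from typing import Any

    class HanaDBSketch:
        @staticmethod
        def _sanitize_name(input_str: str) -> str:
            # Keep only alphanumerics and underscores.
            return re.sub(r"[^a-zA-Z0-9_]", "", input_str)

        @staticmethod
        def _sanitize_int(input_int: Any) -> int:
            value = int(str(input_int))
            if value < -1:
                raise ValueError(f"Value ({value}) must not be smaller than -1")
            return value

    assert HanaDBSketch._sanitize_name("a-b c_1") == "abc_1"
    assert HanaDBSketch._sanitize_int("7") == 7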
@ -185,14 +185,14 @@ class HanaDB(VectorStore):
|
|||||||
# Compile pattern only once, for better performance
|
# Compile pattern only once, for better performance
|
||||||
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
|
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
|
||||||
|
|
||||||
def _sanitize_metadata_keys(metadata: dict) -> dict:
|
def _sanitize_metadata_keys(metadata: dict) -> dict: # type: ignore[misc]
|
||||||
for key in metadata.keys():
|
for key in metadata.keys():
|
||||||
if not HanaDB._compiled_pattern.match(key):
|
if not HanaDB._compiled_pattern.match(key):
|
||||||
raise ValueError(f"Invalid metadata key {key}")
|
raise ValueError(f"Invalid metadata key {key}")
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def add_texts(
|
def add_texts( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
@ -243,7 +243,7 @@ class HanaDB(VectorStore):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts( # type: ignore[no-untyped-def, override]
|
||||||
cls: Type[HanaDB],
|
cls: Type[HanaDB],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
@ -277,7 +277,7 @@ class HanaDB(VectorStore):
|
|||||||
instance.add_texts(texts, metadatas)
|
instance.add_texts(texts, metadatas)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search( # type: ignore[override]
|
||||||
self, query: str, k: int = 4, filter: Optional[dict] = None
|
self, query: str, k: int = 4, filter: Optional[dict] = None
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
@ -382,7 +382,7 @@ class HanaDB(VectorStore):
|
|||||||
)
|
)
|
||||||
return [(result_item[0], result_item[1]) for result_item in whole_result]
|
return [(result_item[0], result_item[1]) for result_item in whole_result]
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector( # type: ignore[override]
|
||||||
self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
|
self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return docs most similar to embedding vector.
|
"""Return docs most similar to embedding vector.
|
||||||
@ -401,7 +401,7 @@ class HanaDB(VectorStore):
|
|||||||
)
|
)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
def _create_where_by_filter(self, filter):
|
def _create_where_by_filter(self, filter): # type: ignore[no-untyped-def]
|
||||||
query_tuple = []
|
query_tuple = []
|
||||||
where_str = ""
|
where_str = ""
|
||||||
if filter:
|
if filter:
|
||||||
@ -427,7 +427,7 @@ class HanaDB(VectorStore):
|
|||||||
|
|
||||||
return where_str, query_tuple
|
return where_str, query_tuple
|
||||||
|
|
||||||
def delete(
|
def delete( # type: ignore[override]
|
||||||
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
|
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
"""Delete entries by filter with metadata values
|
"""Delete entries by filter with metadata values
|
||||||
@ -459,7 +459,7 @@ class HanaDB(VectorStore):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def adelete(
|
async def adelete( # type: ignore[override]
|
||||||
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
|
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
"""Delete by vector ID or other criteria.
|
"""Delete by vector ID or other criteria.
|
||||||
@ -473,7 +473,7 @@ class HanaDB(VectorStore):
|
|||||||
"""
|
"""
|
||||||
return await run_in_executor(None, self.delete, ids=ids, filter=filter)
|
return await run_in_executor(None, self.delete, ids=ids, filter=filter)
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
@ -511,11 +511,11 @@ class HanaDB(VectorStore):
|
|||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_float_array_from_string(array_as_string: str) -> List[float]:
|
def _parse_float_array_from_string(array_as_string: str) -> List[float]: # type: ignore[misc]
|
||||||
array_wo_brackets = array_as_string[1:-1]
|
array_wo_brackets = array_as_string[1:-1]
|
||||||
return [float(x) for x in array_wo_brackets.split(",")]
|
return [float(x) for x in array_wo_brackets.split(",")]
|
||||||
|
|
||||||
def max_marginal_relevance_search_by_vector(
|
def max_marginal_relevance_search_by_vector( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
@ -533,7 +533,7 @@ class HanaDB(VectorStore):
|
|||||||
|
|
||||||
return [whole_result[i][0] for i in mmr_doc_indexes]
|
return [whole_result[i][0] for i in mmr_doc_indexes]
|
||||||
|
|
||||||
async def amax_marginal_relevance_search_by_vector(
|
async def amax_marginal_relevance_search_by_vector( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
@@ -135,7 +135,7 @@ class Jaguar(VectorStore):
     def embeddings(self) -> Optional[Embeddings]:
         return self._embedding
 
-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -351,7 +351,7 @@ class Jaguar(VectorStore):
         return False
 
     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[override]
         cls,
         texts: List[str],
         embedding: Embeddings,
@@ -383,7 +383,7 @@ class Jaguar(VectorStore):
         q = "truncate store " + podstore
         self.run(q)
 
-    def delete(self, zids: List[str], **kwargs: Any) -> None:
+    def delete(self, zids: List[str], **kwargs: Any) -> None:  # type: ignore[override]
         """
         Delete records in jaguardb by a list of zero-ids
         Args:
@@ -554,10 +554,10 @@ class Milvus(VectorStore):
         }
 
         if not self.auto_id:
-            insert_dict[self._primary_field] = ids
+            insert_dict[self._primary_field] = ids  # type: ignore[assignment]
 
         if self._metadata_field is not None:
-            for d in metadatas:
+            for d in metadatas:  # type: ignore[union-attr]
                 insert_dict.setdefault(self._metadata_field, []).append(d)
         else:
             # Collect the metadata into the insert dict.
@@ -901,7 +901,7 @@ class Milvus(VectorStore):
             ret.append(documents[x])
         return ret
 
-    def delete(
+    def delete(  # type: ignore[no-untyped-def]
         self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
     ):
         """Delete by vector ID or boolean expression.
@@ -923,7 +923,7 @@ class Milvus(VectorStore):
             assert isinstance(
                 expr, str
             ), "Either ids list or expr string must be provided."
-        return self.col.delete(expr=expr, **kwargs)
+        return self.col.delete(expr=expr, **kwargs)  # type: ignore[union-attr]
 
     @classmethod
     def from_texts(
@@ -398,7 +398,7 @@ class PGEmbedding(VectorStore):
         docs = [
             (
                 Document(
-                    page_content=result.EmbeddingStore.document,
+                    page_content=result.EmbeddingStore.document,  # type: ignore[arg-type]
                     metadata=result.EmbeddingStore.cmetadata,
                 ),
                 result.distance if self.embedding_function is not None else 0.0,
@@ -133,7 +133,7 @@ class PGVecto_rs(VectorStore):
             Record.from_text(text, embedding, meta)
             for text, embedding, meta in zip(texts, embeddings, metadatas or [])
         ]
-        self._store.insert(records)
+        self._store.insert(records)  # type: ignore[union-attr]
         return [str(record.id) for record in records]
 
     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
@@ -177,7 +177,7 @@ class PGVecto_rs(VectorStore):
             real_filter = meta_contains(filter)
         else:
             real_filter = filter
-        results = self._store.search(
+        results = self._store.search(  # type: ignore[union-attr]
             query_vector,
             distance_func_map[distance_func],
             k,
@@ -238,7 +238,7 @@ class PGVector(VectorStore):
 
     def create_vector_extension(self) -> None:
         try:
-            with Session(self._bind) as session:
+            with Session(self._bind) as session:  # type: ignore[arg-type]
                 # The advisor lock fixes issue arising from concurrent
                 # creation of the vector extension.
                 # https://github.com/langchain-ai/langchain/issues/12933
@@ -256,24 +256,24 @@ class PGVector(VectorStore):
             raise Exception(f"Failed to create vector extension: {e}") from e
 
     def create_tables_if_not_exists(self) -> None:
-        with Session(self._bind) as session, session.begin():
+        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
             Base.metadata.create_all(session.get_bind())
 
     def drop_tables(self) -> None:
-        with Session(self._bind) as session, session.begin():
+        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
             Base.metadata.drop_all(session.get_bind())
 
     def create_collection(self) -> None:
         if self.pre_delete_collection:
             self.delete_collection()
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             self.CollectionStore.get_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )
 
     def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 self.logger.warning("Collection not found")
@@ -284,7 +284,7 @@ class PGVector(VectorStore):
     @contextlib.contextmanager
     def _make_session(self) -> Generator[Session, None, None]:
         """Create a context manager for the session, bind to _conn string."""
-        yield Session(self._bind)
+        yield Session(self._bind)  # type: ignore[arg-type]
 
     def delete(
         self,
@@ -298,7 +298,7 @@ class PGVector(VectorStore):
             ids: List of ids to delete.
             collection_only: Only delete ids in the collection.
         """
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             if ids is not None:
                 self.logger.debug(
                     "Trying to delete vectors by ids (represented by the model "
@@ -383,7 +383,7 @@ class PGVector(VectorStore):
         if not metadatas:
             metadatas = [{} for _ in texts]
 
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -508,7 +508,7 @@ class PGVector(VectorStore):
         ]
         return docs
 
-    def _create_filter_clause(self, key, value):
+    def _create_filter_clause(self, key, value):  # type: ignore[no-untyped-def]
         IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
         EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
 
@@ -575,7 +575,7 @@ class PGVector(VectorStore):
         filter: Optional[Dict[str, str]] = None,
     ) -> List[Any]:
         """Query the collection."""
-        with Session(self._bind) as session:
+        with Session(self._bind) as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -115,7 +115,7 @@ class SurrealDBStore(VectorStore):
         for idx, text in enumerate(texts):
             data = {"text": text, "embedding": embeddings[idx]}
             if metadatas is not None and idx < len(metadatas):
-                data["metadata"] = metadatas[idx]
+                data["metadata"] = metadatas[idx]  # type: ignore[assignment]
             record = await self.sdb.create(
                 self.collection,
                 data,
@@ -316,7 +316,7 @@ class TencentVectorDB(VectorStore):
             meta = result.get(self.field_metadata)
             if meta is not None:
                 meta = json.loads(meta)
-            doc = Document(page_content=result.get(self.field_text), metadata=meta)
+            doc = Document(page_content=result.get(self.field_text), metadata=meta)  # type: ignore[arg-type]
             pair = (doc, result.get("score", 0.0))
             ret.append(pair)
         return ret
@@ -374,7 +374,7 @@ class TencentVectorDB(VectorStore):
             meta = result.get(self.field_metadata)
             if meta is not None:
                 meta = json.loads(meta)
-            doc = Document(page_content=result.get(self.field_text), metadata=meta)
+            doc = Document(page_content=result.get(self.field_text), metadata=meta)  # type: ignore[arg-type]
             documents.append(doc)
             ordered_result_embeddings.append(result.get(self.field_vector))
         # Get the new order of results.
@@ -24,7 +24,7 @@ class NeuralDBVectorStore(VectorStore):
         underscore_attrs_are_private = True
 
     @staticmethod
-    def _verify_thirdai_library(thirdai_key: Optional[str] = None):
+    def _verify_thirdai_library(thirdai_key: Optional[str] = None):  # type: ignore[no-untyped-def]
         try:
             from thirdai import licensing
 
@@ -38,7 +38,7 @@ class NeuralDBVectorStore(VectorStore):
             )
 
     @classmethod
-    def from_scratch(
+    def from_scratch(  # type: ignore[no-untyped-def, no-untyped-def]
         cls,
         thirdai_key: Optional[str] = None,
         **model_kwargs,
@@ -69,10 +69,10 @@ class NeuralDBVectorStore(VectorStore):
         NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb
 
-        return cls(db=ndb.NeuralDB(**model_kwargs))
+        return cls(db=ndb.NeuralDB(**model_kwargs))  # type: ignore[call-arg]
 
     @classmethod
-    def from_bazaar(
+    def from_bazaar(  # type: ignore[no-untyped-def]
         cls,
         base: str,
         bazaar_cache: Optional[str] = None,
@@ -111,10 +111,10 @@ class NeuralDBVectorStore(VectorStore):
             os.mkdir(cache)
         model_bazaar = ndb.Bazaar(cache)
         model_bazaar.fetch()
-        return cls(db=model_bazaar.get_model(base))
+        return cls(db=model_bazaar.get_model(base))  # type: ignore[call-arg]
 
     @classmethod
-    def from_checkpoint(
+    def from_checkpoint(  # type: ignore[no-untyped-def]
         cls,
         checkpoint: Union[str, Path],
         thirdai_key: Optional[str] = None,
@@ -146,7 +146,7 @@ class NeuralDBVectorStore(VectorStore):
         NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb
 
-        return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))
+        return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))  # type: ignore[call-arg]
 
     @classmethod
     def from_texts(
@@ -187,11 +187,11 @@ class NeuralDBVectorStore(VectorStore):
         df = pd.DataFrame({"texts": texts})
         if metadatas:
             df = pd.concat([df, pd.DataFrame.from_records(metadatas)], axis=1)
-        temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)
+        temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)  # type: ignore[call-overload]
         df.to_csv(temp)
         source_id = self.insert([ndb.CSV(temp.name)], **kwargs)[0]
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
-        return [str(offset + i) for i in range(len(texts))]
+        return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]
 
     @root_validator()
     def validate_environments(cls, values: Dict) -> Dict:
@@ -205,7 +205,7 @@ class NeuralDBVectorStore(VectorStore):
         )
         return values
 
-    def insert(
+    def insert(  # type: ignore[no-untyped-def, no-untyped-def]
         self,
         sources: List[Any],
         train: bool = True,
@@ -229,7 +229,7 @@ class NeuralDBVectorStore(VectorStore):
             **kwargs,
         )
 
-    def _preprocess_sources(self, sources):
+    def _preprocess_sources(self, sources):  # type: ignore[no-untyped-def]
         """Checks if the provided sources are string paths. If they are, convert
         to NeuralDB document objects.
 
@@ -261,7 +261,7 @@ class NeuralDBVectorStore(VectorStore):
                 )
         return preprocessed_sources
 
-    def upvote(self, query: str, document_id: Union[int, str]):
+    def upvote(self, query: str, document_id: Union[int, str]):  # type: ignore[no-untyped-def]
         """The vectorstore upweights the score of a document for a specific query.
         This is useful for fine-tuning the vectorstore to user behavior.
 
@@ -271,7 +271,7 @@ class NeuralDBVectorStore(VectorStore):
         """
        self.db.text_to_result(query, int(document_id))
 
-    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):
+    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):  # type: ignore[no-untyped-def]
         """Given a batch of (query, document id) pairs, the vectorstore upweights
         the scores of the document for the corresponding queries.
         This is useful for fine-tuning the vectorstore to user behavior.
@@ -284,7 +284,7 @@ class NeuralDBVectorStore(VectorStore):
             [(query, int(doc_id)) for query, doc_id in query_id_pairs]
         )
 
-    def associate(self, source: str, target: str):
+    def associate(self, source: str, target: str):  # type: ignore[no-untyped-def]
         """The vectorstore associates a source phrase with a target phrase.
         When the vectorstore sees the source phrase, it will also consider results
         that are relevant to the target phrase.
@@ -295,7 +295,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.associate(source, target)
 
-    def associate_batch(self, text_pairs: List[Tuple[str, str]]):
+    def associate_batch(self, text_pairs: List[Tuple[str, str]]):  # type: ignore[no-untyped-def]
         """Given a batch of (source, target) pairs, the vectorstore associates
         each source phrase with the corresponding target phrase.
 
@@ -334,7 +334,7 @@ class NeuralDBVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e
 
-    def save(self, path: str):
+    def save(self, path: str):  # type: ignore[no-untyped-def]
         """Saves a NeuralDB instance to disk. Can be loaded into memory by
         calling NeuralDB.from_checkpoint(path)
 
@@ -384,7 +384,7 @@ class Vectara(VectorStore):
                 f"(code {response.status_code}, reason {response.reason}, details "
                 f"{response.text})",
             )
-            return [], ""
+            return [], ""  # type: ignore[return-value]
 
         result = response.json()
 
@@ -454,7 +454,7 @@ class Vectara(VectorStore):
         docs = self.vectara_query(query, config)
         return docs
 
-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self,
         query: str,
         **kwargs: Any,
@@ -474,7 +474,7 @@ class Vectara(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]
 
-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         fetch_k: int = 50,
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
 
 
 class VikingDBConfig(object):
-    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):
+    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):  # type: ignore[no-untyped-def]
         self.host = host
         self.region = region
         self.ak = ak
@@ -47,11 +47,11 @@ class VikingDB(VectorStore):
         self.index_params = index_params
         self.drop_old = drop_old
         self.service = VikingDBService(
-            connection_args.host,
-            connection_args.region,
-            connection_args.ak,
-            connection_args.sk,
-            connection_args.scheme,
+            connection_args.host,  # type: ignore[union-attr]
+            connection_args.region,  # type: ignore[union-attr]
+            connection_args.ak,  # type: ignore[union-attr]
+            connection_args.sk,  # type: ignore[union-attr]
+            connection_args.scheme,  # type: ignore[union-attr]
         )
 
         try:
@@ -143,7 +143,7 @@ class VikingDB(VectorStore):
                 scalar_index=scalar_index,
             )
 
-    def add_texts(
+    def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
         metadatas: Optional[List[dict]] = None,
@@ -183,7 +183,7 @@ class VikingDB(VectorStore):
             if metadatas is not None and index < len(metadatas):
                 names = list(metadatas[index].keys())
                 for name in names:
-                    field[name] = metadatas[index].get(name)
+                    field[name] = metadatas[index].get(name)  # type: ignore[assignment]
             data.append(Data(field))
 
         total_count = len(data)
@@ -191,10 +191,10 @@ class VikingDB(VectorStore):
             end = min(i + batch_size, total_count)
             insert_data = data[i:end]
             # print(insert_data)
-            self.collection.upsert_data(insert_data)
+            self.collection.upsert_data(insert_data)  # type: ignore[union-attr]
         return pks
 
-    def similarity_search(
+    def similarity_search(  # type: ignore[override]
         self,
         query: str,
         params: Optional[dict] = None,
@@ -216,7 +216,7 @@ class VikingDB(VectorStore):
         )
         return res
 
-    def similarity_search_by_vector(
+    def similarity_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         params: Optional[dict] = None,
@@ -251,7 +251,7 @@ class VikingDB(VectorStore):
             if params.get("partition") is not None:
                 partition = params["partition"]
 
-        res = self.index.search_by_vector(
+        res = self.index.search_by_vector(  # type: ignore[union-attr]
             embedding,
             filter=filter,
             limit=limit,
@@ -269,7 +269,7 @@ class VikingDB(VectorStore):
             ret.append(pair)
         return ret
 
-    def max_marginal_relevance_search(
+    def max_marginal_relevance_search(  # type: ignore[override]
         self,
         query: str,
         k: int = 4,
@@ -286,7 +286,7 @@ class VikingDB(VectorStore):
             **kwargs,
         )
 
-    def max_marginal_relevance_search_by_vector(
+    def max_marginal_relevance_search_by_vector(  # type: ignore[override]
         self,
         embedding: List[float],
         k: int = 4,
@@ -311,7 +311,7 @@ class VikingDB(VectorStore):
             if params.get("partition") is not None:
                 partition = params["partition"]
 
        res = self.index.search_by_vector(  # type: ignore[union-attr]
-        res = self.index.search_by_vector(
+        res = self.index.search_by_vector(  # type: ignore[union-attr]
             embedding,
             filter=filter,
             limit=limit,
@@ -347,10 +347,10 @@ class VikingDB(VectorStore):
     ) -> None:
         if self.collection is None:
             logger.debug("No existing collection to search.")
-        self.collection.delete_data(ids)
+        self.collection.delete_data(ids)  # type: ignore[union-attr]
 
     @classmethod
-    def from_texts(
+    def from_texts(  # type: ignore[no-untyped-def, override]
         cls,
         texts: List[str],
         embedding: Embeddings,
libs/community/poetry.lock (generated, 6 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "aenum"
@@ -3944,7 +3944,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.17"
+version = "0.1.18"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -9252,4 +9252,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"
+content-hash = "1ab63edcddcef2deb01e6fff5c376f7b0773435bb9d5b55bc1d50d19a8f1dee2"
@@ -101,7 +101,7 @@ optional = true
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
+pytest-cov = "^4.1.0"
 pytest-dotenv = "^0.5.2"
 duckdb-engine = "^0.9.2"
 pytest-watcher = "^0.2.6"
@@ -57,7 +57,7 @@ def test_add_messages() -> None:
     assert len(message_store_another.messages) == 0
 
 
-def test_tidb_recent_chat_message():
+def test_tidb_recent_chat_message():  # type: ignore[no-untyped-def]
     """Test the TiDBChatMessageHistory with earliest_time parameter."""
     import time
     from datetime import datetime
@@ -40,7 +40,7 @@ def test_konko_key_masked_when_passed_via_constructor(
     captured = capsys.readouterr()
     assert captured.out == "**********"
 
-    print(chat.konko_secret_key, end="")
+    print(chat.konko_secret_key, end="")  # type: ignore[attr-defined]
     captured = capsys.readouterr()
     assert captured.out == "**********"
 
@@ -49,7 +49,7 @@ def test_uses_actual_secret_value_from_secret_str() -> None:
     """Test that actual secret is retrieved using `.get_secret_value()`."""
     chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")
     assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key"
-    assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"
+    assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"  # type: ignore[attr-defined]
 
 
 def test_konko_chat_test() -> None:
@@ -47,6 +47,6 @@ def test_chat_wasm_service_streaming() -> None:
     output = ""
     for chunk in chat.stream(messages):
         print(chunk.content, end="", flush=True)
-        output += chunk.content
+        output += chunk.content  # type: ignore[operator]
 
     assert "Paris" in output
@@ -167,5 +167,5 @@ class TestAstraDB:
             find_options={"limit": 30},
             extraction_function=lambda x: x["foo"],
         )
-        doc = await anext(loader.alazy_load())
+        doc = await anext(loader.alazy_load())  # type: ignore[name-defined]
         assert doc.page_content == "bar"
@@ -14,7 +14,7 @@ CASSANDRA_TABLE = "docloader_test_table"
 
 
 @pytest.fixture(autouse=True, scope="session")
-def keyspace() -> str:
+def keyspace() -> str:  # type: ignore[misc]
     import cassio
     from cassandra.cluster import Cluster
     from cassio.config import check_resolve_session, resolve_keyspace
@@ -7,8 +7,8 @@ def test_baichuan_embedding_documents() -> None:
     documents = ["今天天气不错", "今天阳光灿烂"]
     embedding = BaichuanTextEmbeddings()
     output = embedding.embed_documents(documents)
-    assert len(output) == 2
-    assert len(output[0]) == 1024
+    assert len(output) == 2  # type: ignore[arg-type]
+    assert len(output[0]) == 1024  # type: ignore[index]
 
 
 def test_baichuan_embedding_query() -> None:
@@ -16,4 +16,4 @@ def test_baichuan_embedding_query() -> None:
     document = "所有的小学生都会学过只因兔同笼问题。"
     embedding = BaichuanTextEmbeddings()
     output = embedding.embed_query(document)
-    assert len(output) == 1024
+    assert len(output) == 1024  # type: ignore[arg-type]
@@ -85,7 +85,7 @@ def test_neo4j_timeout() -> None:
         graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})")
     except Exception as e:
         assert (
-            e.code
+            e.code  # type: ignore[attr-defined]
             == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
         )
 
@@ -62,7 +62,7 @@ def test_custom_formatter() -> None:
         content_type = "application/json"
         accepts = "application/json"
 
-        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
            input_str = json.dumps(
                {
                    "inputs": [prompt],
@@ -72,7 +72,7 @@ def test_custom_formatter() -> None:
            )
            return input_str.encode("utf-8")
 
-        def format_response_payload(self, output: bytes) -> str:
+        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
            response_json = json.loads(output)
            return response_json[0]["summary_text"]
 
@@ -104,7 +104,7 @@ def test_invalid_request_format() -> None:
         content_type = "application/json"
         accepts = "application/json"
 
-        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
            input_str = json.dumps(
                {
                    "incorrect_input": {"input_string": [prompt]},
@@ -113,7 +113,7 @@ def test_invalid_request_format() -> None:
            )
            return str.encode(input_str)
 
-        def format_response_payload(self, output: bytes) -> str:
+        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
            response_json = json.loads(output)
            return response_json[0]["0"]
 
@@ -37,12 +37,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
         if reason == "GUARDRAIL_INTERVENED":
             self.guardrails_intervened = True
 
-    def get_response(self):
+    def get_response(self):  # type: ignore[no-untyped-def]
         return self.guardrails_intervened
 
 
 @pytest.fixture(autouse=True)
-def bedrock_runtime_client():
+def bedrock_runtime_client():  # type: ignore[no-untyped-def]
     import boto3
 
     try:
@@ -56,7 +56,7 @@ def bedrock_runtime_client():
 
 
 @pytest.fixture(autouse=True)
-def bedrock_client():
+def bedrock_client():  # type: ignore[no-untyped-def]
     import boto3
 
     try:
@@ -70,7 +70,7 @@ def bedrock_client():
 
 
 @pytest.fixture
-def bedrock_models(bedrock_client):
+def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
     """List bedrock models."""
     response = bedrock_client.list_foundation_models().get("modelSummaries")
     models = {}
@@ -79,7 +79,7 @@ def bedrock_models(bedrock_client):
     return models
 
 
-def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",
@@ -92,7 +92,7 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
 
 
-def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
     bedrock_runtime_client, bedrock_models
 ):
     try:
@@ -112,7 +112,7 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
 
 
-def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(  # type: ignore[no-untyped-def]
     bedrock_runtime_client, bedrock_models
 ):
     try:
@@ -16,7 +16,7 @@ def _has_env_vars() -> bool:
 
 
 @pytest.fixture
-def astra_db():
+def astra_db():  # type: ignore[no-untyped-def]
     from astrapy.db import AstraDB
 
     return AstraDB(
@@ -26,14 +26,14 @@ def astra_db():
     )
 
 
-def init_store(astra_db, collection_name: str):
+def init_store(astra_db, collection_name: str):  # type: ignore[no-untyped-def, no-untyped-def]
     astra_db.create_collection(collection_name)
     store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
     store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
     return store
 
 
-def init_bytestore(astra_db, collection_name: str):
+def init_bytestore(astra_db, collection_name: str):  # type: ignore[no-untyped-def, no-untyped-def]
     astra_db.create_collection(collection_name)
     store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
     store.mset([("key1", b"value1"), ("key2", b"value2")])
@@ -43,7 +43,7 @@ def init_bytestore(astra_db, collection_name: str):
 @pytest.mark.requires("astrapy")
 @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
 class TestAstraDBStore:
-    def test_mget(self, astra_db) -> None:
+    def test_mget(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test AstraDBStore mget method."""
         collection_name = "lc_test_store_mget"
         try:
@@ -52,7 +52,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_mset(self, astra_db) -> None:
+    def test_mset(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that multiple keys can be set with AstraDBStore."""
         collection_name = "lc_test_store_mset"
         try:
@@ -64,7 +64,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_mdelete(self, astra_db) -> None:
+    def test_mdelete(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that deletion works as expected."""
         collection_name = "lc_test_store_mdelete"
         try:
@@ -75,7 +75,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_yield_keys(self, astra_db) -> None:
+    def test_yield_keys(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         collection_name = "lc_test_store_yield_keys"
         try:
             store = init_store(astra_db, collection_name)
@@ -85,7 +85,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_bytestore_mget(self, astra_db) -> None:
+    def test_bytestore_mget(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test AstraDBByteStore mget method."""
         collection_name = "lc_test_bytestore_mget"
         try:
@@ -94,7 +94,7 @@ class TestAstraDBStore:
         finally:
             astra_db.delete_collection(collection_name)
 
-    def test_bytestore_mset(self, astra_db) -> None:
+    def test_bytestore_mset(self, astra_db) -> None:  # type: ignore[no-untyped-def]
         """Test that multiple keys can be set with AstraDBByteStore."""
         collection_name = "lc_test_bytestore_mset"
         try:
@@ -6,7 +6,7 @@ from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
 
 
 @patch("serpapi.SerpApiClient.get_json")
-def test_unexpected_response(mocked_serpapiclient):
+def test_unexpected_response(mocked_serpapiclient):  # type: ignore[no-untyped-def]
     os.environ["SERPAPI_API_KEY"] = "123abcd"
     resp = {
         "search_metadata": {
@@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
     return True
 
 
-def assert_documents_equals(actual: List[Document], expected: List[Document]):
+def assert_documents_equals(actual: List[Document], expected: List[Document]):  # type: ignore[no-untyped-def]
     assert len(actual) == len(expected)
 
     for actual_doc, expected_doc in zip(actual, expected):
@@ -32,7 +32,7 @@ def store(request: pytest.FixtureRequest) -> BigQueryVectorSearch:
         TestBigQueryVectorStore.dataset_name, exists_ok=True
     )
     TestBigQueryVectorStore.store = BigQueryVectorSearch(
-        project_id=os.environ.get("PROJECT", None),
+        project_id=os.environ.get("PROJECT", None),  # type: ignore[arg-type]
         embedding=FakeEmbeddings(),
         dataset_name=TestBigQueryVectorStore.dataset_name,
         table_name=TEST_TABLE_NAME,
@@ -52,7 +52,7 @@ def test_deeplake_with_metadatas() -> None:
     assert output == [Document(page_content="foo", metadata={"page": "0"})]
 
 
-def test_deeplake_with_persistence(deeplake_datastore) -> None:
+def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test end to end construction and search, with persistence."""
     output = deeplake_datastore.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo", metadata={"page": "0"})]
@@ -72,7 +72,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None:
     # Or on program exit
 
 
-def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
+def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test overwrite behavior"""
     dataset_path = deeplake_datastore.vectorstore.dataset_handler.path
 
@@ -108,7 +108,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
     output = docsearch.similarity_search("foo", k=1)
 
 
-def test_similarity_search(deeplake_datastore) -> None:
+def test_similarity_search(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     """Test similarity search."""
     distance_metric = "cos"
     output = deeplake_datastore.similarity_search(
@@ -38,7 +38,7 @@ embedding = NormalizedFakeEmbeddings()
 
 
 class ConfigData:
-    def __init__(self):
+    def __init__(self):  # type: ignore[no-untyped-def]
         self.conn = None
         self.schema_name = ""
 
@@ -46,7 +46,7 @@ class ConfigData:
 test_setup = ConfigData()
 
 
-def generateSchemaName(cursor):
+def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
     cursor.execute(
         "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
         "DUMMY;"
@@ -59,7 +59,7 @@ def generateSchemaName(cursor):
     return f"VEC_{uid}"
 
 
-def setup_module(module):
+def setup_module(module):  # type: ignore[no-untyped-def]
     test_setup.conn = dbapi.connect(
         address=os.environ.get("HANA_DB_ADDRESS"),
         port=os.environ.get("HANA_DB_PORT"),
@@ -81,7 +81,7 @@ def setup_module(module):
         cur.close()
 
 
-def teardown_module(module):
+def teardown_module(module):  # type: ignore[no-untyped-def]
     try:
         cur = test_setup.conn.cursor()
         sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE"
@@ -100,13 +100,13 @@ def texts() -> List[str]:
 @pytest.fixture
 def metadatas() -> List[str]:
     return [
-        {"start": 0, "end": 100, "quality": "good", "ready": True},
-        {"start": 100, "end": 200, "quality": "bad", "ready": False},
-        {"start": 200, "end": 300, "quality": "ugly", "ready": True},
+        {"start": 0, "end": 100, "quality": "good", "ready": True},  # type: ignore[list-item]
+        {"start": 100, "end": 200, "quality": "bad", "ready": False},  # type: ignore[list-item]
+        {"start": 200, "end": 300, "quality": "ugly", "ready": True},  # type: ignore[list-item]
     ]
 
 
-def drop_table(connection, table_name):
+def drop_table(connection, table_name):  # type: ignore[no-untyped-def]
     try:
         cur = connection.cursor()
         sql_str = f"DROP TABLE {table_name}"
@@ -825,7 +825,7 @@ def test_hanavector_filter_prepared_statement_params(
     rows = cur.fetchall()
     assert len(rows) == 1
 
-    query_value = "good"
+    query_value = "good"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
@@ -839,14 +839,14 @@ def test_hanavector_filter_prepared_statement_params(
     assert len(rows) == 1
 
     # query_value = True
-    query_value = "true"
+    query_value = "true"  # type: ignore[assignment]
     sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
     assert len(rows) == 2
 
     # query_value = False
-    query_value = "false"
+    query_value = "false"  # type: ignore[assignment]
    sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
     cur.execute(sql_str, (query_value))
     rows = cur.fetchall()
@@ -31,7 +31,7 @@ def fix_distance_precision(
 class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
     """Fake embeddings functionality for testing."""
 
-    def __init__(self):
+    def __init__(self):  # type: ignore[no-untyped-def]
         super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -7,7 +7,7 @@ from langchain_community.vectorstores import NeuralDBVectorStore
 
 
 @pytest.fixture(scope="session")
-def test_csv():
+def test_csv():  # type: ignore[no-untyped-def]
     csv = "thirdai-test.csv"
     with open(csv, "w") as o:
         o.write("column_1,column_2\n")
@@ -16,13 +16,13 @@ def test_csv():
     os.remove(csv)
 
 
-def assert_result_correctness(documents):
+def assert_result_correctness(documents):  # type: ignore[no-untyped-def]
    assert len(documents) == 1
    assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_scratch(test_csv):
+def test_neuraldb_retriever_from_scratch(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_scratch()
     retriever.insert([test_csv])
     documents = retriever.similarity_search("column")
@@ -30,7 +30,7 @@ def test_neuraldb_retriever_from_scratch(test_csv):
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_checkpoint(test_csv):
+def test_neuraldb_retriever_from_checkpoint(test_csv):  # type: ignore[no-untyped-def]
     checkpoint = "thirdai-test-save.ndb"
     if os.path.exists(checkpoint):
         shutil.rmtree(checkpoint)
@@ -47,7 +47,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv):
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_from_bazaar(test_csv):
+def test_neuraldb_retriever_from_bazaar(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_bazaar("General QnA")
     retriever.insert([test_csv])
     documents = retriever.similarity_search("column")
@@ -55,7 +55,7 @@ def test_neuraldb_retriever_from_bazaar(test_csv):
 
 
 @pytest.mark.requires("thirdai[neural_db]")
-def test_neuraldb_retriever_other_methods(test_csv):
+def test_neuraldb_retriever_other_methods(test_csv):  # type: ignore[no-untyped-def]
     retriever = NeuralDBVectorStore.from_scratch()
     retriever.insert([test_csv])
     # Make sure they don't throw an error.
@@ -25,7 +25,7 @@ def get_abbr(s: str) -> str:
 
 
 @pytest.fixture(scope="function")
-def vectara1():
+def vectara1():  # type: ignore[no-untyped-def]
     # Set up code
     # create a new Vectara instance
     vectara1: Vectara = Vectara()
@@ -54,7 +54,7 @@ def vectara1():
         vectara1._delete_doc(doc_id)
 
 
-def test_vectara_add_documents(vectara1) -> None:
+def test_vectara_add_documents(vectara1) -> None:  # type: ignore[no-untyped-def]
     """Test add_documents."""
 
     # test without filter
@@ -164,7 +164,7 @@ models can greatly improve the training of DNNs and other deep discriminative mo
 
 
 @pytest.fixture(scope="function")
-def vectara3():
+def vectara3():  # type: ignore[no-untyped-def]
     # Set up code
     vectara3: Vectara = Vectara()
 
@@ -210,7 +210,7 @@ def vectara3():
         vectara3._delete_doc(doc_id)
 
 
-def test_vectara_mmr(vectara3) -> None:
+def test_vectara_mmr(vectara3) -> None:  # type: ignore[no-untyped-def]
     # test max marginal relevance
     output1 = vectara3.max_marginal_relevance_search(
         "generative AI",
@@ -241,7 +241,7 @@ def test_vectara_mmr(vectara3) -> None:
     )
 
 
-def test_vectara_with_summary(vectara3) -> None:
+def test_vectara_with_summary(vectara3) -> None:  # type: ignore[no-untyped-def]
     """Test vectara summary."""
     # test summarization
     num_results = 10
|
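The `vectara1` and `vectara3` fixtures above are standard function-scoped setup/teardown fixtures. The general shape, reduced to a self-contained sketch (hypothetical resource, not the Vectara client):

import pytest


@pytest.fixture(scope="function")
def resource():  # type: ignore[no-untyped-def]
    # Set up code: runs before each test that requests the fixture.
    handle = {"docs": []}
    yield handle
    # Tear down code: runs after the test, mirroring the _delete_doc loop above.
    handle["docs"].clear()


def test_uses_resource(resource) -> None:  # type: ignore[no-untyped-def]
    resource["docs"].append("d1")
    assert resource["docs"] == ["d1"]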
@@ -35,6 +35,6 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
     ("role", "role_response"),
     [("ai", "assistant"), ("human", "user"), ("chat", "user")],
 )
-def test_edenai_message_role(role: str, role_response) -> None:
+def test_edenai_message_role(role: str, role_response) -> None:  # type: ignore[no-untyped-def]
     role = _message_role(role)
     assert role == role_response
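For context, the first three lines of that hunk belong to a `@pytest.mark.parametrize` decorator that fans one test out over three role mappings. Reassembled, with `_message_role` replaced by an illustrative stub (not the real EdenAI helper):

import pytest


def _message_role(role: str) -> str:
    # Stub standing in for the helper under test.
    return "assistant" if role == "ai" else "user"


@pytest.mark.parametrize(
    ("role", "role_response"),
    [("ai", "assistant"), ("human", "user"), ("chat", "user")],
)
def test_edenai_message_role(role: str, role_response) -> None:  # type: ignore[no-untyped-def]
    assert _message_role(role) == role_response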
@@ -29,7 +29,7 @@ class GradientEmbeddingsModel(MagicMock):
         embeddings = []
         for i, inp in enumerate(inputs):
             # verify correct ordering
-            inp = inp["input"]
+            inp = inp["input"]  # type: ignore[assignment]
             if "pizza" in inp:
                 v = [1.0, 0.0, 0.0]
             elif "document" in inp:
@@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
         output.embeddings = embeddings
         return output

-    async def aembed(self, *args) -> Any:
+    async def aembed(self, *args) -> Any:  # type: ignore[no-untyped-def]
         return self.embed(*args)


 class MockGradient(MagicMock):
     """Mock Gradient package."""

-    def __init__(self, access_token: str, workspace_id, host):
+    def __init__(self, access_token: str, workspace_id, host):  # type: ignore[no-untyped-def]
         assert access_token == _GRADIENT_SECRET
         assert workspace_id == _GRADIENT_WORKSPACE_ID
         assert host == _GRADIENT_BASE_URL
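The one ignore here that is not `no-untyped-def`, the `# type: ignore[assignment]`, papers over a loop variable being rebound to a different type. In miniature (illustrative data, not the Gradient mock):

from typing import Any

inputs: list[dict[str, Any]] = [{"input": "pizza slice"}]

for inp in inputs:
    # Rebinding `inp` from dict to str is what mypy's [assignment] code flags;
    # binding a fresh name would sidestep the ignore comment entirely.
    text: str = inp["input"]
    assert "pizza" in text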
@@ -8,7 +8,7 @@ from langchain_community.embeddings import OCIGenAIEmbeddings


 class MockResponseDict(dict):
-    def __getattr__(self, val):
+    def __getattr__(self, val):  # type: ignore[no-untyped-def]
         return self[val]


@@ -25,7 +25,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
         client=oci_gen_ai_client,
     )

-    def mocked_response(invocation_obj):
+    def mocked_response(invocation_obj):  # type: ignore[no-untyped-def]
         docs = invocation_obj.inputs

         embeddings = []
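`MockResponseDict` is a compact trick for faking SDK response objects: subclass `dict` and route attribute access through item access, so nested responses can be built from plain literals. For example (illustrative fields, not the OCI response schema):

class MockResponseDict(dict):
    def __getattr__(self, val):  # type: ignore[no-untyped-def]
        return self[val]


response = MockResponseDict(
    {"status": 200, "data": MockResponseDict({"embeddings": [[0.1, 0.2]]})}
)
assert response.status == 200
assert response.data.embeddings == [[0.1, 0.2]]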
@@ -1,14 +1,14 @@
 from langchain_community.graphs.neo4j_graph import value_sanitize


-def test_value_sanitize_with_small_list():
+def test_value_sanitize_with_small_list():  # type: ignore[no-untyped-def]
     small_list = list(range(15))  # list size < LIST_LIMIT
     input_dict = {"key1": "value1", "small_list": small_list}
     expected_output = {"key1": "value1", "small_list": small_list}
     assert value_sanitize(input_dict) == expected_output


-def test_value_sanitize_with_oversized_list():
+def test_value_sanitize_with_oversized_list():  # type: ignore[no-untyped-def]
     oversized_list = list(range(150))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "oversized_list": oversized_list}
     expected_output = {
@@ -18,14 +18,14 @@ def test_value_sanitize_with_oversized_list():
     assert value_sanitize(input_dict) == expected_output


-def test_value_sanitize_with_nested_oversized_list():
+def test_value_sanitize_with_nested_oversized_list():  # type: ignore[no-untyped-def]
     oversized_list = list(range(150))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
     expected_output = {"key1": "value1", "oversized_list": {}}
     assert value_sanitize(input_dict) == expected_output


-def test_value_sanitize_with_dict_in_list():
+def test_value_sanitize_with_dict_in_list():  # type: ignore[no-untyped-def]
     oversized_list = list(range(150))  # list size > LIST_LIMIT
     input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
     expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
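Read together, these four tests pin down the contract of `value_sanitize`: lists within the size limit pass through untouched, while oversized ones are dropped wherever they occur in a nested structure. A simplified re-implementation of that contract (not the library code; the threshold value is an assumption, the real constant lives in `langchain_community.graphs.neo4j_graph`):

from typing import Any

LIST_LIMIT = 128  # assumed threshold, for illustration only


def sanitize(value: Any) -> Any:
    """Recursively drop dict entries whose value is an oversized list."""
    if isinstance(value, dict):
        return {
            k: sanitize(v)
            for k, v in value.items()
            if not (isinstance(v, list) and len(v) > LIST_LIMIT)
        }
    if isinstance(value, list):
        return [sanitize(v) for v in value]
    return value


assert sanitize({"small": list(range(15))}) == {"small": list(range(15))}
assert sanitize({"big": list(range(150))}) == {}
assert sanitize({"big": [1, 2, {"key": list(range(150))}]}) == {"big": [1, 2, {}]}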
@@ -15,7 +15,7 @@ class TestOntotextGraphDBGraph(unittest.TestCase):

         with self.assertRaises(TypeError) as e:
             OntotextGraphDBGraph._validate_user_query(
-                [
+                [  # type: ignore[arg-type]
                     "PREFIX starwars: <https://swapi.co/ontology/> "
                     "PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
                     "DESCRIBE starwars: ?term "
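This ignore is intentional rather than a debt marker: the test feeds `_validate_user_query` a list precisely to assert that it raises `TypeError`, so mypy's correct `arg-type` complaint has to be silenced. The same pattern, self-contained (hypothetical validator, not the Ontotext method):

import pytest


def validate_query(query: str) -> None:
    # Stand-in that raises like the real method under test.
    if not isinstance(query, str):
        raise TypeError("the query must be a string")


def test_rejects_non_string() -> None:
    with pytest.raises(TypeError):
        validate_query(["not", "a", "string"])  # type: ignore[arg-type]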
@@ -8,7 +8,7 @@ from langchain_community.llms import OCIGenAI


 class MockResponseDict(dict):
-    def __getattr__(self, val):
+    def __getattr__(self, val):  # type: ignore[no-untyped-def]
         return self[val]


@@ -23,7 +23,7 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:

     provider = llm._get_provider()

-    def mocked_response(*args):
+    def mocked_response(*args):  # type: ignore[no-untyped-def]
         response_text = "This is the completion."

         if provider == "cohere":
@@ -4,11 +4,11 @@ from pytest import MonkeyPatch
 from langchain_community.llms.ollama import Ollama


-def mock_response_stream():
+def mock_response_stream():  # type: ignore[no-untyped-def]
     mock_response = [b'{ "response": "Response chunk 1" }']

     class MockRaw:
-        def read(self, chunk_size):
+        def read(self, chunk_size):  # type: ignore[no-untyped-def]
             try:
                 return mock_response.pop()
             except IndexError:
@@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
         timeout=300,
     )

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
 def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -72,7 +72,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
     """Test that top level params are sent to the endpoint as top level params"""
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -118,7 +118,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -165,7 +165,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
     """
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

-    def mock_post(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
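Each `mock_post` above asserts on the exact URL, headers, and payload the Ollama client builds; wiring such a stub in typically goes through pytest's `monkeypatch`. A generic sketch of the technique (patching `requests.post` is an assumption inferred from the `mock_post` signature, not shown in the hunks):

import requests
from pytest import MonkeyPatch


def test_client_builds_expected_request(monkeypatch: MonkeyPatch) -> None:
    def mock_post(url, headers=None, json=None, stream=False, timeout=None):  # type: ignore[no-untyped-def]
        # These assertions fail the test if the caller builds the request wrong.
        assert url == "https://ollama-hostname:8000/api/generate/"
        assert headers == {"Content-Type": "application/json"}

    monkeypatch.setattr(requests, "post", mock_post)
    # Any code path that calls requests.post now hits the stub instead.
    requests.post(
        "https://ollama-hostname:8000/api/generate/",
        headers={"Content-Type": "application/json"},
    )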
@@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
 	[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
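This Makefile hunk is the fix the commit title describes. In the old line, `mkdir $(MYPY_CACHE) || poetry run mypy …` only reaches mypy when `mkdir` fails, which happens exactly when the cache directory already exists; on a fresh checkout, `mkdir` succeeds and the `||` short-circuits mypy away, silently skipping type checking. `mkdir -p` succeeds whether or not the directory exists, and chaining with `&&` makes mypy run on every invocation. (Shell `||` and `&&` have equal precedence and associate left-to-right, so the guard-then-run chain evaluates in the order written.)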
Some files were not shown because too many files have changed in this diff.