rm mutable defaults (#9974)

This commit is contained in:
Bagatur 2023-08-29 20:36:27 -07:00 committed by GitHub
parent 6a51672164
commit d762a6b51f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 26 additions and 21 deletions

View File

@ -242,7 +242,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self, self,
name: Optional[str] = "langchainrun-%", name: Optional[str] = "langchainrun-%",
experiment: Optional[str] = "langchain", experiment: Optional[str] = "langchain",
tags: Optional[Dict] = {}, tags: Optional[Dict] = None,
tracking_uri: Optional[str] = None, tracking_uri: Optional[str] = None,
) -> None: ) -> None:
"""Initialize callback handler.""" """Initialize callback handler."""
@ -254,7 +254,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.name = name self.name = name
self.experiment = experiment self.experiment = experiment
self.tags = tags self.tags = tags or {}
self.tracking_uri = tracking_uri self.tracking_uri = tracking_uri
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()

View File

@ -40,12 +40,12 @@ class PromptLayerCallbackHandler(BaseCallbackHandler):
def __init__( def __init__(
self, self,
pl_id_callback: Optional[Callable[..., Any]] = None, pl_id_callback: Optional[Callable[..., Any]] = None,
pl_tags: Optional[List[str]] = [], pl_tags: Optional[List[str]] = None,
) -> None: ) -> None:
"""Initialize the PromptLayerCallbackHandler.""" """Initialize the PromptLayerCallbackHandler."""
_lazy_import_promptlayer() _lazy_import_promptlayer()
self.pl_id_callback = pl_id_callback self.pl_id_callback = pl_id_callback
self.pl_tags = pl_tags self.pl_tags = pl_tags or []
self.runs: Dict[UUID, Dict[str, Any]] = {} self.runs: Dict[UUID, Dict[str, Any]] = {}
def on_chat_model_start( def on_chat_model_start(

View File

@ -33,7 +33,7 @@ class AsyncHtmlLoader(BaseLoader):
verify_ssl: Optional[bool] = True, verify_ssl: Optional[bool] = True,
proxies: Optional[dict] = None, proxies: Optional[dict] = None,
requests_per_second: int = 2, requests_per_second: int = 2,
requests_kwargs: Dict[str, Any] = {}, requests_kwargs: Optional[Dict[str, Any]] = None,
raise_for_status: bool = False, raise_for_status: bool = False,
): ):
"""Initialize with a webpage path.""" """Initialize with a webpage path."""
@ -67,7 +67,7 @@ class AsyncHtmlLoader(BaseLoader):
self.session.proxies.update(proxies) self.session.proxies.update(proxies)
self.requests_per_second = requests_per_second self.requests_per_second = requests_per_second
self.requests_kwargs = requests_kwargs self.requests_kwargs = requests_kwargs or {}
self.raise_for_status = raise_for_status self.raise_for_status = raise_for_status
async def _fetch( async def _fetch(

View File

@ -1,6 +1,6 @@
import logging import logging
from string import Template from string import Template
from typing import Any, Dict from typing import Any, Dict, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -106,11 +106,12 @@ class NebulaGraph:
"""Returns the schema of the NebulaGraph database""" """Returns the schema of the NebulaGraph database"""
return self.schema return self.schema
def execute(self, query: str, params: dict = {}, retry: int = 0) -> Any: def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any:
"""Query NebulaGraph database.""" """Query NebulaGraph database."""
from nebula3.Exception import IOErrorException, NoValidSessionException from nebula3.Exception import IOErrorException, NoValidSessionException
from nebula3.fbthrift.transport.TTransport import TTransportException from nebula3.fbthrift.transport.TTransport import TTransportException
params = params or {}
try: try:
result = self.session_pool.execute_parameter(query, params) result = self.session_pool.execute_parameter(query, params)
if not result.is_succeeded(): if not result.is_succeeded():

View File

@ -183,9 +183,10 @@ def make_request(
instruction: str, instruction: str,
conversation: str, conversation: str,
url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}", url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}",
params: Dict = {}, params: Optional[Dict] = None,
) -> Any: ) -> Any:
"""Generate text from the model.""" """Generate text from the model."""
params = params or {}
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"ApiKey": f"{self.nebula_api_key}", "ApiKey": f"{self.nebula_api_key}",

View File

@ -372,10 +372,10 @@ class Marqo(VectorStore):
index_name: str = "", index_name: str = "",
url: str = "http://localhost:8882", url: str = "http://localhost:8882",
api_key: str = "", api_key: str = "",
add_documents_settings: Optional[Dict[str, Any]] = {}, add_documents_settings: Optional[Dict[str, Any]] = None,
searchable_attributes: Optional[List[str]] = None, searchable_attributes: Optional[List[str]] = None,
page_content_builder: Optional[Callable[[Dict[str, str]], str]] = None, page_content_builder: Optional[Callable[[Dict[str, str]], str]] = None,
index_settings: Optional[Dict[str, Any]] = {}, index_settings: Optional[Dict[str, Any]] = None,
verbose: bool = True, verbose: bool = True,
**kwargs: Any, **kwargs: Any,
) -> Marqo: ) -> Marqo:
@ -435,7 +435,7 @@ class Marqo(VectorStore):
client = marqo.Client(url=url, api_key=api_key) client = marqo.Client(url=url, api_key=api_key)
try: try:
client.create_index(index_name, settings_dict=index_settings) client.create_index(index_name, settings_dict=index_settings or {})
if verbose: if verbose:
print(f"Created {index_name} successfully.") print(f"Created {index_name} successfully.")
except Exception: except Exception:
@ -446,7 +446,7 @@ class Marqo(VectorStore):
client, client,
index_name, index_name,
searchable_attributes=searchable_attributes, searchable_attributes=searchable_attributes,
add_documents_settings=add_documents_settings, add_documents_settings=add_documents_settings or {},
page_content_builder=page_content_builder, page_content_builder=page_content_builder,
) )
instance.add_texts(texts, metadatas) instance.add_texts(texts, metadatas)

View File

@ -991,7 +991,7 @@ class Redis(VectorStore):
self, self,
k: int, k: int,
filter: Optional[RedisFilterExpression] = None, filter: Optional[RedisFilterExpression] = None,
return_fields: List[str] = [], return_fields: Optional[List[str]] = None,
) -> "Query": ) -> "Query":
try: try:
from redis.commands.search.query import Query from redis.commands.search.query import Query
@ -1000,6 +1000,7 @@ class Redis(VectorStore):
"Could not import redis python package. " "Could not import redis python package. "
"Please install it with `pip install redis`." "Please install it with `pip install redis`."
) from e ) from e
return_fields = return_fields or []
vector_key = self._schema.content_vector_key vector_key = self._schema.content_vector_key
base_query = f"@{vector_key}:[VECTOR_RANGE $distance_threshold $vector]" base_query = f"@{vector_key}:[VECTOR_RANGE $distance_threshold $vector]"
@ -1020,7 +1021,7 @@ class Redis(VectorStore):
self, self,
k: int, k: int,
filter: Optional[RedisFilterExpression] = None, filter: Optional[RedisFilterExpression] = None,
return_fields: List[str] = [], return_fields: Optional[List[str]] = None,
) -> "Query": ) -> "Query":
"""Prepare query for vector search. """Prepare query for vector search.
@ -1038,6 +1039,7 @@ class Redis(VectorStore):
"Could not import redis python package. " "Could not import redis python package. "
"Please install it with `pip install redis`." "Please install it with `pip install redis`."
) from e ) from e
return_fields = return_fields or []
query_prefix = "*" query_prefix = "*"
if filter: if filter:
query_prefix = f"{str(filter)}" query_prefix = f"{str(filter)}"

View File

@ -345,8 +345,9 @@ class SingleStoreDB(VectorStore):
def build_where_clause( def build_where_clause(
where_clause_values: List[Any], where_clause_values: List[Any],
sub_filter: dict, sub_filter: dict,
prefix_args: List[str] = [], prefix_args: Optional[List[str]] = None,
) -> None: ) -> None:
prefix_args = prefix_args or []
for key in sub_filter.keys(): for key in sub_filter.keys():
if isinstance(sub_filter[key], dict): if isinstance(sub_filter[key], dict):
build_where_clause( build_where_clause(

View File

@ -463,7 +463,7 @@ class VectaraRetriever(VectorStoreRetriever):
self, self,
texts: List[str], texts: List[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
doc_metadata: Optional[dict] = {}, doc_metadata: Optional[dict] = None,
) -> None: ) -> None:
"""Add text to the Vectara vectorstore. """Add text to the Vectara vectorstore.
@ -471,4 +471,4 @@ class VectaraRetriever(VectorStoreRetriever):
texts (List[str]): The text texts (List[str]): The text
metadatas (List[dict]): Metadata dicts, must line up with existing store metadatas (List[dict]): Metadata dicts, must line up with existing store
""" """
self.vectorstore.add_texts(texts, metadatas, doc_metadata) self.vectorstore.add_texts(texts, metadatas, doc_metadata or {})

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, List, Optional from typing import Any, Dict, List, Optional
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores.milvus import Milvus from langchain.vectorstores.milvus import Milvus
@ -140,7 +140,7 @@ class Zilliz(Milvus):
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection", collection_name: str = "LangChainCollection",
connection_args: dict[str, Any] = {}, connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session", consistency_level: str = "Session",
index_params: Optional[dict] = None, index_params: Optional[dict] = None,
search_params: Optional[dict] = None, search_params: Optional[dict] = None,
@ -173,7 +173,7 @@ class Zilliz(Milvus):
vector_db = cls( vector_db = cls(
embedding_function=embedding, embedding_function=embedding,
collection_name=collection_name, collection_name=collection_name,
connection_args=connection_args, connection_args=connection_args or {},
consistency_level=consistency_level, consistency_level=consistency_level,
index_params=index_params, index_params=index_params,
search_params=search_params, search_params=search_params,