This commit is contained in:
Bagatur 2023-08-31 00:43:34 -07:00
parent 8c4e29240c
commit b1644bc9ad
3 changed files with 5 additions and 13 deletions

View File

@ -89,7 +89,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"vector_db = TencentVectorDB(embedding_function=embeddings, connection_params=conn_params)\n", "vector_db = TencentVectorDB(embeddings, conn_params)\n",
"\n", "\n",
"vector_db.add_texts([\"Ankush went to Princeton\"])\n", "vector_db.add_texts([\"Ankush went to Princeton\"])\n",
"query = \"Where did Ankush go to college?\"\n", "query = \"Where did Ankush go to college?\"\n",
@ -114,7 +114,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.17" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -78,7 +78,7 @@ class TencentVectorDB(VectorStore):
def __init__( def __init__(
self, self,
embedding_function: Embeddings, embedding: Embeddings,
connection_params: ConnectionParams, connection_params: ConnectionParams,
index_params: IndexParams = IndexParams(128), index_params: IndexParams = IndexParams(128),
database_name: str = "LangChainDatabase", database_name: str = "LangChainDatabase",
@ -87,7 +87,7 @@ class TencentVectorDB(VectorStore):
): ):
self.document = guard_import("tcvectordb.model.document") self.document = guard_import("tcvectordb.model.document")
tcvectordb = guard_import("tcvectordb") tcvectordb = guard_import("tcvectordb")
self.embedding_func = embedding_function self.embedding_func = embedding
self.index_params = index_params self.index_params = index_params
self.vdb_client = tcvectordb.VectorDBClient( self.vdb_client = tcvectordb.VectorDBClient(
url=connection_params.url, url=connection_params.url,
@ -193,7 +193,7 @@ class TencentVectorDB(VectorStore):
else: else:
index_params.dimension = dimension index_params.dimension = dimension
vector_db = cls( vector_db = cls(
embedding_function=embedding, embedding=embedding,
connection_params=connection_params, connection_params=connection_params,
index_params=index_params, index_params=index_params,
database_name=database_name, database_name=database_name,

View File

@ -83,11 +83,3 @@ def test_tencent_vector_db_no_drop() -> None:
time.sleep(3) time.sleep(3)
output = docsearch.similarity_search("foo", k=10) output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6 assert len(output) == 6
# if __name__ == "__main__":
# test_tencent_vector_db()
# test_tencent_vector_db_with_score()
# test_tencent_vector_db_max_marginal_relevance_search()
# test_tencent_vector_db_add_extra()
# test_tencent_vector_db_no_drop()