Fixed bug in AnalyticDB Vector Store caused by upgrading the SQLAlchemy version (#6736)

This commit is contained in:
Richy Wang 2023-06-26 20:35:25 +08:00 committed by GitHub
parent d84a3bcf7a
commit ec8247ec59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -80,34 +80,34 @@ class AnalyticDB(VectorStore):
extend_existing=True, extend_existing=True,
) )
with self.engine.connect() as conn: with self.engine.connect() as conn:
# Create the table with conn.begin():
Base.metadata.create_all(conn) # Create the table
Base.metadata.create_all(conn)
# Check if the index exists # Check if the index exists
index_name = f"{self.collection_name}_embedding_idx" index_name = f"{self.collection_name}_embedding_idx"
index_query = text( index_query = text(
f"""
SELECT 1
FROM pg_indexes
WHERE indexname = '{index_name}';
"""
)
result = conn.execute(index_query).scalar()
# Create the index if it doesn't exist
if not result:
index_statement = text(
f""" f"""
CREATE INDEX {index_name} SELECT 1
ON {self.collection_name} USING ann(embedding) FROM pg_indexes
WITH ( WHERE indexname = '{index_name}';
"dim" = {self.embedding_dimension},
"hnsw_m" = 100
);
""" """
) )
conn.execute(index_statement) result = conn.execute(index_query).scalar()
conn.commit()
# Create the index if it doesn't exist
if not result:
index_statement = text(
f"""
CREATE INDEX {index_name}
ON {self.collection_name} USING ann(embedding)
WITH (
"dim" = {self.embedding_dimension},
"hnsw_m" = 100
);
"""
)
conn.execute(index_statement)
def create_collection(self) -> None: def create_collection(self) -> None:
if self.pre_delete_collection: if self.pre_delete_collection:
@ -118,8 +118,8 @@ class AnalyticDB(VectorStore):
self.logger.debug("Trying to delete collection") self.logger.debug("Trying to delete collection")
drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};") drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
with self.engine.connect() as conn: with self.engine.connect() as conn:
conn.execute(drop_statement) with conn.begin():
conn.commit() conn.execute(drop_statement)
def add_texts( def add_texts(
self, self,
@ -160,30 +160,28 @@ class AnalyticDB(VectorStore):
chunks_table_data = [] chunks_table_data = []
with self.engine.connect() as conn: with self.engine.connect() as conn:
for document, metadata, chunk_id, embedding in zip( with conn.begin():
texts, metadatas, ids, embeddings for document, metadata, chunk_id, embedding in zip(
): texts, metadatas, ids, embeddings
chunks_table_data.append( ):
{ chunks_table_data.append(
"id": chunk_id, {
"embedding": embedding, "id": chunk_id,
"document": document, "embedding": embedding,
"metadata": metadata, "document": document,
} "metadata": metadata,
) }
)
# Execute the batch insert when the batch size is reached # Execute the batch insert when the batch size is reached
if len(chunks_table_data) == batch_size: if len(chunks_table_data) == batch_size:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data)) conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Commit the transaction only once after all records have been inserted
conn.commit()
return ids return ids
@ -333,9 +331,9 @@ class AnalyticDB(VectorStore):
) -> AnalyticDB: ) -> AnalyticDB:
""" """
Return VectorStore initialized from texts and embeddings. Return VectorStore initialized from texts and embeddings.
Postgres connection string is required Postgres Connection string is required
Either pass it as a parameter Either pass it as a parameter
or set the PGVECTOR_CONNECTION_STRING environment variable. or set the PG_CONNECTION_STRING environment variable.
""" """
connection_string = cls.get_connection_string(kwargs) connection_string = cls.get_connection_string(kwargs)
@ -363,7 +361,7 @@ class AnalyticDB(VectorStore):
raise ValueError( raise ValueError(
"Postgres connection string is required" "Postgres connection string is required"
"Either pass it as a parameter" "Either pass it as a parameter"
"or set the PGVECTOR_CONNECTION_STRING environment variable." "or set the PG_CONNECTION_STRING environment variable."
) )
return connection_string return connection_string
@ -381,9 +379,9 @@ class AnalyticDB(VectorStore):
) -> AnalyticDB: ) -> AnalyticDB:
""" """
Return VectorStore initialized from documents and embeddings. Return VectorStore initialized from documents and embeddings.
Postgres connection string is required Postgres Connection string is required
Either pass it as a parameter Either pass it as a parameter
or set the PGVECTOR_CONNECTION_STRING environment variable. or set the PG_CONNECTION_STRING environment variable.
""" """
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]