From d07db457fca73cf41eb7f1b3f9d504161c2ab56a Mon Sep 17 00:00:00 2001 From: Philippe PRADOS Date: Wed, 14 Feb 2024 20:45:28 +0100 Subject: [PATCH] community[patch]: Fix SQLAlchemyMd5Cache race condition (#16279) If the SQLAlchemyMd5Cache is shared among multiple processes, it is possible to encounter a race condition during the cache update. Co-authored-by: Eugene Yurtsev --- libs/community/langchain_community/cache.py | 41 ++++++++++----------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py index 4719d661ba1..c8741c8f67a 100644 --- a/libs/community/langchain_community/cache.py +++ b/libs/community/langchain_community/cache.py @@ -37,13 +37,14 @@ from typing import ( Dict, List, Optional, + Sequence, Tuple, Type, Union, cast, ) -from sqlalchemy import Column, Integer, String, create_engine, select +from sqlalchemy import Column, Integer, String, create_engine, delete, select from sqlalchemy.engine import Row from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session @@ -1308,37 +1309,33 @@ class SQLAlchemyMd5Cache(BaseCache): def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update based on prompt and llm_string.""" - self._delete_previous(prompt, llm_string) - prompt_md5 = self.get_md5(prompt) - items = [ - self.cache_schema( - id=str(uuid.uuid1()), - prompt=prompt, - prompt_md5=prompt_md5, - llm=llm_string, - response=dumps(gen), - idx=i, - ) - for i, gen in enumerate(return_val) - ] with Session(self.engine) as session, session.begin(): + self._delete_previous(session, prompt, llm_string) + prompt_md5 = self.get_md5(prompt) + items = [ + self.cache_schema( + id=str(uuid.uuid1()), + prompt=prompt, + prompt_md5=prompt_md5, + llm=llm_string, + response=dumps(gen), + idx=i, + ) + for i, gen in enumerate(return_val) + ] for item in items: session.merge(item) - def _delete_previous(self, prompt: str, llm_string: str) -> None: + def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None: stmt = ( - select(self.cache_schema.response) + delete(self.cache_schema) .where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore .where(self.cache_schema.llm == llm_string) .where(self.cache_schema.prompt == prompt) - .order_by(self.cache_schema.idx) ) - with Session(self.engine) as session, session.begin(): - rows = session.execute(stmt).fetchall() - for item in rows: - session.delete(item) + session.execute(stmt) - def _search_rows(self, prompt: str, llm_string: str) -> List[Row]: + def _search_rows(self, prompt: str, llm_string: str) -> Sequence[Row]: prompt_pd5 = self.get_md5(prompt) stmt = ( select(self.cache_schema.response)