community[patch]: Fix SQLAlchemyMd5Cache race condition (#16279)
If the SQLAlchemyMd5Cache is shared among multiple processes, it is possible to encounter a race condition during the cache update: the previous rows were deleted in one transaction and the new rows inserted in another, so a concurrent writer could interleave between the two steps. The fix performs the delete and the inserts in a single transaction.

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
parent 70c296ae96
commit d07db457fc
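For context, here is a minimal standalone sketch of the race and the fix. It uses SQLite and a simplified table in place of the real FullMd5LLMCache schema; the model class and column names are illustrative, not the library's actual definitions.

from sqlalchemy import Column, Integer, String, create_engine, delete
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class CacheRow(Base):
    # Illustrative stand-in for the real cache table.
    __tablename__ = "cache_rows"
    id = Column(String, primary_key=True)
    prompt_md5 = Column(String, index=True)
    llm = Column(String)
    response = Column(String)
    idx = Column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Pre-fix shape: the delete ran in its own transaction, so another process
# could write rows between the delete committing and the insert starting.
# Post-fix shape (shown here): the delete and the merge share one
# transaction, so other connections observe the update atomically.
with Session(engine) as session, session.begin():
    session.execute(delete(CacheRow).where(CacheRow.prompt_md5 == "abc"))
    session.merge(
        CacheRow(id="1", prompt_md5="abc", llm="fake-llm", response="hi", idx=0)
    )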
@@ -37,13 +37,14 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
     cast,
 )
 
-from sqlalchemy import Column, Integer, String, create_engine, select
+from sqlalchemy import Column, Integer, String, create_engine, delete, select
 from sqlalchemy.engine import Row
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import Session
@@ -1308,37 +1309,33 @@ class SQLAlchemyMd5Cache(BaseCache):
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update based on prompt and llm_string."""
-        self._delete_previous(prompt, llm_string)
-        prompt_md5 = self.get_md5(prompt)
-        items = [
-            self.cache_schema(
-                id=str(uuid.uuid1()),
-                prompt=prompt,
-                prompt_md5=prompt_md5,
-                llm=llm_string,
-                response=dumps(gen),
-                idx=i,
-            )
-            for i, gen in enumerate(return_val)
-        ]
         with Session(self.engine) as session, session.begin():
+            self._delete_previous(session, prompt, llm_string)
+            prompt_md5 = self.get_md5(prompt)
+            items = [
+                self.cache_schema(
+                    id=str(uuid.uuid1()),
+                    prompt=prompt,
+                    prompt_md5=prompt_md5,
+                    llm=llm_string,
+                    response=dumps(gen),
+                    idx=i,
+                )
+                for i, gen in enumerate(return_val)
+            ]
             for item in items:
                 session.merge(item)
 
-    def _delete_previous(self, prompt: str, llm_string: str) -> None:
+    def _delete_previous(self, session: Session, prompt: str, llm_string: str) -> None:
         stmt = (
-            select(self.cache_schema.response)
+            delete(self.cache_schema)
             .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))  # type: ignore
             .where(self.cache_schema.llm == llm_string)
             .where(self.cache_schema.prompt == prompt)
-            .order_by(self.cache_schema.idx)
         )
-        with Session(self.engine) as session, session.begin():
-            rows = session.execute(stmt).fetchall()
-            for item in rows:
-                session.delete(item)
+        session.execute(stmt)
 
-    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
+    def _search_rows(self, prompt: str, llm_string: str) -> Sequence[Row]:
         prompt_pd5 = self.get_md5(prompt)
         stmt = (
             select(self.cache_schema.response)