fix serious bug with Actor.get_next updating all rows instead of only top row

This commit is contained in:
Nick Sweeting 2024-11-17 22:56:03 -08:00
parent 18403b72f0
commit 148ea907bd
No known key found for this signature in database
3 changed files with 26 additions and 16 deletions

View file

@ -84,6 +84,7 @@ ARCHIVEBOX_BUILTIN_PLUGINS = {
'config': PACKAGE_DIR / 'config', 'config': PACKAGE_DIR / 'config',
'core': PACKAGE_DIR / 'core', 'core': PACKAGE_DIR / 'core',
'crawls': PACKAGE_DIR / 'crawls', 'crawls': PACKAGE_DIR / 'crawls',
'queues': PACKAGE_DIR / 'queues',
'seeds': PACKAGE_DIR / 'seeds', 'seeds': PACKAGE_DIR / 'seeds',
'actors': PACKAGE_DIR / 'actors', 'actors': PACKAGE_DIR / 'actors',
# 'search': PACKAGE_DIR / 'search', # 'search': PACKAGE_DIR / 'search',

View file

@ -75,7 +75,7 @@ class ActorType(Generic[ModelType]):
_SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = [] # used to record all the pids of Actors spawned on the class _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = [] # used to record all the pids of Actors spawned on the class
### Instance attributes (only used within an actor instance inside a spawned actor thread/process) ### Instance attributes (only used within an actor instance inside a spawned actor thread/process)
pid: int pid: int = os.getpid()
idle_count: int = 0 idle_count: int = 0
launch_kwargs: LaunchKwargs = {} launch_kwargs: LaunchKwargs = {}
mode: Literal['thread', 'process'] = 'process' mode: Literal['thread', 'process'] = 'process'
@ -290,7 +290,7 @@ class ActorType(Generic[ModelType]):
Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work. Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work.
(don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part) (don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part)
""" """
return cls.Model.objects.all() return cls.Model.objects.filter()
@classproperty @classproperty
def final_q(cls) -> Q: def final_q(cls) -> Q:
@ -438,25 +438,30 @@ class ActorType(Generic[ModelType]):
assert select_top_canidates_sql.startswith('SELECT ') assert select_top_canidates_sql.startswith('SELECT ')
# e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...' # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...'
update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=qs, update_kwargs=self.get_update_kwargs_to_claim_obj()) update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=self.qs.all(), update_kwargs=self.get_update_kwargs_to_claim_obj())
assert update_claimed_obj_sql.startswith('UPDATE ') assert update_claimed_obj_sql.startswith('UPDATE ') and 'WHERE' not in update_claimed_obj_sql
db_table = self.Model._meta.db_table # e.g. core_archiveresult db_table = self.Model._meta.db_table # e.g. core_archiveresult
# subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N] # subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N]
# main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>) # main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
# this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible) # this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible)
atomic_select_and_update_sql = f""" atomic_select_and_update_sql = f"""
{update_claimed_obj_sql} AND "{db_table}"."id" = ( with top_candidates AS ({select_top_canidates_sql})
SELECT "{db_table}"."id" FROM ( {update_claimed_obj_sql}
{select_top_canidates_sql} WHERE "{db_table}"."id" IN (
) candidates SELECT id FROM top_candidates
ORDER BY RANDOM() ORDER BY RANDOM()
LIMIT 1 LIMIT 1
) )
RETURNING *; RETURNING *;
""" """
# import ipdb; ipdb.set_trace()
try: try:
return self.Model.objects.raw(atomic_select_and_update_sql, (*update_params, *select_params))[0] updated = qs.raw(atomic_select_and_update_sql, (*select_params, *update_params))
assert len(updated) <= 1, f'Expected to claim at most 1 object, but Django modified {len(updated)} objects!'
return updated[0]
except IndexError: except IndexError:
if self.get_queue().exists(): if self.get_queue().exists():
raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()') raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
@ -548,7 +553,7 @@ def compile_sql_select(queryset: QuerySet, filter_kwargs: dict[str, Any] | None=
return select_sql, select_params return select_sql, select_params
def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter_kwargs: dict[str, Any] | None=None) -> tuple[str, tuple[Any, ...]]: def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any]) -> tuple[str, tuple[Any, ...]]:
""" """
Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call
Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
@ -562,11 +567,8 @@ def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter
""" """
assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead' assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead' assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead'
assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_update(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
queryset = queryset._chain() # type: ignore # copy queryset to avoid modifying the original queryset = queryset._chain().all() # type: ignore # copy queryset to avoid modifying the original and clear any filters
if filter_kwargs:
queryset = queryset.filter(**filter_kwargs)
queryset.query.clear_ordering(force=True) # clear any ORDER BY clauses queryset.query.clear_ordering(force=True) # clear any ORDER BY clauses
queryset.query.clear_limits() # clear any LIMIT clauses aka slices[:n] queryset.query.clear_limits() # clear any LIMIT clauses aka slices[:n]
queryset._for_write = True # type: ignore queryset._for_write = True # type: ignore
@ -576,5 +578,12 @@ def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter
# e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s
update_sql, update_params = query.get_compiler(queryset.db).as_sql() update_sql, update_params = query.get_compiler(queryset.db).as_sql()
# make sure you only pass a raw queryset with no .filter(...) clauses applied to it; the return value is designed to be used
# in a manually assembled SQL query with its own WHERE clause later on
assert 'WHERE' not in update_sql, f'compile_sql_update(...) should only contain a SET statement but it tried to return a query with a WHERE clause: {update_sql}'
# print(update_sql, update_params)
return update_sql, update_params return update_sql, update_params

View file

@ -102,7 +102,7 @@ class Orchestrator:
# returns a list of objects that are in the queues of all actor types but not in the queues of any other actor types # returns a list of objects that are in the queues of all actor types but not in the queues of any other actor types
return any( return any(
queue.filter(retry_at__gt=timezone.now()).exists() queue.filter(retry_at__gte=timezone.now()).exists()
for queue in all_queues.values() for queue in all_queues.values()
) )
@ -163,7 +163,7 @@ class Orchestrator:
for actor_type, queue in all_queues.items(): for actor_type, queue in all_queues.items():
next_obj = queue.first() next_obj = queue.first()
print(f'🏃‍♂️ {self}.runloop() {actor_type.__name__.ljust(20)} queue={str(queue.count()).ljust(3)} next={next_obj.abid if next_obj else "None"} {next_obj.status if next_obj else "None"} {(timezone.now() - next_obj.retry_at).total_seconds() if next_obj else "None"}') print(f'🏃‍♂️ {self}.runloop() {actor_type.__name__.ljust(20)} queue={str(queue.count()).ljust(3)} next={next_obj.abid if next_obj else "None"} {next_obj.status if next_obj else "None"} {(timezone.now() - next_obj.retry_at).total_seconds() if next_obj and next_obj.retry_at else "None"}')
try: try:
existing_actors = actor_type.get_running_actors() existing_actors = actor_type.get_running_actors()
all_existing_actors.extend(existing_actors) all_existing_actors.extend(existing_actors)