more orchestrator and actor improvements

Nick Sweeting 2024-11-02 17:25:51 -07:00
parent 721427a484
commit dbe5c0bc07
2 changed files with 137 additions and 70 deletions


@@ -4,9 +4,12 @@ import os
import time
import itertools
import uuid
from typing import Dict, Type
from typing import Dict, Type, Literal
from django.utils.functional import classproperty
from multiprocessing import Process, cpu_count
from threading import Thread, get_native_id
from rich import print
@@ -19,21 +22,41 @@ class Orchestrator:
pid: int
idle_count: int = 0
actor_types: Dict[str, Type[ActorType]]
mode: Literal['thread', 'process'] = 'process'
def __init__(self, actor_types: Dict[str, Type[ActorType]] | None = None):
def __init__(self, actor_types: Dict[str, Type[ActorType]] | None = None, mode: Literal['thread', 'process'] | None = None):
self.actor_types = actor_types or self.actor_types or self.autodiscover_actor_types()
self.mode = mode or self.mode
def __repr__(self) -> str:
return f'[underline]{self.__class__.__name__}[/underline]\\[pid={self.pid}]'
label = 'tid' if self.mode == 'thread' else 'pid'
return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
def __str__(self) -> str:
return self.__repr__()
@classproperty
def name(cls) -> str:
return cls.__name__ # type: ignore
def fork_as_thread(self):
self.thread = Thread(target=self.runloop)
self.thread.start()
assert self.thread.native_id is not None
return self.thread.native_id
def fork_as_process(self):
self.process = Process(target=self.runloop)
self.process.start()
assert self.process.pid is not None
return self.process.pid
def start(self) -> int:
orchestrator_bg_proc = Process(target=self.runloop)
orchestrator_bg_proc.start()
assert orchestrator_bg_proc.pid is not None
return orchestrator_bg_proc.pid
if self.mode == 'thread':
return self.fork_as_thread()
elif self.mode == 'process':
return self.fork_as_process()
raise ValueError(f'Invalid orchestrator mode: {self.mode}')
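# Illustrative usage (not part of this diff): with the new `mode` flag, start()
# either spawns a background Thread and returns its native thread id, or forks a
# child Process and returns its pid. ExtractorsOrchestrator is the concrete
# subclass defined further down in this file.
#
# orchestrator = ExtractorsOrchestrator(mode='thread')
# tid = orchestrator.start()   # Thread.native_id
#
# orchestrator = ExtractorsOrchestrator(mode='process')
# pid = orchestrator.start()   # Process.pid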
@classmethod
def autodiscover_actor_types(cls) -> Dict[str, Type[ActorType]]:
@@ -42,7 +65,8 @@ class Orchestrator:
# return {'Snapshot': SnapshotActorType, 'ArchiveResult_chrome': ChromeActorType, ...}
return {
# look through all models and find all classes that inherit from ActorType
# ...
# actor_type.__name__: actor_type
# for actor_type in abx.pm.hook.get_all_ACTORS_TYPES().values()
}
@classmethod
@@ -56,8 +80,12 @@ class Orchestrator:
return orphaned_objects
def on_startup(self):
self.pid = os.getpid()
print(f'[green]👨‍✈️ {self}.on_startup() STARTUP (PROCESS)[/green]')
if self.mode == 'thread':
self.pid = get_native_id()
print(f'[green]👨‍✈️ {self}.on_startup() STARTUP (THREAD)[/green]')
elif self.mode == 'process':
self.pid = os.getpid()
print(f'[green]👨‍✈️ {self}.on_startup() STARTUP (PROCESS)[/green]')
# abx.pm.hook.on_orchestrator_startup(self)
def on_shutdown(self, err: BaseException | None = None):
@@ -109,8 +137,10 @@ class Orchestrator:
for launch_kwargs in actors_to_spawn:
new_actor_pid = actor_type.start(mode='process', **launch_kwargs)
all_spawned_actors.append(new_actor_pid)
except BaseException as err:
except Exception as err:
print(f'🏃‍♂️ ERROR: {self} Failed to get {actor_type} queue & running actors', err)
except BaseException:
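# re-raise KeyboardInterrupt / SystemExit instead of swallowing them like
# ordinary errors, so Ctrl-C still tears the orchestrator runloop down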
raise
if not any(queue.exists() for queue in all_queues.values()):
self.on_idle(all_queues)
@@ -152,30 +182,36 @@ class FaviconActor(ActorType[ArchiveResult]):
@classmethod
def get_next(cls) -> ArchiveResult | None:
return cls.get_next_atomic(
# return cls.get_next_atomic(
# model=ArchiveResult,
# where='status = "failed"',
# set='status = "started"',
# order_by='created_at DESC',
# choose_from_top=cpu_count() * 10,
# )
return cls.get_random(
model=ArchiveResult,
filter=('status', 'failed'),
update=('status', 'started'),
sort='created_at',
order='DESC',
choose_from_top=cpu_count() * 10
where='status = "failed"',
set='status = "queued"',
choose_from_top=cls.get_queue().count(),
)
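# with this change the claim pipeline is: get_next() re-queues 'failed' rows as
# 'queued', lock() below flips 'queued' -> 'started', and tick() finishes with
# 'started' -> 'success', each step guarded by a conditional UPDATE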
def tick(self, obj: ArchiveResult):
print(f'[grey53]{self}.tick({obj.id}) remaining:[/grey53]', self.get_queue().count())
print(f'[grey53]{self}.tick({obj.abid or obj.id}) remaining:[/grey53]', self.get_queue().count())
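# the filter(id=..., status='started').update(...) below acts as an optimistic
# concurrency check: updated == 1 means this actor won the race, updated == 0
# means another actor already modified the row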
updated = ArchiveResult.objects.filter(id=obj.id, status='started').update(status='success') == 1
if not updated:
raise Exception(f'Failed to update {obj.abid}, interrupted by another actor writing to the same object')
raise Exception(f'Failed to update {obj.abid or obj.id}, interrupted by another actor writing to the same object')
def lock(self, obj: ArchiveResult) -> bool:
"""As an alternative to self.get_next_atomic(), we can use select_for_update() or manually update a semaphore field here"""
# locked = ArchiveResult.objects.select_for_update(skip_locked=True).filter(id=obj.id, status='pending').update(status='started') == 1
# if locked:
# print(f'FaviconActor[{self.pid}] lock({obj.id}) 🔒')
# else:
# print(f'FaviconActor[{self.pid}] lock({obj.id}) X')
return True
locked = ArchiveResult.objects.filter(id=obj.id, status='queued').update(status='started') == 1
if locked:
# print(f'FaviconActor[{self.pid}] lock({obj.id}) 🔒')
pass
else:
print(f'FaviconActor[{self.pid}] lock({obj.id}) X')
return locked
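# Sketch of the select_for_update() alternative mentioned in the docstring above
# (illustrative only, not part of this commit; assumes a DB backend that supports
# SELECT ... FOR UPDATE SKIP LOCKED and the same 'queued' -> 'started' transition):
#
# from django.db import transaction
#
# def lock(self, obj: ArchiveResult) -> bool:
#     with transaction.atomic():
#         row = (
#             ArchiveResult.objects
#             .select_for_update(skip_locked=True)
#             .filter(id=obj.id, status='queued')
#             .first()
#         )
#         if row is None:
#             return False            # row is locked or already claimed elsewhere
#         row.status = 'started'
#         row.save(update_fields=['status'])
#     return True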
class ExtractorsOrchestrator(Orchestrator):
@@ -192,32 +228,32 @@ if __name__ == '__main__':
assert snap is not None
created = 0
while True:
time.sleep(0.005)
try:
ArchiveResult.objects.bulk_create([
ArchiveResult(
id=uuid.uuid4(),
snapshot=snap,
status='failed',
extractor='favicon',
cmd=['echo', '"hello"'],
cmd_version='1.0',
pwd='.',
start_ts=timezone.now(),
end_ts=timezone.now(),
created_at=timezone.now(),
modified_at=timezone.now(),
created_by_id=1,
)
for _ in range(100)
])
created += 100
if created % 1000 == 0:
print(f'[blue]Created {created} ArchiveResults...[/blue]')
time.sleep(25)
except Exception as err:
print(err)
db.connections.close_all()
except BaseException as err:
print(err)
break
time.sleep(0.05)
# try:
# ArchiveResult.objects.bulk_create([
# ArchiveResult(
# id=uuid.uuid4(),
# snapshot=snap,
# status='failed',
# extractor='favicon',
# cmd=['echo', '"hello"'],
# cmd_version='1.0',
# pwd='.',
# start_ts=timezone.now(),
# end_ts=timezone.now(),
# created_at=timezone.now(),
# modified_at=timezone.now(),
# created_by_id=1,
# )
# for _ in range(100)
# ])
# created += 100
# if created % 1000 == 0:
# print(f'[blue]Created {created} ArchiveResults...[/blue]')
# time.sleep(25)
# except Exception as err:
# print(err)
# db.connections.close_all()
# except BaseException as err:
# print(err)
# break