#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import math

from datetime import datetime, timedelta
from multiprocessing.util import Finalize

from celery import current_app, schedules
from celery.beat import ScheduleEntry, Scheduler
from celery.signals import beat_init
from celery.utils.log import get_logger
from redis.asyncio.lock import Lock
from sqlalchemy import select
from sqlalchemy.exc import DatabaseError, InterfaceError

from backend.app.task.enums import PeriodType, TaskSchedulerType
from backend.app.task.model.scheduler import TaskScheduler
from backend.app.task.schema.scheduler import CreateTaskSchedulerParam
from backend.app.task.utils.tzcrontab import TzAwareCrontab, crontab_verify
from backend.common.exception import errors
from backend.core.conf import settings
from backend.database.db import async_db_session
from backend.database.redis import redis_client
from backend.utils._await import run_await
from backend.utils.serializers import select_as_dict
from backend.utils.timezone import timezone

# This scheduler must wake up more often than the regular 5 minutes,
# because it needs to account for external changes to the schedule
DEFAULT_MAX_INTERVAL = 5  # seconds

# Schedule lock duration, to avoid schedules being created more than once
DEFAULT_MAX_LOCK_TIMEOUT = 300  # seconds

# Lock refresh interval; must be shorter than the lock duration
DEFAULT_LOCK_INTERVAL = 60  # seconds

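# A sanity note (not in the original source): the refresh interval must stay
# below the lock timeout so the lock is always renewed before it can expire:
#
#   assert DEFAULT_LOCK_INTERVAL < DEFAULT_MAX_LOCK_TIMEOUT
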
# Copied from:
# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33
# Changes:
# The second line from the bottom: the original Lua script extends the
# lock to (remaining time + additional time), while the script here
# extends it to an expected expiration time.
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - additional milliseconds
# return 1 if the lock's time was extended, otherwise 0
LUA_EXTEND_TO_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
    return 0
end
local expiration = redis.call('pttl', KEYS[1])
if not expiration then
    expiration = 0
end
if expiration < 0 then
    return 0
end
redis.call('pexpire', KEYS[1], ARGV[2])
return 1
"""
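
# A minimal sketch of the changed semantics (hypothetical TTL values):
#
#   lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT)
#   # stock redis-py: extend(300) with 120 s remaining -> TTL becomes 420 s
#   # script above:   extend(300) with 120 s remaining -> TTL becomes 300 s
#   await lock.extend(300)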

logger = get_logger('fba.schedulers')


class ModelEntry(ScheduleEntry):
    """Task schedule entry"""

    def __init__(self, model: TaskScheduler, app=None):
        super().__init__(
            app=app or current_app._get_current_object(),
            name=model.name,
            task=model.task,
        )
        try:
            if (
                model.type == TaskSchedulerType.INTERVAL
                and model.interval_every is not None
                and model.interval_period is not None
            ):
                self.schedule = schedules.schedule(timedelta(**{model.interval_period: model.interval_every}))
            elif model.type == TaskSchedulerType.CRONTAB and model.crontab is not None:
                crontab_split = model.crontab.split(' ')
                self.schedule = TzAwareCrontab(
                    minute=crontab_split[0],
                    hour=crontab_split[1],
                    day_of_week=crontab_split[2],
                    day_of_month=crontab_split[3],
                    month_of_year=crontab_split[4],
                )
            else:
                raise errors.NotFoundError(msg=f'The schedule of {self.name} is empty!')
            # logger.debug('Schedule: {}'.format(self.schedule))
        except Exception as e:
            logger.error(f'Disabling task {self.name} with an empty schedule, details: {e}')
            asyncio.create_task(self._disable(model))

        try:
            self.args = json.loads(model.args) if model.args else None
            self.kwargs = json.loads(model.kwargs) if model.kwargs else None
        except ValueError as exc:
            logger.error(f'Disabling task {self.name} due to invalid arguments; error: {str(exc)}')
            asyncio.create_task(self._disable(model))

        self.options = {}
        for option in ['queue', 'exchange', 'routing_key']:
            value = getattr(model, option)
            if value is None:
                continue
            self.options[option] = value

        expires = getattr(model, 'expires_', None)
        if expires:
            if isinstance(expires, int):
                self.options['expires'] = expires
            elif isinstance(expires, datetime):
                self.options['expires'] = timezone.from_datetime(expires)

        if not model.last_run_time:
            model.last_run_time = timezone.now()
            # When 'start_time' is set, backdate the initial last-run time relative to it
            if model.start_time:
                model.last_run_time = timezone.from_datetime(model.start_time) - timedelta(days=365)

        self.last_run_at = timezone.from_datetime(model.last_run_time)
        self.options['periodic_task_name'] = model.name
        self.model = model

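    # A sketch of the crontab mapping above (note the storage order used by
    # this project: minute hour day_of_week day_of_month month_of_year;
    # the value is hypothetical):
    #
    #   model.crontab = '0 4 1 * *'
    #   # -> TzAwareCrontab(minute='0', hour='4', day_of_week='1',
    #   #                   day_of_month='*', month_of_year='*')
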
    async def _disable(self, model: TaskScheduler) -> None:
        """Disable the task"""
        model.no_changes = True
        self.model.enabled = self.enabled = model.enabled = False
        async with async_db_session.begin() as db:
            # Merge the (possibly detached) model into the session so that
            # enabled=False is actually persisted
            await db.merge(model)

    def is_due(self) -> tuple[bool, int | float]:
        """Task due status"""
        if not self.model.enabled:
            # Delay 5 seconds before re-checking, in case the task is re-enabled
            return schedules.schedstate(is_due=False, next=5)

        # Only run after 'start_time'
        if self.model.start_time is not None:
            now = timezone.now()
            start_time = timezone.from_datetime(self.model.start_time)
            if now < start_time:
                delay = math.ceil((start_time - now).total_seconds())
                return schedules.schedstate(is_due=False, next=delay)

        # One-off tasks
        if self.model.one_off and self.model.enabled and self.model.total_run_count > 0:
            self.model.enabled = False
            self.model.total_run_count = 0
            self.model.no_changes = False
            save_fields = ('enabled',)
            run_await(self.save)(save_fields)
            return schedules.schedstate(is_due=False, next=1000000000)  # Large delay to avoid re-checking

        return self.schedule.is_due(self.last_run_at)

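    # How celery beat consumes the returned schedstate (illustrative only):
    #
    #   is_due, next_time_to_run = entry.is_due()
    #   if is_due:
    #       scheduler.apply_entry(entry)  # send the task
    #   # beat then sleeps at most `next_time_to_run` seconds before polling again
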
    def __next__(self):
        self.model.last_run_time = timezone.now()
        self.model.total_run_count += 1
        self.model.no_changes = True
        return self.__class__(self.model)

    next = __next__

    async def save(self, fields: tuple = ()):
        """
        Save task state fields

        :param fields: extra fields to save
        :return:
        """
        async with async_db_session.begin() as db:
            stmt = select(TaskScheduler).where(TaskScheduler.id == self.model.id).with_for_update()
            query = await db.execute(stmt)
            task = query.scalars().first()
            if task:
                for field in ['last_run_time', 'total_run_count', 'no_changes']:
                    setattr(task, field, getattr(self.model, field))
                for field in fields:
                    setattr(task, field, getattr(self.model, field))
            else:
                logger.warning(f'Task {self.model.name} does not exist, skipping update')

    @classmethod
    async def from_entry(cls, name, app=None, **entry):
        """Save or update a local task schedule"""
        async with async_db_session.begin() as db:
            stmt = select(TaskScheduler).where(TaskScheduler.name == name)
            query = await db.execute(stmt)
            task = query.scalars().first()
            temp = await cls._unpack_fields(name, **entry)
            if not task:
                task = TaskScheduler(**temp)
                db.add(task)
            else:
                for key, value in temp.items():
                    setattr(task, key, value)
            res = cls(task, app=app)
            return res

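    # The entry shape this consumes mirrors celery's beat_schedule config
    # (a sketch; 'myapp.tasks.cleanup' is a hypothetical task path):
    #
    #   entry_fields = {
    #       'task': 'myapp.tasks.cleanup',
    #       'schedule': schedules.crontab('0', '4', '*'),
    #       'args': (1,),
    #       'options': {'queue': 'default'},
    #   }
    #   model_entry = run_await(ModelEntry.from_entry)('cleanup', app=app, **entry_fields)
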
    @staticmethod
    async def to_model_schedule(name: str, task: str, schedule: schedules.schedule | TzAwareCrontab):
        schedule = schedules.maybe_schedule(schedule)

        async with async_db_session() as db:
            if isinstance(schedule, schedules.schedule):
                every = max(schedule.run_every.total_seconds(), 0)
                spec = {
                    'name': name,
                    'type': TaskSchedulerType.INTERVAL.value,
                    'interval_every': every,
                    'interval_period': PeriodType.SECONDS.value,
                }
                stmt = select(TaskScheduler).filter_by(**spec)
                query = await db.execute(stmt)
                obj = query.scalars().first()
                if not obj:
                    obj = TaskScheduler(**CreateTaskSchedulerParam(task=task, **spec).model_dump())
            elif isinstance(schedule, schedules.crontab):
                crontab = f'{schedule._orig_minute} {schedule._orig_hour} {schedule._orig_day_of_week} {schedule._orig_day_of_month} {schedule._orig_month_of_year}'  # noqa: E501
                crontab_verify(crontab)
                spec = {
                    'name': name,
                    'type': TaskSchedulerType.CRONTAB.value,
                    'crontab': crontab,
                }
                stmt = select(TaskScheduler).filter_by(**spec)
                query = await db.execute(stmt)
                obj = query.scalars().first()
                if not obj:
                    obj = TaskScheduler(**CreateTaskSchedulerParam(task=task, **spec).model_dump())
            else:
                raise errors.NotFoundError(msg=f'Unsupported schedule type: {schedule}')

            return obj

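    # For example (a sketch): schedules.schedule(timedelta(seconds=30)) maps to
    # an INTERVAL row (interval_every=30.0, interval_period='seconds'), while
    # schedules.crontab('0', '4', '*') maps to a CRONTAB row with crontab='0 4 * * *'.
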
    @classmethod
    async def _unpack_fields(
        cls,
        name: str,
        task: str,
        schedule: schedules.schedule | TzAwareCrontab,
        args: tuple | None = None,
        kwargs: dict | None = None,
        options: dict | None = None,
        **entry,
    ) -> dict:
        model_schedule = await cls.to_model_schedule(name, task, schedule)
        model_dict = select_as_dict(model_schedule)
        for k in ['id', 'created_time', 'updated_time']:
            try:
                del model_dict[k]
            except KeyError:
                continue
        model_dict.update(
            args=json.dumps(args, ensure_ascii=False) if args else None,
            kwargs=json.dumps(kwargs, ensure_ascii=False) if kwargs else None,
            **cls._unpack_options(**options or {}),
            **entry,
        )
        return model_dict

    @classmethod
    def _unpack_options(
        cls,
        queue: str | None = None,
        exchange: str | None = None,
        routing_key: str | None = None,
        start_time: datetime | None = None,
        expires: datetime | timedelta | int | None = None,
        expire_seconds: int | None = None,
        one_off: bool = False,
    ) -> dict:
        data = {
            'queue': queue,
            'exchange': exchange,
            'routing_key': routing_key,
            'start_time': start_time,
            'expire_time': expires,
            'expire_seconds': expire_seconds,
            'one_off': one_off,
        }
        if expires:
            if isinstance(expires, int):
                # Keep the int out of the datetime column; store it as seconds instead
                data['expire_seconds'] = expires
                data['expire_time'] = None
            elif isinstance(expires, timedelta):
                data['expire_time'] = timezone.now() + expires
        return data
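
    # Normalization sketch (hypothetical values):
    #
    #   ModelEntry._unpack_options(queue='default', expires=timedelta(hours=1))
    #   # -> {'queue': 'default', 'exchange': None, 'routing_key': None,
    #   #     'start_time': None, 'expire_time': <now + 1h>,
    #   #     'expire_seconds': None, 'one_off': False}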


class DatabaseScheduler(Scheduler):
    """Database scheduler"""

    Entry = ModelEntry

    _schedule = None
    _last_update = None
    _initial_read = True
    _heap_invalidated = False

    lock: Lock | None = None
    lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock'

    def __init__(self, *args, **kwargs):
        self.app = kwargs['app']
        self._dirty = set()
        super().__init__(*args, **kwargs)
        # Flush any pending state to the database when the process exits
        self._finalize = Finalize(self, self.sync, exitpriority=5)
        self.max_interval = kwargs.get('max_interval') or self.app.conf.beat_max_loop_interval or DEFAULT_MAX_INTERVAL

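    # How a scheduler like this is typically wired up (a sketch; the module
    # path is an assumption, not taken from this file):
    #
    #   app.conf.beat_scheduler = 'backend.app.task.utils.schedulers:DatabaseScheduler'
    #
    # or on the command line:
    #
    #   celery -A backend.app.task.celery beat -S backend.app.task.utils.schedulers:DatabaseScheduler
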
    def setup_schedule(self):
        """Override parent method"""
        logger.info('setup_schedule')
        tasks = self.schedule
        self.install_default_entries(tasks)
        self.update_from_dict(self.app.conf.beat_schedule)

    async def get_all_task_schedulers(self):
        """Fetch all task schedules"""
        async with async_db_session() as db:
            logger.debug('DatabaseScheduler: Fetching database schedule')
            stmt = select(TaskScheduler).where(TaskScheduler.enabled == 1)
            query = await db.execute(stmt)
            tasks = query.scalars().all()
            s = {}
            for task in tasks:
                s[task.name] = self.Entry(task, app=self.app)
            return s

    def schedule_changed(self) -> bool:
        """Whether the task schedule has changed"""
        now = timezone.now()
        last_update = run_await(redis_client.get)(f'{settings.CELERY_REDIS_PREFIX}:last_update')
        if not last_update:
            run_await(redis_client.set)(f'{settings.CELERY_REDIS_PREFIX}:last_update', timezone.to_str(now))
            return False

        last, ts = self._last_update, timezone.from_str(last_update)
        try:
            if ts and ts > (last if last else ts):
                return True
        finally:
            self._last_update = now
        return False

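    # How an external change is signalled (a sketch; the writing side is
    # assumed to live wherever schedules are mutated, e.g. a CRUD service):
    #
    #   await redis_client.set(
    #       f'{settings.CELERY_REDIS_PREFIX}:last_update',
    #       timezone.to_str(timezone.now()),
    #   )
    #
    # Every beat process then sees the newer timestamp and reloads its schedule.
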
    def reserve(self, entry):
        """Override parent method"""
        new_entry = next(entry)
        # Entries must be stored by name, because they may change
        self._dirty.add(new_entry.name)
        return new_entry

    def close(self):
        """Override parent method"""
        if self.lock:
            logger.info('beat: Releasing lock')
            if run_await(self.lock.owned)():
                run_await(self.lock.release)()
            self.lock = None

        super().close()

    def sync(self):
        """Override parent method"""
        _tried = set()
        _failed = set()
        try:
            while self._dirty:
                name = self._dirty.pop()
                try:
                    tasks = self.schedule
                    run_await(tasks[name].save)()
                    logger.debug(f'Saved the latest state of task {name} to the database')
                    _tried.add(name)
                except KeyError as e:
                    logger.error(f'Failed to save the latest state of task {name}: {e}')
                    _failed.add(name)
        except DatabaseError as e:
            logger.exception('Database error while syncing: %r', e)
        except InterfaceError as e:
            logger.warning(f'DatabaseScheduler InterfaceError: {str(e)}, will retry on the next call...')
        finally:
            # Retry later (only the failed ones)
            self._dirty |= _failed

    def update_from_dict(self, beat_dict: dict):
        """Override parent method"""
        s = {}
        for name, entry_fields in beat_dict.items():
            try:
                entry = run_await(self.Entry.from_entry)(name, app=self.app, **entry_fields)
                if entry.model.enabled:
                    s[name] = entry
            except Exception as e:
                logger.error(f'Failed to add task {name} to the database')
                raise e

        tasks = self.schedule
        tasks.update(s)

    def install_default_entries(self, data):
        """Override parent method"""
        entries = {}
        if self.app.conf.result_expires:
            # Clean up expired task results daily at 04:00
            entries.setdefault(
                'celery.backend_cleanup',
                {
                    'task': 'celery.backend_cleanup',
                    'schedule': schedules.crontab('0', '4', '*'),
                    'options': {'expire_seconds': 12 * 3600},
                },
            )
        self.update_from_dict(entries)

    def schedules_equal(self, *args, **kwargs):
        """Override parent method"""
        if self._heap_invalidated:
            self._heap_invalidated = False
            return False
        return super().schedules_equal(*args, **kwargs)

    @property
    def schedule(self) -> dict[str, ModelEntry]:
        """Get the task schedule"""
        initial = update = False
        if self._initial_read:
            logger.debug('DatabaseScheduler: initial read')
            initial = update = True
            self._initial_read = False
        elif self.schedule_changed():
            logger.info('DatabaseScheduler: Schedule changed.')
            update = True

        if update:
            logger.debug('beat: Synchronizing schedule...')
            self.sync()
            self._schedule = run_await(self.get_all_task_schedulers)()
            # The schedule changed; invalidate the heap in Scheduler.tick
            if not initial:
                self._heap = []
                self._heap_invalidated = True
            logger.debug(
                'Current schedule:\n%s',
                '\n'.join(repr(entry) for entry in self._schedule.values()),
            )

        # logger.debug(self._schedule)
        return self._schedule

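
# Note on the sync/async bridge (an assumption about backend.utils._await):
# celery beat drives the scheduler from synchronous code, so async helpers are
# wrapped, e.g.:
#
#   schedule = run_await(self.get_all_task_schedulers)()
#
# i.e. run_await(async_fn) returns a callable that runs the coroutine to
# completion on an event loop and returns its result.
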
async def extend_scheduler_lock(lock):
    """
    Extend the scheduler lock

    :param lock: the scheduler lock
    :return:
    """
    while True:
        await asyncio.sleep(DEFAULT_LOCK_INTERVAL)
        if lock:
            try:
                await lock.extend(DEFAULT_MAX_LOCK_TIMEOUT)
            except Exception as e:
                logger.error(f'Failed to extend lock: {e}')


@beat_init.connect
def acquire_distributed_beat_lock(sender=None, *args, **kwargs):
    """
    Try to acquire the lock at startup

    :param sender: the sender that the receiver should respond to
    :return:
    """
    scheduler = sender.scheduler
    if not scheduler.lock_key:
        return

    logger.debug('beat: Acquiring lock...')
    lock = redis_client.lock(
        scheduler.lock_key,
        timeout=DEFAULT_MAX_LOCK_TIMEOUT,
        sleep=scheduler.max_interval,
    )
    # Overwrite redis-py's extend script, which adds an additional timeout
    # instead of extending to a new timeout
    lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT)
    run_await(lock.acquire)()
    logger.info('beat: Acquired lock')
    scheduler.lock = lock

    loop = asyncio.get_event_loop()
    loop.create_task(extend_scheduler_lock(scheduler.lock))
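
# Deployment note (a sketch of the intended behaviour): several beat processes
# may be started for high availability. `lock.acquire()` above blocks, so only
# one process proceeds past beat_init at a time; a standby takes over once the
# current holder stops refreshing the 300 s lock.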