import inspect
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Callable, Generic

import celery
from celery import Celery, Task, shared_task
from celery.result import AsyncResult
from celery.utils.log import get_task_logger
from dotenv import find_dotenv, load_dotenv
from pydantic import BaseModel
from redis import Redis
from sqlalchemy import create_engine

from ..database.database import SessionWrapper
from ..worker.interface import ReturnType

# We use get_task_logger to ensure we get the Celery-friendly logger,
# but you could also just use logging.getLogger(__name__) if you prefer.
logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)

# Engine strategy: the database URL is sent along with each queued call; a module-level
# dict maps database URLs to engines, and the worker looks up (or creates) the engine there.

class JobDescription(BaseModel):
    id: str
    type: str
    name: str
    status: str
    start_time: str | None
    end_time: str | None
    result: str | None


def read_environment_variables():
    load_dotenv(find_dotenv())
    host = os.getenv("CELERY_BROKER", "redis://localhost:6379")
    return host
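
# A hypothetical .env entry picked up by load_dotenv above (illustrative, not shipped with
# the repo); pointing at a "redis" hostname is the usual choice when running in docker-compose:
#
#     CELERY_BROKER=redis://redis:6379/0
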
# Setup celery
url = read_environment_variables()
logger.info(f"Connecting to {url}")
app = Celery("worker", broker=url, backend=url)
app.conf.update(
    task_serializer="pickle",
    accept_content=["pickle"],  # Allow pickle serialization so arbitrary callables can be queued
    result_serializer="pickle",
    # Enables tracking of the job lifecycle
    task_track_started=True,
    task_send_sent_event=True,
    worker_send_task_events=True,
)
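
# Because content is pickled, queued callables are sent by reference and the worker must be
# able to import them from the same module path. A hedged sketch (the import path below is
# illustrative, not necessarily this package's real layout):
#
#     from app.worker.celery_worker import celery_run, add_numbers  # hypothetical path
#     async_result = celery_run.delay(add_numbers, 1, 2)
#     async_result.get(timeout=10)  # -> 3
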
# Setup Redis connection (for job metadata)
# TODO: switch to using utils.load_redis()?
redis_host = "redis" if "localhost" not in url else "localhost"
r = Redis(host=redis_host, port=6379, db=2, decode_responses=True)  # TODO: how to set this better?
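
# `get_job_meta` is used by CeleryJob.get_logs further down. A minimal sketch, assuming the
# helper simply reads the per-job hash written by TrackedTask (if the package already defines
# this helper elsewhere, use that one instead):
def get_job_meta(task_id: str) -> dict:
    """Return the metadata stored under job_meta:<task_id>, or an empty dict if missing."""
    return r.hgetall(f"job_meta:{task_id}")
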
# logger.warning("No database URL set")
# This is hacky, but defaults to using the test database. Should be kept in sync with what is set up in conftest.
# engine = create_engine("sqlite:///test.db", connect_args={"check_same_thread": False})


class TrackedTask(Task):
    def __call__(self, *args, **kwargs):
        # Extract the current task id
        task_id = self.request.id
        # Create a file handler for this task's logs (make sure the log directory exists first)
        log_dir = Path("logs")
        log_dir.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_dir / f"task_{task_id}.txt")
        file_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
        file_handler.setFormatter(file_formatter)
        # Remember old handlers so we can restore them later
        old_handlers = logger.handlers[:]
        # Also add this handler to the root logger, so that logging done by other packages is captured too
        root_logger = logging.getLogger()
        old_root_handlers = root_logger.handlers[:]
        root_logger.addHandler(file_handler)
        # Replace the logger handlers with our per-task file handler
        logger.handlers = [file_handler]
        # Also log to stdout
        logger.addHandler(logging.StreamHandler())
        try:
            # Mark as started when the task is actually executing
            r.hmset(
                f"job_meta:{task_id}",
                {
                    "status": "STARTED",
                    # "start_time": datetime.now().isoformat(),  # update the start time
                },
            )
            # Execute the actual task
            return super().__call__(*args, **kwargs)
        finally:
            # Close the file handler and restore old handlers after the task is done
            file_handler.close()
            logger.handlers = old_handlers
            root_logger.handlers = old_root_handlers

    def apply_async(self, args=None, kwargs=None, **options):
        # print('apply async', args, kwargs, options)
        kwargs = kwargs if kwargs is not None else {}  # guard against apply_async being called without kwargs
        job_name = kwargs.pop(JOB_NAME_KW, None) or "Unnamed"
        job_type = kwargs.pop(JOB_TYPE_KW, None) or "Unspecified"
        result = super().apply_async(args=args, kwargs=kwargs, **options)
        r.hmset(
            f"job_meta:{result.id}",
            {"job_name": job_name, "job_type": job_type, "status": "PENDING", "start_time": datetime.now().isoformat()},
        )
        return result

    def on_success(self, retval, task_id, args, kwargs):
        logger.info(f"Task {task_id} succeeded with return value: {retval}")
        # start = float(r.hget(f"job_meta:{task_id}", "start_time") or time.time())
        # duration = time.time() - start
        try:
            retval = json.dumps(retval)
        except TypeError:
            logger.error("RETVAL: Could not serialize return value to JSON")
            logger.error(str(retval))
            r.hmset(
                f"job_meta:{task_id}",
                {
                    "status": "FAILURE",
                    # "duration": duration,
                    "error": "Could not serialize return value to JSON",
                    "end_time": datetime.now().isoformat(),
                },
            )
            raise Exception("Could not serialize return value to JSON. Return value is: " + str(retval))
        r.hmset(
            f"job_meta:{task_id}",
            {
                "status": "SUCCESS",
                # "duration": duration,
                "result": retval,
                "end_time": datetime.now().isoformat(),
            },
        )

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        logger.error(f"Task {task_id} failed: {exc}")
        # start = float(r.hget(f"job_meta:{task_id}", "start_time") or time.time())
        # duration = time.time() - start
        r.hmset(
            f"job_meta:{task_id}",
            {
                "status": "FAILURE",
                # "duration": duration,
                "error": str(exc),
                "traceback": str(einfo.traceback),
                "end_time": datetime.now().isoformat(),
            },
        )
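

# After a run, the per-job hash written by TrackedTask can be inspected directly in Redis.
# A hedged example of its shape (field values are illustrative, not real output):
#
#     r.hgetall("job_meta:<task-id>")
#     # {"job_name": "demo add", "job_type": "arithmetic", "status": "SUCCESS",
#     #  "start_time": "2024-01-01T12:00:00", "end_time": "2024-01-01T12:00:01", "result": "3"}
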
@shared_task(name="celery.ping")
def ping():
    return "pong"


def add_numbers(a: int, b: int):
    logger.info(f"Adding {a} + {b}")
    return a + b


# set base to TrackedTask to enable per-task logging
@app.task(base=TrackedTask)
def celery_run(func, *args, **kwargs):
    return func(*args, **kwargs)


ENGINES_CACHE = {}


@app.task(base=TrackedTask)
def celery_run_with_session(func, *args, **kwargs):
    # Look up (or lazily create) the engine for the database URL supplied by the caller
    database_url = kwargs.pop("database_url")
    if database_url not in ENGINES_CACHE:
        ENGINES_CACHE[database_url] = create_engine(database_url)
    engine = ENGINES_CACHE[database_url]
    named_args = inspect.getfullargspec(func).args
    logger.info(f"Function accepts named args: {named_args}")
    with SessionWrapper(engine) as session:
        ret = func(*args, **kwargs | {"session": session})
    return ret
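
# A function queued through celery_run_with_session receives the open session as a kwarg,
# and the caller must pass `database_url`. A hedged sketch (the model and query below are
# illustrative and assume SessionWrapper yields a SQLAlchemy session):
#
#     def count_rows(session=None):
#         return session.query(SomeModel).count()  # hypothetical model
#
#     celery_run_with_session.delay(count_rows, database_url="sqlite:///test.db")
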
JOB_TYPE_KW = "__job_type__"
JOB_NAME_KW = "__job_name__"
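
# The two marker kwargs above are stripped out by TrackedTask.apply_async before the task
# executes, so they only label the job in Redis. A hedged usage sketch:
#
#     job = celery_run.apply_async(
#         args=(add_numbers, 1, 2),
#         kwargs={JOB_NAME_KW: "demo add", JOB_TYPE_KW: "arithmetic"},
#     )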


class CeleryJob(Generic[ReturnType]):
    """Wrapper for a Celery job."""

    def __init__(self, job: AsyncResult, app: Celery):
        self._job = job
        self._app = app

    @property
    def _result(self) -> AsyncResult:
        return AsyncResult(self._job.id, app=self._app)

    @property
    def status(self) -> str:
        return self._result.state

    @property
    def result(self) -> ReturnType:
        return self._result.result

    @property
    def progress(self) -> float:
        # Progress reporting is not implemented yet; always reports 0
        return 0

    def cancel(self):
        self._result.revoke(terminate=True)
        # Update Redis metadata to reflect the cancellation
        r.hmset(
            f"job_meta:{self._job.id}",
            {
                "status": "REVOKED",
                "end_time": datetime.now().isoformat(),
            },
        )

    @property
    def id(self):
        return self._job.id

    @property
    def is_finished(self) -> bool:
        return self._result.state in ("SUCCESS", "FAILURE", "REVOKED")

    @property
    def exception_info(self) -> str:
        return str(self._result.traceback or "")

    def get_logs(self) -> str:
        log_file = Path("app/logs") / f"task_{self._job.id}.txt"  # TODO: not sure why we have to specify app/logs...
        logger.info(f"Looking for log file at {log_file}")
        logger.info(f"Job id is: {self._job.id}")
        if log_file.exists():
            logs = log_file.read_text()
            job_meta = get_job_meta(self.id)
            if job_meta.get("status") == "FAILURE":
                logs += "\n" + job_meta.get("traceback", "")
            return logs
        else:
            # Fall back to the traceback if the log file is not found
            return self.exception_info
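

# A hedged polling sketch for callers that need to block until a job finishes
# (the sleep interval is arbitrary):
#
#     import time
#     while not job.is_finished:
#         time.sleep(1.0)
#     print(job.status, job.result)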


class CeleryPool(Generic[ReturnType]):
    """Simple abstraction for a Celery worker."""

    def __init__(self, celery: Celery | None = None):
        # Only the module-level app is supported for now; passing a custom app is not implemented
        assert celery is None, "CeleryPool currently only supports the module-level Celery app"
        self._celery = app

    def queue(self, func: Callable[..., ReturnType], *args, **kwargs) -> CeleryJob[ReturnType]:
        job = celery_run.delay(func, *args, **kwargs)
        return CeleryJob(job, app=self._celery)

    def queue_db(self, func: Callable[..., ReturnType], *args, **kwargs) -> CeleryJob[ReturnType]:
        job = celery_run_with_session.delay(func, *args, **kwargs)
        return CeleryJob(job, app=self._celery)

    def get_job(self, task_id: str) -> CeleryJob[ReturnType]:
        return CeleryJob(AsyncResult(task_id, app=self._celery), app=self._celery)

    # def _describe_job(self, job_info: dict) -> str:
    #     func, *args = job_info["args"]
    #     func_name = func.__name__
    #     return f"{func_name}({', '.join(map(str, args))})"

    # def list_jobs(self) -> List[JobDescription]:
    #     all_jobs = {'active': self._celery.control.inspect().active(),}
    #     # 'scheduled': self._celery.control.inspect().scheduled(),
    #     # 'reserved': self._celery.control.inspect().reserved()}
    #     print(all_jobs)
    #     return [JobDescription(id=info['id'],
    #                            description=self._describe_job(info),
    #                            status=status,
    #                            start_time=datetime.fromtimestamp(info["time_start"]),
    #                            hostname=hostname,
    #                            type=self._get_job_type(info))
    #             for status, host_dict in all_jobs.items() for hostname, jobs in host_dict.items() for info in jobs]

    def list_jobs(self, status: str | None = None):
        """List all tracked jobs stored in Redis.

        Optionally filter by status: PENDING, STARTED, SUCCESS, FAILURE or REVOKED.
        """
        keys = r.keys("job_meta:*")
        jobs = []
        for key in keys:
            task_id = key.split(":")[1]
            meta = r.hgetall(key)
            meta["task_id"] = task_id
            if status is None or meta.get("status") == status:
                jobs.append(meta)
        return [
            JobDescription(
                id=meta["task_id"],
                type=meta.get("job_type", "Unspecified"),
                name=meta.get("job_name", "Unnamed"),
                status=meta["status"],
                start_time=meta.get("start_time", None),
                end_time=meta.get("end_time", None),
                result=meta.get("result", None),
            )
            for meta in sorted(jobs, key=lambda x: x.get("start_time", datetime(1900, 1, 1).isoformat()), reverse=True)
        ]
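

# End-to-end usage sketch (hedged; assumes a worker process is running against the same
# broker and that this module is importable under the same path on that worker):
#
#     pool = CeleryPool()
#     job = pool.queue(add_numbers, 1, 2, __job_name__="demo add", __job_type__="arithmetic")
#     job.status            # "PENDING" -> "STARTED" -> "SUCCESS"
#     job.result            # 3 once finished
#     job.get_logs()        # per-task log file written by TrackedTask
#     pool.list_jobs(status="SUCCESS")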