diff --git a/composer/workflows/airflow_db_cleanup.py b/composer/workflows/airflow_db_cleanup.py index 65cc48c4688..b78fec91b56 100644 --- a/composer/workflows/airflow_db_cleanup.py +++ b/composer/workflows/airflow_db_cleanup.py @@ -53,7 +53,6 @@ import airflow from airflow import settings -from airflow.jobs.base_job import BaseJob from airflow.models import ( DAG, DagModel, @@ -69,7 +68,7 @@ from airflow.version import version as airflow_version import dateutil.parser -from sqlalchemy import and_, func +from sqlalchemy import and_, func, text from sqlalchemy.exc import ProgrammingError from sqlalchemy.orm import load_only @@ -101,13 +100,6 @@ # List of all the objects that will be deleted. Comment out the DB objects you # want to skip. DATABASE_OBJECTS = [ - { - "airflow_db_model": BaseJob, - "age_check_column": BaseJob.latest_heartbeat, - "keep_last": False, - "keep_last_filters": None, - "keep_last_group_by": None, - }, { "airflow_db_model": DagRun, "age_check_column": DagRun.execution_date, @@ -228,6 +220,35 @@ except Exception as e: logging.error(e) +if AIRFLOW_VERSION < ["2", "6", "0"]: + try: + from airflow.jobs.base_job import BaseJob + DATABASE_OBJECTS.append( + { + "airflow_db_model": BaseJob, + "age_check_column": BaseJob.latest_heartbeat, + "keep_last": False, + "keep_last_filters": None, + "keep_last_group_by": None, + } + ) + except Exception as e: + logging.error(e) +else: + try: + from airflow.jobs.job import Job + DATABASE_OBJECTS.append( + { + "airflow_db_model": Job, + "age_check_column": Job.latest_heartbeat, + "keep_last": False, + "keep_last_filters": None, + "keep_last_group_by": None, + } + ) + except Exception as e: + logging.error(e) + default_args = { "owner": DAG_OWNER_NAME, "depends_on_past": False, @@ -440,16 +461,42 @@ def cleanup_function(**context): session.close() +def cleanup_sessions(): + session = settings.Session() + + try: + logging.info("Deleting sessions...") + before = len(session.execute(text("SELECT * FROM session WHERE expiry > now()::timestamp(0);")).mappings().all()) + session.execute(text("DELETE FROM session WHERE expiry > now()::timestamp(0);")) + after = len(session.execute(text("SELECT * FROM session WHERE expiry > now()::timestamp(0);")).mappings().all()) + logging.info("Deleted {} expired sessions.".format(before-after)) + except Exception as e: + logging.error(e) + + session.commit() + session.close() + + def analyze_db(): session = settings.Session() session.execute("ANALYZE") session.commit() + session.close() analyze_op = PythonOperator( task_id="analyze_query", python_callable=analyze_db, provide_context=True, dag=dag ) +cleanup_session_op = PythonOperator( + task_id="cleanup_sessions", + python_callable=cleanup_sessions, + provide_context=True, + dag=dag +) + +cleanup_session_op.set_downstream(analyze_op) + for db_object in DATABASE_OBJECTS: cleanup_op = PythonOperator( task_id="cleanup_" + str(db_object["airflow_db_model"].__name__),