# Source code for opdi.pipeline.flights

"""
Flight list generation module.

Creates flight-level summaries from processed tracks by:
1. Detecting departures and arrivals using H3 airport zones
2. Classifying flights as take-off, landing, or overflight
3. Enriching with aircraft metadata from the OSN aircraft database
4. Producing the OPDI flight list table

Ported from: OPDI-live/python/v2.0.0/03_opdi_flight_list_v2.py
"""

import os
from datetime import date, datetime
from typing import List

import pandas as pd

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.functions import (
    broadcast,
    col,
    collect_list,
    concat_ws,
    expr,
    lit,
    max as f_max,
    min as f_min,
    radians,
    row_number,
    sin,
    cos,
    sqrt,
    atan2,
    to_date,
    to_timestamp,
    unix_timestamp,
    when,
)
from pyspark.sql.window import Window

from opdi.config import OPDIConfig
from opdi.utils.datetime_helpers import (
    generate_months,
    get_start_end_of_month,
)


class FlightListProcessor:
    """
    Generates the OPDI flight list from processed track data.

    The flight list is produced in two phases:

    * **DAI (Departures/Arrivals/Internal)** -- Identifies flights with
      known departure and/or arrival airports by matching track points
      to H3 airport detection zones within 30 NM and below FL40.
    * **Overflights** -- Captures remaining tracks (no airport match)
      that have ADS-B signals lasting at least 5 minutes.

    Each flight is enriched with aircraft metadata from the OSN aircraft
    database (registration, model, typecode, operator).

    Args:
        spark: Active SparkSession.
        config: OPDI configuration object.
        log_dir: Directory for processing progress logs.

    Example:
        >>> processor = FlightListProcessor(spark, config)
        >>> processor.process_date_range(date(2024, 1, 1), date(2024, 3, 1))
    """

    # Maximum flight level for airport zone matching (FL40 = 4000 ft).
    MAX_FL = 40
[docs] def __init__( self, spark: SparkSession, config: OPDIConfig, log_dir: str = "OPDI_live/logs", ): self.spark = spark self.config = config self.project = config.project.project_name self.resolution = config.h3.airport_detection_resolution self.log_dir = log_dir self._dai_log = os.path.join(log_dir, "03_osn-flight_table-etl-log-v2.parquet") self._overflight_log = os.path.join( log_dir, "03_osn-flight_table-overflights-etl-log-v2.parquet" ) os.makedirs(log_dir, exist_ok=True)
# ------------------------------------------------------------------ # Progress tracking # ------------------------------------------------------------------ def _load_processed_months(self, log_path: str) -> List[date]: """Load list of already processed months from log file.""" if os.path.isfile(log_path): return pd.read_parquet(log_path).months.to_list() return [] def _mark_month_processed(self, month: date, log_path: str) -> None: """Mark a month as processed.""" processed = self._load_processed_months(log_path) if month not in processed: processed.append(month) pd.DataFrame({"months": processed}).to_parquet(log_path) # ------------------------------------------------------------------ # Data retrieval # ------------------------------------------------------------------ def _get_data_within_timeframe( self, table_name: str, month: date, time_col: str = "event_time" ) -> DataFrame: """Retrieve records from a table within a monthly timeframe.""" start_ts, end_ts = get_start_end_of_month(month) start_lit = to_timestamp(lit(start_ts)) end_lit = to_timestamp(lit(end_ts)) df = self.spark.table(table_name) return df.filter((col(time_col) >= start_lit) & (col(time_col) < end_lit)) def _load_airports_hex(self, airports_hex_path: str) -> DataFrame: """ Load the airport hex detection zone reference data. The data should be pre-generated by AirportDetectionZoneGenerator and saved as a parquet file. It contains H3 hex IDs within 30 NM of each airport with distance-from-center measurements. Args: airports_hex_path: Path to the preprocessed airport hex parquet. Returns: Spark DataFrame with airport hex zones. 
""" df_apt = pd.read_parquet(airports_hex_path) return self.spark.createDataFrame(df_apt.to_dict(orient="records")) # ------------------------------------------------------------------ # DAI processing (Departures / Arrivals / Internal) # ------------------------------------------------------------------ def _fetch_and_label_sv( self, month: date, sdf_apt: DataFrame ) -> DataFrame: """ Fetch track data and label state vectors near airports. Steps: 1. Fetch tracks for the month 2. Add flight level from baro_altitude 3. Add first_seen, last_seen, DOF per track 4. Filter to below FL40 5. Join with airport hex zones within 30 NM Args: month: Month to process. sdf_apt: Airport hex zone reference DataFrame. Returns: DataFrame of state vectors near airports. """ sv = self._get_data_within_timeframe( f"{self.project}.osn_tracks", month ) sv_f = sv.dropna(subset=["lat", "lon", "baro_altitude", "track_id"]) sv_f = sv_f.withColumnRenamed("callsign", "flight_id") sv_f = sv_f.fillna({"flight_id": ""}) sv_f = sv_f.withColumn("event_time", F.to_timestamp(col("event_time"))) sv_f = sv_f.withColumn( "flight_level", (col("baro_altitude") * 3.28084 / 100).cast("int") ) columns_of_interest = [ "track_id", "icao24", "flight_id", "event_time", "lat", "lon", "flight_level", "baro_altitude", "heading", "vert_rate", "on_ground", "h3_res_7", ] sv_f = sv_f.select(columns_of_interest) # Per-track first/last seen and DOF window_track = Window.partitionBy("track_id") sv_f = sv_f.withColumn("first_seen", f_min("event_time").over(window_track)) sv_f = sv_f.withColumn("last_seen", f_max("event_time").over(window_track)) sv_f = sv_f.withColumn("DOF", to_date("first_seen")) # Filter to low altitude (below FL40) and join with airport zones sv_low_alt = sv_f.filter(col("flight_level") <= self.MAX_FL) sv_nearby_apt = sv_low_alt.join( sdf_apt, sv_low_alt.h3_res_7 == sdf_apt.apt_hex_id, "left" ) return sv_nearby_apt @staticmethod def _categorize_landing_take_off(df: DataFrame) -> DataFrame: """ Classify 
each track-airport pair as take-off, landing, or ambiguous. Uses a smoothed altitude change analysis: if the altitude is mostly increasing near the airport, it's a take-off; if decreasing, it's a landing. A margin of +4 state vectors prevents noise from flipping the classification. Args: df: DataFrame from _fetch_and_label_sv. Returns: DataFrame with 'status' column (take-off / landing / ambiguous). """ window_spec = Window.partitionBy( ["icao24", "flight_id", "track_id", "apt_ident"] ).orderBy("event_time") # Smoothed altitude window_avg = Window.partitionBy( ["icao24", "flight_id", "track_id", "apt_ident"] ).orderBy("event_time").rowsBetween(-2, 2) df_m = df.withColumn("smoothed_altitude", F.avg("baro_altitude").over(window_avg)) df_m = df_m.withColumn( "altitude_change", col("smoothed_altitude") - F.lag("smoothed_altitude").over(window_spec), ) df_m = df_m.withColumn( "trajectory_type", when(col("altitude_change") > 0, "take-off") .when(col("altitude_change") < 0, "landing") .otherwise("constant altitude"), ) # Aggregate per track-airport flight_type_df = df_m.groupBy( ["icao24", "flight_id", "track_id", "apt_ident"] ).agg( F.sum(when(col("trajectory_type") == "take-off", 1).otherwise(0)).alias("take_off_count"), F.sum(when(col("trajectory_type") == "landing", 1).otherwise(0)).alias("landing_count"), ) # Classify with +4 margin (at least 20s in one state) flight_type_df = flight_type_df.withColumn( "status", when(col("take_off_count") > (col("landing_count") + 4), "take-off") .when(col("landing_count") > (col("take_off_count") + 4), "landing") .otherwise("ambiguous"), ) return df.join( flight_type_df, on=["icao24", "flight_id", "track_id", "apt_ident"], how="left", ) @staticmethod def _compute_flight_table(df: DataFrame) -> DataFrame: """ Create the flight table from classified tracks. For each track with a take-off or landing classification: 1. Find the state vector closest to the airport center 2. Use Haversine distance to resolve multi-airport ambiguity 3. 
Merge departures and arrivals into a single flight record Args: df: DataFrame from _categorize_landing_take_off. Returns: DataFrame with ADEP, ADES, and flight metadata. """ df = df.filter(df.status != "ambiguous") # Find the point closest to airport center for each track-status pair window_spec = Window.partitionBy(["icao24", "flight_id", "track_id", "status"]) df = df.withColumn("min_distance", f_min("distance_from_center").over(window_spec)) df_min = df.filter(df.distance_from_center == df.min_distance) df_min = df_min.select( "icao24", "flight_id", "track_id", "apt_ident", "apt_longitude_deg", "apt_latitude_deg", "DOF", "first_seen", "last_seen", "status", "event_time", "lat", "lon", "min_distance", "take_off_count", "landing_count", ) # Get time range per track-airport-status window_spec2 = Window.partitionBy( ["icao24", "flight_id", "track_id", "apt_ident", "status"] ) df_min = df_min.withColumn("min_time", f_min("event_time").over(window_spec2)) df_min = df_min.withColumn("max_time", f_max("event_time").over(window_spec2)) df_take_off = df_min.filter( (col("status") == "take-off") & (col("event_time") == col("min_time")) ) df_landing = df_min.filter( (col("status") == "landing") & (col("event_time") == col("max_time")) ) flight_table = df_take_off.union(df_landing) # Haversine distance to airport for multi-airport disambiguation R = 6371.0 flight_table = ( flight_table .withColumn("lat1", radians(col("lat"))) .withColumn("lon1", radians(col("lon"))) .withColumn("lat2", radians(col("apt_latitude_deg"))) .withColumn("lon2", radians(col("apt_longitude_deg"))) .withColumn("dlat", col("lat2") - col("lat1")) .withColumn("dlon", col("lon2") - col("lon1")) .withColumn( "a", sin(col("dlat") / 2) ** 2 + cos(col("lat1")) * cos(col("lat2")) * sin(col("dlon") / 2) ** 2, ) .withColumn("c", 2 * atan2(sqrt(col("a")), sqrt(1 - col("a")))) .withColumn("distance_km", R * col("c")) ) # Select closest airport per flight key_columns = ["icao24", "flight_id", "track_id", "status", 
"first_seen", "last_seen"] window_closest = Window.partitionBy(key_columns).orderBy(col("distance_km")) df_numbered = flight_table.withColumn("row_number", row_number().over(window_closest)) df_numbered = df_numbered.withColumn("is_most_likely", col("row_number") == 1) result_df = df_numbered.groupBy(key_columns).agg( expr("first(apt_ident) as most_likely_airport"), collect_list( expr("case when not is_most_likely then apt_ident end") ).alias("potential_airports"), ) result_df = result_df.select( *key_columns, col("most_likely_airport"), col("potential_airports") ) # Split into departures and arrivals, then merge take_offs = ( result_df.filter(col("status") == "take-off") .withColumnRenamed("most_likely_airport", "ADEP") .withColumnRenamed("potential_airports", "ADEP_P") ) landings = ( result_df.filter(col("status") == "landing") .withColumnRenamed("most_likely_airport", "ADES") .withColumnRenamed("potential_airports", "ADES_P") ) key_cols = ["icao24", "flight_id", "track_id", "first_seen", "last_seen"] flight_table = take_offs.drop("status").join( landings.drop("status"), on=key_cols, how="outer" ) flight_table = flight_table.withColumn("DOF", to_date(col("first_seen"))) flight_table = ( flight_table .withColumnRenamed("track_id", "id") .withColumnRenamed("icao24", "ICAO24") .withColumnRenamed("flight_id", "FLT_ID") ) flight_table = flight_table.withColumn("version", lit("v2.0.0")) flight_table = flight_table.withColumn("ADEP_P", concat_ws(", ", col("ADEP_P"))) flight_table = flight_table.withColumn("ADES_P", concat_ws(", ", col("ADES_P"))) return flight_table.select( "id", "ADEP", "ADES", "ADEP_P", "ADES_P", "ICAO24", "FLT_ID", "first_seen", "last_seen", "DOF", "version", ) def _add_osn_aircraft_db_data(self, flight_table: DataFrame) -> DataFrame: """ Enrich the flight table with aircraft metadata from the OSN database. Args: flight_table: Flight table DataFrame. Returns: Enriched DataFrame with registration, model, typecode, etc. 
""" osn_aircraft_db = self.spark.table(f"{self.project}.osn_aircraft_db") merged = flight_table.alias("ft").join( osn_aircraft_db.alias("adb"), col("ft.ICAO24") == col("adb.icao24"), "left", ) merged_upper = merged.select( *[col(f"ft.{c}").alias(c.upper()) for c in flight_table.columns], *[ col(f"adb.{c}").alias(c.upper()) for c in osn_aircraft_db.columns if c != "icao24" ], ) return merged_upper.select( "ID", "ICAO24", "FLT_ID", "DOF", "ADEP", "ADES", "ADEP_P", "ADES_P", "REGISTRATION", "MODEL", "TYPECODE", "ICAO_AIRCRAFT_CLASS", "ICAO_OPERATOR", "FIRST_SEEN", "LAST_SEEN", "VERSION", ) # ------------------------------------------------------------------ # Overflight processing # ------------------------------------------------------------------ def _fetch_overflights(self, month: date) -> DataFrame: """ Identify overflights: tracks not in the DAI flight list. Overflights are tracks with ADS-B signals lasting >= 5 minutes that don't already appear in the flight list. Args: month: Month to process. Returns: DataFrame of overflight records. 
""" sv = self._get_data_within_timeframe(f"{self.project}.osn_tracks", month) fl = self._get_data_within_timeframe( f"{self.project}.opdi_flight_list_v2", month, time_col="first_seen", ).select("id") window_track = Window.partitionBy("track_id") sv = sv.withColumn("event_time", F.to_timestamp(col("event_time"))) sv = sv.withColumn("first_seen", f_min("event_time").over(window_track)) sv = sv.withColumn("last_seen", f_max("event_time").over(window_track)) # Keep only the first row per track sv_f = sv.filter(col("first_seen") == col("event_time")) sv_f = sv_f.withColumn("event_date", to_date("event_time")) sv_f = sv_f.withColumn("DOF", f_min("event_date").over(window_track)) sv_f = sv_f.withColumnRenamed("track_id", "id") sv_f = sv_f.withColumnRenamed("icao24", "ICAO24") sv_f = sv_f.withColumnRenamed("callsign", "FLT_ID") for col_name in ["ADEP", "ADES", "ADEP_P", "ADES_P"]: sv_f = sv_f.withColumn(col_name, lit(None).cast("string")) sv_f = sv_f.withColumn("version", lit("v2.0.0")) sv_f = sv_f.select( "id", "ADEP", "ADES", "ADEP_P", "ADES_P", "ICAO24", "FLT_ID", "first_seen", "last_seen", "DOF", "version", ) # Anti-join to exclude flights already in the flight list fl_broadcast = broadcast(fl) sv_f = sv_f.join(fl_broadcast, sv_f.id == fl.id, "left_anti") # Filter out short ADS-B signals (< 5 min) sv_f = sv_f.filter( (unix_timestamp("last_seen") - unix_timestamp("first_seen")) >= 300 ) return sv_f # ------------------------------------------------------------------ # Main processing entry points # ------------------------------------------------------------------
[docs] def process_dai( self, month: date, airports_hex_path: str, skip_if_processed: bool = True, ) -> None: """ Process Departures/Arrivals/Internal flights for a month. Args: month: Month to process. airports_hex_path: Path to preprocessed airport hex zones parquet. skip_if_processed: Skip if month already processed. """ if skip_if_processed and month in self._load_processed_months(self._dai_log): print(f"Month DAI {month} already processed. Skipping.") return print(f"Processing DAI for {month}...") sdf_apt = self._load_airports_hex(airports_hex_path) sv_nearby = self._fetch_and_label_sv(month, sdf_apt) sv_classified = self._categorize_landing_take_off(sv_nearby) flight_table = self._compute_flight_table(sv_classified) flight_table = self._add_osn_aircraft_db_data(flight_table) # Prepare and write flight_table = flight_table.withColumn("DOF_day", to_date(col("DOF"))) flight_table = flight_table.repartition("DOF_day").orderBy("DOF_day") flight_table = flight_table.drop("DOF_day") flight_table.writeTo(f"`{self.project}`.`opdi_flight_list_v2`").append() self._mark_month_processed(month, self._dai_log) print(f"DAI processing complete for {month}.")
[docs] def process_overflights( self, month: date, skip_if_processed: bool = True, ) -> None: """ Process overflight records for a month. Args: month: Month to process. skip_if_processed: Skip if month already processed. """ if skip_if_processed and month in self._load_processed_months(self._overflight_log): print(f"Month overflights {month} already processed. Skipping.") return print(f"Processing overflights for {month}...") flight_table = self._fetch_overflights(month) flight_table = self._add_osn_aircraft_db_data(flight_table) flight_table = flight_table.withColumn("DOF_day", to_date(col("DOF"))) flight_table = flight_table.repartition("DOF_day").orderBy("DOF_day") flight_table = flight_table.drop("DOF_day") flight_table.writeTo(f"`{self.project}`.`opdi_flight_list_v2`").append() self._mark_month_processed(month, self._overflight_log) print(f"Overflight processing complete for {month}.")
[docs] def process_date_range( self, start_month: date, end_month: date, airports_hex_path: str, skip_if_processed: bool = True, ) -> None: """ Process the complete flight list for a range of months. Runs DAI processing first, then overflight processing for each month. Args: start_month: First month to process. end_month: Last month to process. airports_hex_path: Path to preprocessed airport hex zones parquet. skip_if_processed: Skip already processed months. Example: >>> processor = FlightListProcessor(spark, config) >>> processor.process_date_range( ... date(2024, 1, 1), ... date(2024, 6, 1), ... "data/airport_hex/zones_res7_processed.parquet" ... ) """ months = generate_months(start_month, end_month) print(f"Processing flight list for {len(months)} months...") for month in months: self.process_dai(month, airports_hex_path, skip_if_processed) for month in months: self.process_overflights(month, skip_if_processed) print(f"Flight list processing complete for {start_month} to {end_month}.")
[docs] def create_table_if_not_exists(self) -> None: """Create the opdi_flight_list_v2 Iceberg table if it doesn't exist.""" today = datetime.today().strftime("%d %B %Y") create_sql = f""" CREATE TABLE IF NOT EXISTS `{self.project}`.`opdi_flight_list_v2` ( ID STRING COMMENT 'Unique flight identifier (track_id)', ICAO24 STRING COMMENT '24-bit ICAO transponder address', FLT_ID STRING COMMENT 'Flight callsign', DOF DATE COMMENT 'Date of flight', ADEP STRING COMMENT 'Aerodrome of departure (ICAO code)', ADES STRING COMMENT 'Aerodrome of destination (ICAO code)', ADEP_P STRING COMMENT 'Alternative departure airports', ADES_P STRING COMMENT 'Alternative destination airports', REGISTRATION STRING COMMENT 'Aircraft registration', MODEL STRING COMMENT 'Aircraft model', TYPECODE STRING COMMENT 'ICAO type designator', ICAO_AIRCRAFT_CLASS STRING COMMENT 'ICAO aircraft class', ICAO_OPERATOR STRING COMMENT 'ICAO operator code', FIRST_SEEN TIMESTAMP COMMENT 'First ADS-B reception time', LAST_SEEN TIMESTAMP COMMENT 'Last ADS-B reception time', VERSION STRING COMMENT 'Processing version' ) USING iceberg PARTITIONED BY (days(FIRST_SEEN)) COMMENT 'OPDI flight list v2. Last updated: {today}.' """ self.spark.sql(create_sql) print(f"Table {self.project}.opdi_flight_list_v2 created/verified.")