#!/usr/bin/env python3
"""
Ingest EIA energy data via the EIA API into PostGIS with GEOID linking.

Usage:
  ingest_eia_energy_layers.py [--category power|state_energy|all]
                              [--endpoint SUBSTRING ...] [--max-records N]
                              [--skip-ingest] [--list-only] [--keep-stale-tables]

Features:
  - Queries the EIA API v2 for the datasets listed in EIA_DATASETS
  - Filters by category (power plants, state energy data)
  - Imports API records to standardized PostGIS table names, attaching a point
    geometry when a record carries usable latitude/longitude
  - Maintains a source catalog and builds analyst-friendly flat tables; helpers
    for tract-level GEOID linkage and summary tables are included
  - Skips unavailable or invalid datasets gracefully

Environment:
  - EIA_API_KEY: Required. Your EIA API key from https://www.eia.gov/opendata/
  - PGWEB_HOST, PGWEB_PORT, PGWEB_USER, PGWEB_PASSWORD: PostGIS connection
"""

import argparse
import hashlib
import json
import os
import re
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, Any

import psycopg2
import requests
from psycopg2 import sql
from psycopg2.extras import execute_values

DB_NAME = "data_centers"
TRACT_TABLE = "public.data_center_census_tracts_2024"
LINK_TABLE = "public.energy_atlas_tract_link"
SUMMARY_TABLE = "public.energy_atlas_tract_summary"

# EIA API configuration
EIA_API_BASE = "https://api.eia.gov/v2"
# Read once at import. The key is validated in main() rather than here so that
# importing this module (or running with --help) never terminates the process.
EIA_API_KEY = os.environ.get("EIA_API_KEY")

SERVICE_URL_RE = re.compile(r"https?://[^\s\"']+", re.IGNORECASE)
# Anything that is not a lowercase letter, digit, or underscore; used by slugify().
SAFE_RE = re.compile(r"[^a-z0-9_]+")

# EIA dataset categories mapped to infrastructure types
EIA_DATASETS = {
    "power": {
        "category": "power_plants",
        "endpoints": [
            # Plant-level generation assets with coordinates: what + where.
            "electricity/operating-generator-capacity",
            # Per-plant monthly net + gross generation (Form EIA-923): how much.
            "electricity/facility-fuel",
        ],
    },
    "state_energy": {
        "category": "state_energy",
        "endpoints": [
            # State Energy Data System (SEDS): annual state-level production,
            # consumption, price, and expenditure across all energy sources.
            "seds",
        ],
    },
}

# Extra data fields (the EIA `data[N]=` query params) each endpoint needs.
# operating-generator-capacity returns only id columns by default; latitude/longitude
# must be requested explicitly. facility-fuel returns only id columns; generation
# values must be requested explicitly. seds returns only id columns; the numeric
# value column must be requested explicitly.
EIA_DATASET_DATA_FIELDS = {
    "electricity/operating-generator-capacity": [
        "latitude",
        "longitude",
        "nameplate-capacity-mw",
        "net-summer-capacity-mw",
        "net-winter-capacity-mw",
    ],
    "electricity/facility-fuel": ["generation", "gross-generation"],
    "seds": ["value"],
}

# Frequency for each endpoint. Drives how period range is discovered and how
# pagination iterates. Endpoints not listed default to "monthly".
EIA_DATASET_FREQUENCY = {
    "electricity/operating-generator-capacity": "monthly",
    "electricity/facility-fuel": "monthly",
    "seds": "annual",
}


def endpoint_frequency(endpoint: str) -> str:
    """Return the configured frequency for *endpoint*, defaulting to "monthly"."""
    return EIA_DATASET_FREQUENCY.get(endpoint, "monthly")


# Endpoints that do not reliably support retry with ad-hoc data[] field requests.
# Only membership (`endpoint in ...`) is tested against this container; it is
# currently empty, so every endpoint is eligible for the no-extra-fields retry.
EIA_NO_RETRY_EXTRA_FIELDS = {
}

# US state abbreviation to FIPS code mapping for state-level GEOID linking
STATE_FIPS = {
    "AL": "01", "AK": "02", "AZ": "04", "AR": "05", "CA": "06",
    "CO": "08", "CT": "09", "DE": "10", "DC": "11", "FL": "12",
    "GA": "13", "HI": "15", "ID": "16", "IL": "17", "IN": "18",
    "IA": "19", "KS": "20", "KY": "21", "LA": "22", "ME": "23",
    "MD": "24", "MA": "25", "MI": "26", "MN": "27", "MS": "28",
    "MO": "29", "MT": "30", "NE": "31", "NV": "32", "NH": "33",
    "NJ": "34", "NM": "35", "NY": "36", "NC": "37", "ND": "38",
    "OH": "39", "OK": "40", "OR": "41", "PA": "42", "RI": "44",
    "SC": "45", "SD": "46", "TN": "47", "TX": "48", "UT": "49",
    "VT": "50", "VA": "51", "WA": "53", "WV": "54", "WI": "55",
    "WY": "56", "PR": "72",
}

# Matches an `api_key=<value>` query parameter so it can be masked before logging.
API_KEY_PARAM_RE = re.compile(r"(api_key=)[^&\s]+")


@dataclass
class EIADataset:
    """One EIA API dataset selected for ingest."""

    dataset_id: str  # endpoint path with "/" replaced by "_" (see get_eia_datasets)
    name: str  # last path segment of the endpoint
    category: str  # catalog category, e.g. "power_plants"
    api_endpoint: str  # endpoint path relative to EIA_API_BASE
    description: str = ""
    source_url: str = ""


def slugify(value: str) -> str:
    """Convert string to SQL-safe identifier.

    Lowercases *value*, replaces every run of characters outside [a-z0-9_]
    with "_", trims leading/trailing underscores, and collapses repeats.
    Returns "layer" when nothing usable remains.
    """
    cleaned = SAFE_RE.sub("_", value.lower()).strip("_")
    cleaned = re.sub(r"_+", "_", cleaned)
    return cleaned or "layer"


def standardize_table_name(dataset_id: str) -> str:
    """Generate standardized table name from dataset ID.

    The name is "energy_eia_<slug>". Names longer than 55 characters are cut
    to 46 characters and suffixed with "_" plus the first 8 hex digits of the
    MD5 of the full name, so distinct long names stay distinct.
    """
    base = f"energy_eia_{slugify(dataset_id)}"
    if len(base) <= 55:
        return base
    # MD5 is used purely as a short, stable fingerprint here, not for security.
    digest = hashlib.md5(base.encode("utf-8")).hexdigest()[:8]
    return f"{base[:46]}_{digest}"


class EIAClientError(Exception):
    """Non-retryable EIA API error (e.g. 400 for unsupported fields)."""


def redact_api_key(text: str) -> str:
    """Return *text* with the value of any `api_key=` query parameter masked.

    Exception messages produced by `requests` embed the request URL, which
    carries the API key as a query parameter; this keeps the key out of logs.
    """
    return API_KEY_PARAM_RE.sub(r"\1***", text)


def extract_point(record: Dict[str, Any]) -> Optional[tuple]:
    """Return `(lon, lat)` as floats when *record* has usable coordinates.

    Reads "latitude" (falling back to "lat") and "longitude" (falling back to
    "lon"). The fallback uses `or`, so any falsy primary value (None, "", 0)
    defers to the alias key. Returns None when either coordinate is missing,
    cannot be converted to float, or lies outside [-90, 90] / [-180, 180].
    """
    raw_lat = record.get("latitude") or record.get("lat")
    raw_lon = record.get("longitude") or record.get("lon")
    try:
        lat = float(raw_lat) if raw_lat is not None else None
        lon = float(raw_lon) if raw_lon is not None else None
    except (TypeError, ValueError):
        return None
    if lat is None or lon is None:
        return None
    if -90 <= lat <= 90 and -180 <= lon <= 180:
        return lon, lat
    return None


def iter_months(start: str, end: str):
    """Yield 'YYYY-MM' strings from start to end inclusive."""
    sy, sm = (int(x) for x in start.split("-"))
    ey, em = (int(x) for x in end.split("-"))
    y, m = sy, sm
    while (y, m) <= (ey, em):
        yield f"{y:04d}-{m:02d}"
        m += 1
        if m > 12:
            m = 1
            y += 1


def iter_years(start: str, end: str):
    """Yield 'YYYY' strings from start to end inclusive.

    Only the first four characters of each bound are read, so 'YYYY-MM'
    inputs are accepted as well.
    """
    sy = int(start[:4])
    ey = int(end[:4])
    for y in range(sy, ey + 1):
        yield f"{y:04d}"


def iter_periods(frequency: str, start: str, end: str):
    """Yield period strings: years when *frequency* is "annual", months otherwise."""
    if frequency == "annual":
        yield from iter_years(start, end)
    else:
        yield from iter_months(start, end)


def discover_period_range(endpoint: str, frequency: str = "monthly") -> tuple:
    """Return (earliest, latest) period strings for an endpoint.

    Forces an explicit frequency so endpoints that publish multiple frequencies
    (e.g. facility-fuel monthly+annual) return periods in the expected format.
    Monthly endpoints get 'YYYY-MM'; annual endpoints (e.g. SEDS) get 'YYYY'.
    Routes through query_eia_api for retry/backoff coverage.

    Raises RuntimeError when the API returns no rows for either direction.
    """

    def _one(direction: str) -> str:
        # Ask for a single row sorted by period; its period is the boundary.
        data = query_eia_api(
            endpoint,
            params={
                "length": 1,
                "sort[0][column]": "period",
                "sort[0][direction]": direction,
            },
            query_params={"frequency": frequency},
        )
        rows = (data or {}).get("response", {}).get("data", [])
        if not rows:
            raise RuntimeError(f"no rows returned discovering period range for {endpoint}")
        return rows[0]["period"]

    return _one("asc"), _one("desc")


def query_eia_api(
    endpoint: str,
    params: Optional[Dict[str, Any]] = None,
    extra_data_fields: Optional[List[str]] = None,
    query_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict]:
    """Query EIA API endpoint with retry/backoff on transient errors.

    Request parameters are layered in this order (later wins): the defaults
    (api_key, length=5000), *query_params*, *params*, then one `data[i]` entry
    per item of *extra_data_fields*.

    Returns parsed JSON on success. Raises EIAClientError on 4xx other than
    429 (caller decides whether to retry without extra fields). Raises the
    last requests exception (or ValueError from JSON decoding) after
    exhausting retries on transient errors.
    """
    if not endpoint.endswith("/data"):
        endpoint = f"{endpoint}/data"
    url = f"{EIA_API_BASE}/{endpoint}/"
    req_params: Dict[str, Any] = {"api_key": EIA_API_KEY, "length": 5000}
    if query_params:
        req_params.update(query_params)
    if params:
        req_params.update(params)
    if extra_data_fields:
        for i, field in enumerate(extra_data_fields):
            req_params[f"data[{i}]"] = field

    max_attempts = 10
    base_backoff = 5.0
    max_backoff = 120.0
    last_exc: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            resp = requests.get(url, params=req_params, timeout=(10, 120))
            # 429 is rate limiting and worth retrying; other 4xx are caller errors.
            if 400 <= resp.status_code < 500 and resp.status_code != 429:
                raise EIAClientError(f"HTTP {resp.status_code} on {endpoint}: {resp.text[:200]}")
            resp.raise_for_status()
            return resp.json()
        except EIAClientError:
            raise
        except (requests.Timeout, requests.ConnectionError, requests.HTTPError, ValueError) as e:
            last_exc = e
            if attempt == max_attempts:
                break
            # Exponential backoff: 5s, 10s, 20s, ... capped at max_backoff.
            sleep_s = min(base_backoff * (2 ** (attempt - 1)), max_backoff)
            # The exception text contains the request URL (and therefore the
            # API key), so mask it before it reaches the log.
            print(
                f" api error on {endpoint} (attempt {attempt}/{max_attempts}): "
                f"{redact_api_key(str(e))}; retrying in {sleep_s:.0f}s"
            )
            time.sleep(sleep_s)
    raise last_exc  # type: ignore[misc]


def fetch_eia_pages(
    endpoint: str,
    max_records: int = 0,
    extra_data_fields: Optional[List[str]] = None,
    query_params: Optional[Dict[str, Any]] = None,
) -> Any:
    """Yield paged EIA records; retry without extra fields on unsupported endpoints.

    Generator of `(page_records, used_extra_fields)` tuples, where
    *used_extra_fields* is the field list actually sent (None once the
    fallback without extra fields has kicked in). *max_records* of 0 means
    no cap. Pagination stops on an empty page, a short page, reaching the
    reported total, reaching *max_records*, or a repeated first row.
    """
    page_size = 5000
    offset = 0
    yielded = 0
    used_extra_fields = extra_data_fields
    previous_first_row: Optional[str] = None
    while True:
        params = {"offset": offset, "length": page_size}
        try:
            data = query_eia_api(
                endpoint,
                params=params,
                extra_data_fields=used_extra_fields,
                query_params=query_params,
            )
        except EIAClientError as e:
            if used_extra_fields and endpoint not in EIA_NO_RETRY_EXTRA_FIELDS:
                print(f" retrying {endpoint} without extra data fields ({e})")
                # Sticky: all later pages are also requested without extras.
                used_extra_fields = None
                data = query_eia_api(
                    endpoint,
                    params=params,
                    extra_data_fields=None,
                    query_params=query_params,
                )
            else:
                raise
        if not data:
            return
        response = data.get("response", {})
        page_records = response.get("data", [])
        if not page_records:
            return

        # Some EIA endpoints ignore offset and repeat page 1 forever.
        # Detect repeated first row signature and stop pagination.
        first_row_sig = json.dumps(page_records[0], sort_keys=True, default=str)
        if previous_first_row is not None and first_row_sig == previous_first_row:
            return
        previous_first_row = first_row_sig

        total = response.get("total")
        try:
            total_int = int(total) if total is not None else None
        except (TypeError, ValueError):
            total_int = None

        if max_records > 0:
            remaining = max_records - yielded
            if remaining <= 0:
                return
            if len(page_records) > remaining:
                page_records = page_records[:remaining]

        yield page_records, used_extra_fields
        yielded += len(page_records)

        if max_records > 0 and yielded >= max_records:
            return
        if total_int is not None and yielded >= total_int:
            return
        if len(page_records) < page_size:
            return
        offset += len(page_records)


def fetch_eia_pages_by_period(
    endpoint: str,
    frequency: str,
    earliest: str,
    latest: str,
    max_records: int = 0,
    extra_data_fields: Optional[List[str]] = None,
) -> Any:
    """Yield pages across periods, querying one period (month or year) at a time.

    EIA's bulk endpoints serve large offsets slowly and return frequent 503s
    under sustained load. Filtering by &frequency=F&start=X&end=X keeps each
    query small (~17k–28k rows per month for operating-generator-capacity,
    ~40k rows per year for SEDS) and dramatically reduces failure rate and
    wall time.

    Yields `(page_records, used_extra_fields, period)` tuples. *max_records*
    caps the total across all periods; 0 means no cap.
    """
    yielded = 0
    for period in iter_periods(frequency, earliest, latest):
        if max_records > 0 and yielded >= max_records:
            return
        # Hand the per-period fetch only the budget that is still left.
        remaining = max_records - yielded if max_records > 0 else 0
        period_params = {"frequency": frequency, "start": period, "end": period}
        for page_records, used_extra_fields in fetch_eia_pages(
            endpoint,
            max_records=remaining,
            extra_data_fields=extra_data_fields,
            query_params=period_params,
        ):
            yield page_records, used_extra_fields, period
            yielded += len(page_records)
            if max_records > 0 and yielded >= max_records:
                return


def get_eia_datasets(category: str = "all") -> List[EIADataset]:
    """Discover EIA datasets by category.

    *category* is a key of EIA_DATASETS or "all". Unknown categories are
    skipped, so the result may be empty.
    """
    datasets = []
    cats = [category] if category != "all" else list(EIA_DATASETS.keys())
    for cat in cats:
        if cat not in EIA_DATASETS:
            continue
        cat_info = EIA_DATASETS[cat]
        for endpoint in cat_info.get("endpoints", []):
            dataset_id = endpoint.replace("/", "_")
            datasets.append(
                EIADataset(
                    dataset_id=dataset_id,
                    name=endpoint.split("/")[-1],
                    category=cat_info["category"],
                    api_endpoint=endpoint,
                    source_url=f"{EIA_API_BASE}/{endpoint}",
                )
            )
    return datasets


def parse_seed_line(line: str) -> Optional[EIADataset]:
    """Parse stub - not used with EIA API."""
    return None


def read_seed_file(seed_path: Path) -> List[EIADataset]:
    """Stub - datasets discovered via EIA API catalog."""
    return []


def classify_layer(text: str) -> Optional[str]:
    """Stub - classification handled by get_eia_datasets()."""
    return None


def import_layer_to_postgis(dataset: EIADataset, table_name: str, max_records: int = 0) -> bool:
    """Import EIA dataset to PostGIS table.

    Opens its own database connection, discovers the endpoint's period range,
    then streams pages period by period. On the first page the target table
    is created if missing and truncated; every page is committed in its own
    transaction. Records with usable coordinates get a point geometry, all
    others are stored with properties only.

    Returns True when at least one record was inserted, False when the API
    returned no data.
    """
    conn = connect_db()
    try:
        extra_fields = EIA_DATASET_DATA_FIELDS.get(dataset.api_endpoint)
        frequency = endpoint_frequency(dataset.api_endpoint)
        earliest, latest = discover_period_range(dataset.api_endpoint, frequency)
        print(f" period range ({frequency}): {earliest} -> {latest}")

        count = 0
        geo_count = 0
        initialized = False
        current_period: Optional[str] = None

        for page_records, used_extra_fields, period in fetch_eia_pages_by_period(
            dataset.api_endpoint,
            frequency=frequency,
            earliest=earliest,
            latest=latest,
            max_records=max_records,
            extra_data_fields=extra_fields,
        ):
            if period != current_period:
                if current_period is not None:
                    print(f" progress: {count} rows ingested through {current_period}")
                current_period = period

            if not initialized:
                # Deferred until data actually arrives so that an empty API
                # response never truncates a previously loaded table.
                with conn:
                    with conn.cursor() as cur:
                        cur.execute(
                            f"""
                            CREATE TABLE IF NOT EXISTS public.{table_name} (
                                gid SERIAL PRIMARY KEY,
                                geom GEOMETRY(GEOMETRY, 4326),
                                properties JSONB
                            )
                            """
                        )
                        cur.execute(f"TRUNCATE TABLE public.{table_name}")
                initialized = True

            geom_rows = []
            prop_rows = []
            for record in page_records:
                props_json = json.dumps(record)
                point = extract_point(record)
                if point is not None:
                    lon, lat = point
                    geom_rows.append((lon, lat, props_json))
                    geo_count += 1
                else:
                    prop_rows.append((props_json,))
                count += 1

            with conn:
                with conn.cursor() as cur:
                    if geom_rows:
                        execute_values(
                            cur,
                            f"INSERT INTO public.{table_name} (geom, properties) VALUES %s",
                            geom_rows,
                            template="(ST_SetSRID(ST_MakePoint(%s, %s), 4326), %s)",
                            page_size=1000,
                        )
                    if prop_rows:
                        execute_values(
                            cur,
                            f"INSERT INTO public.{table_name} (properties) VALUES %s",
                            prop_rows,
                            template="(%s)",
                            page_size=1000,
                        )

            # Track if API ended up running without extra fields after retry.
            if used_extra_fields is None:
                extra_fields = None

        if not initialized:
            print(" no data returned")
            return False

        geo_msg = f", {geo_count} with geometry" if extra_fields else ""
        print(f" inserted {count} features into {table_name}{geo_msg}")
        return count > 0
    finally:
        try:
            conn.close()
        except psycopg2.Error:
            # Best-effort close: a failure here must not mask the real outcome.
            pass


def connect_db():
    """Connect to PostGIS database.

    Connection settings come from PGWEB_HOST, PGWEB_PORT, PGWEB_USER and
    PGWEB_PASSWORD; a missing variable raises KeyError.
    """
    return psycopg2.connect(
        host=os.environ["PGWEB_HOST"],
        port=os.environ["PGWEB_PORT"],
        user=os.environ["PGWEB_USER"],
        password=os.environ["PGWEB_PASSWORD"],
        dbname=DB_NAME,
    )


def ensure_source_catalog_table(conn):
    """Create source catalog table if it doesn't exist."""
    with conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                create table if not exists public.energy_atlas_layers_catalog (
                    table_name text primary key,
                    source_item_id text,
                    source_type text,
                    source_title text,
                    source_owner text,
                    source_url text,
                    category text,
                    imported_at timestamptz default now()
                )
                """
            )


def upsert_layer_catalog(conn, table_name: str, dataset: EIADataset, category: str):
    """Upsert layer metadata into source catalog.

    Keyed on *table_name*; an existing row has all its metadata columns and
    imported_at refreshed.
    """
    with conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                insert into public.energy_atlas_layers_catalog (
                    table_name, source_item_id, source_type, source_title,
                    source_owner, source_url, category, imported_at
                )
                values (%s,%s,%s,%s,%s,%s,%s,now())
                on conflict (table_name) do update set
                    source_item_id = excluded.source_item_id,
                    source_type = excluded.source_type,
                    source_title = excluded.source_title,
                    source_owner = excluded.source_owner,
                    source_url = excluded.source_url,
                    category = excluded.category,
                    imported_at = now()
                """,
                (
                    table_name,
                    dataset.dataset_id,
                    "EIA API",
                    dataset.name,
                    "U.S. Energy Information Administration",
                    dataset.source_url,
                    category,
                ),
            )


def add_geom_index_and_analyze(conn, table_name: str):
    """Create spatial index and analyze table."""
    with conn:
        with conn.cursor() as cur:
            cur.execute(
                f"create index if not exists {table_name}_geom_gix on public.{table_name} using gist (geom)"
            )
            cur.execute(f"analyze public.{table_name}")


def table_geom_class(conn, table_name: str) -> Optional[str]:
    """Get geometry type of table.

    Returns the lowercased `st_geometrytype` of one arbitrary non-null
    geometry, or None when the table has no geometry rows.
    """
    with conn.cursor() as cur:
        cur.execute(
            f"""
            select st_geometrytype(geom)
            from public.{table_name}
            where geom is not null
            limit 1
            """
        )
        row = cur.fetchone()
    if not row:
        return None
    return (row[0] or "").lower()


def reset_link_tables(conn):
    """Recreate GEOID linkage tables."""
    with conn:
        with conn.cursor() as cur:
            cur.execute(f"drop table if exists {LINK_TABLE}")
            cur.execute(
                f"""
                create table {LINK_TABLE} (
                    geoid text not null,
                    source_table text not null,
                    category text not null,
                    feature_count integer,
                    intersect_length_m numeric,
                    intersect_area_sqm numeric,
                    unique (geoid, source_table)
                )
                """
            )


def link_one_table(conn, table_name: str, category: str):
    """Link infrastructure features to census tracts via spatial join or state FIPS fallback.

    Rows with geometry are joined spatially against TRACT_TABLE, with the
    aggregate chosen by geometry class (point count, line length, polygon
    area). Rows without geometry but with a `stateid` property are linked by
    state via link_table_by_state().
    """
    gclass = table_geom_class(conn, table_name)

    # Check if table has a properties column (EIA API tables do; legacy tables may not)
    with conn.cursor() as cur:
        cur.execute(
            "SELECT count(*) FROM information_schema.columns "
            "WHERE table_schema='public' AND table_name=%s AND column_name='properties'",
            (table_name,),
        )
        has_properties = cur.fetchone()[0] > 0

    # Check if table has any rows with stateid for state-level fallback
    stateid_count = 0
    if has_properties:
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT count(*) FROM public.{table_name} WHERE geom IS NULL AND properties->>'stateid' IS NOT NULL"
            )
            stateid_count = cur.fetchone()[0]

    if gclass:
        # Spatial join for rows that have geometry
        with conn:
            with conn.cursor() as cur:
                if "point" in gclass:
                    cur.execute(
                        f"""
                        INSERT INTO {LINK_TABLE} (
                            geoid, source_table, category, feature_count,
                            intersect_length_m, intersect_area_sqm
                        )
                        SELECT
                            t.geoid,
                            %s,
                            %s,
                            count(*)::integer,
                            null,
                            null
                        FROM {TRACT_TABLE} t
                        JOIN public.{table_name} s
                            ON t.geom && s.geom
                            AND st_covers(t.geom, s.geom)
                        GROUP BY t.geoid
                        """,
                        (table_name, category),
                    )
                elif "line" in gclass:
                    cur.execute(
                        f"""
                        INSERT INTO {LINK_TABLE} (
                            geoid, source_table, category, feature_count,
                            intersect_length_m, intersect_area_sqm
                        )
                        SELECT
                            t.geoid,
                            %s,
                            %s,
                            count(*)::integer,
                            sum(st_length(st_intersection(t.geom, s.geom)::geography)),
                            null
                        FROM {TRACT_TABLE} t
                        JOIN public.{table_name} s
                            ON t.geom && s.geom
                            AND st_intersects(t.geom, s.geom)
                        GROUP BY t.geoid
                        """,
                        (table_name, category),
                    )
                elif "polygon" in gclass:
                    cur.execute(
                        f"""
                        INSERT INTO {LINK_TABLE} (
                            geoid, source_table, category, feature_count,
                            intersect_length_m, intersect_area_sqm
                        )
                        SELECT
                            t.geoid,
                            %s,
                            %s,
                            count(*)::integer,
                            null,
                            sum(st_area(st_intersection(t.geom, s.geom)::geography))
                        FROM {TRACT_TABLE} t
                        JOIN public.{table_name} s
                            ON t.geom && s.geom
                            AND st_intersects(t.geom, s.geom)
                        GROUP BY t.geoid
                        """,
                        (table_name, category),
                    )

    # State-level GEOID fallback: link rows without geometry using stateid → state FIPS prefix
    if stateid_count > 0:
        link_table_by_state(conn, table_name, category)


def link_table_by_state(conn, table_name: str, category: str):
    """Link EIA records to census tracts via state FIPS code prefix on GEOID.

    EIA data has stateid (2-letter abbreviation). We map to FIPS and match all
    census tracts whose GEOID starts with that state FIPS code. Every matching
    tract gets a link row carrying the record count for its state; tracts that
    already have a row for this source table are left untouched
    (ON CONFLICT DO NOTHING).
    """
    # Build a VALUES list for the state FIPS mapping
    # (values come from the STATE_FIPS constant, never from external input).
    fips_values = ", ".join(f"('{abbr}', '{fips}')" for abbr, fips in STATE_FIPS.items())

    with conn:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                INSERT INTO {LINK_TABLE} (
                    geoid, source_table, category, feature_count,
                    intersect_length_m, intersect_area_sqm
                )
                SELECT
                    t.geoid,
                    %s,
                    %s,
                    state_counts.record_count::integer,
                    null,
                    null
                FROM (
                    SELECT sf.fips AS state_fips, count(*) AS record_count
                    FROM public.{table_name} s
                    JOIN (VALUES {fips_values}) AS sf(abbr, fips)
                        ON upper(s.properties->>'stateid') = sf.abbr
                    WHERE s.geom IS NULL
                    GROUP BY sf.fips
                ) state_counts
                JOIN {TRACT_TABLE} t
                    ON left(t.geoid, 2) = state_counts.state_fips
                ON CONFLICT DO NOTHING
                """,
                (table_name, category),
            )
            rows = cur.rowcount
            if rows > 0:
                print(f" state-level geocoding linked {rows} tract rows for {table_name}")


def build_summary_table(conn):
    """Create GEOID summary aggregating all categories.

    Drops and rebuilds SUMMARY_TABLE from LINK_TABLE (one row per geoid),
    then indexes and analyzes both tables. The link-table indexes are created
    with IF NOT EXISTS so this can be re-run without first recreating the
    link table.
    """
    with conn:
        with conn.cursor() as cur:
            cur.execute(f"drop table if exists {SUMMARY_TABLE}")
            cur.execute(
                f"""
                create table {SUMMARY_TABLE} as
                select
                    geoid,
                    sum(feature_count) filter (where category = 'electric_grid')::integer
                        as electric_grid_feature_count,
                    sum(coalesce(intersect_length_m, 0)) filter (where category = 'electric_grid')
                        as electric_grid_length_m,
                    sum(coalesce(intersect_area_sqm, 0)) filter (where category = 'electric_grid')
                        as electric_grid_area_sqm,
                    sum(feature_count) filter (where category = 'power_plants')::integer
                        as power_plant_feature_count,
                    sum(feature_count) filter (where category = 'gas_infrastructure')::integer
                        as gas_infrastructure_feature_count,
                    sum(coalesce(intersect_length_m, 0)) filter (where category = 'gas_infrastructure')
                        as gas_infrastructure_length_m,
                    sum(coalesce(intersect_area_sqm, 0)) filter (where category = 'gas_infrastructure')
                        as gas_infrastructure_area_sqm
                from {LINK_TABLE}
                group by geoid
                """
            )
            cur.execute(f"alter table {SUMMARY_TABLE} add primary key (geoid)")
            cur.execute(
                f"create index if not exists energy_atlas_tract_link_geoid_idx on {LINK_TABLE} (geoid)"
            )
            cur.execute(
                f"create index if not exists energy_atlas_tract_link_category_idx on {LINK_TABLE} (category)"
            )
            cur.execute(f"analyze {LINK_TABLE}")
            cur.execute(f"analyze {SUMMARY_TABLE}")


def build_flat_tables(conn):
    """Create analyst-friendly flat tables from JSON properties.

    For each raw source table that currently exists, the matching *_flat
    table is dropped and rebuilt with typed columns extracted from the JSONB
    `properties`, then indexed and analyzed. Flat tables whose source table
    is absent are left as they are. Two legacy flat tables are always dropped.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            select table_name
            from information_schema.tables
            where table_schema='public'
              and table_name in (
                'energy_eia_electricity_operating_generator_capacity',
                'energy_eia_electricity_facility_fuel',
                'energy_eia_seds'
              )
            """
        )
        available = {row[0] for row in cur.fetchall()}

    with conn:
        with conn.cursor() as cur:
            cur.execute("drop table if exists public.energy_eia_electric_power_operational_data_flat")
            cur.execute("drop table if exists public.energy_eia_rto_region_data_flat")

            if "energy_eia_electricity_operating_generator_capacity" in available:
                cur.execute("drop table if exists public.energy_eia_operating_generator_capacity_flat")
                # EIA stored lower-48 longitudes as positive numbers for periods
                # 2008-01 through 2010-11 (~600k rows). The negative sign is
                # restored here for any state other than AK (Alaska legitimately
                # has Aleutian plants east of the dateline with positive lons).
                # geom is rebuilt from the corrected coordinates so the source
                # table's pre-correction geometry is discarded.
                cur.execute(
                    r"""
                    create table public.energy_eia_operating_generator_capacity_flat as
                    with parsed as (
                        select
                            gid,
                            properties,
                            properties->>'stateid' as state_id_raw,
                            case
                                when (properties->>'latitude') ~ '^-?[0-9]+(\.[0-9]+)?$'
                                then (properties->>'latitude')::double precision
                            end as latitude_raw,
                            case
                                when (properties->>'longitude') ~ '^-?[0-9]+(\.[0-9]+)?$'
                                then (properties->>'longitude')::double precision
                            end as longitude_raw
                        from public.energy_eia_electricity_operating_generator_capacity
                    ),
                    fixed as (
                        select
                            *,
                            case
                                when longitude_raw > 0 and state_id_raw <> 'AK'
                                then -longitude_raw
                                else longitude_raw
                            end as longitude_fixed
                        from parsed
                    )
                    select
                        gid,
                        case
                            when latitude_raw is not null
                                and longitude_fixed is not null
                                and latitude_raw between -90 and 90
                                and longitude_fixed between -180 and 180
                            then st_setsrid(st_makepoint(longitude_fixed, latitude_raw), 4326)
                        end as geom,
                        properties->>'period' as period,
                        properties->>'plantid' as plant_id,
                        properties->>'plantName' as plant_name,
                        state_id_raw as state_id,
                        properties->>'stateName' as state_name,
                        properties->>'entityid' as entity_id,
                        properties->>'entityName' as entity_name,
                        properties->>'generatorid' as generator_id,
                        properties->>'status' as status,
                        properties->>'sector' as sector,
                        properties->>'sectorName' as sector_name,
                        properties->>'energy_source_code' as energy_source_code,
                        properties->>'energy-source-desc' as energy_source_desc,
                        properties->>'prime_mover_code' as prime_mover_code,
                        properties->>'balancing_authority_code' as balancing_authority_code,
                        properties->>'balancing-authority-name' as balancing_authority_name,
                        latitude_raw as latitude,
                        longitude_fixed as longitude,
                        nullif(properties->>'nameplate-capacity-mw', '')::double precision as nameplate_capacity_mw,
                        nullif(properties->>'net-summer-capacity-mw', '')::double precision as net_summer_capacity_mw,
                        nullif(properties->>'net-winter-capacity-mw', '')::double precision as net_winter_capacity_mw,
                        properties as raw_properties
                    from fixed
                    """
                )
                cur.execute(
                    "create index energy_eia_operating_generator_capacity_flat_geom_gix "
                    "on public.energy_eia_operating_generator_capacity_flat using gist (geom)"
                )
                cur.execute(
                    "create index energy_eia_operating_generator_capacity_flat_plant_id_idx "
                    "on public.energy_eia_operating_generator_capacity_flat (plant_id)"
                )
                cur.execute(
                    "create index energy_eia_operating_generator_capacity_flat_state_id_idx "
                    "on public.energy_eia_operating_generator_capacity_flat (state_id)"
                )
                cur.execute("analyze public.energy_eia_operating_generator_capacity_flat")

            if "energy_eia_electricity_facility_fuel" in available:
                cur.execute("drop table if exists public.energy_eia_facility_fuel_flat")
                cur.execute(
                    r"""
                    create table public.energy_eia_facility_fuel_flat as
                    select
                        gid,
                        properties->>'period' as period,
                        coalesce(properties->>'plantCode', properties->>'plantid') as plant_id,
                        properties->>'plantName' as plant_name,
                        properties->>'state' as state_id,
                        properties->>'stateDescription' as state_name,
                        properties->>'primeMover' as prime_mover_code,
                        properties->>'primeMoverDescription' as prime_mover_desc,
                        properties->>'fuel2002' as energy_source_code,
                        properties->>'fuel2002Description' as energy_source_desc,
                        case
                            when (properties->>'generation') ~ '^-?[0-9]+(\.[0-9]+)?$'
                            then (properties->>'generation')::double precision
                            else null
                        end as generation_mwh,
                        case
                            when (properties->>'gross-generation') ~ '^-?[0-9]+(\.[0-9]+)?$'
                            then (properties->>'gross-generation')::double precision
                            else null
                        end as gross_generation_mwh,
                        properties as raw_properties
                    from public.energy_eia_electricity_facility_fuel
                    """
                )
                cur.execute(
                    "create index energy_eia_facility_fuel_flat_plant_id_idx "
                    "on public.energy_eia_facility_fuel_flat (plant_id)"
                )
                cur.execute(
                    "create index energy_eia_facility_fuel_flat_period_idx "
                    "on public.energy_eia_facility_fuel_flat (period)"
                )
                cur.execute(
                    "create index energy_eia_facility_fuel_flat_state_id_idx "
                    "on public.energy_eia_facility_fuel_flat (state_id)"
                )
                cur.execute("analyze public.energy_eia_facility_fuel_flat")

            if "energy_eia_seds" in available:
                # SEDS column mapping verified 2026-05-17 via length=5 probe of
                # https://api.eia.gov/v2/seds/data/. Confirmed keys: period (YYYY),
                # seriesId, seriesDescription, stateId, stateDescription, value, unit.
                # No sector field — sector is encoded in seriesId.
                cur.execute("drop table if exists public.energy_eia_seds_flat")
                cur.execute(
                    r"""
                    create table public.energy_eia_seds_flat as
                    select
                        gid,
                        properties->>'period' as period,
                        case
                            when (properties->>'period') ~ '^[0-9]{4}$'
                            then (properties->>'period')::integer
                        end as year,
                        properties->>'seriesId' as series_id,
                        properties->>'seriesDescription' as series_description,
                        properties->>'stateId' as state_id,
                        properties->>'stateDescription' as state_name,
                        case
                            when (properties->>'value') ~ '^-?[0-9]+(\.[0-9]+)?$'
                            then (properties->>'value')::double precision
                        end as value,
                        properties->>'unit' as unit,
                        properties as raw_properties
                    from public.energy_eia_seds
                    """
                )
                cur.execute(
                    "create index energy_eia_seds_flat_state_id_idx "
                    "on public.energy_eia_seds_flat (state_id)"
                )
                cur.execute(
                    "create index energy_eia_seds_flat_series_id_idx "
                    "on public.energy_eia_seds_flat (series_id)"
                )
                cur.execute(
                    "create index energy_eia_seds_flat_year_idx "
                    "on public.energy_eia_seds_flat (year)"
                )
                cur.execute("analyze public.energy_eia_seds_flat")


def prune_stale_layer_versions(conn) -> int:
    """Drop superseded EIA layer tables and remove stale catalog rows.

    Superseded versions are identified by a normalized source key
    (source_url without trailing /data, falling back to source_item_id and
    then table_name). The newest entry per key is kept.

    Returns the number of catalog rows removed.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            with ranked as (
                select
                    c.table_name,
                    row_number() over (
                        partition by coalesce(
                            nullif(regexp_replace(c.source_url, '/data/?$', ''), ''),
                            nullif(c.source_item_id, ''),
                            c.table_name
                        )
                        order by c.imported_at desc, c.table_name desc
                    ) as rn
                from public.energy_atlas_layers_catalog c
            )
            select r.table_name
            from ranked r
            where r.rn > 1
            """
        )
        stale_tables = [row[0] for row in cur.fetchall()]

    pruned = 0
    with conn:
        with conn.cursor() as cur:
            for table_name in stale_tables:
                # Guardrail: only manage script-owned EIA tables.
                if not table_name.startswith("energy_eia_"):
                    continue
                cur.execute(
                    """
                    select exists (
                        select 1
                        from information_schema.tables
                        where table_schema='public' and table_name=%s
                    )
                    """,
                    (table_name,),
                )
                table_exists = cur.fetchone()[0]
                if table_exists:
                    cur.execute(
                        sql.SQL("drop table if exists public.{} cascade").format(
                            sql.Identifier(table_name)
                        )
                    )
                    print(f"pruned stale table public.{table_name}")
                cur.execute(
                    "delete from public.energy_atlas_layers_catalog where table_name = %s",
                    (table_name,),
                )
                pruned += 1
    return pruned


def prune_unselected_layers(conn, selected_table_names: List[str]) -> int:
    """Drop catalog/table entries that are not in the currently selected dataset set.

    Only catalog rows whose table_name matches 'energy_eia_%' are considered.
    Returns the number of catalog rows removed.
    """
    selected = set(selected_table_names)
    with conn.cursor() as cur:
        cur.execute(
            """
            select table_name
            from public.energy_atlas_layers_catalog
            where table_name like 'energy_eia_%'
            """
        )
        existing = [row[0] for row in cur.fetchall()]

    to_remove = [name for name in existing if name not in selected]
    removed = 0
    with conn:
        with conn.cursor() as cur:
            for table_name in to_remove:
                cur.execute(
                    """
                    select exists (
                        select 1
                        from information_schema.tables
                        where table_schema='public' and table_name=%s
                    )
                    """,
                    (table_name,),
                )
                table_exists = cur.fetchone()[0]
                if table_exists:
                    cur.execute(
                        sql.SQL("drop table if exists public.{} cascade").format(
                            sql.Identifier(table_name)
                        )
                    )
                    print(f"pruned unselected table public.{table_name}")
                cur.execute(
                    "delete from public.energy_atlas_layers_catalog where table_name = %s",
                    (table_name,),
                )
                removed += 1
    return removed


FINAL_FLAT_TABLES = (
    "energy_eia_operating_generator_capacity_flat",
    "energy_eia_facility_fuel_flat",
    "energy_eia_seds_flat",
)


def keep_only_target_flat_table(conn) -> int:
    """Drop all energy_eia_* tables except the final flat tables.

    This includes the raw per-endpoint source tables. Returns the number of
    tables dropped.
    """
    with conn.cursor() as cur:
        # `%%` is a literal percent sign: this statement is executed with
        # parameters, so psycopg2 applies %-style interpolation to it.
        cur.execute(
            """
            select table_name
            from information_schema.tables
            where table_schema='public'
              and table_name like 'energy_eia_%%'
              and table_name <> ALL(%s)
            """,
            (list(FINAL_FLAT_TABLES),),
        )
        to_drop = [row[0] for row in cur.fetchall()]

    dropped = 0
    with conn:
        with conn.cursor() as cur:
            for table_name in to_drop:
                cur.execute(
                    sql.SQL("drop table if exists public.{} cascade").format(
                        sql.Identifier(table_name)
                    )
                )
                print(f"dropped non-target table public.{table_name}")
                dropped += 1
    return dropped


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description=(
            "Ingest EIA energy infrastructure data via EIA API into PostGIS with GEOID linking."
        )
    )
    parser.add_argument(
        "--category",
        choices=["power", "state_energy", "all"],
        default="power",
        help="Infrastructure category to ingest.",
    )
    parser.add_argument(
        "--max-records",
        type=int,
        default=0,
        help="Cap on API records to process per dataset (0=all).",
    )
    parser.add_argument(
        "--endpoint",
        action="append",
        default=None,
        help=(
            "Limit ingest to specific EIA endpoint(s). "
            "Repeatable. Substring match against api_endpoint. "
            "Other datasets are skipped (not re-ingested, not pruned)."
        ),
    )
    parser.add_argument(
        "--skip-ingest",
        action="store_true",
        help="Skip import; rebuild GEOID links/summary only.",
    )
    parser.add_argument(
        "--list-only",
        action="store_true",
        help="List selected datasets and exit.",
    )
    parser.add_argument(
        "--keep-stale-tables",
        action="store_true",
        help="Do not prune superseded EIA tables/catalog entries.",
    )
    return parser.parse_args()


def main():
    """Main ingestion pipeline.

    Selects datasets, optionally ingests them, prunes catalog/table entries,
    rebuilds the flat tables, drops every other energy_eia_* table, and
    prints row counts for the final flat tables.
    """
    args = parse_args()

    if not EIA_API_KEY:
        print("ERROR: EIA_API_KEY not set. Export it and try again.")
        sys.exit(1)

    # Discover EIA datasets by category
    datasets = get_eia_datasets(args.category)
    if not datasets:
        raise RuntimeError(f"No datasets found for category '{args.category}'")

    if args.max_records > 0:
        print(f"limiting to {args.max_records} records per dataset")

    # Build ingest list with table names
    datasets_to_ingest = []
    for dataset in datasets:
        if args.endpoint and not any(filt in dataset.api_endpoint for filt in args.endpoint):
            continue
        table_name = standardize_table_name(dataset.dataset_id)
        datasets_to_ingest.append((dataset, table_name, dataset.category))

    if args.list_only:
        print(f"selected_datasets={len(datasets_to_ingest)}")
        for dataset, table_name, category in datasets_to_ingest:
            print(f"{category}\tpublic.{table_name}\t{dataset.name}\t{dataset.source_url}")
        return

    if not datasets_to_ingest:
        raise RuntimeError("No datasets selected. Try --category all or --list-only.")

    conn = connect_db()
    try:
        ensure_source_catalog_table(conn)

        if not args.skip_ingest:
            for dataset, table_name, category in datasets_to_ingest:
                print(f"importing {dataset.name} -> public.{table_name} [{category}]")
                success = import_layer_to_postgis(dataset, table_name, max_records=args.max_records)
                if success:
                    upsert_layer_catalog(conn, table_name, dataset, category)
                    add_geom_index_and_analyze(conn, table_name)

        # Pruning compares against the *full* selected set; skip when --endpoint
        # is narrowing the run, otherwise we'd drop catalog entries for endpoints
        # we deliberately chose not to touch.
        if not args.endpoint:
            selected_table_names = [table_name for _, table_name, _ in datasets_to_ingest]
            removed = prune_unselected_layers(conn, selected_table_names)
            if removed > 0:
                print(f"pruned unselected layers: {removed}")

        if not args.keep_stale_tables:
            pruned = prune_stale_layer_versions(conn)
            if pruned > 0:
                print(f"pruned stale layer versions: {pruned}")

        build_flat_tables(conn)
        dropped_non_target = keep_only_target_flat_table(conn)
        if dropped_non_target > 0:
            print(f"dropped non-target energy tables: {dropped_non_target}")

        with conn.cursor() as cur:
            cur.execute("select count(*) from public.energy_atlas_layers_catalog")
            catalog_count = cur.fetchone()[0]
            counts = {}
            for tbl in FINAL_FLAT_TABLES:
                # to_regclass returns NULL for a missing relation instead of raising.
                cur.execute(
                    "select to_regclass(%s) is not null",
                    (f"public.{tbl}",),
                )
                if cur.fetchone()[0]:
                    cur.execute(f"select count(*) from public.{tbl}")
                    counts[tbl] = cur.fetchone()[0]
                else:
                    counts[tbl] = None

        print(f"\ndone: catalog_layers={catalog_count}")
        for tbl, n in counts.items():
            print(f" {tbl}: {n if n is not None else 'missing'} rows")
    finally:
        conn.close()


if __name__ == "__main__":
    main()