Files
data-centers/ingest_eia_energy_layers.py

1206 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Ingest EIA energy infrastructure data via EIA API into PostGIS with GEOID linking.
Usage:
ingest_eia_energy_layers.py [--category power|all] [--max-records N] [--endpoint ENDPOINT]... [--skip-ingest] [--list-only] [--keep-stale-tables]
Features:
- Queries EIA API for energy infrastructure datasets
- Filters by category (electric grid, power plants, gas infrastructure)
- Imports GeoJSON features to standardized PostGIS table names
- Creates source catalog, tract-level GEOID linkage, and summary tables
- Skips unavailable or invalid datasets gracefully
Environment:
- EIA_API_KEY: Required. Your EIA API key from https://www.eia.gov/opendata/
- PGWEB_HOST, PGWEB_PORT, PGWEB_USER, PGWEB_PASSWORD: PostGIS connection
"""
import argparse
import hashlib
import json
import os
import re
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, Any
import psycopg2
import requests
from psycopg2 import sql
from psycopg2.extras import execute_values
# Database name passed to psycopg2 in connect_db(), plus the schema-qualified
# tract, linkage and summary tables used by the GEOID linking functions.
DB_NAME = "data_centers"
TRACT_TABLE = "public.data_center_census_tracts_2024"
LINK_TABLE = "public.energy_atlas_tract_link"
SUMMARY_TABLE = "public.energy_atlas_tract_summary"
# EIA API configuration
EIA_API_BASE = "https://api.eia.gov/v2"
EIA_API_KEY = os.environ.get("EIA_API_KEY")
# Fail fast at import time: query_eia_api() sends this key with every request.
if not EIA_API_KEY:
    print("ERROR: EIA_API_KEY environment variable not set")
    sys.exit(1)
# Matches an http(s) URL up to the next whitespace or quote character.
SERVICE_URL_RE = re.compile(r"https?://[^\s\"']+", re.IGNORECASE)
# Matches runs of characters other than lowercase letters, digits and "_";
# used by slugify() to build SQL-safe identifiers.
SAFE_RE = re.compile(r"[^a-z0-9_]+")
# EIA dataset categories mapped to infrastructure types.
# Each key is a selectable category name (see get_eia_datasets); the value
# carries the catalog category label and the API endpoint paths to ingest.
EIA_DATASETS = {
    "power": {
        "category": "power_plants",
        "endpoints": [
            # Plant-level generation assets with coordinates: what + where.
            "electricity/operating-generator-capacity",
            # Per-plant monthly net + gross generation (Form EIA-923): how much.
            "electricity/facility-fuel",
        ],
    },
}
# Extra data fields (the EIA `data[N]=` query params) each endpoint needs.
# operating-generator-capacity returns only id columns by default; latitude/longitude
# must be requested explicitly. facility-fuel returns only id columns; generation
# values must be requested explicitly.
EIA_DATASET_DATA_FIELDS = {
    "electricity/operating-generator-capacity": ["latitude", "longitude"],
    "electricity/facility-fuel": ["generation", "gross-generation"],
}
# Endpoints that do not reliably support retry with ad-hoc data[] field requests.
# Currently empty; fetch_eia_pages() only tests membership against it.
EIA_NO_RETRY_EXTRA_FIELDS = {
}
# US state abbreviation to FIPS code mapping for state-level GEOID linking
# (used by link_table_by_state to match the leading two characters of a GEOID).
STATE_FIPS = {
    "AL": "01", "AK": "02", "AZ": "04", "AR": "05", "CA": "06",
    "CO": "08", "CT": "09", "DE": "10", "DC": "11", "FL": "12",
    "GA": "13", "HI": "15", "ID": "16", "IL": "17", "IN": "18",
    "IA": "19", "KS": "20", "KY": "21", "LA": "22", "ME": "23",
    "MD": "24", "MA": "25", "MI": "26", "MN": "27", "MS": "28",
    "MO": "29", "MT": "30", "NE": "31", "NV": "32", "NH": "33",
    "NJ": "34", "NM": "35", "NY": "36", "NC": "37", "ND": "38",
    "OH": "39", "OK": "40", "OR": "41", "PA": "42", "RI": "44",
    "SC": "45", "SD": "46", "TN": "47", "TX": "48", "UT": "49",
    "VT": "50", "VA": "51", "WA": "53", "WV": "54", "WI": "55",
    "WY": "56", "PR": "72",
}
@dataclass
class EIADataset:
    """Descriptor for one EIA API dataset selected for ingestion."""

    # Identifier derived from the endpoint path; input to standardize_table_name().
    dataset_id: str
    # Short display name (get_eia_datasets uses the last endpoint path segment).
    name: str
    # Catalog category label, e.g. "power_plants".
    category: str
    # Endpoint path relative to EIA_API_BASE, e.g. "electricity/facility-fuel".
    api_endpoint: str
    # Optional free-text description; nothing in this file populates it.
    description: str = ""
    # URL recorded in the layer catalog as the source of the data.
    source_url: str = ""
def slugify(value: str) -> str:
    """Reduce *value* to a lowercase ``[a-z0-9_]`` token usable as a SQL identifier.

    Runs of any other character become a single underscore, repeated
    underscores are collapsed, and leading/trailing underscores are removed.
    When nothing usable remains the fallback ``"layer"`` is returned.
    """
    lowered = value.lower()
    underscored = SAFE_RE.sub("_", lowered)
    collapsed = re.sub(r"_+", "_", underscored.strip("_"))
    if not collapsed:
        return "layer"
    return collapsed
def standardize_table_name(dataset_id: str) -> str:
    """Build the ``energy_eia_<slug>`` table name for a dataset.

    Names longer than 55 characters are cut to their first 46 characters and
    suffixed with ``_`` plus the first 8 hex digits of the MD5 of the full
    name, so the result never exceeds 55 characters and distinct long names
    stay distinct.
    """
    candidate = "energy_eia_" + slugify(dataset_id)
    if len(candidate) > 55:
        suffix = hashlib.md5(candidate.encode("utf-8")).hexdigest()[:8]
        candidate = f"{candidate[:46]}_{suffix}"
    return candidate
class EIAClientError(Exception):
    """Non-retryable EIA API error (e.g. 400 for unsupported fields).

    Raised by query_eia_api() for 4xx responses other than 429; fetch_eia_pages()
    catches it to retry the request without extra data fields.
    """
def iter_months(start: str, end: str):
    """Generate every ``'YYYY-MM'`` label from *start* through *end*, inclusive.

    Both arguments are ``'YYYY-MM'`` strings. Nothing is produced when *start*
    is later than *end*.
    """
    year, month = map(int, start.split("-"))
    last_year, last_month = map(int, end.split("-"))
    final = (last_year, last_month)
    while (year, month) <= final:
        yield f"{year:04d}-{month:02d}"
        if month >= 12:
            # Roll over into January of the following year.
            year, month = year + 1, 1
        else:
            month += 1
def discover_period_range(endpoint: str) -> tuple:
    """Look up the earliest and latest monthly period published for *endpoint*.

    Two single-row queries are issued through query_eia_api (so they share its
    retry/backoff handling), sorted by ``period`` ascending and descending.
    ``frequency=monthly`` is forced so that endpoints which also publish
    annual/quarterly series (e.g. facility-fuel) cannot return period strings
    that iter_months is unable to parse.

    Returns:
        ``(earliest, latest)`` period strings.

    Raises:
        RuntimeError: if either query comes back without any rows.
    """

    def _edge_period(direction: str) -> str:
        payload = query_eia_api(
            endpoint,
            params={
                "length": 1,
                "sort[0][column]": "period",
                "sort[0][direction]": direction,
            },
            query_params={"frequency": "monthly"},
        )
        rows = (payload or {}).get("response", {}).get("data", [])
        if rows:
            return rows[0]["period"]
        raise RuntimeError(f"no rows returned discovering period range for {endpoint}")

    earliest = _edge_period("asc")
    latest = _edge_period("desc")
    return earliest, latest
def _redact_api_key(text: str) -> str:
    """Return *text* with any occurrence of the EIA API key masked."""
    if EIA_API_KEY:
        return text.replace(EIA_API_KEY, "***")
    return text


def query_eia_api(
    endpoint: str,
    params: Optional[Dict[str, Any]] = None,
    extra_data_fields: Optional[List[str]] = None,
    query_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict]:
    """Query an EIA API endpoint with retry/backoff on transient errors.

    Args:
        endpoint: Endpoint path relative to EIA_API_BASE; ``/data`` is appended
            when missing.
        params: Request parameters applied last, so they override both the
            defaults and *query_params* (e.g. ``offset``/``length``).
        extra_data_fields: Field names sent as ``data[0]``, ``data[1]``, ...
        query_params: Additional filters (e.g. ``frequency``, ``start``, ``end``)
            applied on top of the defaults.

    Returns:
        The parsed JSON body on success.

    Raises:
        EIAClientError: on any 4xx response except 429; never retried here, the
            caller decides whether to retry without extra fields.
        requests.RequestException or ValueError: the last transient failure
            (timeout, connection error, 429/5xx, undecodable JSON) once all
            attempts are exhausted.
    """
    if not endpoint.endswith("/data"):
        endpoint = f"{endpoint}/data"
    url = f"{EIA_API_BASE}/{endpoint}/"
    req_params: Dict[str, Any] = {"api_key": EIA_API_KEY, "length": 5000}
    if query_params:
        req_params.update(query_params)
    if params:
        req_params.update(params)
    if extra_data_fields:
        for i, field in enumerate(extra_data_fields):
            req_params[f"data[{i}]"] = field
    max_attempts = 10
    base_backoff = 5.0
    max_backoff = 120.0
    last_exc: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            resp = requests.get(url, params=req_params, timeout=(10, 120))
            if 400 <= resp.status_code < 500 and resp.status_code != 429:
                # Redact before truncating so a key straddling the cut cannot leak.
                body_excerpt = _redact_api_key(resp.text)[:200]
                raise EIAClientError(f"HTTP {resp.status_code} on {endpoint}: {body_excerpt}")
            resp.raise_for_status()
            return resp.json()
        except EIAClientError:
            raise
        except (requests.Timeout, requests.ConnectionError, requests.HTTPError, ValueError) as e:
            last_exc = e
            if attempt == max_attempts:
                break
            # Exponential backoff, capped at max_backoff seconds.
            sleep_s = min(base_backoff * (2 ** (attempt - 1)), max_backoff)
            # requests embeds the full request URL (including api_key=...) in its
            # exception text, so mask the key before it reaches the log.
            reason = _redact_api_key(str(e))
            print(f" api error on {endpoint} (attempt {attempt}/{max_attempts}): {reason}; retrying in {sleep_s:.0f}s")
            time.sleep(sleep_s)
    raise last_exc  # type: ignore[misc]
def fetch_eia_pages(
    endpoint: str,
    max_records: int = 0,
    extra_data_fields: Optional[List[str]] = None,
    query_params: Optional[Dict[str, Any]] = None,
) -> Any:
    """Yield paged EIA records; retry without extra fields on unsupported endpoints.

    Generator. Each item is a ``(page_records, used_extra_fields)`` tuple where
    ``page_records`` is the list of rows for one page and ``used_extra_fields``
    is the field list actually sent for that page (``None`` once the fallback
    without extra fields has kicked in).

    Args:
        endpoint: EIA endpoint path, passed through to query_eia_api.
        max_records: Upper bound on rows yielded in total; 0 means no cap.
        extra_data_fields: ``data[]`` fields to request.
        query_params: Extra filters forwarded unchanged on every page.

    Raises:
        EIAClientError: when a 4xx occurs and no fallback applies (no extra
            fields were in use, the endpoint is listed in
            EIA_NO_RETRY_EXTRA_FIELDS, or the fallback request fails too).
    """
    page_size = 5000
    offset = 0
    yielded = 0
    used_extra_fields = extra_data_fields
    previous_first_row: Optional[str] = None
    while True:
        params = {"offset": offset, "length": page_size}
        try:
            data = query_eia_api(
                endpoint,
                params=params,
                extra_data_fields=used_extra_fields,
                query_params=query_params,
            )
        except EIAClientError as e:
            if used_extra_fields and endpoint not in EIA_NO_RETRY_EXTRA_FIELDS:
                print(f" retrying {endpoint} without extra data fields ({e})")
                # Sticky: all following pages are requested without extra fields too.
                used_extra_fields = None
                data = query_eia_api(
                    endpoint,
                    params=params,
                    extra_data_fields=None,
                    query_params=query_params,
                )
            else:
                raise
        if not data:
            return
        response = data.get("response", {})
        page_records = response.get("data", [])
        if not page_records:
            return
        # Some EIA endpoints ignore offset and repeat page 1 forever.
        # Detect repeated first row signature and stop pagination.
        first_row_sig = json.dumps(page_records[0], sort_keys=True, default=str)
        if previous_first_row is not None and first_row_sig == previous_first_row:
            return
        previous_first_row = first_row_sig
        # "total" may be missing or non-numeric; treat that as unknown.
        total = response.get("total")
        try:
            total_int = int(total) if total is not None else None
        except (TypeError, ValueError):
            total_int = None
        if max_records > 0:
            remaining = max_records - yielded
            if remaining <= 0:
                return
            # Trim the final page so the cap is honoured exactly.
            if len(page_records) > remaining:
                page_records = page_records[:remaining]
        yield page_records, used_extra_fields
        yielded += len(page_records)
        # Stop conditions, checked in order: cap reached, reported total
        # reached, or a short page (which signals the last page).
        if max_records > 0 and yielded >= max_records:
            return
        if total_int is not None and yielded >= total_int:
            return
        if len(page_records) < page_size:
            return
        offset += len(page_records)
def fetch_eia_pages_by_month(
    endpoint: str,
    earliest: str,
    latest: str,
    max_records: int = 0,
    extra_data_fields: Optional[List[str]] = None,
) -> Any:
    """Stream pages one calendar month at a time between *earliest* and *latest*.

    EIA's bulk endpoints serve large offsets slowly and return frequent 503s
    under sustained load. Restricting every query with
    ``frequency=monthly&start=X&end=X`` keeps each result set small (the
    original author's note cites roughly 17k-28k rows per month for
    operating-generator-capacity), which lowers the failure rate and wall time.

    Yields:
        ``(page_records, used_extra_fields, month)`` tuples: the two values
        produced by fetch_eia_pages plus the ``'YYYY-MM'`` label being fetched.

    *max_records* caps the total number of rows across all months; 0 disables
    the cap.
    """
    capped = max_records > 0
    emitted = 0
    for month in iter_months(earliest, latest):
        if capped and emitted >= max_records:
            return
        # fetch_eia_pages treats 0 as "no limit", which is what an uncapped run needs.
        budget = max_records - emitted if capped else 0
        pages = fetch_eia_pages(
            endpoint,
            max_records=budget,
            extra_data_fields=extra_data_fields,
            query_params={"frequency": "monthly", "start": month, "end": month},
        )
        for batch, fields_in_use in pages:
            yield batch, fields_in_use, month
            emitted += len(batch)
            if capped and emitted >= max_records:
                return
def get_eia_datasets(category: str = "all") -> List[EIADataset]:
    """Expand a category name into EIADataset descriptors.

    ``"all"`` selects every key of EIA_DATASETS; any other value selects just
    that key, and an unknown key yields an empty list. One descriptor is
    produced per configured endpoint: the dataset id is the endpoint path with
    ``/`` replaced by ``_`` and the name is the last path segment.
    """
    wanted = list(EIA_DATASETS.keys()) if category == "all" else [category]
    return [
        EIADataset(
            dataset_id=endpoint.replace("/", "_"),
            name=endpoint.split("/")[-1],
            category=EIA_DATASETS[key]["category"],
            api_endpoint=endpoint,
            source_url=f"{EIA_API_BASE}/{endpoint}",
        )
        for key in wanted
        if key in EIA_DATASETS
        for endpoint in EIA_DATASETS[key].get("endpoints", [])
    ]
def parse_seed_line(line: str) -> Optional[EIADataset]:
    """Parse stub - not used with EIA API.

    Ignores *line* and always returns None.
    """
    return None
def read_seed_file(seed_path: Path) -> List[EIADataset]:
    """Stub - datasets discovered via EIA API catalog.

    Ignores *seed_path* (no file is read) and always returns an empty list.
    """
    return []
def classify_layer(text: str) -> Optional[str]:
    """Stub - classification handled by get_eia_datasets().

    Ignores *text* and always returns None.
    """
    return None
def _point_from_record(record: Dict[str, Any]) -> Optional[tuple]:
    """Return ``(lon, lat)`` floats when *record* has a usable coordinate pair, else None."""
    lat = record.get("latitude") or record.get("lat")
    lon = record.get("longitude") or record.get("lon")
    try:
        lat = float(lat) if lat is not None else None
        lon = float(lon) if lon is not None else None
    except (TypeError, ValueError):
        return None
    if lat is None or lon is None:
        return None
    if -90 <= lat <= 90 and -180 <= lon <= 180:
        return lon, lat
    return None


def import_layer_to_postgis(dataset: EIADataset, table_name: str, max_records: int = 0) -> bool:
    """Import one EIA dataset into ``public.<table_name>``, month by month.

    Opens its own database connection. On the first page received the table is
    created if necessary and truncated; every page is then inserted in its own
    transaction. Rows with a valid latitude/longitude get a point geometry
    (SRID 4326); all other rows are stored with properties only.

    Args:
        dataset: Dataset descriptor; only ``api_endpoint`` is read here.
        table_name: Unqualified target table name in schema ``public``.
        max_records: Cap on rows ingested for this dataset; 0 means no cap.

    Returns:
        True when at least one row was inserted, False when the API returned
        no data at all (in which case the table is left untouched).

    Raises:
        Whatever discover_period_range / the fetch generators / psycopg2 raise;
        pages committed before a failure stay in the table.
    """
    # Quote the identifier the same way the prune_* helpers do instead of
    # interpolating it into SQL text.
    target = sql.Identifier(table_name)
    conn = connect_db()
    try:
        create_sql = sql.SQL(
            """
            CREATE TABLE IF NOT EXISTS public.{} (
                gid SERIAL PRIMARY KEY,
                geom GEOMETRY(GEOMETRY, 4326),
                properties JSONB
            )
            """
        ).format(target)
        truncate_sql = sql.SQL("TRUNCATE TABLE public.{}").format(target)
        # execute_values needs a plain string, so render the composed statements.
        insert_geom_sql = sql.SQL(
            "INSERT INTO public.{} (geom, properties) VALUES %s"
        ).format(target).as_string(conn)
        insert_props_sql = sql.SQL(
            "INSERT INTO public.{} (properties) VALUES %s"
        ).format(target).as_string(conn)

        extra_fields = EIA_DATASET_DATA_FIELDS.get(dataset.api_endpoint)
        earliest, latest = discover_period_range(dataset.api_endpoint)
        print(f" period range: {earliest} -> {latest}")
        count = 0
        geo_count = 0
        initialized = False
        current_month: Optional[str] = None
        for page_records, used_extra_fields, month in fetch_eia_pages_by_month(
            dataset.api_endpoint,
            earliest=earliest,
            latest=latest,
            max_records=max_records,
            extra_data_fields=extra_fields,
        ):
            if month != current_month:
                if current_month is not None:
                    print(f" progress: {count} rows ingested through {current_month}")
                current_month = month
            if not initialized:
                # Deferred until data actually arrives so an empty API response
                # never wipes an existing table.
                with conn:
                    with conn.cursor() as cur:
                        cur.execute(create_sql)
                        cur.execute(truncate_sql)
                initialized = True
            geom_rows = []
            prop_rows = []
            for record in page_records:
                props_json = json.dumps(record)
                point = _point_from_record(record)
                if point is not None:
                    geom_rows.append((point[0], point[1], props_json))
                    geo_count += 1
                else:
                    prop_rows.append((props_json,))
                count += 1
            with conn:
                with conn.cursor() as cur:
                    if geom_rows:
                        execute_values(
                            cur,
                            insert_geom_sql,
                            geom_rows,
                            template="(ST_SetSRID(ST_MakePoint(%s, %s), 4326), %s)",
                            page_size=1000,
                        )
                    if prop_rows:
                        execute_values(
                            cur,
                            insert_props_sql,
                            prop_rows,
                            template="(%s)",
                            page_size=1000,
                        )
            # Track if API ended up running without extra fields after retry.
            if used_extra_fields is None:
                extra_fields = None
        if not initialized:
            print(" no data returned")
            return False
        geo_msg = f", {geo_count} with geometry" if extra_fields else ""
        print(f" inserted {count} features into {table_name}{geo_msg}")
        return count > 0
    finally:
        # Best-effort close: report driver errors but never mask the real
        # outcome, and let KeyboardInterrupt/SystemExit propagate.
        try:
            conn.close()
        except psycopg2.Error as exc:
            print(f" warning: closing database connection failed: {exc}")
def connect_db():
    """Open a psycopg2 connection to the DB_NAME database.

    Host, port, user and password come from the PGWEB_HOST, PGWEB_PORT,
    PGWEB_USER and PGWEB_PASSWORD environment variables; a missing variable
    raises KeyError.
    """
    env = os.environ
    settings = {
        "host": env["PGWEB_HOST"],
        "port": env["PGWEB_PORT"],
        "user": env["PGWEB_USER"],
        "password": env["PGWEB_PASSWORD"],
        "dbname": DB_NAME,
    }
    return psycopg2.connect(**settings)
def ensure_source_catalog_table(conn):
    """Make sure ``public.energy_atlas_layers_catalog`` exists (committed on success)."""
    ddl = """
        create table if not exists public.energy_atlas_layers_catalog (
            table_name text primary key,
            source_item_id text,
            source_type text,
            source_title text,
            source_owner text,
            source_url text,
            category text,
            imported_at timestamptz default now()
        )
    """
    with conn, conn.cursor() as cur:
        cur.execute(ddl)
def upsert_layer_catalog(conn, table_name: str, dataset: EIADataset, category: str):
    """Insert or refresh the catalog row describing *table_name*.

    The row is keyed on ``table_name``; an existing row has all of its source
    columns overwritten and ``imported_at`` reset to the current time.
    """
    statement = """
        insert into public.energy_atlas_layers_catalog (
            table_name,
            source_item_id,
            source_type,
            source_title,
            source_owner,
            source_url,
            category,
            imported_at
        )
        values (%s,%s,%s,%s,%s,%s,%s,now())
        on conflict (table_name) do update
        set source_item_id = excluded.source_item_id,
            source_type = excluded.source_type,
            source_title = excluded.source_title,
            source_owner = excluded.source_owner,
            source_url = excluded.source_url,
            category = excluded.category,
            imported_at = now()
    """
    row = (
        table_name,
        dataset.dataset_id,
        "EIA API",
        dataset.name,
        "U.S. Energy Information Administration",
        dataset.source_url,
        category,
    )
    with conn, conn.cursor() as cur:
        cur.execute(statement, row)
def add_geom_index_and_analyze(conn, table_name: str):
    """Ensure a GiST index on ``geom`` exists for the table, then ANALYZE it."""
    index_stmt = (
        f"create index if not exists {table_name}_geom_gix on public.{table_name} using gist (geom)"
    )
    analyze_stmt = f"analyze public.{table_name}"
    with conn, conn.cursor() as cur:
        for stmt in (index_stmt, analyze_stmt):
            cur.execute(stmt)
def table_geom_class(conn, table_name: str) -> Optional[str]:
    """Report the geometry type of one non-null geometry in the table.

    Returns the lowercased ``ST_GeometryType`` of a single sampled row, or
    None when the table has no row with a geometry.
    """
    probe = f"""
        select st_geometrytype(geom)
        from public.{table_name}
        where geom is not null
        limit 1
    """
    with conn.cursor() as cur:
        cur.execute(probe)
        sample = cur.fetchone()
    if sample:
        return (sample[0] or "").lower()
    return None
def reset_link_tables(conn):
    """Drop and recreate the empty tract linkage table (LINK_TABLE)."""
    drop_stmt = f"drop table if exists {LINK_TABLE}"
    create_stmt = f"""
        create table {LINK_TABLE} (
            geoid text not null,
            source_table text not null,
            category text not null,
            feature_count integer,
            intersect_length_m numeric,
            intersect_area_sqm numeric,
            unique (geoid, source_table)
        )
    """
    with conn, conn.cursor() as cur:
        cur.execute(drop_stmt)
        cur.execute(create_stmt)
def link_one_table(conn, table_name: str, category: str):
    """Write tract linkage rows for one layer table.

    Rows with geometry are joined spatially to TRACT_TABLE, using the statement
    that matches the sampled geometry type (point: count only; line: count and
    intersected length; polygon: count and intersected area). Afterwards, if
    the table holds rows without geometry that carry a ``stateid`` property,
    link_table_by_state is invoked as a state-level fallback.
    """
    geom_kind = table_geom_class(conn, table_name)

    # EIA API tables have a properties column; legacy tables may not.
    with conn.cursor() as cur:
        cur.execute(
            "SELECT count(*) FROM information_schema.columns "
            "WHERE table_schema='public' AND table_name=%s AND column_name='properties'",
            (table_name,),
        )
        has_properties = cur.fetchone()[0] > 0

    # Count geometry-less rows that can still be linked through their state.
    stateid_count = 0
    if has_properties:
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT count(*) FROM public.{table_name} WHERE geom IS NULL AND properties->>'stateid' IS NOT NULL"
            )
            stateid_count = cur.fetchone()[0]

    if geom_kind:
        spatial_sql = None
        if "point" in geom_kind:
            spatial_sql = f"""
                INSERT INTO {LINK_TABLE} (
                    geoid,
                    source_table,
                    category,
                    feature_count,
                    intersect_length_m,
                    intersect_area_sqm
                )
                SELECT
                    t.geoid,
                    %s,
                    %s,
                    count(*)::integer,
                    null,
                    null
                FROM {TRACT_TABLE} t
                JOIN public.{table_name} s
                    ON t.geom && s.geom
                    AND st_covers(t.geom, s.geom)
                GROUP BY t.geoid
            """
        elif "line" in geom_kind:
            spatial_sql = f"""
                INSERT INTO {LINK_TABLE} (
                    geoid,
                    source_table,
                    category,
                    feature_count,
                    intersect_length_m,
                    intersect_area_sqm
                )
                SELECT
                    t.geoid,
                    %s,
                    %s,
                    count(*)::integer,
                    sum(st_length(st_intersection(t.geom, s.geom)::geography)),
                    null
                FROM {TRACT_TABLE} t
                JOIN public.{table_name} s
                    ON t.geom && s.geom
                    AND st_intersects(t.geom, s.geom)
                GROUP BY t.geoid
            """
        elif "polygon" in geom_kind:
            spatial_sql = f"""
                INSERT INTO {LINK_TABLE} (
                    geoid,
                    source_table,
                    category,
                    feature_count,
                    intersect_length_m,
                    intersect_area_sqm
                )
                SELECT
                    t.geoid,
                    %s,
                    %s,
                    count(*)::integer,
                    null,
                    sum(st_area(st_intersection(t.geom, s.geom)::geography))
                FROM {TRACT_TABLE} t
                JOIN public.{table_name} s
                    ON t.geom && s.geom
                    AND st_intersects(t.geom, s.geom)
                GROUP BY t.geoid
            """
        # The transaction block is entered for any geometry type, even one
        # that matches none of the statements above.
        with conn:
            with conn.cursor() as cur:
                if spatial_sql is not None:
                    cur.execute(spatial_sql, (table_name, category))

    # State-level GEOID fallback: link rows without geometry using stateid → state FIPS prefix
    if stateid_count > 0:
        link_table_by_state(conn, table_name, category)
def link_table_by_state(conn, table_name: str, category: str):
    """Link geometry-less EIA records to census tracts through their state.

    Records carry ``stateid`` (a 2-letter abbreviation) in ``properties``. The
    abbreviation is mapped to its FIPS code via STATE_FIPS and the records are
    counted per state. Every tract in TRACT_TABLE whose GEOID starts with that
    FIPS code then receives a link row holding the state-wide record count.
    Tracts that already have a row for this source table are left unchanged
    (``ON CONFLICT DO NOTHING``).
    """
    # Inline VALUES list providing the abbreviation -> FIPS lookup inside the query.
    fips_values = ", ".join(
        [f"('{abbr}', '{fips}')" for abbr, fips in STATE_FIPS.items()]
    )
    statement = f"""
        INSERT INTO {LINK_TABLE} (
            geoid,
            source_table,
            category,
            feature_count,
            intersect_length_m,
            intersect_area_sqm
        )
        SELECT
            t.geoid,
            %s,
            %s,
            state_counts.record_count::integer,
            null,
            null
        FROM (
            SELECT
                sf.fips AS state_fips,
                count(*) AS record_count
            FROM public.{table_name} s
            JOIN (VALUES {fips_values}) AS sf(abbr, fips)
                ON upper(s.properties->>'stateid') = sf.abbr
            WHERE s.geom IS NULL
            GROUP BY sf.fips
        ) state_counts
        JOIN {TRACT_TABLE} t
            ON left(t.geoid, 2) = state_counts.state_fips
        ON CONFLICT DO NOTHING
    """
    with conn, conn.cursor() as cur:
        cur.execute(statement, (table_name, category))
        linked = cur.rowcount
        if linked > 0:
            print(f" state-level geocoding linked {linked} tract rows for {table_name}")
def build_summary_table(conn):
    """Rebuild the per-GEOID summary table from the linkage table.

    SUMMARY_TABLE is dropped and recreated with one row per GEOID, aggregating
    feature counts (and, for electric grid and gas infrastructure, intersected
    length and area) from LINK_TABLE. The summary gets a primary key on
    ``geoid``; LINK_TABLE gets indexes on ``geoid`` and ``category``; both
    tables are analyzed. Everything runs in one transaction.

    The function is safe to call repeatedly: the LINK_TABLE indexes are created
    with ``if not exists`` so a rerun without recreating the link table does
    not fail and roll back the rebuild.
    """
    with conn:
        with conn.cursor() as cur:
            cur.execute(f"drop table if exists {SUMMARY_TABLE}")
            cur.execute(
                f"""
                create table {SUMMARY_TABLE} as
                select
                    geoid,
                    sum(feature_count) filter (where category = 'electric_grid')::integer
                        as electric_grid_feature_count,
                    sum(coalesce(intersect_length_m, 0)) filter (where category = 'electric_grid')
                        as electric_grid_length_m,
                    sum(coalesce(intersect_area_sqm, 0)) filter (where category = 'electric_grid')
                        as electric_grid_area_sqm,
                    sum(feature_count) filter (where category = 'power_plants')::integer
                        as power_plant_feature_count,
                    sum(feature_count) filter (where category = 'gas_infrastructure')::integer
                        as gas_infrastructure_feature_count,
                    sum(coalesce(intersect_length_m, 0)) filter (where category = 'gas_infrastructure')
                        as gas_infrastructure_length_m,
                    sum(coalesce(intersect_area_sqm, 0)) filter (where category = 'gas_infrastructure')
                        as gas_infrastructure_area_sqm
                from {LINK_TABLE}
                group by geoid
                """
            )
            cur.execute(f"alter table {SUMMARY_TABLE} add primary key (geoid)")
            # The link table may outlive a previous call, so its indexes can
            # already exist.
            cur.execute(
                f"create index if not exists energy_atlas_tract_link_geoid_idx on {LINK_TABLE} (geoid)"
            )
            cur.execute(
                f"create index if not exists energy_atlas_tract_link_category_idx on {LINK_TABLE} (category)"
            )
            cur.execute(f"analyze {LINK_TABLE}")
            cur.execute(f"analyze {SUMMARY_TABLE}")
def build_flat_tables(conn):
    """Materialise analyst-friendly flat tables from the JSON ``properties`` column.

    For each raw source table that currently exists, the matching ``*_flat``
    table is dropped, recreated from the raw table, indexed and analyzed. Two
    legacy flat tables are always dropped. A flat table whose raw source is
    absent is left as it is. All writes happen in one transaction.
    """
    capacity_src = "energy_eia_electricity_operating_generator_capacity"
    fuel_src = "energy_eia_electricity_facility_fuel"

    with conn.cursor() as cur:
        cur.execute(
            """
            select table_name
            from information_schema.tables
            where table_schema='public'
            and table_name in (
                'energy_eia_electricity_operating_generator_capacity',
                'energy_eia_electricity_facility_fuel'
            )
            """
        )
        available = {row[0] for row in cur.fetchall()}

    # EIA stored lower-48 longitudes as positive numbers for periods
    # 2008-01 through 2010-11 (~600k rows). The negative sign is
    # restored here for any state other than AK (Alaska legitimately
    # has Aleutian plants east of the dateline with positive lons).
    # geom is rebuilt from the corrected coordinates so the source
    # table's pre-correction geometry is discarded.
    capacity_flat_sql = r"""
        create table public.energy_eia_operating_generator_capacity_flat as
        with parsed as (
            select
                gid,
                properties,
                properties->>'stateid' as state_id_raw,
                case
                    when (properties->>'latitude') ~ '^-?[0-9]+(\.[0-9]+)?$'
                    then (properties->>'latitude')::double precision
                end as latitude_raw,
                case
                    when (properties->>'longitude') ~ '^-?[0-9]+(\.[0-9]+)?$'
                    then (properties->>'longitude')::double precision
                end as longitude_raw
            from public.energy_eia_electricity_operating_generator_capacity
        ),
        fixed as (
            select
                *,
                case
                    when longitude_raw > 0 and state_id_raw <> 'AK'
                    then -longitude_raw
                    else longitude_raw
                end as longitude_fixed
            from parsed
        )
        select
            gid,
            case
                when latitude_raw is not null
                and longitude_fixed is not null
                and latitude_raw between -90 and 90
                and longitude_fixed between -180 and 180
                then st_setsrid(st_makepoint(longitude_fixed, latitude_raw), 4326)
            end as geom,
            properties->>'period' as period,
            properties->>'plantid' as plant_id,
            properties->>'plantName' as plant_name,
            state_id_raw as state_id,
            properties->>'stateName' as state_name,
            properties->>'entityid' as entity_id,
            properties->>'entityName' as entity_name,
            properties->>'generatorid' as generator_id,
            properties->>'status' as status,
            properties->>'sector' as sector,
            properties->>'sectorName' as sector_name,
            properties->>'energy_source_code' as energy_source_code,
            properties->>'energy-source-desc' as energy_source_desc,
            properties->>'prime_mover_code' as prime_mover_code,
            properties->>'balancing_authority_code' as balancing_authority_code,
            properties->>'balancing-authority-name' as balancing_authority_name,
            latitude_raw as latitude,
            longitude_fixed as longitude,
            properties as raw_properties
        from fixed
    """
    capacity_followups = (
        "create index energy_eia_operating_generator_capacity_flat_geom_gix "
        "on public.energy_eia_operating_generator_capacity_flat using gist (geom)",
        "create index energy_eia_operating_generator_capacity_flat_plant_id_idx "
        "on public.energy_eia_operating_generator_capacity_flat (plant_id)",
        "create index energy_eia_operating_generator_capacity_flat_state_id_idx "
        "on public.energy_eia_operating_generator_capacity_flat (state_id)",
        "analyze public.energy_eia_operating_generator_capacity_flat",
    )

    fuel_flat_sql = r"""
        create table public.energy_eia_facility_fuel_flat as
        select
            gid,
            properties->>'period' as period,
            coalesce(properties->>'plantCode', properties->>'plantid') as plant_id,
            properties->>'plantName' as plant_name,
            properties->>'state' as state_id,
            properties->>'stateDescription' as state_name,
            properties->>'primeMover' as prime_mover_code,
            properties->>'primeMoverDescription' as prime_mover_desc,
            properties->>'fuel2002' as energy_source_code,
            properties->>'fuel2002Description' as energy_source_desc,
            case
                when (properties->>'generation') ~ '^-?[0-9]+(\.[0-9]+)?$'
                then (properties->>'generation')::double precision
                else null
            end as generation_mwh,
            case
                when (properties->>'gross-generation') ~ '^-?[0-9]+(\.[0-9]+)?$'
                then (properties->>'gross-generation')::double precision
                else null
            end as gross_generation_mwh,
            properties as raw_properties
        from public.energy_eia_electricity_facility_fuel
    """
    fuel_followups = (
        "create index energy_eia_facility_fuel_flat_plant_id_idx "
        "on public.energy_eia_facility_fuel_flat (plant_id)",
        "create index energy_eia_facility_fuel_flat_period_idx "
        "on public.energy_eia_facility_fuel_flat (period)",
        "create index energy_eia_facility_fuel_flat_state_id_idx "
        "on public.energy_eia_facility_fuel_flat (state_id)",
        "analyze public.energy_eia_facility_fuel_flat",
    )

    with conn, conn.cursor() as cur:
        # Legacy flat tables are removed unconditionally.
        cur.execute("drop table if exists public.energy_eia_electric_power_operational_data_flat")
        cur.execute("drop table if exists public.energy_eia_rto_region_data_flat")

        if capacity_src in available:
            cur.execute("drop table if exists public.energy_eia_operating_generator_capacity_flat")
            cur.execute(capacity_flat_sql)
            for stmt in capacity_followups:
                cur.execute(stmt)

        if fuel_src in available:
            cur.execute("drop table if exists public.energy_eia_facility_fuel_flat")
            cur.execute(fuel_flat_sql)
            for stmt in fuel_followups:
                cur.execute(stmt)
def prune_stale_layer_versions(conn) -> int:
    """Remove superseded EIA layer tables together with their catalog rows.

    Catalog rows are grouped by a normalized source key: ``source_url`` with a
    trailing ``/data`` removed, falling back to ``source_item_id`` and then the
    table name. Within each group the most recently imported row is kept and
    the rest are treated as stale. Only tables named ``energy_eia_*`` are ever
    touched.

    Returns:
        The number of stale catalog entries removed.
    """
    ranking_sql = """
        with ranked as (
            select
                c.table_name,
                row_number() over (
                    partition by coalesce(
                        nullif(regexp_replace(c.source_url, '/data/?$', ''), ''),
                        nullif(c.source_item_id, ''),
                        c.table_name
                    )
                    order by c.imported_at desc, c.table_name desc
                ) as rn
            from public.energy_atlas_layers_catalog c
        )
        select r.table_name
        from ranked r
        where r.rn > 1
    """
    exists_sql = """
        select exists (
            select 1
            from information_schema.tables
            where table_schema='public' and table_name=%s
        )
    """
    with conn.cursor() as cur:
        cur.execute(ranking_sql)
        superseded = [row[0] for row in cur.fetchall()]

    # Guardrail: only manage script-owned EIA tables.
    owned = [name for name in superseded if name.startswith("energy_eia_")]

    with conn, conn.cursor() as cur:
        for name in owned:
            cur.execute(exists_sql, (name,))
            if cur.fetchone()[0]:
                cur.execute(
                    sql.SQL("drop table if exists public.{} cascade").format(
                        sql.Identifier(name)
                    )
                )
                print(f"pruned stale table public.{name}")
            # The catalog row goes away whether or not the table still existed.
            cur.execute(
                "delete from public.energy_atlas_layers_catalog where table_name = %s",
                (name,),
            )
    return len(owned)
def prune_unselected_layers(conn, selected_table_names: List[str]) -> int:
    """Remove cataloged ``energy_eia_*`` layers that are not part of the current selection.

    For every catalog entry whose table name is missing from
    *selected_table_names*, the table is dropped (when it exists) and the
    catalog row is deleted, all in one transaction.

    Returns:
        The number of catalog entries removed.
    """
    keep = set(selected_table_names)
    exists_sql = """
        select exists (
            select 1
            from information_schema.tables
            where table_schema='public' and table_name=%s
        )
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            select table_name
            from public.energy_atlas_layers_catalog
            where table_name like 'energy_eia_%'
            """
        )
        cataloged = [row[0] for row in cur.fetchall()]

    doomed = [name for name in cataloged if name not in keep]

    with conn, conn.cursor() as cur:
        for name in doomed:
            cur.execute(exists_sql, (name,))
            if cur.fetchone()[0]:
                cur.execute(
                    sql.SQL("drop table if exists public.{} cascade").format(
                        sql.Identifier(name)
                    )
                )
                print(f"pruned unselected table public.{name}")
            cur.execute(
                "delete from public.energy_atlas_layers_catalog where table_name = %s",
                (name,),
            )
    return len(doomed)
# Flat tables produced by build_flat_tables(). keep_only_target_flat_table()
# preserves exactly these, and main() reports their row counts at the end.
FINAL_FLAT_TABLES = (
    "energy_eia_operating_generator_capacity_flat",
    "energy_eia_facility_fuel_flat",
)
def keep_only_target_flat_table(conn) -> int:
    """Drop every ``public.energy_eia_*`` table that is not listed in FINAL_FLAT_TABLES.

    Returns:
        The number of tables dropped.
    """
    # "%%" is a literal "%" here because the statement is executed with parameters.
    lookup_sql = """
        select table_name
        from information_schema.tables
        where table_schema='public'
        and table_name like 'energy_eia_%%'
        and table_name <> ALL(%s)
    """
    with conn.cursor() as cur:
        cur.execute(lookup_sql, (list(FINAL_FLAT_TABLES),))
        victims = [row[0] for row in cur.fetchall()]

    with conn, conn.cursor() as cur:
        for name in victims:
            cur.execute(
                sql.SQL("drop table if exists public.{} cascade").format(
                    sql.Identifier(name)
                )
            )
            print(f"dropped non-target table public.{name}")
    return len(victims)
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns the argparse namespace with ``category``, ``max_records``,
    ``endpoint`` (list or None), ``skip_ingest``, ``list_only`` and
    ``keep_stale_tables``.
    """
    cli = argparse.ArgumentParser(
        description="Ingest EIA energy infrastructure data via EIA API into PostGIS with GEOID linking."
    )

    # Options that take a value.
    cli.add_argument(
        "--category",
        choices=["power", "all"],
        default="power",
        help="Infrastructure category to ingest.",
    )
    cli.add_argument(
        "--max-records",
        type=int,
        default=0,
        help="Cap on API records to process per dataset (0=all).",
    )
    endpoint_help = (
        "Limit ingest to specific EIA endpoint(s). "
        "Repeatable. Substring match against api_endpoint. "
        "Other datasets are skipped (not re-ingested, not pruned)."
    )
    cli.add_argument("--endpoint", action="append", default=None, help=endpoint_help)

    # Boolean switches, all off by default.
    switches = (
        ("--skip-ingest", "Skip import; rebuild GEOID links/summary only."),
        ("--list-only", "List selected datasets and exit."),
        ("--keep-stale-tables", "Do not prune superseded EIA tables/catalog entries."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
def main():
    """Run the pipeline: select datasets, ingest, prune, build flat tables, report.

    Steps, in order: resolve the datasets for the chosen category and optional
    endpoint filters; with ``--list-only`` print them and stop; otherwise
    import each dataset (unless ``--skip-ingest``), prune unselected and stale
    layers, rebuild the flat tables, drop every other ``energy_eia_*`` table
    and print the catalog size and flat-table row counts.
    """
    args = parse_args()
    if not EIA_API_KEY:
        print("ERROR: EIA_API_KEY not set. Export it and try again.")
        sys.exit(1)

    # Discover EIA datasets by category
    datasets = get_eia_datasets(args.category)
    if not datasets:
        raise RuntimeError(f"No datasets found for category '{args.category}'")
    if args.max_records > 0:
        print(f"limiting to {args.max_records} records per dataset")

    # Pair each selected dataset with its table name, honouring --endpoint filters.
    filters = args.endpoint
    datasets_to_ingest = [
        (ds, standardize_table_name(ds.dataset_id), ds.category)
        for ds in datasets
        if not filters or any(filt in ds.api_endpoint for filt in filters)
    ]

    if args.list_only:
        print(f"selected_datasets={len(datasets_to_ingest)}")
        for ds, table_name, category in datasets_to_ingest:
            print(f"{category}\tpublic.{table_name}\t{ds.name}\t{ds.source_url}")
        return
    if not datasets_to_ingest:
        raise RuntimeError("No datasets selected. Try --category all or --list-only.")

    conn = connect_db()
    try:
        ensure_source_catalog_table(conn)

        if not args.skip_ingest:
            for ds, table_name, category in datasets_to_ingest:
                print(f"importing {ds.name} -> public.{table_name} [{category}]")
                if import_layer_to_postgis(ds, table_name, max_records=args.max_records):
                    upsert_layer_catalog(conn, table_name, ds, category)
                    add_geom_index_and_analyze(conn, table_name)

        # Pruning compares against the *full* selected set; skip when --endpoint
        # is narrowing the run, otherwise we'd drop catalog entries for endpoints
        # we deliberately chose not to touch.
        if not filters:
            selected_table_names = [name for _, name, _ in datasets_to_ingest]
            removed = prune_unselected_layers(conn, selected_table_names)
            if removed > 0:
                print(f"pruned unselected layers: {removed}")

        if not args.keep_stale_tables:
            pruned = prune_stale_layer_versions(conn)
            if pruned > 0:
                print(f"pruned stale layer versions: {pruned}")

        build_flat_tables(conn)
        dropped_non_target = keep_only_target_flat_table(conn)
        if dropped_non_target > 0:
            print(f"dropped non-target energy tables: {dropped_non_target}")

        # Final report: catalog size plus row counts of the flat tables that exist.
        counts = {}
        with conn.cursor() as cur:
            cur.execute("select count(*) from public.energy_atlas_layers_catalog")
            catalog_count = cur.fetchone()[0]
            for tbl in FINAL_FLAT_TABLES:
                cur.execute(
                    "select to_regclass(%s) is not null",
                    (f"public.{tbl}",),
                )
                present = cur.fetchone()[0]
                if not present:
                    counts[tbl] = None
                    continue
                cur.execute(f"select count(*) from public.{tbl}")
                counts[tbl] = cur.fetchone()[0]

        print(f"\ndone: catalog_layers={catalog_count}")
        for tbl, n in counts.items():
            print(f" {tbl}: {n if n is not None else 'missing'} rows")
    finally:
        conn.close()
# Run the pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()