Files
data-centers/postgis_table_loader.ipynb

427 lines
15 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# PostGIS Table Loader Notebook\n",
"\n",
"Use this notebook to load additional tabular files into the same PostgreSQL/PostGIS database used by the scripts in this folder.\n",
"\n",
"Expected environment variables (same pattern as your .py files):\n",
"- `PGWEB_HOST`, `PGWEB_PORT`, `PGWEB_USER`, `PGWEB_PASSWORD`\n",
"- Optional override: `PGDATABASE` (defaults to `data_centers`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"from typing import List, Tuple\n",
"\n",
"import pandas as pd\n",
"import psycopg2\n",
"from psycopg2 import sql\n",
"from psycopg2.extras import execute_values\n",
"\n",
"print('pandas:', pd.__version__)\n",
"print('psycopg2 loaded successfully')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"# Load DB env vars for notebook kernels (which often do not inherit shell exports).\n",
"def load_env_file(env_path: str = '.env') -> None:\n",
" p = Path(env_path)\n",
" if not p.exists():\n",
" print(f'No {env_path} file found in {Path.cwd()}')\n",
" return\n",
"\n",
" loaded = 0\n",
" for raw_line in p.read_text(encoding='utf-8').splitlines():\n",
" line = raw_line.strip()\n",
" if not line or line.startswith('#') or '=' not in line:\n",
" continue\n",
" key, value = line.split('=', 1)\n",
" key = key.strip()\n",
" value = value.strip().strip('\"').strip(\"'\")\n",
" if key and key not in os.environ:\n",
" os.environ[key] = value\n",
" loaded += 1\n",
" print(f'Loaded {loaded} env var(s) from {env_path}')\n",
"\n",
"\n",
"def require_env(keys):\n",
" missing = [k for k in keys if not os.getenv(k)]\n",
" if missing:\n",
" raise EnvironmentError(\n",
" 'Missing required env vars in notebook kernel: ' + ', '.join(missing) +\n",
" '.\\nSet them in this notebook, or add them to a .env file in this folder.'\n",
" )\n",
"\n",
"\n",
"load_env_file('.env')\n",
"print('PGWEB_HOST:', os.getenv('PGWEB_HOST', '<not set>'))\n",
"print('PGWEB_PORT:', os.getenv('PGWEB_PORT', '<not set>'))\n",
"print('PGWEB_USER:', os.getenv('PGWEB_USER', '<not set>'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"# Connection setup: mirrors the existing scripts in this repository.\n",
"required_keys = ['PGWEB_HOST', 'PGWEB_PORT', 'PGWEB_USER', 'PGWEB_PASSWORD']\n",
"require_env(required_keys)\n",
"\n",
"DB_NAME = os.getenv('PGDATABASE', 'data_centers')\n",
"\n",
"def get_conn():\n",
" return psycopg2.connect(\n",
" host=os.environ['PGWEB_HOST'],\n",
" port=os.environ['PGWEB_PORT'],\n",
" user=os.environ['PGWEB_USER'],\n",
" password=os.environ['PGWEB_PASSWORD'],\n",
" dbname=\"data_centers\",\n",
" )\n",
"\n",
"with get_conn() as conn:\n",
" with conn.cursor() as cur:\n",
" cur.execute('select current_database(), current_user, version()')\n",
" db, usr, ver = cur.fetchone()\n",
" print('Connected to DB:', db)\n",
" print('As user:', usr)\n",
" print('Postgres:', ver.split(',')[0])\n",
" cur.execute('create extension if not exists postgis')\n",
" print('PostGIS extension is available')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"def parse_table_name(table_fqn: str) -> Tuple[str, str]:\n",
" table_fqn = table_fqn.strip()\n",
" if '.' in table_fqn:\n",
" schema, table = table_fqn.split('.', 1)\n",
" else:\n",
" schema, table = 'public', table_fqn\n",
" return schema, table\n",
"\n",
"\n",
"def clean_column_name(name: str) -> str:\n",
" s = str(name).strip().lower()\n",
" out = []\n",
" last_was_us = False\n",
" for ch in s:\n",
" keep = ch.isalnum() or ch == '_'\n",
" c = ch if keep else '_'\n",
" if c == '_':\n",
" if last_was_us:\n",
" continue\n",
" last_was_us = True\n",
" else:\n",
" last_was_us = False\n",
" out.append(c)\n",
" cleaned = ''.join(out).strip('_')\n",
" if not cleaned:\n",
" cleaned = 'col'\n",
" if cleaned[0].isdigit():\n",
" cleaned = 'c_' + cleaned\n",
" return cleaned\n",
"\n",
"\n",
"def uniquify_columns(cols: List[str]) -> List[str]:\n",
" seen = {}\n",
" result = []\n",
" for c in cols:\n",
" base = clean_column_name(c)\n",
" idx = seen.get(base, 0)\n",
" if idx == 0:\n",
" result.append(base)\n",
" else:\n",
" result.append(f'{base}_{idx}')\n",
" seen[base] = idx + 1\n",
" return result\n",
"\n",
"\n",
"def postgres_type_for_series(s: pd.Series) -> str:\n",
" dt = s.dtype\n",
" if pd.api.types.is_integer_dtype(dt):\n",
" return 'bigint'\n",
" if pd.api.types.is_float_dtype(dt):\n",
" return 'double precision'\n",
" if pd.api.types.is_bool_dtype(dt):\n",
" return 'boolean'\n",
" if pd.api.types.is_datetime64_any_dtype(dt):\n",
" return 'timestamp'\n",
" return 'text'\n",
"\n",
"\n",
"def read_tabular(path: str, sheet_name=0) -> pd.DataFrame:\n",
" p = Path(path)\n",
" suffix = p.suffix.lower()\n",
" if suffix == '.csv':\n",
" return pd.read_csv(p)\n",
" if suffix in ['.xlsx', '.xls']:\n",
" return pd.read_excel(p, sheet_name=sheet_name)\n",
" if suffix == '.jsonl':\n",
" return pd.read_json(p, lines=True)\n",
" if suffix == '.json':\n",
" return pd.read_json(p)\n",
" if suffix == '.parquet':\n",
" return pd.read_parquet(p)\n",
" raise ValueError(f'Unsupported file type: {suffix}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"def load_dataframe_to_postgis(\n",
" df: pd.DataFrame,\n",
" table_fqn: str,\n",
" if_exists: str = 'replace', # replace | append | fail\n",
" batch_size: int = 5000,\n",
" analyze: bool = True,\n",
") -> None:\n",
" if if_exists not in {'replace', 'append', 'fail'}:\n",
" raise ValueError(\"if_exists must be one of: replace, append, fail\")\n",
"\n",
" schema, table = parse_table_name(table_fqn)\n",
"\n",
" work = df.copy()\n",
" work.columns = uniquify_columns([str(c) for c in work.columns])\n",
"\n",
" # Convert pandas NaN/NaT to Python None for psycopg2.\n",
" work = work.astype(object).where(pd.notna(work), None)\n",
"\n",
" col_defs = []\n",
" for col in work.columns:\n",
" pg_type = postgres_type_for_series(df[col]) if col in df.columns else 'text'\n",
" col_defs.append((col, pg_type))\n",
"\n",
" rows = [tuple(row) for row in work.itertuples(index=False, name=None)]\n",
"\n",
" with get_conn() as conn:\n",
" with conn.cursor() as cur:\n",
" cur.execute('create extension if not exists postgis')\n",
"\n",
" cur.execute(\n",
" 'select to_regclass(%s)',\n",
" (f'{schema}.{table}',),\n",
" )\n",
" exists = cur.fetchone()[0] is not None\n",
"\n",
" if exists and if_exists == 'fail':\n",
" raise RuntimeError(f'Table {schema}.{table} already exists')\n",
"\n",
" if exists and if_exists == 'replace':\n",
" cur.execute(\n",
" sql.SQL('drop table {}.{}').format(\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" )\n",
" )\n",
" exists = False\n",
"\n",
" if not exists:\n",
" ddl_cols = [\n",
" sql.SQL('{} {}').format(sql.Identifier(c), sql.SQL(t))\n",
" for c, t in col_defs\n",
" ]\n",
" create_sql = sql.SQL('create table {}.{} ({})').format(\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" sql.SQL(', ').join(ddl_cols),\n",
" )\n",
" cur.execute(create_sql)\n",
"\n",
" if rows:\n",
" insert_sql = sql.SQL('insert into {}.{} ({}) values %s').format(\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" sql.SQL(', ').join(sql.Identifier(c) for c in work.columns),\n",
" )\n",
" execute_values(cur, insert_sql.as_string(cur), rows, page_size=batch_size)\n",
"\n",
" if analyze:\n",
" cur.execute(\n",
" sql.SQL('analyze {}.{}').format(\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" )\n",
" )\n",
"\n",
" conn.commit()\n",
"\n",
" print(f'Loaded {len(rows)} rows into {schema}.{table}')\n",
" print('Columns:', ', '.join(work.columns))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"def add_point_geometry(\n",
" table_fqn: str,\n",
" lon_col: str = 'longitude',\n",
" lat_col: str = 'latitude',\n",
" geom_col: str = 'geom',\n",
" srid: int = 4326,\n",
") -> None:\n",
" schema, table = parse_table_name(table_fqn)\n",
"\n",
" with get_conn() as conn:\n",
" with conn.cursor() as cur:\n",
" cur.execute(\n",
" sql.SQL('alter table {}.{} add column if not exists {} geometry(Point, %s)').format(\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" sql.Identifier(geom_col),\n",
" ),\n",
" (srid,),\n",
" )\n",
"\n",
" cur.execute(\n",
" sql.SQL(\n",
" 'update {}.{} set {} = case '\n",
" 'when {} is not null and {} is not null '\n",
" 'then ST_SetSRID(ST_MakePoint({}::double precision, {}::double precision), %s) '\n",
" 'else null end'\n",
" ).format(\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" sql.Identifier(geom_col),\n",
" sql.Identifier(lon_col),\n",
" sql.Identifier(lat_col),\n",
" sql.Identifier(lon_col),\n",
" sql.Identifier(lat_col),\n",
" ),\n",
" (srid,),\n",
" )\n",
"\n",
" idx_name = f'{table}_{geom_col}_gix'\n",
" cur.execute(\n",
" sql.SQL('create index if not exists {} on {}.{} using gist ({})').format(\n",
" sql.Identifier(idx_name),\n",
" sql.Identifier(schema),\n",
" sql.Identifier(table),\n",
" sql.Identifier(geom_col),\n",
" )\n",
" )\n",
" cur.execute(\n",
" sql.SQL('analyze {}.{}').format(sql.Identifier(schema), sql.Identifier(table))\n",
" )\n",
" conn.commit()\n",
"\n",
" print(f'Geometry built in {schema}.{table}.{geom_col} and GiST indexed')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"# Example for the opposition cases table.\n",
"FILE_PATH = 'new/Opposition_Cases_Geocoded.csv'\n",
"TARGET_TABLE = 'public.opposition_cases_geocoded'\n",
"IF_EXISTS = 'replace' # replace | append | fail\n",
"\n",
"# This file uses state abbreviations in the state column and includes lon/lat.\n",
"df = read_tabular(FILE_PATH)\n",
"print('Rows:', len(df), 'Cols:', len(df.columns))\n",
"display(df[['id', 'state', 'state_id', 'lon', 'lat']].head(5))\n",
"\n",
"load_dataframe_to_postgis(df, TARGET_TABLE, if_exists=IF_EXISTS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"# Optional: this table has lon/lat columns, so build geometry with those names.\n",
"add_point_geometry(TARGET_TABLE, lon_col='lon', lat_col='lat', geom_col='geom', srid=4326)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check: show row count and latest tables in public schema.\n",
"with get_conn() as conn:\n",
" with conn.cursor() as cur:\n",
" schema, table = parse_table_name(TARGET_TABLE)\n",
" cur.execute(\n",
" sql.SQL('select count(*) from {}.{}').format(sql.Identifier(schema), sql.Identifier(table))\n",
" )\n",
" print('Target row count:', cur.fetchone()[0])\n",
"\n",
" cur.execute(\n",
" \"\"\"\n",
" select schemaname, tablename\n",
" from pg_tables\n",
" where schemaname = 'public'\n",
" order by tablename desc\n",
" limit 25\n",
" \"\"\"\n",
" )\n",
" for row in cur.fetchall():\n",
" print(f'{row[0]}.{row[1]}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.14.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}