# %% [markdown]
# # PostGIS Table Loader Notebook
#
# Use this notebook to load additional tabular files into the same PostgreSQL/PostGIS database used by the scripts in this folder.
#
# Expected environment variables (same pattern as your .py files):
# - `PGWEB_HOST`, `PGWEB_PORT`, `PGWEB_USER`, `PGWEB_PASSWORD`
# - Optional override: `PGDATABASE` (defaults to `data_centers`)

# %%
import os
from contextlib import contextmanager
from pathlib import Path
from typing import List, Tuple

import pandas as pd
import psycopg2
from psycopg2 import sql
from psycopg2.extras import execute_values

print('pandas:', pd.__version__)
print('psycopg2 loaded successfully')

# %%
# Load DB env vars for notebook kernels (which often do not inherit shell exports).
def load_env_file(env_path: str = '.env') -> None:
    """Copy KEY=VALUE pairs from a dotenv-style file into ``os.environ``.

    Blank lines, ``#`` comment lines and lines without ``=`` are skipped.
    Surrounding whitespace and quote characters are stripped from values.
    Variables that are already set in the process environment are never
    overwritten, so explicit exports win over the file.

    Args:
        env_path: Path of the file to read, relative to the working directory.
    """
    p = Path(env_path)
    if not p.exists():
        print(f'No {env_path} file found in {Path.cwd()}')
        return

    loaded = 0
    for raw_line in p.read_text(encoding='utf-8').splitlines():
        line = raw_line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        if key and key not in os.environ:
            os.environ[key] = value
            loaded += 1
    print(f'Loaded {loaded} env var(s) from {env_path}')


def require_env(keys: List[str]) -> None:
    """Raise ``EnvironmentError`` listing every key in ``keys`` that is unset or empty."""
    missing = [k for k in keys if not os.getenv(k)]
    if missing:
        raise EnvironmentError(
            'Missing required env vars in notebook kernel: ' + ', '.join(missing) +
            '.\nSet them in this notebook, or add them to a .env file in this folder.'
        )


load_env_file('.env')
print('PGWEB_HOST:', os.getenv('PGWEB_HOST', ''))
print('PGWEB_PORT:', os.getenv('PGWEB_PORT', ''))
print('PGWEB_USER:', os.getenv('PGWEB_USER', ''))

# %%
# Connection setup: mirrors the existing scripts in this repository.
required_keys = ['PGWEB_HOST', 'PGWEB_PORT', 'PGWEB_USER', 'PGWEB_PASSWORD']
require_env(required_keys)

DB_NAME = os.getenv('PGDATABASE', 'data_centers')


def get_conn():
    """Open a new psycopg2 connection from the PGWEB_* variables.

    The database name comes from ``DB_NAME`` so the ``PGDATABASE`` override
    is honoured. The caller owns the returned connection and must close it;
    prefer ``db_cursor()`` which does that automatically.
    """
    return psycopg2.connect(
        host=os.environ['PGWEB_HOST'],
        port=os.environ['PGWEB_PORT'],
        user=os.environ['PGWEB_USER'],
        password=os.environ['PGWEB_PASSWORD'],
        dbname=DB_NAME,
    )


@contextmanager
def db_cursor():
    """Yield a cursor bound to a fresh connection running one transaction.

    The transaction is committed when the ``with`` body finishes normally and
    rolled back when it raises. The connection is always closed afterwards:
    psycopg2's own ``with connection`` only ends the transaction and would
    otherwise leave one open connection behind per call.
    """
    conn = get_conn()
    try:
        with conn:  # commit on success, rollback on exception
            with conn.cursor() as cur:
                yield cur
    finally:
        conn.close()


with db_cursor() as cur:
    cur.execute('select current_database(), current_user, version()')
    db, usr, ver = cur.fetchone()
    print('Connected to DB:', db)
    print('As user:', usr)
    print('Postgres:', ver.split(',')[0])
    cur.execute('create extension if not exists postgis')
    print('PostGIS extension is available')

# %%
def parse_table_name(table_fqn: str) -> Tuple[str, str]:
    """Split ``'schema.table'`` into ``(schema, table)``; the schema defaults to ``'public'``.

    Only the first dot separates the schema, so any further dots stay in the
    table part.
    """
    table_fqn = table_fqn.strip()
    if '.' in table_fqn:
        schema, table = table_fqn.split('.', 1)
    else:
        schema, table = 'public', table_fqn
    return schema, table


def clean_column_name(name: str) -> str:
    """Normalise an arbitrary header into a lower-case snake_case identifier.

    Every character that is not alphanumeric or ``_`` becomes ``_``, runs of
    underscores collapse to one, and leading/trailing underscores are removed.
    An empty result becomes ``'col'``; a result starting with a digit gets a
    ``'c_'`` prefix.
    """
    s = str(name).strip().lower()
    out = []
    last_was_us = False
    for ch in s:
        keep = ch.isalnum() or ch == '_'
        c = ch if keep else '_'
        if c == '_':
            if last_was_us:
                continue  # collapse consecutive underscores
            last_was_us = True
        else:
            last_was_us = False
        out.append(c)
    cleaned = ''.join(out).strip('_')
    if not cleaned:
        cleaned = 'col'
    if cleaned[0].isdigit():
        cleaned = 'c_' + cleaned
    return cleaned


def uniquify_columns(cols: List[str]) -> List[str]:
    """Clean every name and make the resulting list collision-free.

    Repeats of the same cleaned name get ``_1``, ``_2``, ... suffixes. A
    suffixed candidate that is already taken by another column (for example
    ``['a', 'a', 'a_1']``) is skipped, so the returned names are always
    distinct and positions line up with ``cols``.
    """
    next_suffix = {}
    used = set()
    result = []
    for c in cols:
        base = clean_column_name(c)
        idx = next_suffix.get(base, 0)
        candidate = base if idx == 0 else f'{base}_{idx}'
        while candidate in used:
            idx += 1
            candidate = f'{base}_{idx}'
        next_suffix[base] = idx + 1
        used.add(candidate)
        result.append(candidate)
    return result


def postgres_type_for_series(s: pd.Series) -> str:
    """Map a pandas dtype to a PostgreSQL column type name.

    Integers map to ``bigint``, floats to ``double precision``, booleans to
    ``boolean``, timezone-aware datetimes to ``timestamptz``, other datetimes
    to ``timestamp`` and everything else to ``text``.
    """
    dt = s.dtype
    if pd.api.types.is_integer_dtype(dt):
        return 'bigint'
    if pd.api.types.is_float_dtype(dt):
        return 'double precision'
    if pd.api.types.is_bool_dtype(dt):
        return 'boolean'
    if isinstance(dt, pd.DatetimeTZDtype):
        # A plain "timestamp" column would silently discard the UTC offset.
        return 'timestamptz'
    if pd.api.types.is_datetime64_any_dtype(dt):
        return 'timestamp'
    return 'text'


def read_tabular(path: str, sheet_name=0) -> pd.DataFrame:
    """Read a CSV, Excel, JSON, JSON-lines or Parquet file, chosen by extension.

    Args:
        path: File to read.
        sheet_name: Forwarded to ``pd.read_excel`` for ``.xlsx`` / ``.xls`` files.

    Raises:
        ValueError: If the file extension is not one of the supported types.
    """
    p = Path(path)
    suffix = p.suffix.lower()
    if suffix == '.csv':
        return pd.read_csv(p)
    if suffix in ['.xlsx', '.xls']:
        return pd.read_excel(p, sheet_name=sheet_name)
    if suffix == '.jsonl':
        return pd.read_json(p, lines=True)
    if suffix == '.json':
        return pd.read_json(p)
    if suffix == '.parquet':
        return pd.read_parquet(p)
    raise ValueError(f'Unsupported file type: {suffix}')

# %%
def load_dataframe_to_postgis(
    df: pd.DataFrame,
    table_fqn: str,
    if_exists: str = 'replace',  # replace | append | fail
    batch_size: int = 5000,
    analyze: bool = True,
) -> None:
    """Create (or reuse) a table and bulk-insert the rows of ``df`` into it.

    Column names are cleaned and de-duplicated with ``uniquify_columns``;
    column types are inferred from the original dtypes by position. All
    statements run in one transaction on a connection that is closed when
    the function returns.

    Args:
        df: Data to load. It is not modified.
        table_fqn: Target as ``'schema.table'`` or ``'table'`` (schema ``public``).
        if_exists: ``'replace'`` drops and recreates an existing table,
            ``'append'`` inserts into it, ``'fail'`` raises.
        batch_size: Rows per INSERT statement sent by ``execute_values``.
        analyze: Run ``ANALYZE`` on the table after loading.

    Raises:
        ValueError: For an unknown ``if_exists`` or a ``batch_size`` below 1.
        RuntimeError: If the table exists and ``if_exists == 'fail'``.
    """
    if if_exists not in {'replace', 'append', 'fail'}:
        raise ValueError("if_exists must be one of: replace, append, fail")
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    schema, table = parse_table_name(table_fqn)
    qualified = sql.SQL('{}.{}').format(sql.Identifier(schema), sql.Identifier(table))

    work = df.copy()
    work.columns = uniquify_columns([str(c) for c in work.columns])

    # Infer types by position: the cleaned names generally differ from the
    # original headers, and the original frame may hold duplicate labels.
    col_defs = [
        (col, postgres_type_for_series(df.iloc[:, pos]))
        for pos, col in enumerate(work.columns)
    ]

    # Convert pandas NaN/NaT to Python None for psycopg2.
    work = work.astype(object).where(pd.notna(work), None)
    rows = list(work.itertuples(index=False, name=None))

    with db_cursor() as cur:
        cur.execute('create extension if not exists postgis')

        # Look the table up by its quoted name so the check resolves exactly
        # the identifier that the DDL below creates or drops.
        cur.execute('select to_regclass(%s)', (qualified.as_string(cur),))
        exists = cur.fetchone()[0] is not None

        if exists and if_exists == 'fail':
            raise RuntimeError(f'Table {schema}.{table} already exists')

        if exists and if_exists == 'replace':
            cur.execute(sql.SQL('drop table {}').format(qualified))
            exists = False

        if not exists:
            ddl_cols = [
                sql.SQL('{} {}').format(sql.Identifier(c), sql.SQL(t))
                for c, t in col_defs
            ]
            cur.execute(
                sql.SQL('create table {} ({})').format(
                    qualified,
                    sql.SQL(', ').join(ddl_cols),
                )
            )

        if rows:
            insert_sql = sql.SQL('insert into {} ({}) values %s').format(
                qualified,
                sql.SQL(', ').join(sql.Identifier(c) for c in work.columns),
            )
            execute_values(cur, insert_sql.as_string(cur), rows, page_size=batch_size)

        if analyze:
            cur.execute(sql.SQL('analyze {}').format(qualified))

    print(f'Loaded {len(rows)} rows into {schema}.{table}')
    print('Columns:', ', '.join(work.columns))

# %%
def add_point_geometry(
    table_fqn: str,
    lon_col: str = 'longitude',
    lat_col: str = 'latitude',
    geom_col: str = 'geom',
    srid: int = 4326,
) -> None:
    """Add a Point geometry column built from longitude/latitude columns and index it.

    The geometry column is added if missing, every row is (re)computed from
    ``lon_col`` / ``lat_col`` (rows with a NULL coordinate get NULL geometry),
    a GiST index named ``<table>_<geom_col>_gix`` is created if missing, and
    the table is analysed. Everything runs in one transaction on a connection
    that is closed when the function returns.

    Args:
        table_fqn: Target as ``'schema.table'`` or ``'table'`` (schema ``public``).
        lon_col: Column holding longitudes.
        lat_col: Column holding latitudes.
        geom_col: Name of the geometry column to create or refresh.
        srid: Spatial reference id assigned to the points.
    """
    schema, table = parse_table_name(table_fqn)

    with db_cursor() as cur:
        cur.execute(
            sql.SQL('alter table {}.{} add column if not exists {} geometry(Point, %s)').format(
                sql.Identifier(schema),
                sql.Identifier(table),
                sql.Identifier(geom_col),
            ),
            (srid,),
        )

        cur.execute(
            sql.SQL(
                'update {}.{} set {} = case '
                'when {} is not null and {} is not null '
                'then ST_SetSRID(ST_MakePoint({}::double precision, {}::double precision), %s) '
                'else null end'
            ).format(
                sql.Identifier(schema),
                sql.Identifier(table),
                sql.Identifier(geom_col),
                sql.Identifier(lon_col),
                sql.Identifier(lat_col),
                sql.Identifier(lon_col),
                sql.Identifier(lat_col),
            ),
            (srid,),
        )

        idx_name = f'{table}_{geom_col}_gix'
        cur.execute(
            sql.SQL('create index if not exists {} on {}.{} using gist ({})').format(
                sql.Identifier(idx_name),
                sql.Identifier(schema),
                sql.Identifier(table),
                sql.Identifier(geom_col),
            )
        )
        cur.execute(
            sql.SQL('analyze {}.{}').format(sql.Identifier(schema), sql.Identifier(table))
        )

    print(f'Geometry built in {schema}.{table}.{geom_col} and GiST indexed')

# %%
# Example for the opposition cases table.
FILE_PATH = 'new/Opposition_Cases_Geocoded.csv'
TARGET_TABLE = 'public.opposition_cases_geocoded'
IF_EXISTS = 'replace'  # replace | append | fail

# This file uses state abbreviations in the state column and includes lon/lat.
df = read_tabular(FILE_PATH)
print('Rows:', len(df), 'Cols:', len(df.columns))
display(df[['id', 'state', 'state_id', 'lon', 'lat']].head(5))

load_dataframe_to_postgis(df, TARGET_TABLE, if_exists=IF_EXISTS)

# %%
# Optional: this table has lon/lat columns, so build geometry with those names.
add_point_geometry(TARGET_TABLE, lon_col='lon', lat_col='lat', geom_col='geom', srid=4326)

# %%
# Quick sanity check: show the target row count and up to 25 tables in the
# public schema (reverse alphabetical; pg_tables carries no creation time).
with db_cursor() as cur:
    schema, table = parse_table_name(TARGET_TABLE)
    cur.execute(
        sql.SQL('select count(*) from {}.{}').format(sql.Identifier(schema), sql.Identifier(table))
    )
    print('Target row count:', cur.fetchone()[0])

    cur.execute(
        """
        select schemaname, tablename
        from pg_tables
        where schemaname = 'public'
        order by tablename desc
        limit 25
        """
    )
    for row in cur.fetchall():
        print(f'{row[0]}.{row[1]}')

# Notebook metadata carried over from the original .ipynb: kernelspec "python3"
# (display name ".venv"), language_info Python 3.14.5, nbformat 4.5.