gonna figure it out

This commit is contained in:
2024-09-28 23:12:43 -07:00
parent b6e3e24aec
commit d637131886
61 changed files with 9059 additions and 174 deletions

View File

@@ -0,0 +1,7 @@
"""This module defines some dialect-specific functions used for administration tasks."""
from geoalchemy2.admin.dialects import common # noqa
from geoalchemy2.admin.dialects import geopackage # noqa
from geoalchemy2.admin.dialects import mysql # noqa
from geoalchemy2.admin.dialects import postgresql # noqa
from geoalchemy2.admin.dialects import sqlite # noqa

View File

@@ -0,0 +1,104 @@
"""This module defines functions used by several dialects."""
import sqlalchemy
from packaging import version
from sqlalchemy import Column
from sqlalchemy.sql import expression
from sqlalchemy.types import TypeDecorator
from geoalchemy2.types import Geometry
_SQLALCHEMY_VERSION_BEFORE_14 = version.parse(sqlalchemy.__version__) < version.parse("1.4")
def _spatial_idx_name(table_name, column_name):
return "idx_{}_{}".format(table_name, column_name)
def _format_select_args(*args):
if _SQLALCHEMY_VERSION_BEFORE_14:
return [args]
else:
return args
def check_management(*args):
"""Default function to check management (always True by default)."""
return True
def _get_gis_cols(table, spatial_types, dialect, check_col_management=None):
if check_col_management is not None:
func = check_col_management
else:
func = check_management
return [
col
for col in table.columns
if (
isinstance(col, Column)
and _check_spatial_type(col.type, spatial_types, dialect)
and func(col)
)
]
def _check_spatial_type(tested_type, spatial_types, dialect=None):
return isinstance(tested_type, spatial_types) or (
isinstance(tested_type, TypeDecorator)
and isinstance(tested_type.load_dialect_impl(dialect), spatial_types)
)
def _get_dispatch_info(table, bind, check_col_management=None):
"""Get info required for dispatch events."""
dialect = bind.dialect
# Filter Geometry columns from the table
# Note: Geography and PostGIS >= 2.0 don't need this
gis_cols = _get_gis_cols(table, Geometry, dialect, check_col_management=check_col_management)
# Find all other columns that are not managed Geometries
regular_cols = [x for x in table.columns if x not in gis_cols]
return dialect, gis_cols, regular_cols
def _update_table_for_dispatch(table, regular_cols):
"""Update the table before dispatch events."""
# Save original table column list for later
table.info["_saved_columns"] = table.columns
# Temporarily patch a set of columns not including the
# managed Geometry columns
column_collection = expression.ColumnCollection()
for col in regular_cols:
column_collection.add(col)
table.columns = column_collection
def setup_create_drop(table, bind, check_col_management=None):
"""Prepare the table for before_create and before_drop events."""
dialect, gis_cols, regular_cols = _get_dispatch_info(table, bind, check_col_management)
_update_table_for_dispatch(table, regular_cols)
return dialect, gis_cols, regular_cols
def reflect_geometry_column(inspector, table, column_info):
return
def before_create(table, bind, **kw):
return
def after_create(table, bind, **kw):
return
def before_drop(table, bind, **kw):
return
def after_drop(table, bind, **kw):
return

View File

@@ -0,0 +1,381 @@
"""This module defines specific functions for GeoPackage dialect.
See GeoPackage specifications here: http://www.geopackage.org/spec/
"""
import re
from sqlalchemy import text
from sqlalchemy.dialects import registry
from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _format_select_args
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.admin.dialects.sqlite import _SQLITE_FUNCTIONS
from geoalchemy2.admin.dialects.sqlite import get_col_dim
from geoalchemy2.admin.dialects.sqlite import load_spatialite_driver
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
from geoalchemy2.types import _DummyGeometry
class GeoPackageDialect(SQLiteDialect_pysqlite):
"""Define a specific dialect for GeoPackage."""
name = "geopackage"
driver = "gpkg"
supports_statement_cache = True
"""Enable caching for GeoPackage dialect."""
registry.register("gpkg", "geoalchemy2.admin.dialects.geopackage", "GeoPackageDialect")
def load_geopackage_driver(dbapi_conn, *args):
"""Load SpatiaLite extension in GeoPackage connection and set VirtualGpkg and Amphibious modes.
.. Warning::
The path to the SpatiaLite module should be set in the `SPATIALITE_LIBRARY_PATH`
environment variable.
Args:
dbapi_conn: The DBAPI connection.
"""
load_spatialite_driver(dbapi_conn, *args)
dbapi_conn.execute("SELECT AutoGpkgStart();")
dbapi_conn.execute("SELECT EnableGpkgAmphibiousMode();")
def init_geopackage(dbapi_conn, *args):
"""Initialize GeoPackage tables.
Args:
dbapi_conn: The DBAPI connection.
.. Warning::
No EPSG SRID is loaded in the `gpkg_spatial_ref_sys` table after initialization but
it is possible to load other EPSG SRIDs afterwards using the
`gpkgInsertEpsgSRID(srid)`.
Nevertheless, SRIDs of newly created tables are automatically added.
"""
if not dbapi_conn.execute("SELECT CheckGeoPackageMetaData();").fetchone()[0]:
# This only works on the main database
dbapi_conn.execute("SELECT gpkgCreateBaseTables();")
def load_spatialite_gpkg(*args, **kwargs):
"""Load SpatiaLite extension in GeoPackage and initialize internal tables.
See :func:`geoalchemy2.admin.dialects.geopackage.load_geopackage_driver` and
:func:`geoalchemy2.admin.dialects.geopackage.init_geopackage` functions for details about
arguments.
"""
load_geopackage_driver(*args)
init_geopackage(*args, **kwargs)
def _get_spatialite_attrs(bind, table_name, col_name):
attrs = bind.execute(
text(
"""SELECT
A.geometry_type_name,
A.srs_id,
A.z,
A.m,
IFNULL(B.has_index, 0) AS has_index
FROM gpkg_geometry_columns
AS A
LEFT JOIN (
SELECT table_name, column_name, COUNT(*) AS has_index
FROM gpkg_extensions
WHERE LOWER(table_name) = LOWER(:table_name)
AND column_name = :column_name
AND extension_name = 'gpkg_rtree_index'
) AS B
ON LOWER(A.table_name) = LOWER(B.table_name)
AND A.column_name = B.column_name
WHERE LOWER(A.table_name) = LOWER(:table_name)
AND A.column_name = :column_name;
"""
).bindparams(table_name=table_name, column_name=col_name)
).fetchone()
if attrs is None:
# If the column is not registered as a spatial column we ignore it
return None
geometry_type, srid, has_z, has_m, has_index = attrs
coord_dimension = "XY"
if has_z:
coord_dimension += "Z"
if has_m:
coord_dimension += "M"
col_attributes = geometry_type, coord_dimension, srid, has_index
return col_attributes
def _setup_dummy_type(table, gis_cols):
"""Setup dummy type for new Geometry columns so they can be updated later."""
for col in gis_cols:
# Add dummy columns with GEOMETRY type
type_str = re.fullmatch("(.+?)[ZMzm]*", col.type.geometry_type).group(1)
col._actual_type = col.type
col.type = _DummyGeometry(geometry_type=type_str)
table.columns = table.info["_saved_columns"]
def create_spatial_index(bind, table, col):
"""Create spatial index on the given column."""
stmt = select(*_format_select_args(func.gpkgAddSpatialIndex(table.name, col.name)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
def disable_spatial_index(bind, table, col):
"""Disable spatial indexes if present."""
for i in ["", "_node", "_parent", "_rowid"]:
bind.execute(
text(
"DROP TABLE IF EXISTS rtree_{}_{}{};".format(
table.name,
col.name,
i,
)
)
)
bind.execute(
text(
"""DELETE FROM gpkg_extensions
WHERE LOWER(table_name) = LOWER(:table_name)
AND column_name = :column_name
AND extension_name = 'gpkg_rtree_index';"""
).bindparams(
table_name=table.name,
column_name=col.name,
)
)
def reflect_geometry_column(inspector, table, column_info):
"""Reflect a column of type Geometry with GeoPackage dialect."""
# Get geometry type, SRID and spatial index from the SpatiaLite metadata
if not isinstance(column_info.get("type"), Geometry):
return
col_attributes = _get_spatialite_attrs(inspector.bind, table.name, column_info["name"])
if col_attributes is not None:
geometry_type, coord_dimension, srid, spatial_index = col_attributes
coord_dimension = {
"XY": 2,
"XYZ": 3,
"XYM": 3,
"XYZM": 4,
}.get(coord_dimension, coord_dimension)
# Set attributes
column_info["type"].geometry_type = geometry_type
column_info["type"].dimension = coord_dimension
column_info["type"].srid = srid
column_info["type"].spatial_index = bool(spatial_index)
# Spatial indexes are not automatically reflected with GeoPackage dialect
column_info["type"]._spatial_index_reflected = False
def before_create(table, bind, **kw):
"""Handle spatial indexes during the before_create event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind)
# Remove the spatial indexes from the table metadata because they should not be
# created during the table.create() step since the associated columns do not exist
# at this time.
table.info["_after_create_indexes"] = []
current_indexes = set(table.indexes)
for idx in current_indexes:
for col in table.info["_saved_columns"]:
if _check_spatial_type(col.type, Geometry, dialect) and col in idx.columns.values():
table.indexes.remove(idx)
if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
col.type, "spatial_index", False
):
table.info["_after_create_indexes"].append(idx)
if len(gis_cols) > 1:
raise ValueError("Only one geometry column is allowed for a table stored in a GeoPackage.")
elif len(gis_cols) == 1:
col = gis_cols[0]
srid = col.type.srid
if col.type.geometry_type is None:
col.type.geometry_type = "GEOMETRY"
# Add the SRID of the table in 'gpkg_spatial_ref_sys' if this table exists
if not bind.execute(
text("SELECT COUNT(*) FROM gpkg_spatial_ref_sys WHERE srs_id = :srid;").bindparams(
srid=srid
)
).scalar():
bind.execute(text("SELECT gpkgInsertEpsgSRID(:srid)").bindparams(srid=srid))
_setup_dummy_type(table, gis_cols)
def after_create(table, bind, **kw):
"""Handle spatial indexes during the after_create event."""
dialect = bind.dialect
for col in table.columns:
# Add the managed Geometry columns with gpkgAddGeometryColumn()
if _check_spatial_type(col.type, Geometry, dialect):
col.type = col._actual_type
del col._actual_type
dimension = get_col_dim(col)
has_z = "Z" in dimension
has_m = "M" in dimension
bind.execute(
text(
"""INSERT INTO gpkg_contents
VALUES (
:table_name,
'features',
:table_name,
"",
strftime('%Y-%m-%dT%H:%M:%fZ', CURRENT_TIMESTAMP),
NULL,
NULL,
NULL,
NULL,
:srid
);"""
).bindparams(
table_name=table.name,
srid=col.type.srid,
)
)
bind.execute(
text(
"""INSERT INTO gpkg_geometry_columns
VALUES (:table_name, :column_name, :geometry_type, :srid, :has_z, :has_m);"""
).bindparams(
table_name=table.name,
column_name=col.name,
geometry_type=col.type.geometry_type,
srid=col.type.srid,
has_z=has_z,
has_m=has_m,
)
)
stmt = select(*_format_select_args(func.gpkgAddGeometryTriggers(table.name, col.name)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
for col in table.columns:
# Add spatial indexes for the Geometry and Geography columns
# TODO: Check that the Geography type makes sense here
if (
_check_spatial_type(col.type, (Geometry, Geography), dialect)
and col.type.spatial_index is True
):
create_spatial_index(bind, table, col)
for idx in table.info.pop("_after_create_indexes"):
table.indexes.add(idx)
idx.create(bind=bind)
def before_drop(table, bind, **kw):
"""Handle spatial indexes during the before_drop event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind)
for col in gis_cols:
# Disable spatial indexes if present
# TODO: This is useless but if we remove it then the disable_spatial_index should be
# tested separately
disable_spatial_index(bind, table, col)
# Remove metadata from internal tables
# (this is equivalent to DiscardGeometryColumn but for GeoPackage)
bind.execute(
text(
"""DELETE FROM gpkg_extensions
WHERE LOWER(table_name) = LOWER(:table_name)
AND column_name = :column_name;"""
).bindparams(
table_name=table.name,
column_name=col.name,
)
)
bind.execute(
text(
"""DELETE FROM gpkg_geometry_columns
WHERE LOWER(table_name) = LOWER(:table_name)
AND column_name = :column_name;"""
).bindparams(
table_name=table.name,
column_name=col.name,
)
)
bind.execute(
text(
"""DELETE FROM gpkg_contents
WHERE LOWER(table_name) = LOWER(:table_name);"""
).bindparams(table_name=table.name)
)
def after_drop(table, bind, **kw):
"""Handle spatial indexes during the after_drop event."""
table.columns = table.info.pop("_saved_columns")
def _compiles_gpkg(cls, fn):
def _compile_gpkg(element, compiler, **kw):
return "{}({})".format(fn, compiler.process(element.clauses, **kw))
compiles(getattr(functions, cls), "geopackage")(_compile_gpkg)
def register_gpkg_mapping(mapping):
"""Register compilation mappings for the given functions.
Args:
mapping: Should have the following form::
{
"function_name_1": "gpkg_function_name_1",
"function_name_2": "gpkg_function_name_2",
...
}
"""
for cls, fn in mapping.items():
_compiles_gpkg(cls, fn)
register_gpkg_mapping(_SQLITE_FUNCTIONS)
def create_spatial_ref_sys_view(bind):
"""Create the `spatial_ref_sys` view from the `gpkg_spatial_ref_sys` table.
.. Note::
This is usually only needed to use the `ST_Transform` function on GeoPackage data
because this function, when used with SpatiaLite, requires the `spatial_ref_sys` table.
"""
bind.execute(
text(
"""CREATE VIEW spatial_ref_sys AS
SELECT
srs_id AS srid,
organization AS auth_name,
organization_coordsys_id AS auth_srid,
definition AS srtext
FROM gpkg_spatial_ref_sys;"""
)
)

View File

@@ -0,0 +1,200 @@
"""This module defines specific functions for MySQL dialect."""
from sqlalchemy import text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.sqltypes import NullType
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
_POSSIBLE_TYPES = [
"geometry",
"point",
"linestring",
"polygon",
"multipoint",
"multilinestring",
"multipolygon",
"geometrycollection",
]
def reflect_geometry_column(inspector, table, column_info):
"""Reflect a column of type Geometry with Postgresql dialect."""
if not isinstance(column_info.get("type"), (Geometry, NullType)):
return
column_name = column_info.get("name")
schema = table.schema or inspector.default_schema_name
# Check geometry type, SRID and if the column is nullable
geometry_type_query = """SELECT DATA_TYPE, SRS_ID, IS_NULLABLE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format(
table.name, column_name
)
if schema is not None:
geometry_type_query += """ and table_schema = '{}'""".format(schema)
geometry_type, srid, nullable_str = inspector.bind.execute(text(geometry_type_query)).one()
is_nullable = str(nullable_str).lower() == "yes"
if geometry_type not in _POSSIBLE_TYPES:
return
# Check if the column has spatial index
has_index_query = """SELECT DISTINCT
INDEX_TYPE
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format(
table.name, column_name
)
if schema is not None:
has_index_query += """ and TABLE_SCHEMA = '{}'""".format(schema)
spatial_index_res = inspector.bind.execute(text(has_index_query)).scalar()
spatial_index = str(spatial_index_res).lower() == "spatial"
# Set attributes
column_info["type"] = Geometry(
geometry_type=geometry_type.upper(),
srid=srid,
spatial_index=spatial_index,
nullable=is_nullable,
_spatial_index_reflected=True,
)
def before_create(table, bind, **kw):
"""Handle spatial indexes during the before_create event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind)
# Remove the spatial indexes from the table metadata because they should not be
# created during the table.create() step since the associated columns do not exist
# at this time.
table.info["_after_create_indexes"] = []
current_indexes = set(table.indexes)
for idx in current_indexes:
for col in table.info["_saved_columns"]:
if (_check_spatial_type(col.type, Geometry, dialect)) and col in idx.columns.values():
table.indexes.remove(idx)
if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
col.type, "spatial_index", False
):
table.info["_after_create_indexes"].append(idx)
table.columns = table.info.pop("_saved_columns")
def after_create(table, bind, **kw):
"""Handle spatial indexes during the after_create event."""
# Restore original column list including managed Geometry columns
dialect = bind.dialect
# table.columns = table.info.pop("_saved_columns")
for col in table.columns:
# Add spatial indices for the Geometry and Geography columns
if (
_check_spatial_type(col.type, (Geometry, Geography), dialect)
and col.type.spatial_index is True
):
# If the index does not exist, define it and create it
if not [i for i in table.indexes if col in i.columns.values()]:
sql = "ALTER TABLE {} ADD SPATIAL INDEX({});".format(table.name, col.name)
q = text(sql)
bind.execute(q)
for idx in table.info.pop("_after_create_indexes"):
table.indexes.add(idx)
def before_drop(table, bind, **kw):
return
def after_drop(table, bind, **kw):
return
_MYSQL_FUNCTIONS = {
"ST_AsEWKB": "ST_AsBinary",
}
def _compiles_mysql(cls, fn):
def _compile_mysql(element, compiler, **kw):
return "{}({})".format(fn, compiler.process(element.clauses, **kw))
compiles(getattr(functions, cls), "mysql")(_compile_mysql)
compiles(getattr(functions, cls), "mariadb")(_compile_mysql)
def register_mysql_mapping(mapping):
"""Register compilation mappings for the given functions.
Args:
mapping: Should have the following form::
{
"function_name_1": "mysql_function_name_1",
"function_name_2": "mysql_function_name_2",
...
}
"""
for cls, fn in mapping.items():
_compiles_mysql(cls, fn)
register_mysql_mapping(_MYSQL_FUNCTIONS)
def _compile_GeomFromText_MySql(element, compiler, **kw):
element.identifier = "ST_GeomFromText"
compiled = compiler.process(element.clauses, **kw)
srid = element.type.srid
if srid > 0:
return "{}({}, {})".format(element.identifier, compiled, srid)
else:
return "{}({})".format(element.identifier, compiled)
def _compile_GeomFromWKB_MySql(element, compiler, **kw):
element.identifier = "ST_GeomFromWKB"
wkb_data = list(element.clauses)[0].value
if isinstance(wkb_data, memoryview):
list(element.clauses)[0].value = wkb_data.tobytes()
compiled = compiler.process(element.clauses, **kw)
srid = element.type.srid
if srid > 0:
return "{}({}, {})".format(element.identifier, compiled, srid)
else:
return "{}({})".format(element.identifier, compiled)
@compiles(functions.ST_GeomFromText, "mysql") # type: ignore
@compiles(functions.ST_GeomFromText, "mariadb") # type: ignore
def _MySQL_ST_GeomFromText(element, compiler, **kw):
return _compile_GeomFromText_MySql(element, compiler, **kw)
@compiles(functions.ST_GeomFromEWKT, "mysql") # type: ignore
@compiles(functions.ST_GeomFromEWKT, "mariadb") # type: ignore
def _MySQL_ST_GeomFromEWKT(element, compiler, **kw):
return _compile_GeomFromText_MySql(element, compiler, **kw)
@compiles(functions.ST_GeomFromWKB, "mysql") # type: ignore
@compiles(functions.ST_GeomFromWKB, "mariadb") # type: ignore
def _MySQL_ST_GeomFromWKB(element, compiler, **kw):
return _compile_GeomFromWKB_MySql(element, compiler, **kw)
@compiles(functions.ST_GeomFromEWKB, "mysql") # type: ignore
@compiles(functions.ST_GeomFromEWKB, "mariadb") # type: ignore
def _MySQL_ST_GeomFromEWKB(element, compiler, **kw):
return _compile_GeomFromWKB_MySql(element, compiler, **kw)

View File

@@ -0,0 +1,162 @@
"""This module defines specific functions for Postgresql dialect."""
from sqlalchemy import Index
from sqlalchemy import text
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _format_select_args
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
def check_management(column):
"""Check if the column should be managed."""
return getattr(column.type, "use_typmod", None) is False
def create_spatial_index(bind, table, col):
"""Create spatial index on the given column."""
if col.type.use_N_D_index:
postgresql_ops = {col.name: "gist_geometry_ops_nd"}
else:
postgresql_ops = {}
idx = Index(
_spatial_idx_name(table.name, col.name),
col,
postgresql_using="gist",
postgresql_ops=postgresql_ops,
_column_flag=True,
)
idx.create(bind=bind)
def reflect_geometry_column(inspector, table, column_info):
"""Reflect a column of type Geometry with Postgresql dialect."""
if not isinstance(column_info.get("type"), Geometry):
return
geo_type = column_info["type"]
geometry_type = geo_type.geometry_type
coord_dimension = geo_type.dimension
if geometry_type.endswith("ZM"):
coord_dimension = 4
elif geometry_type[-1] in ["Z", "M"]:
coord_dimension = 3
# Query to check a given column has spatial index
if table.schema is not None:
schema_part = " AND nspname = '{}'".format(table.schema)
else:
schema_part = ""
has_index_query = """SELECT (indexrelid IS NOT NULL) AS has_index
FROM (
SELECT
n.nspname,
c.relname,
c.oid AS relid,
a.attname,
a.attnum
FROM pg_attribute a
INNER JOIN pg_class c ON (a.attrelid=c.oid)
INNER JOIN pg_type t ON (a.atttypid=t.oid)
INNER JOIN pg_namespace n ON (c.relnamespace=n.oid)
WHERE t.typname='geometry'
AND c.relkind='r'
) g
LEFT JOIN pg_index i ON (g.relid = i.indrelid AND g.attnum = ANY(i.indkey))
WHERE relname = '{}' AND attname = '{}'{};
""".format(
table.name, column_info["name"], schema_part
)
spatial_index = inspector.bind.execute(text(has_index_query)).scalar()
# Set attributes
column_info["type"].geometry_type = geometry_type
column_info["type"].dimension = coord_dimension
column_info["type"].spatial_index = bool(spatial_index)
# Spatial indexes are automatically reflected with PostgreSQL dialect
column_info["type"]._spatial_index_reflected = True
def before_create(table, bind, **kw):
"""Handle spatial indexes during the before_create event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind, check_management)
# Remove the spatial indexes from the table metadata because they should not be
# created during the table.create() step since the associated columns do not exist
# at this time.
table.info["_after_create_indexes"] = []
current_indexes = set(table.indexes)
for idx in current_indexes:
for col in table.info["_saved_columns"]:
if (
_check_spatial_type(col.type, Geometry, dialect) and check_management(col)
) and col in idx.columns.values():
table.indexes.remove(idx)
if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
col.type, "spatial_index", False
):
table.info["_after_create_indexes"].append(idx)
def after_create(table, bind, **kw):
"""Handle spatial indexes during the after_create event."""
# Restore original column list including managed Geometry columns
dialect = bind.dialect
table.columns = table.info.pop("_saved_columns")
for col in table.columns:
# Add the managed Geometry columns with AddGeometryColumn()
if _check_spatial_type(col.type, Geometry, dialect) and check_management(col):
dimension = col.type.dimension
args = [table.schema] if table.schema else []
args.extend([table.name, col.name, col.type.srid, col.type.geometry_type, dimension])
if col.type.use_typmod is not None:
args.append(col.type.use_typmod)
stmt = select(*_format_select_args(func.AddGeometryColumn(*args)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
# Add spatial indices for the Geometry and Geography columns
if (
_check_spatial_type(col.type, (Geometry, Geography), dialect)
and col.type.spatial_index is True
):
# If the index does not exist, define it and create it
if not [i for i in table.indexes if col in i.columns.values()] and check_management(
col
):
create_spatial_index(bind, table, col)
for idx in table.info.pop("_after_create_indexes"):
table.indexes.add(idx)
idx.create(bind=bind)
def before_drop(table, bind, **kw):
"""Handle spatial indexes during the before_drop event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind, check_management)
# Drop the managed Geometry columns
for col in gis_cols:
args = [table.schema] if table.schema else []
args.extend([table.name, col.name])
stmt = select(*_format_select_args(func.DropGeometryColumn(*args)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
def after_drop(table, bind, **kw):
"""Handle spatial indexes during the after_drop event."""
# Restore original column list including managed Geometry columns
saved_cols = table.info.pop("_saved_columns", None)
if saved_cols is not None:
table.columns = saved_cols

View File

@@ -0,0 +1,369 @@
"""This module defines specific functions for SQLite dialect."""
import os
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _format_select_args
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
from geoalchemy2.types import _DummyGeometry
from geoalchemy2.utils import authorized_values_in_docstring
def load_spatialite_driver(dbapi_conn, *args):
"""Load SpatiaLite extension in SQLite connection.
.. Warning::
The path to the SpatiaLite module should be set in the `SPATIALITE_LIBRARY_PATH`
environment variable.
Args:
dbapi_conn: The DBAPI connection.
"""
if "SPATIALITE_LIBRARY_PATH" not in os.environ:
raise RuntimeError("The SPATIALITE_LIBRARY_PATH environment variable is not set.")
dbapi_conn.enable_load_extension(True)
dbapi_conn.load_extension(os.environ["SPATIALITE_LIBRARY_PATH"])
dbapi_conn.enable_load_extension(False)
_JOURNAL_MODE_VALUES = ["DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"]
@authorized_values_in_docstring(JOURNAL_MODE_VALUES=_JOURNAL_MODE_VALUES)
def init_spatialite(
dbapi_conn,
*args,
transaction: bool = False,
init_mode: Optional[str] = None,
journal_mode: Optional[str] = None,
):
"""Initialize internal SpatiaLite tables.
Args:
dbapi_conn: The DBAPI connection.
init_mode: Can be `None` to load all EPSG SRIDs, `'WGS84'` to load only the ones related
to WGS84 or `'EMPTY'` to not load any EPSG SRID.
.. Note::
It is possible to load other EPSG SRIDs afterwards using `InsertEpsgSrid(srid)`.
transaction: If set to `True` the whole operation will be handled as a single Transaction
(faster). The default value is `False` (slower, but safer).
journal_mode: Change the journal mode to the given value. This can make the table creation
much faster. The possible values are the following: <JOURNAL_MODE_VALUES>. See
https://www.sqlite.org/pragma.html#pragma_journal_mode for more details.
.. Warning::
Some values, like 'MEMORY' or 'OFF', can lead to corrupted databases if the process
is interrupted during initialization.
.. Note::
The original value is restored after the initialization.
.. Note::
When using this function as a listener it is not possible to pass the `transaction`,
`init_mode` or `journal_mode` arguments directly. To do this you can either create another
function that calls `init_spatialite` (or
:func:`geoalchemy2.admin.dialects.sqlite.load_spatialite` if you also want to load the
SpatiaLite drivers) with an hard-coded `init_mode` or just use a lambda::
>>> sqlalchemy.event.listen(
... engine,
... "connect",
... lambda x, y: init_spatialite(
... x,
... y,
... transaction=True,
... init_mode="EMPTY",
... journal_mode="OFF",
... )
... )
"""
func_args = []
# Check the value of the 'transaction' parameter
if not isinstance(transaction, (bool, int)):
raise ValueError("The 'transaction' argument must be True or False.")
else:
func_args.append(str(transaction))
# Check the value of the 'init_mode' parameter
init_mode_values = ["WGS84", "EMPTY"]
if isinstance(init_mode, str):
init_mode = init_mode.upper()
if init_mode is not None:
if init_mode not in init_mode_values:
raise ValueError("The 'init_mode' argument must be one of {}.".format(init_mode_values))
func_args.append(f"'{init_mode}'")
# Check the value of the 'journal_mode' parameter
if isinstance(journal_mode, str):
journal_mode = journal_mode.upper()
if journal_mode is not None:
if journal_mode not in _JOURNAL_MODE_VALUES:
raise ValueError(
"The 'journal_mode' argument must be one of {}.".format(_JOURNAL_MODE_VALUES)
)
if dbapi_conn.execute("SELECT CheckSpatialMetaData();").fetchone()[0] < 1:
if journal_mode is not None:
current_journal_mode = dbapi_conn.execute("PRAGMA journal_mode").fetchone()[0]
dbapi_conn.execute("PRAGMA journal_mode = {}".format(journal_mode))
dbapi_conn.execute("SELECT InitSpatialMetaData({});".format(", ".join(func_args)))
if journal_mode is not None:
dbapi_conn.execute("PRAGMA journal_mode = {}".format(current_journal_mode))
def load_spatialite(*args, **kwargs):
"""Load SpatiaLite extension in SQLite DB and initialize internal tables.
See :func:`geoalchemy2.admin.dialects.sqlite.load_spatialite_driver` and
:func:`geoalchemy2.admin.dialects.sqlite.init_spatialite` functions for details about
arguments.
"""
load_spatialite_driver(*args)
init_spatialite(*args, **kwargs)
def _get_spatialite_attrs(bind, table_name, col_name):
attrs = bind.execute(
text(
"""SELECT * FROM "geometry_columns"
WHERE LOWER(f_table_name) = LOWER(:table_name)
AND LOWER(f_geometry_column) = LOWER(:column_name)
"""
).bindparams(table_name=table_name, column_name=col_name)
).fetchone()
if attrs is None:
# If the column is not registered as a spatial column we ignore it
return None
return attrs[2:]
def get_spatialite_version(bind):
"""Get the version of the currently loaded Spatialite extension."""
return bind.execute(text("SELECT spatialite_version();")).fetchone()[0]
def _setup_dummy_type(table, gis_cols):
"""Setup dummy type for new Geometry columns so they can be updated later."""
for col in gis_cols:
# Add dummy columns with GEOMETRY type
col._actual_type = col.type
col.type = _DummyGeometry()
table.columns = table.info["_saved_columns"]
def get_col_dim(col):
"""Get dimension of the column type."""
if col.type.dimension == 4:
dimension = "XYZM"
elif col.type.dimension == 2:
dimension = "XY"
else:
if col.type.geometry_type.endswith("M"):
dimension = "XYM"
else:
dimension = "XYZ"
return dimension
def create_spatial_index(bind, table, col):
"""Create spatial index on the given column."""
stmt = select(*_format_select_args(func.CreateSpatialIndex(table.name, col.name)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
def disable_spatial_index(bind, table, col):
"""Disable spatial indexes if present."""
stmt = select(*_format_select_args(func.CheckSpatialIndex(table.name, col.name)))
if bind.execute(stmt).fetchone()[0] is not None:
stmt = select(*_format_select_args(func.DisableSpatialIndex(table.name, col.name)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
bind.execute(
text(
"DROP TABLE IF EXISTS {};".format(
_spatial_idx_name(
table.name,
col.name,
)
)
)
)
def reflect_geometry_column(inspector, table, column_info):
"""Reflect a column of type Geometry with SQLite dialect."""
# Get geometry type, SRID and spatial index from the SpatiaLite metadata
if not isinstance(column_info.get("type"), Geometry):
return
col_attributes = _get_spatialite_attrs(inspector.bind, table.name, column_info["name"])
if col_attributes is not None:
geometry_type, coord_dimension, srid, spatial_index = col_attributes
if isinstance(geometry_type, int):
geometry_type_str = str(geometry_type)
if geometry_type >= 1000:
first_digit = geometry_type_str[0]
has_z = first_digit in ["1", "3"]
has_m = first_digit in ["2", "3"]
else:
has_z = has_m = False
geometry_type = {
"0": "GEOMETRY",
"1": "POINT",
"2": "LINESTRING",
"3": "POLYGON",
"4": "MULTIPOINT",
"5": "MULTILINESTRING",
"6": "MULTIPOLYGON",
"7": "GEOMETRYCOLLECTION",
}[geometry_type_str[-1]]
if has_z:
geometry_type += "Z"
if has_m:
geometry_type += "M"
else:
if "Z" in coord_dimension and "Z" not in geometry_type[-2:]:
geometry_type += "Z"
if "M" in coord_dimension and "M" not in geometry_type[-2:]:
geometry_type += "M"
coord_dimension = {
"XY": 2,
"XYZ": 3,
"XYM": 3,
"XYZM": 4,
}.get(coord_dimension, coord_dimension)
# Set attributes
column_info["type"].geometry_type = geometry_type
column_info["type"].dimension = coord_dimension
column_info["type"].srid = srid
column_info["type"].spatial_index = bool(spatial_index)
# Spatial indexes are not automatically reflected with SQLite dialect
column_info["type"]._spatial_index_reflected = False
def before_create(table, bind, **kw):
"""Handle spatial indexes during the before_create event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind)
# Remove the spatial indexes from the table metadata because they should not be
# created during the table.create() step since the associated columns do not exist
# at this time.
table.info["_after_create_indexes"] = []
current_indexes = set(table.indexes)
for idx in current_indexes:
for col in table.info["_saved_columns"]:
if (_check_spatial_type(col.type, Geometry, dialect)) and col in idx.columns.values():
table.indexes.remove(idx)
if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
col.type, "spatial_index", False
):
table.info["_after_create_indexes"].append(idx)
_setup_dummy_type(table, gis_cols)
def after_create(table, bind, **kw):
"""Handle spatial indexes during the after_create event."""
dialect = bind.dialect
table.columns = table.info.pop("_saved_columns")
for col in table.columns:
# Add the managed Geometry columns with RecoverGeometryColumn()
if _check_spatial_type(col.type, Geometry, dialect):
col.type = col._actual_type
del col._actual_type
dimension = get_col_dim(col)
args = [table.name, col.name, col.type.srid, col.type.geometry_type, dimension]
stmt = select(*_format_select_args(func.RecoverGeometryColumn(*args)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
for col in table.columns:
# Add spatial indexes for the Geometry and Geography columns
# TODO: Check that the Geography type makes sense here
if (
_check_spatial_type(col.type, (Geometry, Geography), dialect)
and col.type.spatial_index is True
):
create_spatial_index(bind, table, col)
for idx in table.info.pop("_after_create_indexes"):
table.indexes.add(idx)
idx.create(bind=bind)
def before_drop(table, bind, **kw):
"""Handle spatial indexes during the before_drop event."""
dialect, gis_cols, regular_cols = setup_create_drop(table, bind)
for col in gis_cols:
# Disable spatial indexes if present
disable_spatial_index(bind, table, col)
args = [table.name, col.name]
stmt = select(*_format_select_args(func.DiscardGeometryColumn(*args)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
def after_drop(table, bind, **kw):
"""Handle spatial indexes during the after_drop event."""
table.columns = table.info.pop("_saved_columns")
# Define compiled versions for functions in SpatiaLite whose names don't have
# the ST_ prefix.
_SQLITE_FUNCTIONS = {
"ST_GeomFromEWKT": "GeomFromEWKT",
"ST_GeomFromEWKB": "GeomFromEWKB",
"ST_AsBinary": "AsBinary",
"ST_AsEWKB": "AsEWKB",
"ST_AsGeoJSON": "AsGeoJSON",
}
def _compiles_sqlite(cls, fn):
def _compile_sqlite(element, compiler, **kw):
return "{}({})".format(fn, compiler.process(element.clauses, **kw))
compiles(getattr(functions, cls), "sqlite")(_compile_sqlite)
def register_sqlite_mapping(mapping):
"""Register compilation mappings for the given functions.
Args:
mapping: Should have the following form::
{
"function_name_1": "sqlite_function_name_1",
"function_name_2": "sqlite_function_name_2",
...
}
"""
for cls, fn in mapping.items():
_compiles_sqlite(cls, fn)
register_sqlite_mapping(_SQLITE_FUNCTIONS)