158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
# This file is part of the Open Parts Database software
|
|
# Copyright (C) 2022 Valentin Lorentz
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it under the
|
|
# terms of the GNU Affero General Public License version 3, as published by the
|
|
# Free Software Foundation.
|
|
#
|
|
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
|
|
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
|
|
# PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License along with
|
|
# this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
"""
|
|
A minimalist ORM
|
|
|
|
Features:
|
|
|
|
* generates postgresql schemas
|
|
* provides easy access to postgresql's ``COPY ... FROM STDIN`` bulk loading (even for jsonb columns)
|
|
* checks that :class:`datetime.datetime` field values are timezone-aware.
|
|
"""
|
|
|
|
import dataclasses
|
|
import datetime
|
|
import json
|
|
import typing
|
|
|
|
import psycopg
|
|
|
|
# Type variable bound to BaseModel, used to annotate classmethods so that
# e.g. ``SomeModel.copy_to_db`` is typed as taking ``SomeModel`` instances.
_TSelf = typing.TypeVar("_TSelf", bound="BaseModel")


# Maps Python column types to postgresql column types (used by _type_to_sql).
_TYPE_TO_SQL = {
    datetime.datetime: "timestamptz",
    str: "text",
    bytes: "bytea",
    dict: "jsonb",
}


def _type_to_sql(type_: type, *, nullable=False) -> str:
    """
    Returns the postgresql column type for the Python annotation ``type_``.

    ``Optional[X]`` (or ``X | None``) maps to the type of ``X`` without a
    constraint; any other supported type gets ``NOT NULL`` appended.
    Parameterized generics such as ``dict[str, int]`` are mapped by their
    origin (``dict``); their type arguments are ignored.

    :param nullable: whether the column may hold NULL; set internally when
        unwrapping an ``Optional``.
    :raises TypeError: for unions other than ``Optional[X]``.
    :raises KeyError: if the (unwrapped) type has no entry in ``_TYPE_TO_SQL``.
    """
    # ``types.UnionType`` (PEP 604 ``X | None``) only exists on Python >= 3.10
    import types

    union_type = getattr(types, "UnionType", None)
    origin = getattr(type_, "__origin__", None)
    if origin is typing.Union or (
        union_type is not None and isinstance(type_, union_type)
    ):
        variants = type_.__args__  # type: ignore[attr-defined]
        # Compare by identity: variants may be generic aliases such as
        # dict[str, int], which issubclass() rejects with a TypeError.
        non_none_variants = [
            variant for variant in variants if variant is not type(None)
        ]
        if len(variants) != 2:
            raise TypeError(
                f"Unsupported type: {type_} (expected exactly 2 variants, "
                f"got {variants!r})"
            )
        if len(non_none_variants) != 1:
            raise TypeError(
                f"Unsupported type: {type_} (expected exactly 1 non-None variant, "
                f"got {non_none_variants!r})"
            )

        (inner_type,) = non_none_variants
        # type is Optional[inner_type]

        return _type_to_sql(inner_type, nullable=True)
    elif origin is not None:
        # another generic type; ignore its __args__ but keep the nullability,
        # so that Optional[dict[str, int]] stays nullable
        return _type_to_sql(origin, nullable=nullable)
    else:
        sql_type = _TYPE_TO_SQL[type_]
        if not nullable:
            sql_type += " NOT NULL"
        return sql_type
|
|
|
|
|
|
class BaseModel:
    """
    Base class for all model classes, which provides class methods to generate
    DB schema and efficiently insert instances.

    Subclasses are meant to be dataclasses: ``__post_init__`` is the dataclass
    post-init hook, and ``copy_to_db`` / ``db_schema`` read columns through
    :func:`dataclasses.fields`. They must also set ``TABLE`` and ``PK``.
    """

    TABLE: str
    """Name of the SQL table."""

    PK: tuple[str, ...]
    """Primary key of the SQL table."""

    # Names of fields annotated as ``datetime`` (possibly Optional); filled in
    # for each subclass by __init_subclass__.
    __DATETIME_FIELD_NAMES: list[str]
    # Names of fields annotated as ``dict`` (possibly parameterized and/or
    # Optional); their values are JSON-encoded by copy_to_db.
    __JSON_FIELD_NAMES: list[str]

    def __init_subclass__(cls, *args, **kwargs):
        """
        Precomputes ``__DATETIME_FIELD_NAMES`` and ``__JSON_FIELD_NAMES`` on
        class initialization, so ``__post_init__`` and ``copy_to_db`` do not need
        to run the whole introspection machinery every time.

        ``Optional[X]`` / ``X | None`` annotations are classified like ``X``,
        and parameterized generics (e.g. ``dict[str, int]``) like their origin.
        Annotations that do not resolve to a class (string forward references,
        ``typing.Any``, other unions, ...) are ignored.
        """
        # ``types.UnionType`` (PEP 604 ``X | None``) only exists on Python >= 3.10
        import types

        super().__init_subclass__(*args, **kwargs)

        union_type = getattr(types, "UnionType", None)

        def strip_optional(annotation):
            """Returns ``X`` for ``Optional[X]`` / ``X | None``, else ``annotation``."""
            is_union = getattr(annotation, "__origin__", None) is typing.Union or (
                union_type is not None and isinstance(annotation, union_type)
            )
            if is_union:
                non_none = [
                    arg for arg in annotation.__args__ if arg is not type(None)
                ]
                if len(non_none) == 1:
                    return non_none[0]
            return annotation

        cls.__DATETIME_FIELD_NAMES = []
        cls.__JSON_FIELD_NAMES = []
        for field_name, field_type in cls.__annotations__.items():
            base_type = strip_optional(field_type)
            # Generic aliases such as dict[str, int] are not classes (and make
            # issubclass() raise), so classify them by their origin.
            base_type = getattr(base_type, "__origin__", base_type)
            if not isinstance(base_type, type):
                continue
            if issubclass(base_type, datetime.datetime):
                cls.__DATETIME_FIELD_NAMES.append(field_name)
            if issubclass(base_type, dict):
                cls.__JSON_FIELD_NAMES.append(field_name)

    def __post_init__(self):
        """
        Errors if any of the datetime fields holds a naive datetime.

        ``None`` is accepted, so ``Optional[datetime]`` fields may hold it.

        :raises TypeError: if a datetime field's value has ``tzinfo is None``.
        """
        for field_name in self.__DATETIME_FIELD_NAMES:
            value = getattr(self, field_name)
            if value is not None and value.tzinfo is None:
                raise TypeError(f"{field_name} must be a timezone-aware datetime.")

    @classmethod
    def copy_to_db(
        cls: type[_TSelf], conn: psycopg.Connection, objects: typing.Iterable[_TSelf]
    ) -> None:
        """
        Takes a postgresql connection and an iterable of instances,
        and inserts all the instances efficiently in postgresql.

        Uses ``COPY <TABLE> (<columns>) FROM STDIN``. Values of ``dict`` fields
        are sent as JSON text (so postgresql parses them as ``jsonb``); a
        ``None`` value is sent as SQL NULL rather than JSON ``null``.
        """
        cols = [field.name for field in dataclasses.fields(cls)]
        # Decide once per column, not once per cell, whether to JSON-encode.
        json_cols = set(cls.__JSON_FIELD_NAMES)
        encode_flags = [col in json_cols for col in cols]
        with conn.cursor() as cur:
            with cur.copy(f"COPY {cls.TABLE} ({', '.join(cols)}) FROM STDIN") as copy:
                for obj in objects:
                    row = []
                    for col, encode in zip(cols, encode_flags):
                        value = getattr(obj, col)
                        if encode and value is not None:
                            value = json.dumps(value)
                        row.append(value)
                    copy.write_row(row)

    @classmethod
    def db_schema(cls) -> str:
        """
        Returns SQL code suitable to initialize a table to store instances
        of this class.

        The output is a ``CREATE TABLE IF NOT EXISTS`` statement with one column
        per dataclass field (typed via ``_type_to_sql``), followed by a
        ``CREATE UNIQUE INDEX IF NOT EXISTS <TABLE>_pk`` over the ``PK`` columns.
        """
        return "\n".join(
            [
                f"CREATE TABLE IF NOT EXISTS {cls.TABLE} (",
                ",\n".join(
                    f"    {field.name} {_type_to_sql(field.type)}"
                    for field in dataclasses.fields(cls)
                ),
                ");",
                f"CREATE UNIQUE INDEX IF NOT EXISTS {cls.TABLE}_pk ON {cls.TABLE} "
                f"({', '.join(cls.PK)});",
            ]
        )
|