opdb/opdb/db/orm.py

158 lines
5.3 KiB
Python

# This file is part of the Open Parts Database software
# Copyright (C) 2022 Valentin Lorentz
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License version 3, as published by the
# Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
"""
A minimalist ORM.

Features:

* generates PostgreSQL schemas
* provides easy bulk insertion through PostgreSQL's ``COPY ... FROM STDIN``
  (even for jsonb columns)
* checks that :class:`datetime.datetime` fields are timezone-aware.
"""
import dataclasses
import datetime
import json
import typing
import psycopg
_TSelf = typing.TypeVar("_TSelf", bound="BaseModel")

# Maps each supported (non-generic) field type to its PostgreSQL column type.
_TYPE_TO_SQL = {
    datetime.datetime: "timestamptz",
    str: "text",
    bytes: "bytea",
    dict: "jsonb",
}


def _type_to_sql(type_: type, *, nullable=False) -> str:
    """
    Returns the PostgreSQL column type for a dataclass field annotation,
    followed by ``NOT NULL`` unless the column is nullable.

    Supported annotations are the keys of ``_TYPE_TO_SQL``, generic aliases
    whose ``__origin__`` is supported (eg. ``dict[str, int]``; the type
    arguments are ignored), and ``Optional[...]`` of any of those, which makes
    the column nullable.

    :param nullable: omit the ``NOT NULL`` constraint
    :raises TypeError: for a ``Union`` that is not of the form ``Optional[X]``
    :raises KeyError: for any other unsupported type
    """
    origin = getattr(type_, "__origin__", None)
    if origin is typing.Union:
        variants = type_.__args__  # type: ignore[attr-defined]
        # Compare by identity: issubclass() raises TypeError when a variant is
        # a parameterized generic such as dict[str, int].
        non_none_variants = [
            variant for variant in variants if variant is not type(None)
        ]
        if len(variants) != 2:
            raise TypeError(
                f"Unsupported type: {type_} (expected exactly 2 variants, "
                f"got {variants!r})"
            )
        if len(non_none_variants) != 1:
            raise TypeError(
                f"Unsupported type: {type_} (expected exactly 1 non-None variant, "
                f"got {non_none_variants!r})"
            )
        (inner_type,) = non_none_variants
        # type is Optional[inner_type]
        return _type_to_sql(inner_type, nullable=True)
    elif origin is not None:
        # Another generic type: ignore its __args__, but keep the nullability
        # so that eg. Optional[dict[str, int]] stays a nullable column.
        return _type_to_sql(origin, nullable=nullable)
    else:
        sql_type = _TYPE_TO_SQL[type_]
        if not nullable:
            sql_type += " NOT NULL"
        return sql_type
def _annotation_is(annotation: typing.Any, target: type) -> bool:
    """
    Returns whether a field annotation denotes ``target`` (or a subclass of it).

    Generic aliases are matched on their ``__origin__`` (eg. ``dict[str, int]``
    matches ``dict``), and ``Optional``/``Union`` annotations match if any of
    their variants does. String (forward-reference) annotations never match.
    """
    origin = getattr(annotation, "__origin__", None)
    if origin is typing.Union:
        return any(_annotation_is(arg, target) for arg in annotation.__args__)
    if origin is not None:
        # ignore the generic's arguments and look at the underlying class
        annotation = origin
    # isinstance() guard: issubclass() raises TypeError on non-class objects
    return isinstance(annotation, type) and issubclass(annotation, target)


class BaseModel:
    """
    Base class for all model classes, which provides class methods to generate
    DB schema and efficiently insert instances.

    ``copy_to_db`` and ``db_schema`` call :func:`dataclasses.fields` on the
    class, so concrete subclasses must be dataclasses defining ``TABLE`` and
    ``PK``.
    """

    TABLE: str
    """Name of the SQL table."""
    PK: tuple[str, ...]
    """Primary key of the SQL table."""

    # Names of the fields annotated as (optional) datetimes / (optional) dicts.
    # Recomputed for every subclass by __init_subclass__; these empty defaults
    # are what direct subclasses of BaseModel inherit from.
    __DATETIME_FIELD_NAMES: tuple[str, ...] = ()
    __JSON_FIELD_NAMES: tuple[str, ...] = ()

    def __init_subclass__(cls, *args, **kwargs):
        """
        Precomputes ``__DATETIME_FIELD_NAMES`` and ``__JSON_FIELD_NAMES`` on
        class initialization, so ``__post_init__`` and ``copy_to_db`` do not need
        to run the whole introspection machinery every time.

        This runs before ``@dataclass`` processes the class, so it classifies
        the class's own annotations, on top of what the parent model computed.
        """
        super().__init_subclass__(*args, **kwargs)
        own_annotations = cls.__annotations__
        # Start from the parent model's classification so inherited fields are
        # kept; names re-annotated by this class are classified again below.
        datetime_fields = [
            name for name in cls.__DATETIME_FIELD_NAMES if name not in own_annotations
        ]
        json_fields = [
            name for name in cls.__JSON_FIELD_NAMES if name not in own_annotations
        ]
        for field_name, field_type in own_annotations.items():
            if _annotation_is(field_type, datetime.datetime):
                datetime_fields.append(field_name)
            if _annotation_is(field_type, dict):
                json_fields.append(field_name)
        cls.__DATETIME_FIELD_NAMES = tuple(datetime_fields)
        cls.__JSON_FIELD_NAMES = tuple(json_fields)

    def __post_init__(self):
        """
        Errors if any of the datetime fields holds a naive datetime.

        ``None`` is accepted, for ``Optional`` datetime fields. A datetime is
        naive when its ``utcoffset()`` is ``None``, which covers both a missing
        ``tzinfo`` and a ``tzinfo`` that provides no offset.

        :raises TypeError: if a datetime field is naive
        """
        for field_name in self.__DATETIME_FIELD_NAMES:
            value = getattr(self, field_name)
            if value is not None and value.utcoffset() is None:
                raise TypeError(f"{field_name} must be a timezone-aware datetime.")

    @classmethod
    def copy_to_db(
        cls: type[_TSelf], conn: psycopg.Connection, objects: typing.Iterable[_TSelf]
    ) -> None:
        """
        Takes a postgresql connection and an iterable of instances,
        and inserts all the instances efficiently in postgresql.

        Rows are streamed with ``COPY ... FROM STDIN``. Values of dict-typed
        fields (including parameterized and optional ones) are serialized with
        :func:`json.dumps`; ``None`` values are sent as SQL ``NULL``.
        """
        cols = [field.name for field in dataclasses.fields(cls)]
        # Decide once per column whether it needs JSON encoding, instead of a
        # list membership test for every column of every row.
        json_cols = frozenset(cls.__JSON_FIELD_NAMES)
        col_is_json = [col in json_cols for col in cols]
        # NOTE(review): TABLE and field names are interpolated unquoted; this
        # assumes they are trusted identifiers coming from model definitions.
        query = f"COPY {cls.TABLE} ({', '.join(cols)}) FROM STDIN"
        with conn.cursor() as cur:
            with cur.copy(query) as copy:
                for obj in objects:
                    row = []
                    for col, is_json in zip(cols, col_is_json):
                        value = getattr(obj, col)
                        # json.dumps(None) would store the JSON literal 'null'
                        # rather than SQL NULL in a nullable jsonb column.
                        if is_json and value is not None:
                            value = json.dumps(value)
                        row.append(value)
                    copy.write_row(row)

    @classmethod
    def db_schema(cls) -> str:
        """
        Returns SQL code suitable to initialize a table to store instances
        of this class: a ``CREATE TABLE`` with one column per dataclass field,
        and a unique index on the ``PK`` columns.
        """
        return "\n".join(
            [
                f"CREATE TABLE IF NOT EXISTS {cls.TABLE} (",
                ",\n".join(
                    f"    {field.name} {_type_to_sql(field.type)}"
                    for field in dataclasses.fields(cls)
                ),
                ");",
                f"CREATE UNIQUE INDEX IF NOT EXISTS {cls.TABLE}_pk ON {cls.TABLE} "
                f"({', '.join(cls.PK)});",
            ]
        )