# This file is part of the Open Parts Database software
# Copyright (C) 2022 Valentin Lorentz
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License version 3, as published by the
# Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.

"""
A minimalist ORM

Features:

* generates postgresql schemas
* provides easy access to postgresql's COPY FROM (even for jsonb columns)
* checks :class:`datetime.datetime` objects are timezone-aware.
"""

import dataclasses
import datetime
import json
import types
import typing

import psycopg

_TSelf = typing.TypeVar("_TSelf", bound="BaseModel")

_TYPE_TO_SQL = {
    datetime.datetime: "timestamptz",
    str: "text",
    bytes: "bytea",
    dict: "jsonb",
}

_NONE_TYPE = type(None)

# PEP 604 unions (``X | None``) are instances of this class on Python 3.10+; unlike
# ``typing.Optional[X]`` they cannot be recognized through ``__origin__``.
_UNION_TYPE = getattr(types, "UnionType", None)


def _is_union(type_: typing.Any) -> bool:
    """Returns whether ``type_`` is a ``typing.Union[...]`` or an ``X | Y`` union."""
    if getattr(type_, "__origin__", None) is typing.Union:
        return True
    return _UNION_TYPE is not None and isinstance(type_, _UNION_TYPE)


def _annotation_matches(type_: typing.Any, base: type) -> bool:
    """
    Returns whether the annotation ``type_`` refers to ``base`` (or a subclass of
    it): either directly, through a parametrized generic (eg. ``dict[str, int]``
    for ``dict``), or as one of the variants of a union (eg. ``Optional[base]``).

    Annotations that do not resolve to a class (eg. string annotations) never match.
    """
    if _is_union(type_):
        return any(_annotation_matches(arg, base) for arg in type_.__args__)
    origin = getattr(type_, "__origin__", None)
    if origin is not None:
        # parametrized generic: only its origin matters
        return _annotation_matches(origin, base)
    return isinstance(type_, type) and issubclass(type_, base)


def _type_to_sql(type_: type, *, nullable: bool = False) -> str:
    """
    Returns the postgresql column type (including nullability) for a field
    annotation.

    ``Optional[X]`` / ``X | None`` give the nullable SQL type of ``X``; any other
    annotation gets a ``NOT NULL`` constraint. Parametrized generics (eg.
    ``dict[str, int]``) are mapped through their origin (``dict``), ignoring their
    type arguments.

    :raises TypeError: for unions that are not exactly ``Optional[X]``
    :raises KeyError: if the resulting type has no entry in ``_TYPE_TO_SQL``
    """
    if _is_union(type_):
        variants = type_.__args__  # type: ignore[attr-defined]
        # Identity comparison: NoneType cannot be subclassed, and issubclass()
        # raises on some typing generic aliases.
        non_none_variants = [
            variant for variant in variants if variant is not _NONE_TYPE
        ]
        if len(variants) != 2:
            raise TypeError(
                f"Unsupported type: {type_} (expected exactly 2 variants, "
                f"got {variants!r})"
            )
        if len(non_none_variants) != 1:
            raise TypeError(
                f"Unsupported type: {type_} (expected exactly 1 non-None variant, "
                f"got {non_none_variants!r})"
            )
        (inner_type,) = non_none_variants
        # type is Optional[inner_type]
        return _type_to_sql(inner_type, nullable=True)

    origin = getattr(type_, "__origin__", None)
    if origin is not None:
        # another generic type; simply ignore its __args__, but keep the
        # nullability coming from an enclosing Optional[...]
        return _type_to_sql(origin, nullable=nullable)

    sql_type = _TYPE_TO_SQL[type_]
    if not nullable:
        sql_type += " NOT NULL"
    return sql_type


class BaseModel:
    """
    Base class for all model classes, which provides class methods to generate
    DB schema and efficiently insert instances.
    """

    TABLE: str
    """Name of the SQL table."""

    PK: tuple[str, ...]
    """Primary key of the SQL table."""

    __DATETIME_FIELD_NAMES: list[str]
    __JSON_FIELD_NAMES: list[str]

    def __init_subclass__(cls, *args, **kwargs):
        """
        Precomputes ``__DATETIME_FIELD_NAMES`` and ``__JSON_FIELD_NAMES`` on class
        initialization, so ``__post_init__`` and ``copy_to_db`` do not need to run
        the whole introspection machinery every time.

        Fields are classified from ``cls.__annotations__``; plain classes,
        parametrized generics (eg. ``dict[str, typing.Any]``) and unions
        (``Optional[X]`` / ``X | None``) are all recognized.

        NOTE(review): only ``cls.__annotations__`` is inspected, so fields inherited
        from an intermediate model class may not be classified — confirm if model
        inheritance is used.
        """
        super().__init_subclass__(*args, **kwargs)
        hints = cls.__annotations__
        cls.__DATETIME_FIELD_NAMES = [
            field_name
            for (field_name, field_type) in hints.items()
            if _annotation_matches(field_type, datetime.datetime)
        ]
        cls.__JSON_FIELD_NAMES = [
            field_name
            for (field_name, field_type) in hints.items()
            if _annotation_matches(field_type, dict)
        ]

    def __post_init__(self):
        """
        Errors if any of the datetime fields holds a naive datetime.

        ``None`` (eg. an unset ``Optional[datetime]`` field) and values of other
        types (for union-typed fields) are not checked.

        :raises TypeError: if a datetime field holds a naive datetime
        """
        for field_name in self.__DATETIME_FIELD_NAMES:
            value = getattr(self, field_name)
            # A datetime is aware only if utcoffset() is not None; checking tzinfo
            # alone would accept a tzinfo whose utcoffset() returns None.
            if isinstance(value, datetime.datetime) and value.utcoffset() is None:
                raise TypeError(f"{field_name} must be a timezone-aware datetime.")

    @classmethod
    def copy_to_db(
        cls: type[_TSelf], conn: psycopg.Connection, objects: typing.Iterable[_TSelf]
    ) -> None:
        """
        Takes a postgresql connection and an iterable of instances, and inserts all
        the instances efficiently in postgresql, with ``COPY ... FROM STDIN``.

        Values of JSON (``dict``) fields are serialized with :func:`json.dumps`;
        ``None`` values are written as SQL ``NULL``.
        """
        cols = [field.name for field in dataclasses.fields(cls)]
        # Decided once per column instead of searching the field list per value.
        json_cols = set(cls.__JSON_FIELD_NAMES)
        is_json_col = [col in json_cols for col in cols]
        with conn.cursor() as cur:
            with cur.copy(f"COPY {cls.TABLE} ({', '.join(cols)}) FROM STDIN") as copy:
                for obj in objects:
                    row = []
                    for (col, as_json) in zip(cols, is_json_col):
                        value = getattr(obj, col)
                        if as_json and value is not None:
                            value = json.dumps(value)
                        row.append(value)
                    copy.write_row(tuple(row))

    @classmethod
    def db_schema(cls) -> str:
        """
        Returns SQL code suitable to initialize a table to store instances of this
        class: a ``CREATE TABLE IF NOT EXISTS`` statement with one column per
        dataclass field, followed by a unique index on the ``PK`` columns.
        """
        return "\n".join(
            [
                f"CREATE TABLE IF NOT EXISTS {cls.TABLE} (",
                ",\n".join(
                    f" {field.name} {_type_to_sql(field.type)}"
                    for field in dataclasses.fields(cls)
                ),
                ");",
                f"CREATE UNIQUE INDEX IF NOT EXISTS {cls.TABLE}_pk ON {cls.TABLE} "
                f"({', '.join(cls.PK)});",
            ]
        )