diff --git a/opdb/__main__.py b/opdb/__main__.py
new file mode 100644
index 0000000..261c20a
--- /dev/null
+++ b/opdb/__main__.py
@@ -0,0 +1,51 @@
+# This file is part of the Open Parts Database software
+# Copyright (C) 2022 Valentin Lorentz
+#
+# This program is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Affero General Public License version 3, as published by the
+# Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
+# PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License along with
+# this program. If not, see .
+
+"""
+CLI entrypoint.
+"""
+
+import sys
+import typing
+
+
+def error(msg: str) -> typing.NoReturn:
+ """Prints the message on stderr and exits with code 1."""
+ print(msg, file=sys.stderr)
+ sys.exit(1)
+
+
+def main() -> None:
+ """CLI entrypoint"""
+ try:
+ (executable, subcommand, *args) = sys.argv
+ except ValueError:
+ error(f"Syntax: {sys.argv[0]} [ [ [...]]]")
+
+ if subcommand == "initdb":
+ from opdb.db import Db # pylint: disable=import-outside-toplevel
+
+ try:
+ (dsn,) = args
+ except ValueError:
+ error(f"Syntax: {executable} initdb ")
+
+ with Db.open(dsn) as db:
+ db.init()
+ else:
+ error(f"Unknown subcommand: {subcommand}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/opdb/conftest.py b/opdb/conftest.py
index 1a7d0ff..87802c1 100644
--- a/opdb/conftest.py
+++ b/opdb/conftest.py
@@ -18,7 +18,7 @@ pytest fixtures
import pytest
-from opdb.db import Db, models
+from opdb.db import Db
def iter_subclasses(cls):
@@ -35,9 +35,6 @@ def opdb_db(postgresql) -> Db:
"""
pytest fixture which yields an empty initialized OPDB database.
"""
- with postgresql.cursor() as cur:
- for name in dir(models):
- cls = getattr(models, name)
- if hasattr(cls, "TABLE"):
- cur.execute(cls.db_schema())
- return Db(postgresql)
+ db = Db(postgresql)
+ db.init()
+ return db
diff --git a/opdb/db/db.py b/opdb/db/db.py
index f59c6a2..28f0e1f 100644
--- a/opdb/db/db.py
+++ b/opdb/db/db.py
@@ -44,6 +44,16 @@ class Db:
with psycopg.connect(dsn) as conn:
yield Db(conn)
+ def init(self) -> None:
+ """
+ Initializes the schema for the connected database.
+ """
+ with self.conn.cursor() as cur:
+ for name in dir(models):
+ cls = getattr(models, name)
+ if hasattr(cls, "TABLE"):
+ cur.execute(cls.db_schema())
+
def get_last_web_page_snapshot(
self, url: str
) -> typing.Optional[models.WebPageSnapshot]: