Cache SPARQL queries to a local database

This commit is contained in:
Val Lorentz 2023-05-28 20:03:55 +02:00
parent 44eb8147c8
commit f076efffc6
3 changed files with 121 additions and 5 deletions

103
glowtables/cache.py Normal file
View File

@ -0,0 +1,103 @@
# This file is part of the Glowtables software
# Copyright (C) 2023 Valentin Lorentz
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License version 3, as published by the
# Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
"""SPARQL query cache"""
import datetime
import random
import sqlite3
from typing import Optional
# Compared against ``random.random()`` before each cache write to decide
# whether outdated rows get purged first.
EXPIRE_PROBA = 0.001
"""Probability an ``INSERT INTO`` is preceded by a ``DELETE`` of all old records."""

# Age beyond which a stored response is ignored by lookups and becomes
# eligible for deletion.
CACHE_LIFETIME = datetime.timedelta(weeks=1)
def _now() -> datetime.datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = datetime.timezone.utc
    return datetime.datetime.now(utc)
class Cache:
    """A simple key-value cache for SPARQL queries.

    Responses are stored in an SQLite table keyed by ``(url, query)``, together
    with the UTC timestamp at which they were recorded. Entries older than
    ``CACHE_LIFETIME`` are ignored by :meth:`get` and occasionally purged by
    :meth:`set`.
    """

    def __init__(self, db: str):
        """
        :param db: database to open; passed unchanged to :func:`sqlite3.connect`
        """
        self._db = sqlite3.connect(db)
        self._init_schema()

    def _init_schema(self) -> None:
        """Initialize tables and indexes"""
        with self._db:
            self._db.execute(
                """
                CREATE TABLE IF NOT EXISTS sparql_queries (
                    url TEXT,
                    query TEXT,
                    response TEXT,
                    date TEXT -- ISO8601 timestamp of the recorded query, must be UTC
                );
                """
            )
            # The unique index is what makes the ``ON CONFLICT(url, query)``
            # clause in :meth:`set` valid.
            self._db.execute(
                """
                CREATE UNIQUE INDEX IF NOT EXISTS sparql_queries_pk
                ON sparql_queries (url, query)
                """
            )

    def _expire(self) -> None:
        """With probability ``EXPIRE_PROBA``, deletes outdated items from the
        database."""
        if random.random() < EXPIRE_PROBA:
            threshold = (_now() - CACHE_LIFETIME).isoformat()
            with self._db:
                self._db.execute(
                    """
                    DELETE FROM sparql_queries WHERE date < ?
                    """,
                    # Must be a 1-tuple: a bare string would be interpreted as
                    # a sequence of one-character parameters and rejected.
                    (threshold,),
                )

    def get(self, url: str, query: str) -> Optional[str]:
        """Gets the response to a previous query from the cache, or None.

        :param url: URL of the endpoint the query was sent to
        :param query: text of the query
        :returns: the stored response if one was recorded less than
            ``CACHE_LIFETIME`` ago, ``None`` otherwise
        """
        threshold = (_now() - CACHE_LIFETIME).isoformat()
        with self._db:
            cur = self._db.execute(
                """
                SELECT response
                FROM sparql_queries
                WHERE url=? AND query=? AND date >= ?
                """,
                (url, query, threshold),
            )
            # The unique index on (url, query) guarantees at most one row.
            row = cur.fetchone()

        if row is None:
            # cache miss
            return None

        # cache hit
        (resp,) = row
        return resp

    def set(self, url: str, query: str, response: str) -> None:
        """Adds the response of a query to the cache.

        Any response previously stored for the same ``(url, query)`` pair is
        replaced, and its timestamp is reset to the current time.

        :param url: URL of the endpoint the query was sent to
        :param query: text of the query
        :param response: response body to store
        """
        self._expire()
        with self._db:
            self._db.execute(
                """
                INSERT INTO sparql_queries(url, query, response, date)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(url, query) DO UPDATE SET
                    response=EXCLUDED.response,
                    date=EXCLUDED.date
                """,
                (url, query, response, _now().isoformat()),
            )

View File

@ -15,11 +15,14 @@
"""Abstraction over SPARQL backends, primarily meant to be mocked by tests."""
import abc
import json
import urllib.parse
from typing import Iterable
import requests
from .cache import Cache
class SparqlBackend(abc.ABC):
"""Abstract class for SPARQL clients"""
@ -32,7 +35,7 @@ class SparqlBackend(abc.ABC):
class RemoteSparqlBackend(SparqlBackend):
"""Queries a SPARQL API over HTTP."""
def __init__(self, url: str, agent: str):
def __init__(self, url: str, agent: str, cache: Cache):
"""
:param url: Base URL of the endpoint
:param agent: User-Agent to use in HTTP requests
@ -40,6 +43,7 @@ class RemoteSparqlBackend(SparqlBackend):
self._url = url
self._session = requests.Session()
self._session.headers["User-Agent"] = agent
self._cache = cache
def query(self, query: str) -> Iterable[tuple]:
headers = {
@ -47,9 +51,15 @@ class RemoteSparqlBackend(SparqlBackend):
"Accept": "application/json",
}
params = {"query": query}
resp = self._session.post(
self._url, headers=headers, data=urllib.parse.urlencode(params)
).json()
resp_text = self._cache.get(self._url, query)
if not resp_text:
resp_text = self._session.post(
self._url, headers=headers, data=urllib.parse.urlencode(params)
).text
self._cache.set(self._url, query, resp_text)
resp = json.loads(resp_text)
variables = resp["head"]["vars"]
for result in resp["results"]["bindings"]:
yield tuple(result.get(variable) for variable in variables)

View File

@ -21,6 +21,7 @@ import urllib.parse
import pytest
import rdflib
from glowtables.cache import Cache
from glowtables.sparql import RemoteSparqlBackend
@ -47,4 +48,6 @@ def rdflib_sparql(requests_mock, rdflib_graph: rdflib.Graph) -> RemoteSparqlBack
}
requests_mock.register_uri("POST", "mock://sparql.example.org/", json=json_callback)
return RemoteSparqlBackend("mock://sparql.example.org/", agent="Mock Client")
return RemoteSparqlBackend(
"mock://sparql.example.org/", agent="Mock Client", cache=Cache(":memory:")
)