From e3dd878cf31bedb0ac7859df00ab97f6fefb1202 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eli=C3=A1n=20Hanisch?= Date: Thu, 5 Jul 2012 02:31:53 -0300 Subject: [PATCH] save ban autoremoval information in a csv files, so it isn't lost during plugin reloads. --- Bantracker/plugin.py | 58 +++++++++++++++++++++++++++++++++++++++++--- Bantracker/test.py | 23 ++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/Bantracker/plugin.py b/Bantracker/plugin.py index 34fd292..c080a02 100644 --- a/Bantracker/plugin.py +++ b/Bantracker/plugin.py @@ -58,6 +58,7 @@ import sqlite import pytz import cPickle import datetime +import csv import time import random import hashlib @@ -255,6 +256,23 @@ class Ban(object): return 'ban' return 'removal' + def serialize(self): + id = self.id + if id is None: + id = '' + return (id, self.channel, self.mask, self.who, self.when) + + def deserialize(self, L): + id = L[0] + if id == '': + id = None + else: + id = int(id) + self.id = id + self.channel, self.mask, self.who = L[1:4] + self.when = float(L[4]) + self.ascwhen = time.asctime(time.gmtime(self.when)) + class ReviewStore(dict): def __init__(self, filename): @@ -317,8 +335,6 @@ class BanRemoval(object): ban: ban object expires: time in seconds for it to expire """ - assert isinstance(ban, Ban), "ban is not a Ban object" - assert isinstance(expires, int), "expire time isn't an integer" self.ban = ban self.expires = expires self.notified = False @@ -329,10 +345,22 @@ class BanRemoval(object): return True return False + def serialize(self): + notified = self.notified and 1 or 0 + L = [ self.expires, notified ] + L.extend(self.ban.serialize()) + return tuple(L) + + def deserialize(self, L): + self.expires = int(L[0]) + self.notified = bool(int(L[1])) + self.ban = Ban(args=(None, None, None, None, 0)) + self.ban.deserialize(L[2:]) + class BanStore(object): def __init__(self, filename): - # this should be stored into a file + self.filename = conf.supybot.directories.data.dirize(filename) self.shelf = [] def __iter__(self): @@ -341,6 +369,26 @@ class BanStore(object): def __len__(self): return len(self.shelf) + def open(self): + try: + reader = csv.reader(open(self.filename, 'rb')) + except IOError: + return + + for row in reader: + ban = BanRemoval(None, None) + ban.deserialize(row) + self.add(ban) + + def close(self): + try: + writer = csv.writer(open(self.filename, 'wb')) + except IOError: + return + + for ban in self: + writer.writerow(ban.serialize()) + def add(self, obj): self.shelf.append(obj) @@ -408,7 +456,8 @@ class Bantracker(callbacks.Plugin): self._banreviewfix() # init autoremove stuff - self.managedBans = BanStore('FIXME') + self.managedBans = BanStore('bt.autoremove.db') + self.managedBans.open() # add our scheduled events for check bans for reviews or removal schedule.addPeriodicEvent(lambda: self.reviewBans(irc), 60*60, @@ -545,6 +594,7 @@ class Bantracker(callbacks.Plugin): schedule.removeEvent(self.name() + '_review') schedule.removeEvent(self.name() + '_autoremove') self.pendingReviews.close() + self.managedBans.close() def reset(self): global queue diff --git a/Bantracker/test.py b/Bantracker/test.py index 0afd3c8..d854410 100644 --- a/Bantracker/test.py +++ b/Bantracker/test.py @@ -406,5 +406,28 @@ class BantrackerTestCase(ChannelPluginTestCase): finally: del pluginConf.autoremove.notify.channels()[:] + def testAutoremoveStore(self): + self.feedBan('asd!*@*') + self.feedBan('qwe!*@*') + self.feedBan('zxc!*@*', mode='q') + self.assertNotError('banremove 1 10m') + self.assertNotError('banremove 2 1d') + self.assertNotError('banremove 3 1w') + cb = self.getCallback() + cb.managedBans.shelf[1].notified = True + cb.managedBans.close() + cb.managedBans.shelf = [] + cb.managedBans.open() + L = cb.managedBans.shelf + for i, n in enumerate((600, 86400, 604800)): + self.assertEqual(L[i].expires, n) + for i, n in enumerate((False, True, False)): + self.assertEqual(L[i].notified, n) + for i, n in enumerate((1, 2, 3)): + self.assertEqual(L[i].ban.id, n) + for i, n in enumerate(('asd!*@*', 'qwe!*@*', '%zxc!*@*')): + self.assertEqual(L[i].ban.mask, n) + self.assertEqual(L[0].ban.channel, '#test') +