Improved db migrations, & fixes
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
import datetime
|
||||
import functools
|
||||
import logging
|
||||
import os.path
|
||||
import re
|
||||
import sqlite3 as sqlite
|
||||
from collections import namedtuple
|
||||
|
||||
### Decorators ###
|
||||
|
||||
|
||||
def transaction(func):
|
||||
"""Use this decorator on any methods that needs to query the database to
|
||||
ensure that connections are properly opened and closed.
|
||||
ensure that connections are properly opened and closed, without being
|
||||
left open unnecessarily.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def newfunc(self, *args, **kwargs):
|
||||
@@ -20,7 +22,16 @@ def transaction(func):
|
||||
self.cursor = self.conn.cursor()
|
||||
created_conn = True
|
||||
|
||||
ret = func(self, *args, **kwargs)
|
||||
try:
|
||||
return_value = func(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
if self.conn.in_transaction:
|
||||
self.conn.rollback()
|
||||
if created_conn:
|
||||
self.cursor.close()
|
||||
self.conn.close()
|
||||
self.cursor = self.conn = None
|
||||
raise e
|
||||
|
||||
if self.conn.in_transaction:
|
||||
self.conn.commit()
|
||||
@@ -30,75 +41,114 @@ def transaction(func):
|
||||
self.conn.close()
|
||||
self.cursor = self.conn = None
|
||||
|
||||
return ret
|
||||
return return_value
|
||||
|
||||
return newfunc
|
||||
|
||||
|
||||
### Classes ###
|
||||
|
||||
DatabaseVersion = namedtuple('DatabaseVersion', 'major minor patch pre_release_name pre_release_number')
|
||||
class DatabaseVersion:
|
||||
|
||||
PRE_RELEASE_NAME_ORDER_NUMBER = {
|
||||
None: None,
|
||||
'alpha': 0,
|
||||
'beta': 1,
|
||||
'rc': 2
|
||||
}
|
||||
|
||||
def __init__(self, major, minor, patch, pre_release_name=None, pre_release_number=None):
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
self.patch = patch
|
||||
self.pre_release_name = pre_release_name
|
||||
self.pre_release_number = pre_release_number
|
||||
|
||||
def __lt__(self, other):
|
||||
self_tuple = (self.major, self.minor, self.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], self.pre_release_number)
|
||||
other_tuple = (other.major, other.minor, other.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[other.pre_release_name], other.pre_release_number)
|
||||
return self_tuple < other_tuple
|
||||
|
||||
def __str__(self):
|
||||
version_string = f'{self.major}.{self.minor}.{self.patch}'
|
||||
if self.pre_release_name is not None:
|
||||
version_string += f'-{self.pre_release_name}'
|
||||
if self.pre_release_number is not None:
|
||||
version_string += f'.{self.pre_release_number}'
|
||||
return version_string
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, version_string):
|
||||
match = re.match(r'(\d+).(\d+).(\d+)(?:-([:alpha:]+)(?:.(\d+))?)?', version_string)
|
||||
if not match:
|
||||
return None
|
||||
groups = match.groups()
|
||||
return cls(int(groups[0]), int(groups[1]), int(groups[2]), groups[3], int(groups[4]))
|
||||
|
||||
|
||||
class Database:
|
||||
|
||||
VERSION = DatabaseVersion(0, 2, 0, None, None)
|
||||
# LATEST_VERSION = DatabaseVersion(0, 2, 0, None, None)
|
||||
LATEST_VERSION = DatabaseVersion(0, 2, 0)
|
||||
|
||||
SCHEMA = '''
|
||||
---------------------------
|
||||
-- Schema version: 0.1.0 --
|
||||
---------------------------
|
||||
|
||||
CREATE TABLE IF NOT EXISTS redditor_points (
|
||||
id TEXT UNIQUE NOT NULL,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
points INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
---------------------------
|
||||
-- Schema version: 0.2.0 --
|
||||
---------------------------
|
||||
|
||||
-- Tracking bot/db version for potential future use in migrations et al.
|
||||
CREATE TABLE IF NOT EXISTS bot_version (
|
||||
major INTEGER NOT NULL,
|
||||
minor INTEGER NOT NULL,
|
||||
patch INTEGER NOT NULL,
|
||||
pre_release_name TEXT,
|
||||
pre_release_number INTEGER
|
||||
);
|
||||
INSERT OR IGNORE INTO bot_version (major, minor, patch) VALUES (0, 2, 0);
|
||||
|
||||
ALTER TABLE redditor_points RENAME TO redditor;
|
||||
-- TODO rename "id" columns to "reddit_id" for consistency/clarity?
|
||||
-- ALTER TABLE redditor RENAME COLUMN id TO reddit_id;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS submission (
|
||||
id TEXT UNIQUE NOT NULL,
|
||||
author_id TEXT UNIQUE NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS comment (
|
||||
id TEXT UNIQUE NOT NULL,
|
||||
author_id TEXT NOT NULL,
|
||||
author_rowid INTEGER, -- May be NULL **for now**
|
||||
created_at_datetime TEXT NOT NULL,
|
||||
FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS solution (
|
||||
submission_rowid INTEGER NOT NULL,
|
||||
author_rowid INTEGER NOT NULL,
|
||||
comment_rowid INTEGER NOT NULL,
|
||||
chosen_by_comment_rowid INTEGER NOT NULL,
|
||||
removed_by_comment_rowid INTEGER,
|
||||
FOREIGN KEY (submission_rowid) REFERENCES submission (rowid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (comment_rowid) REFERENCES comment (rowid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (chosen_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
|
||||
FOREIGN KEY (removed_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
|
||||
PRIMARY KEY (submission_rowid, author_rowid) ON DELETE CASCADE
|
||||
);
|
||||
'''
|
||||
SCHEMA_VERSION_STATEMENTS = {
|
||||
DatabaseVersion(0, 1, 0): [
|
||||
'''
|
||||
CREATE TABLE IF NOT EXISTS redditor_points (
|
||||
id TEXT UNIQUE NOT NULL,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
points INTEGER DEFAULT 0
|
||||
)
|
||||
'''
|
||||
],
|
||||
DatabaseVersion(0, 2, 0): [
|
||||
'''
|
||||
CREATE TABLE IF NOT EXISTS bot_version (
|
||||
major INTEGER NOT NULL,
|
||||
minor INTEGER NOT NULL,
|
||||
patch INTEGER NOT NULL,
|
||||
pre_release_name TEXT,
|
||||
pre_release_number INTEGER
|
||||
)
|
||||
''',
|
||||
'''
|
||||
INSERT OR IGNORE INTO bot_version (major, minor, patch) VALUES (0, 2, 0)
|
||||
''',
|
||||
'''
|
||||
ALTER TABLE redditor_points RENAME TO redditor
|
||||
''',
|
||||
'''
|
||||
CREATE TABLE IF NOT EXISTS submission (
|
||||
id TEXT UNIQUE NOT NULL,
|
||||
author_id TEXT UNIQUE NOT NULL
|
||||
)
|
||||
''',
|
||||
'''
|
||||
CREATE TABLE IF NOT EXISTS comment (
|
||||
id TEXT UNIQUE NOT NULL,
|
||||
author_id TEXT NOT NULL,
|
||||
author_rowid INTEGER, -- May be NULL **for now**
|
||||
created_at_datetime TEXT NOT NULL,
|
||||
FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE
|
||||
)
|
||||
''',
|
||||
'''
|
||||
CREATE TABLE IF NOT EXISTS solution (
|
||||
submission_rowid INTEGER NOT NULL,
|
||||
author_rowid INTEGER NOT NULL,
|
||||
comment_rowid INTEGER NOT NULL,
|
||||
chosen_by_comment_rowid INTEGER NOT NULL,
|
||||
removed_by_comment_rowid INTEGER,
|
||||
FOREIGN KEY (submission_rowid) REFERENCES submission (rowid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (comment_rowid) REFERENCES comment (rowid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (chosen_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
|
||||
FOREIGN KEY (removed_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
|
||||
PRIMARY KEY (submission_rowid, author_rowid)
|
||||
)
|
||||
'''
|
||||
]
|
||||
}
|
||||
|
||||
def __init__(self, dbpath):
|
||||
self.path = dbpath
|
||||
@@ -106,28 +156,41 @@ class Database:
|
||||
self.cursor = None
|
||||
|
||||
if not os.path.exists(self.path):
|
||||
self._create()
|
||||
logging.info('No database found; creating...')
|
||||
self._run_migrations()
|
||||
logging.info('Successfully created database')
|
||||
else:
|
||||
self._migrate_if_necessary()
|
||||
current_version = self._get_current_version()
|
||||
if current_version != self.LATEST_VERSION:
|
||||
logging.info('Newer database version exists; migrating...')
|
||||
self._run_migrations(current_version)
|
||||
logging.info('Successfully completed all migrations')
|
||||
|
||||
@transaction
|
||||
def _create(self):
|
||||
self.cursor.execute(self.SCHEMA)
|
||||
def _run_migrations(self, current_version=None):
|
||||
if not current_version:
|
||||
current_version = DatabaseVersion(0, 0, 0)
|
||||
logging.info(f'Current database version: {current_version}')
|
||||
versions = sorted(v for v in self.SCHEMA_VERSION_STATEMENTS if current_version < v)
|
||||
for v in versions:
|
||||
logging.info(f'Beginning migration to version: {v}...')
|
||||
for sql_stmt in self.SCHEMA_VERSION_STATEMENTS[v]:
|
||||
self.cursor.execute(sql_stmt)
|
||||
logging.info(f'Successfully completed migration')
|
||||
|
||||
@transaction
|
||||
def _migrate_if_necessary(self):
|
||||
def _get_current_version(self):
|
||||
self.cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'bot_version'")
|
||||
has_version_table = (self.cursor.rowcount == 1)
|
||||
has_outdated_version = False
|
||||
if has_version_table:
|
||||
if not has_version_table:
|
||||
current_version = DatabaseVersion(0, 1, 0)
|
||||
else:
|
||||
self.cursor.execute('SELECT major, minor, patch, pre_release_name, pre_release_number FROM bot_version')
|
||||
row = self.cursor.fetchone()
|
||||
pre_release_number = int(row['pre_release_number']) if row['pre_release_number'] else None
|
||||
current_version = DatabaseVersion(int(row['major']), int(row['minor']), int(row['patch']), row['pre_release_name'], pre_release_number)
|
||||
has_outdated_version = (current_version == self.VERSION)
|
||||
|
||||
if not has_version_table or has_outdated_version:
|
||||
self.cursor.execute(self.SCHEMA)
|
||||
return current_version
|
||||
|
||||
### Public Methods ###
|
||||
|
||||
@@ -151,18 +214,14 @@ class Database:
|
||||
return self.cursor.rowcount
|
||||
|
||||
@transaction
|
||||
# def has_already_solved_once(self, solver, submission):
|
||||
def has_already_solved_once(self, submission, solver):
|
||||
# author_id = self._get_rowid_from_reddit_id('redditor', solver)
|
||||
select_stmt = '''
|
||||
SELECT count(solution.rowid) AS num_solutions
|
||||
FROM solution
|
||||
JOIN submission ON (solution.submission_rowid = submission.rowid)
|
||||
JOIN redditor ON (solution.author_rowid = redditor.rowid)
|
||||
-- JOIN comment ON (solution.comment_rowid = comment.rowid)
|
||||
WHERE submission.id = :submission_id
|
||||
AND redditor.id = :author_id
|
||||
-- AND comment.author_id = :author_id
|
||||
'''
|
||||
self.cursor.execute(select_stmt, {'submission_id': submission.id, 'author_id': solver.id})
|
||||
row = self.cursor.fetchone()
|
||||
@@ -174,32 +233,24 @@ class Database:
|
||||
self._add_comment(chosen_by_comment, chooser)
|
||||
|
||||
self._update_points(solver, 1)
|
||||
# rowcount = self._add_solution(submission, solution_comment, chosen_by_comment)
|
||||
rowcount = self._add_solution(submission, solver, solution_comment, chosen_by_comment)
|
||||
if rowcount == 0:
|
||||
# Was not able to add solution, probably because user has already solved this submission
|
||||
self._update_points(solver, -1)
|
||||
# if rowcount > 0:
|
||||
# rowcount = self._update_points(solver, 1)
|
||||
# # TODO update author_rowid for comment?
|
||||
return rowcount
|
||||
|
||||
# def remove_point_for_solution(self, submission, solver, solution_comment, remover, removed_by_comment):
|
||||
# def remove_point_for_solution(self, submission, solver, remover, removed_by_comment):
|
||||
def soft_remove_point_for_solution(self, submission, solver, remover, removed_by_comment):
|
||||
# submission = removed_by_comment.submission
|
||||
self._add_comment(removed_by_comment, remover)
|
||||
# rowcount = self._soft_remove_solution(submission, solution_comment, removed_by_comment)
|
||||
rowcount = self._soft_remove_solution(submission, solver, removed_by_comment)
|
||||
if rowcount > 0:
|
||||
rowcount = self._update_points(solver, -1)
|
||||
# TODO move "remove redditor" logic here since it doesn't need to be considered when adding points?
|
||||
return rowcount
|
||||
|
||||
@transaction
|
||||
def add_back_point_for_solution(self, submission, solver):
|
||||
self._update_points(solver, 1)
|
||||
# submission_rowid = self._get_submission_rowid(submission)
|
||||
submission_rowid = self._get_rowid_from_reddit_id('submission', submission)
|
||||
author_rowid = self._get_rowid_from_reddit_id('redditor', solver)
|
||||
params = {'submission_rowid': submission_rowid, 'author_rowid': author_rowid}
|
||||
@@ -266,12 +317,6 @@ class Database:
|
||||
self.cursor.execute(insert_stmt, params)
|
||||
return self.cursor.rowcount
|
||||
|
||||
# @transaction
|
||||
# def _get_comment_rowid(self, comment):
|
||||
# self.cursor.execute('SELECT rowid FROM comment WHERE id = :id', {'id': comment.id})
|
||||
# row = self.cursor.fetchone()
|
||||
# return row['rowid'] if row else None
|
||||
|
||||
@transaction
|
||||
def _add_submission(self, submission):
|
||||
insert_stmt = '''
|
||||
@@ -281,17 +326,8 @@ class Database:
|
||||
self.cursor.execute(insert_stmt, {'id': submission.id, 'author_id': submission.author.id})
|
||||
return self.cursor.rowcount
|
||||
|
||||
# @transaction
|
||||
# def _get_submission_rowid(self, submission):
|
||||
# self.cursor.execute('SELECT rowid FROM submission WHERE id = :id', {'id': submission.id})
|
||||
# row = self.cursor.fetchone()
|
||||
# return row['rowid'] if row else None
|
||||
|
||||
@transaction
|
||||
def _add_solution(self, submission, solver, comment, chosen_by_comment):
|
||||
# submission_rowid = self._get_submission_rowid(submission)
|
||||
# comment_rowid = self._get_comment_rowid(comment)
|
||||
# chosen_by_comment_rowid = self._get_comment_rowid(chosen_by_comment)
|
||||
submission_rowid = self._get_rowid_from_reddit_id('submission', submission)
|
||||
author_rowid = self._get_rowid_from_reddit_id('redditor', solver)
|
||||
comment_rowid = self._get_rowid_from_reddit_id('comment', comment)
|
||||
@@ -310,18 +346,12 @@ class Database:
|
||||
return self.cursor.rowcount
|
||||
|
||||
@transaction
|
||||
# def _soft_remove_solution(self, submission, comment, removed_by_comment):
|
||||
def _soft_remove_solution(self, submission, solver, removed_by_comment):
|
||||
# submission_rowid = self._get_submission_rowid(submission)
|
||||
# comment_rowid = self._get_comment_rowid(comment)
|
||||
# removed_by_comment_rowid = self._get_comment_rowid(removed_by_comment)
|
||||
submission_rowid = self._get_rowid_from_reddit_id('submission', submission)
|
||||
author_rowid = self._get_rowid_from_reddit_id('redditor', solver)
|
||||
# comment_rowid = self._get_rowid_from_reddit_id('comment', comment)
|
||||
removed_by_comment_rowid = self._get_rowid_from_reddit_id('comment', removed_by_comment)
|
||||
params = {
|
||||
'submission_rowid': submission_rowid,
|
||||
# 'comment_rowid': comment_rowid,
|
||||
'author_rowid': author_rowid,
|
||||
'removed_by_comment_rowid': removed_by_comment_rowid,
|
||||
}
|
||||
@@ -333,12 +363,6 @@ class Database:
|
||||
self.cursor.execute(update_stmt, params)
|
||||
return self.cursor.rowcount
|
||||
|
||||
# @transaction
|
||||
# def _get_redditor_rowid(self, redditor):
|
||||
# self.cursor.execute('SELECT rowid FROM redditor WHERE id = :id', {'id': redditor.id})
|
||||
# row = self.cursor.fetchone()
|
||||
# return row['rowid'] if row else None
|
||||
|
||||
@transaction
|
||||
def _update_points(self, redditor, points_modifier):
|
||||
"""points_modifier is positive to add points, negative to subtract."""
|
||||
|
||||
Reference in New Issue
Block a user