435 lines
17 KiB
Python
435 lines
17 KiB
Python
import datetime
|
|
import functools
|
|
import logging
|
|
import os.path
|
|
import re
|
|
import sqlite3 as sqlite
|
|
|
|
### Decorators ###
|
|
|
|
|
|
def transaction(func):
    """Run the decorated method inside a managed SQLite connection.

    The decorated object must expose ``path``, ``conn`` and ``cursor``
    attributes. The outermost decorated call opens ``self.conn`` /
    ``self.cursor`` on ``self.path`` and does the following:

    * commits if the method returns normally;
    * rolls back and re-raises if it raises;
    * always closes the cursor and connection and resets both attributes to
      ``None``.

    A decorated call made while a connection is already open (a nested call)
    reuses that connection. It leaves commit/rollback to the outermost call,
    so a method built from several decorated helpers is applied as one
    transaction.
    """
    @functools.wraps(func)
    def newfunc(self, *args, **kwargs):
        if self.conn is not None:
            # Nested call: the outer decorated call owns the connection and
            # decides whether the combined work is committed or rolled back.
            return func(self, *args, **kwargs)

        self.conn = sqlite.connect(self.path)
        self.conn.row_factory = sqlite.Row
        self.cursor = self.conn.cursor()
        try:
            return_value = func(self, *args, **kwargs)
        except BaseException:
            # Undo any partial writes so the failed call leaves no trace.
            if self.conn.in_transaction:
                self.conn.rollback()
            raise
        else:
            if self.conn.in_transaction:
                self.conn.commit()
            return return_value
        finally:
            # Always release the connection, even if commit itself fails, so
            # the next call starts from a clean state.
            self.cursor.close()
            self.conn.close()
            self.cursor = self.conn = None

    return newfunc
|
|
|
|
|
|
### Classes ###
|
|
|
|
@functools.total_ordering
class DatabaseVersion:
    """A ``major.minor.patch[-name[.number]]`` version of the database schema.

    Instances are ordered, comparable for equality and hashable, so they can
    be sorted and used as dict keys. Ordering follows Semantic Versioning
    precedence for the supported pre-release names:
    ``1.0.0-alpha < 1.0.0-alpha.1 < 1.0.0-beta < 1.0.0-rc < 1.0.0``.
    """

    # Precedence rank of each supported pre-release name. A final release
    # (no name) ranks above every pre-release of the same version.
    PRE_RELEASE_NAME_ORDER_NUMBER = {
        'alpha': 1,
        'beta': 2,
        'rc': 3,
        None: 4,
    }

    # Dots are escaped so that only '.' separates the components.
    _VERSION_PATTERN = re.compile(r'(\d+)\.(\d+)\.(\d+)(?:-([A-Za-z]+)(?:\.(\d+))?)?')

    def __init__(self, major, minor, patch, pre_release_name=None, pre_release_number=None):
        self.major = major
        self.minor = minor
        self.patch = patch
        self.pre_release_name = pre_release_name
        self.pre_release_number = pre_release_number

    def _sort_key(self):
        """Return the tuple used for ordering, equality and hashing.

        A missing pre-release number maps to -1. That puts ``alpha`` before
        ``alpha.0`` and avoids comparing ``None`` with an ``int``. An unknown
        pre-release name raises KeyError.
        """
        number = -1 if self.pre_release_number is None else self.pre_release_number
        return (self.major, self.minor, self.patch,
                self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], number)

    def __lt__(self, other):
        if not isinstance(other, DatabaseVersion):
            return NotImplemented
        return self._sort_key() < other._sort_key()

    def __eq__(self, other):
        if not isinstance(other, DatabaseVersion):
            return NotImplemented
        return self._sort_key() == other._sort_key()

    def __hash__(self):
        # Must agree with __eq__, since versions are used as dict keys.
        return hash(self._sort_key())

    def __str__(self):
        version_string = f'{self.major}.{self.minor}.{self.patch}'
        if self.pre_release_name is not None:
            version_string += f'-{self.pre_release_name}'
            if self.pre_release_number is not None:
                version_string += f'.{self.pre_release_number}'
        return version_string

    def __repr__(self):
        return f'{type(self).__name__}.from_string({str(self)!r})'

    @classmethod
    def from_string(cls, version_string):
        """Parse ``'1.2.3'``, ``'1.2.3-beta'`` or ``'1.2.3-rc.2'``.

        Returns None if the string is not a complete version, or if it names
        a pre-release that is not in ``PRE_RELEASE_NAME_ORDER_NUMBER``.
        """
        match = cls._VERSION_PATTERN.fullmatch(version_string.strip())
        if not match:
            return None
        major, minor, patch, name, number = match.groups()
        if name not in cls.PRE_RELEASE_NAME_ORDER_NUMBER:
            return None
        return cls(int(major), int(minor), int(patch), name,
                   int(number) if number is not None else None)
|
|
|
|
|
|
class Database:
    """SQLite-backed store for redditors, submissions, comments and solutions.

    On construction, the database file at ``dbpath`` is created, or migrated
    up to ``LATEST_VERSION`` if it is older. Methods decorated with
    ``transaction`` open a connection on demand, so callers never manage
    connections themselves.
    """

    LATEST_VERSION = DatabaseVersion(0, 2, 0)

    # TODO now that I'm separating these statements by version, I could probably make these
    # scripts instead of lists of individual statements...
    SCHEMA_VERSION_STATEMENTS = {
        DatabaseVersion(0, 1, 0): [
            '''
            CREATE TABLE IF NOT EXISTS redditor_points (
                id TEXT UNIQUE NOT NULL,
                name TEXT UNIQUE NOT NULL,
                points INTEGER DEFAULT 0
            )
            '''
        ],
        # NOTE(review): nothing in this module enables `PRAGMA foreign_keys`,
        # so the FOREIGN KEY clauses below, and their ON DELETE actions, are
        # not enforced. They also reference the implicit `rowid`, which SQLite
        # may reject ("foreign key mismatch") once enforcement is turned on.
        # chosen_by_comment_rowid is also declared NOT NULL with ON DELETE
        # SET NULL. Review all of this before enabling enforcement.
        DatabaseVersion(0, 2, 0): [
            '''
            CREATE TABLE IF NOT EXISTS bot_version (
                major INTEGER NOT NULL,
                minor INTEGER NOT NULL,
                patch INTEGER NOT NULL,
                pre_release_name TEXT,
                pre_release_number INTEGER
            )
            ''',
            '''
            ALTER TABLE redditor_points RENAME TO redditor
            ''',
            # NOTE(review): author_id is UNIQUE, so only one submission per
            # author can be stored. _add_submission's INSERT OR IGNORE then
            # silently drops that author's later submissions. Fixing this needs
            # a new migration that rebuilds the table and preserves rowids.
            '''
            CREATE TABLE IF NOT EXISTS submission (
                id TEXT UNIQUE NOT NULL,
                author_id TEXT UNIQUE NOT NULL
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS comment (
                id TEXT UNIQUE NOT NULL,
                author_id TEXT NOT NULL,
                author_rowid INTEGER, -- May be NULL **for now**
                created_at_datetime TEXT NOT NULL,
                FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS solution (
                submission_rowid INTEGER NOT NULL,
                author_rowid INTEGER NOT NULL,
                comment_rowid INTEGER NOT NULL,
                chosen_by_comment_rowid INTEGER NOT NULL,
                removed_by_comment_rowid INTEGER,
                FOREIGN KEY (submission_rowid) REFERENCES submission (rowid) ON DELETE CASCADE,
                FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE,
                FOREIGN KEY (comment_rowid) REFERENCES comment (rowid) ON DELETE CASCADE,
                FOREIGN KEY (chosen_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
                FOREIGN KEY (removed_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
                PRIMARY KEY (submission_rowid, author_rowid)
            )
            '''
        ]
    }

    def __init__(self, dbpath):
        """Open the database at *dbpath*, creating or migrating it as needed."""
        self.path = dbpath
        self.conn = None
        self.cursor = None

        if not os.path.exists(self.path):
            logging.info('No database found; creating...')
            self._run_migrations()
            logging.info('Successfully created database')
        else:
            logging.info(f'Using existing database: {self.path}')
            current_version = self._get_current_version()
            # Compare by ordering rather than with `!=`. An up-to-date database
            # must not be re-migrated, because the 0.2.0 migration's ALTER TABLE
            # fails if run twice.
            if current_version < self.LATEST_VERSION:
                logging.info('Newer database version exists; migrating...')
                self._run_migrations(current_version)
                logging.info('Successfully completed all migrations')

    @transaction
    def _run_migrations(self, current_version=None):
        """Apply, in version order, every schema version newer than *current_version*.

        *current_version* defaults to 0.0.0, meaning a brand-new database.
        After each version newer than 0.1.0 is applied, its number is
        recorded as the only row of ``bot_version``.
        """
        if current_version is None:
            current_version = DatabaseVersion(0, 0, 0)
        logging.info(f'Current database version: {current_version}')

        pending_versions = sorted(v for v in self.SCHEMA_VERSION_STATEMENTS if current_version < v)
        for v in pending_versions:
            logging.info(f'Beginning migration to version: {v}...')
            for sql_stmt in self.SCHEMA_VERSION_STATEMENTS[v]:
                self.cursor.execute(sql_stmt)
            if DatabaseVersion(0, 1, 0) < v:
                # bot_version only exists from 0.2.0 onwards.
                self.cursor.execute('DELETE FROM bot_version')
                params = {
                    'major': v.major,
                    'minor': v.minor,
                    'patch': v.patch,
                    'pre_release_name': v.pre_release_name,
                    'pre_release_number': v.pre_release_number
                }
                insert_stmt = '''
                    INSERT INTO bot_version (major, minor, patch, pre_release_name, pre_release_number)
                    VALUES (:major, :minor, :patch, :pre_release_name, :pre_release_number)
                '''
                self.cursor.execute(insert_stmt, params)
            logging.info('Successfully completed migration')

    @transaction
    def _get_current_version(self):
        """Return the schema version recorded in an existing database.

        A database without a ``bot_version`` table predates 0.2.0 and is
        reported as 0.1.0.

        Raises:
            RuntimeError: if ``bot_version`` exists but holds no row.
        """
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'bot_version'")
        # sqlite3 reports rowcount == -1 for SELECT statements, so the table's
        # existence has to be checked by fetching a row.
        if self.cursor.fetchone() is None:
            return DatabaseVersion(0, 1, 0)

        self.cursor.execute('SELECT major, minor, patch, pre_release_name, pre_release_number FROM bot_version')
        row = self.cursor.fetchone()
        if row is None:
            raise RuntimeError(f'bot_version table in {self.path} is empty; cannot determine database version')
        pre_release_number = row['pre_release_number']
        return DatabaseVersion(
            int(row['major']),
            int(row['minor']),
            int(row['patch']),
            row['pre_release_name'],
            # Compare against None so that a pre-release number of 0 survives.
            int(pre_release_number) if pre_release_number is not None else None,
        )

    ### Public Methods ###

    @transaction
    def add_redditor(self, redditor):
        """Insert *redditor* (by ``id`` and ``name``) unless it already exists.

        Returns the number of rows inserted: 0 if the row already existed.
        """
        insert_stmt = '''
            INSERT OR IGNORE INTO redditor (id, name)
            VALUES (:id, :name)
        '''
        self.cursor.execute(insert_stmt, {'id': redditor.id, 'name': redditor.name})
        return self.cursor.rowcount

    @transaction
    def remove_redditor(self, redditor):
        """Delete the row matching *redditor*'s ``id`` and ``name``; return rows deleted."""
        delete_stmt = '''
            DELETE FROM redditor
            WHERE id = :id
              AND name = :name
        '''
        self.cursor.execute(delete_stmt, {'id': redditor.id, 'name': redditor.name})
        return self.cursor.rowcount

    @transaction
    def has_already_solved_once(self, submission, solver):
        """Return True if *solver* has any solution row (active or soft-removed) for *submission*."""
        select_stmt = '''
            SELECT count(solution.rowid) AS num_solutions
            FROM solution
            JOIN submission ON (solution.submission_rowid = submission.rowid)
            JOIN redditor ON (solution.author_rowid = redditor.rowid)
            WHERE submission.id = :submission_id
              AND redditor.id = :author_id
        '''
        self.cursor.execute(select_stmt, {'submission_id': submission.id, 'author_id': solver.id})
        row = self.cursor.fetchone()
        return row is not None and row['num_solutions'] > 0

    @transaction
    def add_point_for_solution(self, submission, solver, solution_comment, chooser, chosen_by_comment):
        """Record *solution_comment* as *solver*'s solution and award one point.

        Returns the number of solution rows inserted. 0 means no solution was
        recorded (for example, the solver already has one for this
        submission), and in that case the point is taken back.
        """
        self._add_submission(submission)
        self._add_comment(solution_comment, solver)
        self._add_comment(chosen_by_comment, chooser)

        # Award the point first. _update_points creates the solver's redditor
        # row if needed (get_points(add_if_none=True)), and _add_solution
        # looks up that row's rowid.
        self._update_points(solver, 1)
        rowcount = self._add_solution(submission, solver, solution_comment, chosen_by_comment)
        if rowcount == 0:
            # The solution was not inserted, so undo the point just awarded.
            self._update_points(solver, -1)
        return rowcount

    @transaction
    def soft_remove_point_for_solution(self, submission, solver, remover, removed_by_comment):
        """Mark *solver*'s active solution as removed and take back its point.

        Returns 0 if no active solution was found. Otherwise returns the
        rowcount of the points update.
        """
        self._add_comment(removed_by_comment, remover)
        rowcount = self._soft_remove_solution(submission, solver, removed_by_comment)
        if rowcount > 0:
            rowcount = self._update_points(solver, -1)
        return rowcount

    @transaction
    def add_back_point_for_solution(self, submission, solver):
        """Restore a soft-removed solution of *solver* and give the point back.

        Returns the number of solutions restored. If there was nothing to
        restore, no point is awarded.
        """
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
        }
        update_stmt = '''
            UPDATE solution
            SET removed_by_comment_rowid = NULL
            WHERE submission_rowid = :submission_rowid
              AND author_rowid = :author_rowid
              AND removed_by_comment_rowid IS NOT NULL
        '''
        self.cursor.execute(update_stmt, params)
        # Capture this before _update_points reuses the cursor.
        rowcount = self.cursor.rowcount
        if rowcount > 0:
            self._update_points(solver, 1)
        return rowcount

    @transaction
    def remove_point_and_delete_solution(self, submission, solver):
        """Delete *solver*'s solution for *submission* and take back its point.

        The point is deducted only if the solution was still active, because a
        soft-removed solution already lost it. Returns 0 if there was no
        solution.
        """
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
        }
        select_stmt = '''
            SELECT removed_by_comment_rowid
            FROM solution
            WHERE submission_rowid = :submission_rowid
              AND author_rowid = :author_rowid
        '''
        self.cursor.execute(select_stmt, params)
        row = self.cursor.fetchone()
        if row is None:
            return 0

        delete_stmt = '''
            DELETE FROM solution
            WHERE submission_rowid = :submission_rowid
              AND author_rowid = :author_rowid
        '''
        self.cursor.execute(delete_stmt, params)
        deleted = self.cursor.rowcount
        if row['removed_by_comment_rowid'] is None:
            return self._update_points(solver, -1)
        return deleted

    @transaction
    def get_points(self, redditor, add_if_none=False):
        """Return *redditor*'s points, or 0 if it has no row.

        If *add_if_none* is true, a missing redditor is inserted with the
        column's default points.
        """
        params = {'id': redditor.id, 'name': redditor.name}
        select_stmt = '''
            SELECT points
            FROM redditor
            WHERE id = :id AND name = :name
        '''
        self.cursor.execute(select_stmt, params)
        row = self.cursor.fetchone()
        points = 0
        if row:
            points = row['points']
        elif add_if_none:
            self.add_redditor(redditor)

        return points

    ### Private Methods ###

    def _get_submission_rowid(self, submission):
        """Return the rowid of *submission*, or None if it is not stored."""
        return self._get_rowid_from_reddit_id('SELECT rowid FROM submission WHERE id = :reddit_id', {'reddit_id': submission.id})

    def _get_comment_rowid(self, comment):
        """Return the rowid of *comment*, or None if it is not stored."""
        return self._get_rowid_from_reddit_id('SELECT rowid FROM comment WHERE id = :reddit_id', {'reddit_id': comment.id})

    def _get_redditor_rowid(self, redditor):
        """Return the rowid of *redditor*, or None if it is not stored."""
        return self._get_rowid_from_reddit_id('SELECT rowid FROM redditor WHERE id = :reddit_id', {'reddit_id': redditor.id})

    @transaction
    def _get_rowid_from_reddit_id(self, stmt, params):
        """Run a ``SELECT rowid ...`` statement; return the first rowid or None."""
        self.cursor.execute(stmt, params)
        row = self.cursor.fetchone()
        return row['rowid'] if row else None

    @transaction
    def _add_comment(self, comment, author):
        """Store *comment* by *author* unless it is already stored; return rows inserted."""
        params = {
            'id': comment.id,
            'author_id': author.id,
            'created_at_datetime': reddit_datetime_to_iso(comment.created_utc)
        }
        # OR IGNORE: the same comment can legitimately be seen again (e.g. a
        # solution re-chosen after deletion), and comment.id is UNIQUE.
        insert_stmt = '''
            INSERT OR IGNORE INTO comment (id, author_id, created_at_datetime)
            VALUES (:id, :author_id, :created_at_datetime)
        '''
        self.cursor.execute(insert_stmt, params)
        return self.cursor.rowcount

    @transaction
    def _add_submission(self, submission):
        """Store *submission* unless it is already stored; return rows inserted."""
        insert_stmt = '''
            INSERT OR IGNORE INTO submission (id, author_id)
            VALUES (:id, :author_id)
        '''
        self.cursor.execute(insert_stmt, {'id': submission.id, 'author_id': submission.author.id})
        return self.cursor.rowcount

    @transaction
    def _add_solution(self, submission, solver, comment, chosen_by_comment):
        """Insert a solution row; return 1 if inserted, 0 if it was skipped.

        OR IGNORE makes a constraint violation (an existing solution by this
        solver for this submission, or an unresolved rowid) report rowcount 0
        instead of raising. add_point_for_solution relies on this.
        """
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
            'comment_rowid': self._get_comment_rowid(comment),
            'chosen_by_comment_rowid': self._get_comment_rowid(chosen_by_comment)
        }
        insert_stmt = '''
            INSERT OR IGNORE INTO solution (submission_rowid, author_rowid, comment_rowid, chosen_by_comment_rowid)
            VALUES (:submission_rowid, :author_rowid, :comment_rowid, :chosen_by_comment_rowid)
        '''
        self.cursor.execute(insert_stmt, params)
        return self.cursor.rowcount

    @transaction
    def _soft_remove_solution(self, submission, solver, removed_by_comment):
        """Mark *solver*'s active solution as removed; return rows updated.

        Solutions that are already soft-removed are left untouched, so the
        caller does not deduct the same point twice.
        """
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
            'removed_by_comment_rowid': self._get_comment_rowid(removed_by_comment),
        }
        update_stmt = '''
            UPDATE solution
            SET removed_by_comment_rowid = :removed_by_comment_rowid
            WHERE submission_rowid = :submission_rowid
              AND author_rowid = :author_rowid
              AND removed_by_comment_rowid IS NULL
        '''
        self.cursor.execute(update_stmt, params)
        return self.cursor.rowcount

    @transaction
    def _update_points(self, redditor, points_modifier):
        """Add *points_modifier* (negative to subtract) to *redditor*'s points.

        The redditor row is created if missing. The total is clamped at 0 and
        the row is kept rather than deleted, because solution rows refer to it
        by rowid; deleting it would orphan them. Returns the rowcount of the
        UPDATE.
        """
        points = self.get_points(redditor, add_if_none=True)
        params = {
            'id': redditor.id,
            'name': redditor.name,
            'points': max(points + points_modifier, 0),
        }
        update_stmt = '''
            UPDATE redditor
            SET points = :points
            WHERE id = :id AND name = :name
        '''
        self.cursor.execute(update_stmt, params)
        return self.cursor.rowcount
|
|
|
|
|
|
### Utility ###
|
|
|
|
|
|
def reddit_datetime_to_iso(timestamp):
    """Convert a UTC epoch *timestamp* (seconds) to a naive ISO-8601 string.

    The output has no UTC offset suffix, e.g. ``'1970-01-01T00:00:00'``.
    It matches what the deprecated ``datetime.utcfromtimestamp(...).isoformat()``
    produced.
    """
    aware = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
    # Drop tzinfo so the stored text keeps its existing offset-free format.
    return aware.replace(tzinfo=None).isoformat()
|
|
|