Files
PointsBot/pointsbot/database.py

424 lines
16 KiB
Python
Raw Normal View History

import datetime
2020-01-15 11:07:13 -08:00
import functools
2021-02-15 21:25:58 -08:00
import logging
2020-01-15 11:07:13 -08:00
import os.path
2021-02-15 21:25:58 -08:00
import re
2020-01-15 11:07:13 -08:00
import sqlite3 as sqlite
### Decorators ###
2020-01-15 17:42:49 -08:00
def transaction(func):
    """Run a method against ``self.conn`` inside a database transaction.

    Use this decorator on any method that needs to query the database to ensure that
    connections are properly opened and closed, without being left open unnecessarily.

    The outermost decorated call opens a connection to ``self.path``, sets
    ``sqlite.Row`` as the row factory and creates ``self.cursor``. Decorated methods
    called while that connection is open reuse it instead of opening their own.

    Only the outermost call decides the transaction's outcome:

    - On success it commits.
    - On any exception it rolls back and re-raises.

    In both cases it closes the cursor and connection and resets ``self.cursor`` and
    ``self.conn`` to ``None``. As a result, a decorated method that calls other decorated
    methods is applied atomically.
    """
    @functools.wraps(func)
    def newfunc(self, *args, **kwargs):
        if self.conn is not None:
            # Nested call: the call that opened the connection owns commit/rollback/close.
            return func(self, *args, **kwargs)

        self.conn = sqlite.connect(self.path)
        self.conn.row_factory = sqlite.Row
        self.cursor = self.conn.cursor()
        try:
            return_value = func(self, *args, **kwargs)
            if self.conn.in_transaction:
                self.conn.commit()
            return return_value
        except BaseException:
            # Undo partial work instead of leaving an open transaction behind.
            if self.conn.in_transaction:
                self.conn.rollback()
            raise
        finally:
            self.cursor.close()
            self.conn.close()
            self.cursor = self.conn = None

    return newfunc
2020-01-15 11:07:13 -08:00
2020-01-15 17:42:49 -08:00
### Classes ###
2021-03-07 22:34:08 -08:00
2021-02-15 21:25:58 -08:00
@functools.total_ordering
class DatabaseVersion:
    """A semver-like version identifier for the database schema.

    Versions have the form ``MAJOR.MINOR.PATCH``, optionally followed by a pre-release
    tag such as ``-alpha``, ``-beta.2`` or ``-rc.1``. Ordering follows semver
    precedence:

    - A pre-release sorts before the release it precedes
      (``1.0.0-alpha < 1.0.0-beta < 1.0.0-rc < 1.0.0``).
    - A tag without a number sorts before the same tag with a number
      (``1.0.0-alpha < 1.0.0-alpha.0``).

    Only the pre-release names listed in ``PRE_RELEASE_NAME_ORDER_NUMBER`` can be
    compared.
    """

    # Precedence of pre-release names. A full release (None) outranks every pre-release.
    PRE_RELEASE_NAME_ORDER_NUMBER = {
        'alpha': 1,
        'beta': 2,
        'rc': 3,
        None: 4,
    }

    _VERSION_PATTERN = re.compile(r'(\d+)\.(\d+)\.(\d+)(?:-([A-Za-z]+)(?:\.(\d+))?)?')

    def __init__(self, major, minor, patch, pre_release_name=None, pre_release_number=None):
        self.major = major
        self.minor = minor
        self.patch = patch
        self.pre_release_name = pre_release_name
        self.pre_release_number = pre_release_number

    def __lt__(self, other):
        if not isinstance(other, DatabaseVersion):
            return NotImplemented
        return self._to_tuple() < other._to_tuple()

    def __eq__(self, other):
        if not isinstance(other, DatabaseVersion):
            return NotImplemented
        return self._to_tuple() == other._to_tuple()

    def __hash__(self):
        return hash(self._to_tuple())

    def __repr__(self):
        return (f'{type(self).__name__}({self.major!r}, {self.minor!r}, {self.patch!r}, '
                f'{self.pre_release_name!r}, {self.pre_release_number!r})')

    def _to_tuple(self):
        """Return a sort key that implements the precedence described on the class."""
        # (has_number, number) keeps an un-numbered tag ahead of a numbered one and
        # avoids ever comparing None with an int.
        number_key = (self.pre_release_number is not None, self.pre_release_number or 0)
        return (self.major, self.minor, self.patch,
                self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], number_key)

    def __str__(self):
        version_string = f'{self.major}.{self.minor}.{self.patch}'
        if self.pre_release_name is not None:
            version_string += f'-{self.pre_release_name}'
        if self.pre_release_number is not None:
            version_string += f'.{self.pre_release_number}'
        return version_string

    @classmethod
    def from_string(cls, version_string):
        """Parse strings such as ``'0.2.0'``, ``'1.0.0-rc'`` or ``'1.0.0-beta.2'``.

        Returns ``None`` if the whole string (ignoring surrounding whitespace) is not
        a version.
        """
        match = cls._VERSION_PATTERN.fullmatch(version_string.strip())
        if not match:
            return None
        major, minor, patch, name, number = match.groups()
        return cls(int(major), int(minor), int(patch), name,
                   int(number) if number is not None else None)
2020-01-15 17:42:49 -08:00
class Database:
    """SQLite-backed store of redditors' points, submissions, comments and solutions.

    Methods that touch the database are wrapped in ``transaction``. That decorator
    manages ``self.conn`` / ``self.cursor``, so no connection is held open between calls.
    """

    LATEST_VERSION = DatabaseVersion(0, 2, 0)

    # TODO now that I'm separating these statements by version, I could probably make these
    # scripts instead of lists of individual statements...
    #
    # NOTE(review): shipped migrations must not be edited, but 0.2.0 has issues that
    # probably deserve a follow-up migration:
    #   * submission.author_id is UNIQUE, so a second submission by the same author is
    #     silently dropped by the INSERT OR IGNORE in _add_submission.
    #   * solution.chosen_by_comment_rowid is NOT NULL but declares ON DELETE SET NULL.
    # SQLite only enforces FOREIGN KEY clauses after "PRAGMA foreign_keys = ON", which this
    # class never issues, so the ON DELETE actions are currently inert.
    SCHEMA_VERSION_STATEMENTS = {
        DatabaseVersion(0, 1, 0): [
            '''
            CREATE TABLE IF NOT EXISTS redditor_points (
                id TEXT UNIQUE NOT NULL,
                name TEXT UNIQUE NOT NULL,
                points INTEGER DEFAULT 0
            )
            '''
        ],
        DatabaseVersion(0, 2, 0): [
            '''
            CREATE TABLE IF NOT EXISTS bot_version (
                major INTEGER NOT NULL,
                minor INTEGER NOT NULL,
                patch INTEGER NOT NULL,
                pre_release_name TEXT,
                pre_release_number INTEGER
            )
            ''',
            '''
            ALTER TABLE redditor_points RENAME TO redditor
            ''',
            '''
            CREATE TABLE IF NOT EXISTS submission (
                id TEXT UNIQUE NOT NULL,
                author_id TEXT UNIQUE NOT NULL
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS comment (
                id TEXT UNIQUE NOT NULL,
                author_id TEXT NOT NULL,
                author_rowid INTEGER, -- May be NULL **for now**
                created_at_datetime TEXT NOT NULL,
                FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS solution (
                submission_rowid INTEGER NOT NULL,
                author_rowid INTEGER NOT NULL,
                comment_rowid INTEGER NOT NULL,
                chosen_by_comment_rowid INTEGER NOT NULL,
                removed_by_comment_rowid INTEGER,
                FOREIGN KEY (submission_rowid) REFERENCES submission (rowid) ON DELETE CASCADE,
                FOREIGN KEY (author_rowid) REFERENCES redditor (rowid) ON DELETE CASCADE,
                FOREIGN KEY (comment_rowid) REFERENCES comment (rowid) ON DELETE CASCADE,
                FOREIGN KEY (chosen_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
                FOREIGN KEY (removed_by_comment_rowid) REFERENCES comment (rowid) ON DELETE SET NULL,
                PRIMARY KEY (submission_rowid, author_rowid)
            )
            '''
        ]
    }

    def __init__(self, dbpath):
        """Open (creating or migrating as needed) the database file at ``dbpath``."""
        self.path = dbpath
        self.conn = None
        self.cursor = None
        if not os.path.exists(self.path):
            logging.info('No database found; creating...')
            self._run_migrations()
            logging.info('Successfully created database')
        else:
            logging.info(f'Using existing database: {self.path}')
            current_version = self._get_current_version()
            if current_version != self.LATEST_VERSION:
                logging.info('Newer database version exists; migrating...')
                self._run_migrations(current_version)
                logging.info('Successfully completed all migrations')

    @transaction
    def _run_migrations(self, current_version=None):
        """Execute, in version order, every migration newer than ``current_version``.

        ``current_version`` defaults to 0.0.0, which means the schema is built from scratch.
        For every migration after 0.1.0, the ``bot_version`` table is cleared and then given
        a single row holding the version just applied.
        """
        if current_version is None:
            current_version = DatabaseVersion(0, 0, 0)
        logging.info(f'Current database version: {current_version}')

        versions = sorted(v for v in self.SCHEMA_VERSION_STATEMENTS if current_version < v)
        for v in versions:
            logging.info(f'Beginning migration to version: {v}...')
            for sql_stmt in self.SCHEMA_VERSION_STATEMENTS[v]:
                self.cursor.execute(sql_stmt)
            if DatabaseVersion(0, 1, 0) < v:
                # Only update bot_version table starting at version 0.2.0
                self.cursor.execute('DELETE FROM bot_version')
                params = {
                    'major': v.major,
                    'minor': v.minor,
                    'patch': v.patch,
                    'pre_release_name': v.pre_release_name,
                    'pre_release_number': v.pre_release_number,
                }
                insert_stmt = '''
                    INSERT INTO bot_version (major, minor, patch, pre_release_name, pre_release_number)
                    VALUES (:major, :minor, :patch, :pre_release_name, :pre_release_number)
                '''
                self.cursor.execute(insert_stmt, params)
            logging.info('Successfully completed migration')

    @transaction
    def _get_current_version(self):
        """Return the schema version recorded in the database.

        A database without a ``bot_version`` table predates 0.2.0 and is reported as
        0.1.0.

        Raises:
            sqlite.DatabaseError: if ``bot_version`` exists but holds no row, because the
                version cannot be determined.
        """
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'bot_version'")
        has_version_table = self.cursor.fetchone()
        if not has_version_table:
            return DatabaseVersion(0, 1, 0)

        self.cursor.execute('SELECT major, minor, patch, pre_release_name, pre_release_number FROM bot_version')
        row = self.cursor.fetchone()
        if row is None:
            raise sqlite.DatabaseError('bot_version table is empty; cannot determine database version')
        # Compare against None explicitly: a pre-release number of 0 is a valid value.
        pre_release_number = int(row['pre_release_number']) if row['pre_release_number'] is not None else None
        return DatabaseVersion(int(row['major']), int(row['minor']), int(row['patch']),
                               row['pre_release_name'], pre_release_number)

    ### Public Methods ###

    @transaction
    def add_redditor(self, redditor):
        """Insert ``redditor`` (by ``id`` and ``name``) if absent; return the rows inserted (0 or 1)."""
        insert_stmt = '''
            INSERT OR IGNORE INTO redditor (id, name)
            VALUES (:id, :name)
        '''
        self.cursor.execute(insert_stmt, {'id': redditor.id, 'name': redditor.name})
        return self.cursor.rowcount

    @transaction
    def remove_redditor(self, redditor):
        """Delete the row matching ``redditor``'s ``id`` and ``name``; return the rows deleted."""
        delete_stmt = '''
            DELETE FROM redditor
            WHERE id = :id
            AND name = :name
        '''
        self.cursor.execute(delete_stmt, {'id': redditor.id, 'name': redditor.name})
        return self.cursor.rowcount

    @transaction
    def has_already_solved_once(self, submission, solver):
        """Return True if a solution by ``solver`` is recorded for ``submission``."""
        select_stmt = '''
            SELECT count(solution.rowid) AS num_solutions
            FROM solution
            JOIN submission ON (solution.submission_rowid = submission.rowid)
            JOIN redditor ON (solution.author_rowid = redditor.rowid)
            WHERE submission.id = :submission_id
            AND redditor.id = :author_id
        '''
        self.cursor.execute(select_stmt, {'submission_id': submission.id, 'author_id': solver.id})
        row = self.cursor.fetchone()
        return row is not None and row['num_solutions'] > 0

    @transaction
    def add_point_for_solution(self, submission, solver, solution_comment, chooser, chosen_by_comment):
        """Record a solution to ``submission`` and give ``solver`` one point.

        ``solution_comment`` is the solver's answer and ``chosen_by_comment`` is
        ``chooser``'s comment that selected it.

        Returns the number of solution rows inserted. The result is 0 when the row could
        not be inserted, e.g. because ``solver`` already has a solution for this
        submission; in that case the point is taken back.
        """
        self._add_submission(submission)
        self._add_comment(solution_comment, solver)
        self._add_comment(chosen_by_comment, chooser)
        # Adding the point first also ensures the solver has a redditor row, whose rowid
        # _add_solution looks up.
        self._update_points(solver, 1)
        rowcount = self._add_solution(submission, solver, solution_comment, chosen_by_comment)
        if rowcount == 0:
            # Was not able to add solution, probably because user has already solved this submission
            self._update_points(solver, -1)
        # TODO update author_rowid for comment?
        return rowcount

    @transaction
    def soft_remove_point_for_solution(self, submission, solver, remover, removed_by_comment):
        """Mark ``solver``'s solution to ``submission`` as removed and take back its point.

        The solution row is kept, with ``removed_by_comment`` recorded as the reason. Returns
        0 if no un-removed solution matched. Otherwise returns the rowcount from the points
        update.
        """
        self._add_comment(removed_by_comment, remover)
        rowcount = self._soft_remove_solution(submission, solver, removed_by_comment)
        if rowcount > 0:
            rowcount = self._update_points(solver, -1)
        return rowcount

    @transaction
    def add_back_point_for_solution(self, submission, solver):
        """Give ``solver`` a point back and clear the removal mark on their solution.

        Returns the number of solution rows updated.
        """
        # NOTE(review): the point is added even if no solution row matches. Also, if the
        # solver's redditor row was deleted when their points reached 0, it is re-created here
        # with a new rowid that no longer matches solution.author_rowid.
        self._update_points(solver, 1)

        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
        }
        update_stmt = '''
            UPDATE solution
            SET removed_by_comment_rowid = NULL
            WHERE submission_rowid = :submission_rowid
            AND author_rowid = :author_rowid
        '''
        self.cursor.execute(update_stmt, params)
        # Return the count, not the cursor: the cursor is closed once the transaction ends.
        return self.cursor.rowcount

    @transaction
    def remove_point_and_delete_solution(self, submission, solver):
        """Delete ``solver``'s solution to ``submission`` and take one point away.

        Returns the rowcount from the points update.
        """
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
        }
        delete_stmt = '''
            DELETE FROM solution
            WHERE submission_rowid = :submission_rowid
            AND author_rowid = :author_rowid
        '''
        self.cursor.execute(delete_stmt, params)
        return self._update_points(solver, -1)

    @transaction
    def get_points(self, redditor, add_if_none=False):
        """Return ``redditor``'s points, or 0 if no row matches.

        If no row matches and ``add_if_none`` is true, a row is inserted for the redditor.
        """
        params = {'id': redditor.id, 'name': redditor.name}
        select_stmt = '''
            SELECT points
            FROM redditor
            WHERE id = :id AND name = :name
        '''
        self.cursor.execute(select_stmt, params)
        row = self.cursor.fetchone()
        points = 0
        if row:
            points = row['points']
        elif add_if_none:
            self.add_redditor(redditor)
        return points

    ### Private Methods ###

    def _get_submission_rowid(self, submission):
        return self._get_rowid_from_reddit_id('SELECT rowid FROM submission WHERE id = :reddit_id', {'reddit_id': submission.id})

    def _get_comment_rowid(self, comment):
        return self._get_rowid_from_reddit_id('SELECT rowid FROM comment WHERE id = :reddit_id', {'reddit_id': comment.id})

    def _get_redditor_rowid(self, redditor):
        return self._get_rowid_from_reddit_id('SELECT rowid FROM redditor WHERE id = :reddit_id', {'reddit_id': redditor.id})

    @transaction
    def _get_rowid_from_reddit_id(self, stmt, params):
        """Run a ``SELECT rowid ...`` query; return the first rowid, or None if nothing matched."""
        self.cursor.execute(stmt, params)
        row = self.cursor.fetchone()
        return row['rowid'] if row else None

    @transaction
    def _add_comment(self, comment, author):
        """Insert ``comment`` by ``author`` unless its id is already stored; return rows inserted."""
        params = {
            'id': comment.id,
            'author_id': author.id,
            'created_at_datetime': reddit_datetime_to_iso(comment.created_utc),
        }
        # OR IGNORE (like add_redditor/_add_submission): re-recording a known comment must
        # not abort the surrounding operation with an IntegrityError.
        insert_stmt = '''
            INSERT OR IGNORE INTO comment (id, author_id, created_at_datetime)
            VALUES (:id, :author_id, :created_at_datetime)
        '''
        self.cursor.execute(insert_stmt, params)
        return self.cursor.rowcount

    @transaction
    def _add_submission(self, submission):
        """Insert ``submission`` unless a constraint rejects it; return rows inserted."""
        insert_stmt = '''
            INSERT OR IGNORE INTO submission (id, author_id)
            VALUES (:id, :author_id)
        '''
        self.cursor.execute(insert_stmt, {'id': submission.id, 'author_id': submission.author.id})
        return self.cursor.rowcount

    @transaction
    def _add_solution(self, submission, solver, comment, chosen_by_comment):
        """Insert a solution row; return 1 if inserted, 0 if a constraint rejected it.

        The row is rejected, for example, when the PRIMARY KEY (submission, author) already
        exists or a referenced rowid could not be found (NOT NULL).
        """
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
            'comment_rowid': self._get_comment_rowid(comment),
            'chosen_by_comment_rowid': self._get_comment_rowid(chosen_by_comment),
        }
        insert_stmt = '''
            INSERT OR IGNORE INTO solution (submission_rowid, author_rowid, comment_rowid, chosen_by_comment_rowid)
            VALUES (:submission_rowid, :author_rowid, :comment_rowid, :chosen_by_comment_rowid)
        '''
        self.cursor.execute(insert_stmt, params)
        return self.cursor.rowcount

    @transaction
    def _soft_remove_solution(self, submission, solver, removed_by_comment):
        """Mark a not-yet-removed solution as removed by ``removed_by_comment``; return rows updated."""
        params = {
            'submission_rowid': self._get_submission_rowid(submission),
            'author_rowid': self._get_redditor_rowid(solver),
            'removed_by_comment_rowid': self._get_comment_rowid(removed_by_comment),
        }
        # Skip solutions that are already removed, so a repeated removal cannot
        # subtract the same point twice.
        update_stmt = '''
            UPDATE solution
            SET removed_by_comment_rowid = :removed_by_comment_rowid
            WHERE submission_rowid = :submission_rowid AND author_rowid = :author_rowid
            AND removed_by_comment_rowid IS NULL
        '''
        self.cursor.execute(update_stmt, params)
        return self.cursor.rowcount

    @transaction
    def _update_points(self, redditor, points_modifier):
        """Add ``points_modifier`` (negative to subtract) to ``redditor``'s points.

        If the new total would be <= 0, the redditor's row is deleted instead. Returns the
        rowcount of the UPDATE or DELETE that was run.
        """
        # NOTE(review): deleting the row orphans solution/comment rows that reference this
        # redditor's rowid; re-adding the redditor later gives them a different rowid.
        points = self.get_points(redditor, add_if_none=True)
        new_points = points + points_modifier
        if new_points <= 0:
            return self.remove_redditor(redditor)

        params = {
            'id': redditor.id,
            'name': redditor.name,
            'points': new_points,
        }
        update_stmt = '''
            UPDATE redditor
            SET points = :points
            WHERE id = :id AND name = :name
        '''
        self.cursor.execute(update_stmt, params)
        return self.cursor.rowcount
2020-01-15 17:42:49 -08:00
### Utility ###
2021-02-22 00:26:34 -08:00
def reddit_datetime_to_iso(timestamp):
    """Convert a Unix timestamp (e.g. Reddit's ``created_utc``) to a naive UTC ISO-8601 string.

    The output has no UTC offset, e.g. ``'2020-01-26T00:53:20'``, matching the format
    previously produced via the deprecated ``datetime.utcfromtimestamp``.
    """
    utc_dt = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
    # Drop tzinfo so the string stays offset-free, as stored by earlier versions.
    return utc_dt.replace(tzinfo=None).isoformat()
2020-01-15 17:42:49 -08:00