diff --git a/pointsbot/database.py b/pointsbot/database.py index 0a66818..cb0ad94 100644 --- a/pointsbot/database.py +++ b/pointsbot/database.py @@ -9,9 +9,8 @@ import sqlite3 as sqlite def transaction(func): - """Use this decorator on any methods that needs to query the database to - ensure that connections are properly opened and closed, without being - left open unnecessarily. + """Use this decorator on any methods that needs to query the database to ensure that connections + are properly opened and closed, without being left open unnecessarily. """ @functools.wraps(func) def newfunc(self, *args, **kwargs): @@ -49,6 +48,8 @@ def transaction(func): ### Classes ### + +# @functools.total_ordering class DatabaseVersion: PRE_RELEASE_NAME_ORDER_NUMBER = { @@ -66,9 +67,21 @@ class DatabaseVersion: self.pre_release_number = pre_release_number def __lt__(self, other): - self_tuple = (self.major, self.minor, self.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], self.pre_release_number) - other_tuple = (other.major, other.minor, other.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[other.pre_release_name], other.pre_release_number) - return self_tuple < other_tuple + # self_tuple = (self.major, self.minor, self.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], self.pre_release_number) + # other_tuple = (other.major, other.minor, other.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[other.pre_release_name], other.pre_release_number) + return self._to_tuple() < other._to_tuple() + + def __eq__(self, other): + # self_tuple = (self.major, self.minor, self.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], self.pre_release_number) + # other_tuple = (other.major, other.minor, other.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[other.pre_release_name], other.pre_release_number) + return self._to_tuple() == other._to_tuple() + + def __hash__(self): + # self_tuple = (self.major, self.minor, self.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], self.pre_release_number) + return hash(self._to_tuple()) + + def _to_tuple(self): + return (self.major, self.minor, self.patch, self.PRE_RELEASE_NAME_ORDER_NUMBER[self.pre_release_name], self.pre_release_number) def __str__(self): version_string = f'{self.major}.{self.minor}.{self.patch}' @@ -200,7 +213,7 @@ class Database: # for row in self.cursor.fetchmany(): # logging.info(tuple(row)) self.cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'bot_version'") - has_version_table = (self.cursor.rowcount == 1) + has_version_table = self.cursor.fetchone() if not has_version_table: current_version = DatabaseVersion(0, 1, 0) else: