import MySQLdb
import atexit
import time

# Text types treated as raw SQL fragments. `unicode` exists only on
# Python 2; on Python 3 plain `str` is the only text type.
try:
    _STRING_TYPES = (str, unicode)
except NameError:
    _STRING_TYPES = (str,)


class SelectWrapper(tuple):
    """
    A deferred SELECT, as returned by ``Database.select(..., table_name=...)``.

    Holds ``(query, params, table_name)``. It can be used as a derived table
    in another ``select`` call or as the value of a where/join item.
    """


class Field(str):
    """
    A string that marks a field/column reference instead of a literal value.

    Field values are inlined verbatim into generated SQL rather than being
    passed as escaped query parameters.
    """

    @staticmethod
    def test(value):
        """Return ``value`` itself if it is a Field, else the '%s' placeholder."""
        if isinstance(value, Field):
            return value
        return '%s'


class Database(object):
    """
    High level database access object.

    Wraps one MySQLdb connection and cursor, prepends ``prefix`` to table
    names, logs every executed ``(query, params)`` pair and accumulates the
    time spent inside ``cursor.execute``.
    """

    # Also available as module-level ``escape``. staticmethod keeps the
    # function from binding ``self`` when it is called through an instance.
    escape = staticmethod(MySQLdb.escape_string)

    def __init__(self, prefix, mysql_args):
        """
        Connect with ``MySQLdb.connect(**mysql_args)`` and open a cursor.

        :param prefix: string prepended to table names (``select`` skips it
            for names containing '.').
        :param mysql_args: keyword arguments forwarded to ``MySQLdb.connect``.
        """
        self.connection = MySQLdb.connect(**mysql_args)
        self.cursor = self.connection.cursor()
        self.prefix = prefix
        # Total seconds spent inside cursor.execute() across all calls.
        self.execution_time = 0.0
        # (query, params) of every execute() call, seeded with an empty entry.
        # NOTE(review): grows by one entry per execute() and is never trimmed.
        self.last_queries = [('', ())]

    def execute(self, query, params=()):
        """
        Run ``query`` with ``params`` on the cursor, logging and timing it.

        :returns: whatever ``cursor.execute`` returns.
        """
        self.last_queries.append((query, params))
        start = time.time()
        length = self.cursor.execute(query, params)
        self.execution_time += time.time() - start
        return length

    def __iter__(self):
        """Iterate over the rows of the cursor's current result set."""
        return iter(self.cursor)

    def fetchone(self):
        """Return ``cursor.fetchone()``."""
        return self.cursor.fetchone()

    def fetchall(self):
        """Return ``cursor.fetchall()``."""
        return self.cursor.fetchall()

    def select(self, fields, tables, wheres=(), joins=(), order_by=None,
               group_by=None, havings=(), limit=None, distinct='ALL',
               table_name=None):
        """
        Execute a select with straight joins.

        :param fields: iterable of column expressions, joined with ', '.
        :param tables: iterable of table names or SelectWrapper subqueries.
            Names containing '.' are used as-is; others get ``self.prefix``.
        :param wheres: where items, see ``construct_where``.
        :param joins: iterable of ``(join_type, table, condition)``; the
            condition is rendered by ``construct_expression``. The join table
            is used verbatim (no prefix is added).
        :param order_by: ``(expression, direction)`` pair.
        :param group_by: GROUP BY expression string.
        :param havings: HAVING items, same format as ``wheres``.
        :param limit: see ``construct_limit``.
        :param distinct: keyword placed right after SELECT (default 'ALL').
        :param table_name: if given, the query is not executed; a
            ``SelectWrapper((query, params, table_name))`` is returned so it
            can be embedded as a subquery.
        :returns: the result of ``execute``, or a SelectWrapper.
        """
        query = 'SELECT %s %s FROM ' % (distinct, ', '.join(fields))
        params = []

        t_query = []
        for table in tables:
            if isinstance(table, SelectWrapper):
                t_query.append('(%s) AS %s ' % (table[0], table[2]))
                params.extend(table[1])
            elif '.' in table:
                # Already qualified (e.g. db.table): leave untouched.
                t_query.append('%s ' % table)
            else:
                t_query.append('%s%s ' % (self.prefix, table))
        query += ', '.join(t_query) + ' '

        for join in joins:
            expr, params2 = self.construct_expression(join[2])
            query += '%s JOIN %s ON %s ' % (join[0], join[1], expr)
            params.extend(params2)

        if wheres:
            where, params2 = self.construct_where(wheres)
            query += ' WHERE %s ' % where
            params.extend(params2)

        if group_by:
            query += ' GROUP BY %s ' % group_by

        if havings:
            having, params2 = self.construct_where(havings)
            query += ' HAVING %s ' % having
            params.extend(params2)

        if order_by:
            query += ' ORDER BY %s %s ' % order_by

        query += ' %s ' % self.construct_limit(limit)

        if table_name:
            return SelectWrapper((query, params, table_name))
        return self.execute(query, params)

    def insert(self, table, keys=None, values=(), unescaped_values=()):
        """
        INSERT a single row into ``prefix + table``.

        :param keys: optional iterable of column names.
        :param values: values sent as escaped query parameters.
        :param unescaped_values: raw SQL fragments appended verbatim after the
            ``values`` placeholders; never pass untrusted input here.
        :returns: the result of ``execute``.
        """
        query = 'INSERT INTO %s%s ' % (self.prefix, table)
        if keys:
            query += '(%s) ' % ', '.join(keys)
        placeholders = ['%s' for _ in values] + list(unescaped_values)
        query += 'VALUES (%s)' % ', '.join(placeholders)
        return self.execute(query, values)

    def delete(self, table, wheres, limit=None):
        """
        DELETE rows of ``prefix + table`` that match ``wheres``.

        With empty ``wheres`` no WHERE clause is emitted, so every row (up to
        ``limit``) is deleted.

        :returns: the result of ``execute``.
        """
        query = 'DELETE FROM %s%s ' % (self.prefix, table)
        where, params = self.construct_where(wheres)
        if where:
            # construct_where returns only the conditions, not the keyword.
            query += ' WHERE %s ' % where
        query += ' %s ' % self.construct_limit(limit)
        return self.execute(query, params)

    def update(self, table, values, wheres, limit=None):
        """
        UPDATE ``prefix + table`` SET ``values`` WHERE ``wheres``.

        The WHERE keyword is always emitted, so empty ``wheres`` produces a
        query the server rejects instead of silently updating every row.

        :param values: SET items, see ``construct_update``.
        :returns: the result of ``execute``.
        """
        query = 'UPDATE %s%s SET ' % (self.prefix, table)
        updates, params1 = self.construct_update(values)
        where, params2 = self.construct_where(wheres)
        query += ' %s WHERE %s ' % (updates, where)
        query += ' %s ' % self.construct_limit(limit)
        return self.execute(query, params1 + params2)

    @staticmethod
    def construct_update(values):
        """
        Build the SET list of an UPDATE.

        Items are either raw SQL strings (inserted verbatim, unsanitized) or
        ``(field, value)`` pairs rendered as ``field = %s`` with ``value``
        passed as a parameter.

        :returns: ``(sql, params_tuple)``.
        """
        query = []
        params = []
        for item in values:
            if isinstance(item, _STRING_TYPES):
                query.append(item)
            else:
                query.append('%s = %%s' % item[0])
                params.append(item[1])
        return ', '.join(query), tuple(params)

    def commit(self):
        """Commit the current transaction on the connection."""
        self.connection.commit()

    def close(self):
        """Close the connection, ignoring MySQLdb errors (best effort)."""
        try:
            self.connection.close()
        except MySQLdb.Error:
            pass

    def lastrowid(self):
        """Return ``cursor.lastrowid``."""
        return self.cursor.lastrowid

    @staticmethod
    def construct_where(wheres):
        """
        Construct a WHERE condition from a list (without the WHERE keyword).

        List items can be:
        - str/unicode: Will be inserted straightly into the query, without
          any sanitizing.
        - tuple/list (field, value): Inserts field = value into the query
          with value properly sanitized.
        - tuple/list (field, values) with values tuple/list: Inserts
          field IN values into the query with values properly sanitized.
        - an optional third element overrides the operator.
        See ``construct_expression`` for details.

        The list is concatenated using AND.

        :returns: ``(sql, params_tuple)``; sql is '' for an empty list.
        """
        query = []
        params = []
        for item in wheres:
            q, p = Database.construct_expression(item)
            query.append(q)
            params.extend(p)
        return ' AND '.join(query), tuple(params)

    @staticmethod
    def construct_expression(item):
        """
        Render one where/having/join item as ``(sql, params)``.

        - str/unicode: returned verbatim with no params.
        - ``(field, SelectWrapper)``: ``field (subquery) AS name`` with the
          subquery's params.
        - ``(field, values[, op])`` with values a tuple/list:
          ``field op (%s, ...)``; op defaults to 'IN'.
        - ``(field, value[, op])``: ``field op %s``; op defaults to '='.

        Field instances used as values are inlined verbatim (as column
        references) instead of being passed as parameters.
        """
        if isinstance(item, _STRING_TYPES):
            return item, ()

        field, value = item[0], item[1]

        if isinstance(value, SelectWrapper):
            return ('%s (%s) AS %s' % (field, value[0], value[2]),
                    value[1])

        if len(item) > 2:
            op = item[2]
        elif isinstance(value, (tuple, list)):
            op = 'IN'
        else:
            op = '='

        if isinstance(value, (tuple, list)):
            placeholders = ', '.join(Field.test(v) for v in value)
            params = [v for v in value if not isinstance(v, Field)]
            return '%s %s (%s)' % (field, op, placeholders), params

        # Scalar value: same Field convention as in the list branch above.
        if isinstance(value, Field):
            return '%s %s %s' % (field, op, value), ()
        return '%s %s %%s' % (field, op), (value,)

    @staticmethod
    def construct_limit(limit):
        """
        Build a LIMIT clause.

        :param limit: None (no clause), an int-convertible row count, or an
            ``(offset, count)`` tuple/list.
        :raises ValueError/TypeError: if a part is not int-convertible.
        """
        if limit is None:
            return ''
        if isinstance(limit, (tuple, list)):
            offset, count = limit
            # Cast both parts so the interpolated clause cannot carry SQL.
            return 'LIMIT %d, %d' % (int(offset), int(count))
        return 'LIMIT %d' % int(limit)


escape = MySQLdb.escape_string