wudb.py 56.2 KB
Newer Older
1 2
#!/usr/bin/env python3

3 4 5 6 7 8 9 10 11 12
# TODO:
# FILES table: OBSOLETE column
#     OBSOLETE says that this file was replaced by a newer version; for example, checkrels may want to create a new file with only part of the output.
#     Should this be a pointer to the file that replaced the obsolete one? How do we signal a file that is obsolete, but not replaced by anything?
#     If one file is replaced by several (say, due to data corruption in the middle), we need a 1:n relationship. If several files are replaced by one (merge),
#     we need n:1. What should we do? Do we really want an n:n relationship here? Disallow fragmenting files, or maybe simply not track it in the DB if we do?
# FILES table: CHECKSUM column
#     We need a fast check that the information stored in the DB still accurately reflects the file system contents. The test should also warn about files in upload/ which are not listed in the DB.


13 14 15
# Make Python 2.7 use the print() syntax from Python 3
from __future__ import print_function

16
import abc
import collections
import collections.abc
import logging
import sqlite3
import sys
import threading
import traceback
from datetime import datetime

if sys.version_info.major == 3:
    from queue import Queue
else:
    from Queue import Queue

from workunit import Workunit
import patterns
import cadologger
31

32
DEBUG = 1
33
exclusive_transaction = [None, None]
34 35 36 37

DEFERRED = object()
IMMEDIATE = object()
EXCLUSIVE = object()
38 39 40

logger = logging.getLogger("Database")
logger.setLevel(logging.NOTSET)
41

42

43 44
PRINTED_CANCELLED_WARNING = False

45
def join3(l, pre=None, post=None, sep=", "):
    """
    Surround every string in l with pre and post, then join them with sep.

    If any parameter is None, it is interpreted as the empty string

    >>> join3 ( ('a',), pre="+", post="-", sep=", ")
    '+a-'
    >>> join3 ( ('a', 'b'), pre="+", post="-", sep=", ")
    '+a-, +b-'
    >>> join3 ( ('a', 'b'))
    'a, b'
    >>> join3 ( ('a', 'b', 'c'), pre="+", post="-", sep=", ")
    '+a-, +b-, +c-'
    """
    prefix = "" if pre is None else pre
    suffix = "" if post is None else post
    separator = "" if sep is None else sep
    return separator.join(prefix + item + suffix for item in l)
64 65

def dict_join3(d, sep=None, op=None, pre=None, post=None):
    """
    Render each key:value pair of d as pre + key + op + value + post and
    join the rendered pairs with sep. Keys and values must be strings.

    If any parameter is None, it is interpreted as the empty string
    >>> dict_join3 ( {"a": "1", "b": "2"}, sep=",", op="=", pre="-", post="+")
    '-a=1+,-b=2+'
    """
    prefix = "" if pre is None else pre
    suffix = "" if post is None else post
    separator = "" if sep is None else sep
    operator = "" if op is None else op
    rendered = []
    for pair in d.items():
        rendered.append(prefix + operator.join(pair) + suffix)
    return separator.join(rendered)
80

81
def conn_commit(conn):
    """ Commit the current transaction on conn and log that we did so.

    With DEBUG > 1, the module-wide exclusive_transaction record is checked
    against conn (a mismatch is logged as a warning) and then reset.
    """
    logger.transaction("Commit on connection %d", id(conn))
    if DEBUG > 1:
        holder = exclusive_transaction[0]
        if holder is not None and holder is not conn:
            logger.warning("Commit on connection %d, but exclusive lock was on %d", id(conn), id(holder))
        exclusive_transaction[0] = None
        exclusive_transaction[1] = None
    conn.commit()

def conn_close(conn):
    """ Close the connection conn, logging the close.

    A warning is logged if conn still has an open transaction.
    """
    conn_id = id(conn)
    logger.transaction("Closing connection %d", conn_id)
    if conn.in_transaction:
        logger.warning("Connection %d being closed while in transaction", conn_id)
    conn.close()

96 97 98 99 100 101 102 103 104 105 106 107 108
# Enum-like "constants" with reverse lookup: WuStatus.<NAME> gives the
# integer value of a status, WuStatus.get_name(value) gives its name back
STATUS_NAMES = [
    "AVAILABLE",
    "ASSIGNED",
    "RECEIVED_OK",
    "RECEIVED_ERROR",
    "VERIFIED_OK",
    "VERIFIED_ERROR",
    "CANCELLED",
]
STATUS_VALUES = range(len(STATUS_NAMES))
WuStatusBase = collections.namedtuple("WuStatusBase", STATUS_NAMES)


class WuStatusClass(WuStatusBase):
    """ Named tuple whose field names are the status names and whose
        entries are the corresponding integer status values """

    def check(self, status):
        """ Assert that status is one of the values stored in this tuple """
        assert status in self

    def get_name(self, status):
        """ Return the name belonging to the status value status """
        self.check(status)
        # The field names of the tuple are exactly STATUS_NAMES, in order
        return self._fields[status]


WuStatus = WuStatusClass(*STATUS_VALUES)
109

110

111 112 113 114 115 116 117 118 119
def check_tablename(name):
    """ Test whether name is a valid SQL table name.

    A name is accepted if, once all underscores are removed, it is
    non-empty, starts with a letter and consists only of letters and digits.

    Raise an exception if it isn't.
    """
    no_ = name.replace("_", "")
    # isalnum() is applied to the whole string rather than to no_[1:], since
    # "".isalnum() is False and would reject valid one-letter names. The
    # emptiness test avoids an IndexError for "" and underscore-only names.
    if not no_ or not no_[0].isalpha() or not no_.isalnum():
        raise Exception("%s is not valid for an SQL table name" % name)

120 121
class StatusUpdateError(Exception):
    """ Raised if we try to update the status in any way other than
        progressive (AVAILABLE -> ASSIGNED -> ...) """

125
class MyCursor(sqlite3.Cursor):
    """ This class represents a DB cursor and provides convenience functions
        around SQL queries. In particular it is meant to provide
        (1) an interface to SQL functionality via method calls with parameters,
        and
        (2) hiding some particularities of the SQL variant of the underlying
            DBMS as far as possible """

    # This is used in where queries; it converts from named arguments such as
    # "eq" to a binary operator such as "="
    name_to_operator = {"lt": "<", "le": "<=", "eq": "=", "ge": ">=",
                        "gt": ">", "ne": "!="}

    def __init__(self, conn):
        # Keep our own reference to the connection; its id() tags log messages
        self._conn = conn
        super().__init__(conn)

    @staticmethod
    def _without_None(d):
        """ Return a copy of the dictionary d, but without entries whose values
            are None """
        return {key: value for (key, value) in d.items() if value is not None}

    @staticmethod
    def as_string(d):
        """ Return ", col AS alias, ..." for the column:alias pairs in d,
            or the empty string if d is None """
        if d is None:
            return ""
        return ", " + dict_join3(d, sep=", ", op=" AS ")

    @classmethod
    def _where_str(cls, name, **args):
        """ Build a condition clause introduced by the keyword name.

            Each keyword argument is named after a key of name_to_operator
            ("eq", "lt", ...) and is a dictionary of column:value pairs, or
            None to be ignored. All conditions are combined with AND.

            Returns (clause, values): the clause with "?" placeholders and the
            list of values to bind to them. The clause is "" if there are no
            conditions. """
        where = ""
        values = []
        for opname in args:
            if args[opname] is None:
                continue
            if where == "":
                where = " " + name + " "
            else:
                where = where + " AND "
            where = where + join3(args[opname].keys(),
                                  post=" " + cls.name_to_operator[opname] + " ?",
                                  sep=" AND ")
            values = values + list(args[opname].values())
        return (where, values)

    @staticmethod
    def _format_command(command, values):
        """ Return command with its "?" placeholders replaced by the repr()
            of the entries of values. Used for log output only. """
        if values is None:
            return command.replace("?", "%r")
        # A literal "%" in the SQL text must be escaped, or the %-formatting
        # below would misread it as a conversion specifier
        template = command.replace("%", "%%").replace("?", "%r")
        try:
            return template % tuple(values)
        except TypeError:
            # Number of "?" and of values differ (e.g., a "?" inside an SQL
            # string literal). Logging must not prevent the query from running
            return "%s with values %r" % (command, values)

    def _exec(self, command, values=None):
        """ Wrapper around self.execute() that logs the command and its
            arguments for debugging and retries in case of "database is
            locked" exception.

            The command is tried up to 10 times, without delay between tries.
            Any other error, or the tenth "database is locked" error, is
            re-raised to the caller. """

        # FIXME: should be the caller's class name, as _exec could be
        # called from outside this class
        classname = self.__class__.__name__
        parent = sys._getframe(1).f_code.co_name
        command_str = self._format_command(command, values)
        logger.transaction("%s.%s(): connection = %s, command = %s",
                           classname, parent, id(self._conn), command_str)

        i = 0
        while True:
            try:
                if values is None:
                    self.execute(command)
                else:
                    self.execute(command, values)
                break
            except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
                if str(e) == "database disk image is malformed" or \
                        str(e) == "disk I/O error":
                    logger.critical("sqlite3 reports error accessing the database.")
                    logger.critical("Database file may have gotten corrupted, "
                            "or maybe filesystem does not properly support "
                            "file locking.")
                    raise
                i += 1
                if i == 10 or str(e) != "database is locked":
                    raise
        logger.transaction("%s.%s(): connection = %s, command finished",
                           classname, parent, id(self._conn))

    def begin(self, mode=None):
        """ Start a transaction. mode is None (plain BEGIN) or one of the
            module constants DEFERRED, IMMEDIATE, EXCLUSIVE.

            With DEBUG > 1, exclusive transactions are recorded in the
            module-wide exclusive_transaction list to detect nested use.

            Raises TypeError for any other mode. """
        if mode is None:
            self._exec("BEGIN")
        elif mode is DEFERRED:
            self._exec("BEGIN DEFERRED")
        elif mode is IMMEDIATE:
            self._exec("BEGIN IMMEDIATE")
        elif mode is EXCLUSIVE:
            if DEBUG > 1:
                tb = traceback.extract_stack()
                if exclusive_transaction != [None, None]:
                    old_tb_str = "".join(traceback.format_list(exclusive_transaction[1]))
                    new_tb_str = "".join(traceback.format_list(tb))
                    logger.warning("Called MyCursor.begin(EXCLUSIVE) when there was aleady an exclusive transaction %d\n%s",
                                        id(exclusive_transaction[0]), old_tb_str)
                    logger.warning("New transaction: %d\n%s", id(self.connection), new_tb_str)

            self._exec("BEGIN EXCLUSIVE")

            if DEBUG > 1:
                assert exclusive_transaction == [None, None]
                exclusive_transaction[0] = self.connection
                exclusive_transaction[1] = tb
        else:
            raise TypeError("Invalid mode parameter: %r" % mode)

    def pragma(self, prag):
        """ Execute the statement "PRAGMA prag;" """
        self._exec("PRAGMA %s;" % prag)

    def create_table(self, table, layout):
        """ Creates a table with fields as described in the layout parameter.

            layout is a sequence of string tuples; the strings of each tuple
            are joined with spaces to form one column definition. """
        command = "CREATE TABLE IF NOT EXISTS %s( %s );" % \
                  (table, ", ".join(" ".join(k) for k in layout))
        self._exec(command)

    def create_index(self, name, table, columns):
        """ Creates an index with fields as described in the columns list """
        command = "CREATE INDEX IF NOT EXISTS %s ON %s( %s );" \
                  % (name, table, ", ".join(columns))
        self._exec(command)

    def insert(self, table, d):
        """ Insert a new entry, where d is a dictionary containing the
            field:value pairs. Returns the row id of the newly created entry """
        # INSERT INTO table (field_1, field_2, ..., field_n)
        #     VALUES (value_1, value_2, ..., value_n)

        # Fields is a copy of d but with entries removed that have value None.
        # This is done primarily to avoid having "id" listed explicitly in the
        # INSERT statement, because the DB fills in a new value automatically
        # if "id" is the primary key. But I guess not listing field:NULL items
        # explicitly in an INSERT is a good thing in general
        fields = self._without_None(d)
        fields_str = ", ".join(fields.keys())

        sqlformat = ", ".join(("?",) * len(fields)) # sqlformat = "?, ?, ?, " ... "?"
        command = "INSERT INTO %s( %s ) VALUES ( %s );" \
                  % (table, fields_str, sqlformat)
        values = list(fields.values())
        self._exec(command, values)
        return self.lastrowid

    def update(self, table, d, **conditions):
        """ Update fields of an existing entry. conditions specifies the where
            clause to use for the update, entries in the dictionary d are the
            fields and their values to update """
        # UPDATE table SET column_1=value1, column2=value_2, ...,
        # column_n=value_n WHERE column_n+1=value_n+1, ...,
        setstr = join3(d.keys(), post=" = ?", sep=", ")
        (wherestr, wherevalues) = self._where_str("WHERE", **conditions)
        command = "UPDATE %s SET %s %s" % (table, setstr, wherestr)
        values = list(d.values()) + wherevalues
        self._exec(command, values)

    def where_query(self, joinsource, col_alias=None, limit=None, order=None,
                    **conditions):
        """ Build, but do not execute, a SELECT on joinsource.

            order is None or a pair (column, "ASC"|"DESC"); limit is None
            for no limit, or is converted with int(). Returns
            (command, values) where command has no trailing ";".

            Raises Exception if the order direction is invalid. """
        # Table/Column names cannot be substituted, so include in query directly.
        (WHERE, values) = self._where_str("WHERE", **conditions)
        if order is None:
            ORDER = ""
        else:
            if order[1] not in ("ASC", "DESC"):
                raise Exception("Invalid sort direction %r, must be ASC or DESC"
                                % (order[1],))
            ORDER = " ORDER BY %s %s" % (order[0], order[1])
        if limit is None:
            LIMIT = ""
        else:
            LIMIT = " LIMIT %s" % int(limit)
        AS = self.as_string(col_alias)
        command = "SELECT * %s FROM %s %s %s %s" \
                  % (AS, joinsource, WHERE, ORDER, LIMIT)
        return (command, values)

    def where(self, joinsource, col_alias=None, limit=None, order=None,
              values=None, **conditions):
        """ Select up to "limit" table rows (limit is None: no limit) where
            the key:value pairs of the dictionary "conditions" are set to the
            same value in the database table.

            values, if given, are bound to placeholders already present in
            joinsource; they precede the values of the conditions. The rows
            are left on the cursor for fetching. """
        (command, newvalues) = self.where_query(joinsource, col_alias, limit,
                                                order, **conditions)
        self._exec(command + ";", list(values or ()) + newvalues)

    def count(self, joinsource, **conditions):
        """ Count rows where the key:value pairs of the dictionary "conditions" are
            set to the same value in the database table """

        # Table/Column names cannot be substituted, so include in query directly.
        (WHERE, values) = self._where_str("WHERE", **conditions)

        command = "SELECT COUNT(*) FROM %s %s;" % (joinsource, WHERE)
        self._exec(command, values)
        r = self.fetchone()
        return int(r[0])

    def delete(self, table, **conditions):
        """ Delete the rows specified by conditions """
        (WHERE, values) = self._where_str("WHERE", **conditions)
        command = "DELETE FROM %s %s;" % (table, WHERE)
        self._exec(command, values)

    def where_as_dict(self, joinsource, col_alias=None, limit=None,
                      order=None, values=None, **conditions):
        """ Like where(), but fetch all result rows and return them as a
            list of dictionaries mapping column name to value """
        self.where(joinsource, col_alias=col_alias, limit=limit,
                   order=order, values=values, **conditions)
        # cursor.description is a list of lists, where the first element of
        # each inner list is the column name
        desc = [k[0] for k in self.description]
        return [dict(zip(desc, row)) for row in self.fetchall()]

342 343

class DbTable(object):
    """ A class template defining access methods to a database table

    Subclasses provide the attributes tablename, fields (a sequence of string
    tuples whose first entry is the column name), primarykey, references
    (the table object referenced via a foreign key, or None) and index
    (a dictionary mapping an index name to the columns it covers). """

    @staticmethod
    def _subdict(d, l):
        """ Returns a dictionary of those key:value pairs of d for which key
            exists in l. Returns None if d is None """
        if d is None:
            return None
        return {k: d[k] for k in d.keys() if k in l}

    def _get_colnames(self):
        """ Return the list of column names listed in self.fields """
        return [k[0] for k in self.fields]

    def getname(self):
        return self.tablename

    def getpk(self):
        return self.primarykey

    def dictextract(self, d):
        """ Return a dictionary with all those key:value pairs of d
            for which key is in self._get_colnames() """
        return self._subdict(d, self._get_colnames())

    def create(self, cursor):
        """ Create this table and its indices, if they do not exist yet """
        fields = list(self.fields)
        fk = None
        if self.references:
            # If this table references another table, we use the primary
            # key of the referenced table as the foreign key name
            r = self.references # referenced table
            fk = (r.getpk(), "INTEGER", "REFERENCES %s ( %s ) " \
                  % (r.getname(), r.getpk()))
            fields.append(fk)
        cursor.create_table(self.tablename, fields)
        if fk is not None:
            # We always create an index on the foreign key. The foreign key
            # column lives in this table, so the index must be on this table,
            # not on the referenced one. Note that the index is created with
            # IF NOT EXISTS: a database that already has an index of this
            # name keeps it as it is.
            cursor.create_index(self.tablename + "_pkindex", self.tablename,
                                (fk[0], ))
        for indexname in self.index:
            cursor.create_index(self.tablename + "_" + indexname,
                                self.tablename, self.index[indexname])

    def insert(self, cursor, values, foreign=None):
        """ Insert a new row into this table. The column:value pairs are
            specified key:value pairs of the dictionary values; keys that are
            not columns of this table are ignored.
            If foreign is given, it is stored in the foreign key column.
            The database's row id for the new entry is stored in
            values[primarykey] """
        d = self.dictextract(values)
        assert self.primarykey not in d or d[self.primarykey] is None
        # If a foreign key is specified in foreign, add it to the column
        # that is marked as being a foreign key
        if foreign:
            r = self.references.primarykey
            assert r not in d or d[r] is None
            d[r] = foreign
        values[self.primarykey] = cursor.insert(self.tablename, d)

    def insert_list(self, cursor, values, foreign=None):
        """ Insert each dictionary of the list values as a new row """
        for v in values:
            self.insert(cursor, v, foreign)

    def update(self, cursor, d, **conditions):
        """ Update an existing row in this table. The column:value pairs to
            be written are specified key:value pairs of the dictionary d """
        cursor.update(self.tablename, d, **conditions)

    def delete(self, cursor, **conditions):
        """ Delete an existing row in this table """
        cursor.delete(self.tablename, **conditions)

    def where(self, cursor, limit=None, order=None, **conditions):
        """ Return the rows matching conditions as a list of dictionaries.
            order, if given, must name a column of this table in order[0] """
        assert order is None or order[0] in self._get_colnames()
        return cursor.where_as_dict(self.tablename, limit=limit,
                                    order=order, **conditions)

419 420

class WuTable(DbTable):
    """ Layout of the "workunits" table """
    tablename = "workunits"
    fields = (
        ("wurowid", "INTEGER PRIMARY KEY ASC", "UNIQUE NOT NULL"),
        ("wuid", "TEXT", "UNIQUE NOT NULL"),
        ("status", "INTEGER", "NOT NULL"),
        ("wu", "TEXT", "NOT NULL"),
        ("timecreated", "TEXT", ""),
        ("timeassigned", "TEXT", ""),
        ("assignedclient", "TEXT", ""),
        ("timeresult", "TEXT", ""),
        ("resultclient", "TEXT", ""),
        ("errorcode", "INTEGER", ""),
        ("failedcommand", "INTEGER", ""),
        ("timeverified", "TEXT", ""),
        # A retried workunit points back at a row of this same table
        ("retryof", "INTEGER", "REFERENCES " + tablename),
        ("priority", "INTEGER", ""),
    )
    # The first column is the primary key
    primarykey = "wurowid"
    references = None
    # Look-ups by workunit id and by status are indexed
    index = {"wuidindex": ("wuid",), "statusindex": ("status",)}
441

442
class FilesTable(DbTable):
    """ Layout of the "files" table; each row references a workunit row """
    tablename = "files"
    fields = (
        ("filesrowid", "INTEGER PRIMARY KEY ASC", "UNIQUE NOT NULL"),
        ("filename", "TEXT", ""),
        ("path", "TEXT", "UNIQUE NOT NULL"),
        ("type", "TEXT", ""),
        ("command", "INTEGER", ""),
    )
    # The first column is the primary key
    primarykey = "filesrowid"
    # Foreign key relationship to the workunits table
    references = WuTable()
    index = dict()
454

455

456 457
class DictDbTable(DbTable):
    """ Layout of a key/type/value table. Unlike the other tables, the
        table name is not fixed but set per instance via the name keyword """
    fields = (
        ("rowid", "INTEGER PRIMARY KEY ASC", "UNIQUE NOT NULL"),
        ("key", "TEXT", "UNIQUE NOT NULL"),
        ("type", "INTEGER", "NOT NULL"),
        ("value", "TEXT", ""),
    )
    # The first column is the primary key
    primarykey = "rowid"
    references = None
    index = {"keyindex": ("key",)}

    def __init__(self, *args, name=None, **kwargs):
        # The table name is instance data; everything else goes to the parent
        self.tablename = name
        super().__init__(*args, **kwargs)
469 470


471
class DictDbAccess(collections.abc.MutableMapping):
    """ A DB-backed flat dictionary.

    Flat means that the value of each dictionary entry must be a type that
    the underlying DB understands, like integers, strings, etc., but not
    collections or other complex types.

    A copy of all the data in the table is kept in memory; read accesses
    are always served from the in-memory dict. Write accesses write through
    to the DB.

    >>> conn = sqlite3.connect(':memory:')
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {}
    True
    >>> d['a'] = '1'
    >>> d == {'a': '1'}
    True
    >>> d['a'] = 2
    >>> d == {'a': 2}
    True
    >>> d['b'] = '3'
    >>> d == {'a': 2, 'b': '3'}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {'a': 2, 'b': '3'}
    True
    >>> del(d['b'])
    >>> d == {'a': 2}
    True
    >>> d.setdefault('a', '3')
    2
    >>> d == {'a': 2}
    True
    >>> d.setdefault('b', 3.0)
    3.0
    >>> d == {'a': 2, 'b': 3.0}
    True
    >>> d.setdefault(None, {'a': '3', 'c': '4'})
    >>> d == {'a': 2, 'b': 3.0, 'c': '4'}
    True
    >>> d.update({'a': '3', 'd': True})
    >>> d == {'a': '3', 'b': 3.0, 'c': '4', 'd': True}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {'a': '3', 'b': 3.0, 'c': '4', 'd': True}
    True
    >>> d.clear(['a', 'd'])
    >>> d == {'b': 3.0, 'c': '4'}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {'b': 3.0, 'c': '4'}
    True
    >>> d.clear()
    >>> d == {}
    True
    >>> del(d)
    >>> d = DictDbAccess(conn, 'test')
    >>> d == {}
    True
    """

    # The supported value types. The index of a value's type in this tuple
    # is what gets stored in the "type" column, so the order must not change
    types = (str, int, float, bool)

    def __init__(self, db, name):
        ''' Attaches to a DB table and reads values stored therein.

        db can be a string giving the file name for the DB (same as for
        sqlite3.connect()), or an open DB connection. The latter is allowed
        primarily for making the doctest work, so we can reuse the same
        memory-backed DB connection, but it may be useful in other contexts.

        name is the name of the table to attach to.
        '''

        if isinstance(db, str):
            self._conn = sqlite3.connect(db)
            self._ownconn = True
        else:
            self._conn = db
            self._ownconn = False
        self._table = DictDbTable(name=name)
        # Create an empty table if none exists
        cursor = self._conn.cursor(MyCursor)
        self._table.create(cursor)
        # Get the entries currently stored in the DB
        self._data = self._getall()
        cursor.close()

    # Implement the abstract methods defined by collections.abc.MutableMapping
    # All but __delitem__ and __setitem__ are simply passed through to the
    # self._data dictionary

    def __getitem__(self, key):
        return self._data.__getitem__(key)

    def __iter__(self):
        return self._data.__iter__()

    def __len__(self):
        return self._data.__len__()

    def __str__(self):
        return self._data.__str__()

    def __del__(self):
        """ Close the DB connection, if this object opened it itself """
        # If __init__() failed before setting _ownconn (e.g., the DB file
        # could not be opened), there is nothing to close
        if getattr(self, "_ownconn", False):
            # When we shut down Python hard, e.g., in an exception, the
            # conn_close() function object may have been destroyed already
            # and trying to call it would raise another exception.
            if callable(conn_close):
                conn_close(self._conn)
            else:
                self._conn.close()

    def __convert_value(self, row):
        """ Convert the "value" string of a table row back to a Python
            object of the type encoded in the row's "type" column.

            Raises ValueError if a Bool value is neither "True" nor "False" """
        valuestr = row["value"]
        valuetype = row["type"]
        # Look up constructor for this type
        typecon = self.types[int(valuetype)]
        # Bool is handled separately as bool("False") == True
        if typecon == bool:
            if valuestr == "True":
                return True
            elif valuestr == "False":
                return False
            else:
                raise ValueError("Value %s invalid for Bool type" % valuestr)
        return typecon(valuestr)

    def __get_type_idx(self, value):
        """ Return the index in self.types of the type of value.

            Raises TypeError if the type is not supported """
        valuetype = type(value)
        # The comparison is for the exact type, not isinstance(), as
        # bool is a subclass of int but must be stored as its own type
        for (idx, t) in enumerate(self.types):
            if valuetype == t:
                return idx
        raise TypeError("Type %s not supported" % str(valuetype))

    def _getall(self):
        """ Reads the whole table and returns it as a dict """
        cursor = self._conn.cursor(MyCursor)
        rows = self._table.where(cursor)
        cursor.close()
        return {r["key"]: self.__convert_value(r) for r in rows}

    def __setitem_nocommit(self, cursor, key, value):
        """ Set dictionary key to value and update/insert into table,
        but don't commit. Cursor must be given
        """
        update = {"value": str(value), "type": self.__get_type_idx(value)}
        if key in self._data:
            # Update the table row where column "key" equals key
            self._table.update(cursor, update, eq={"key": key})
        else:
            # Insert a new row
            update["key"] = key
            self._table.insert(cursor, update)
        # Update the in-memory dict
        self._data[key] = value

    def __setitem__(self, key, value):
        """ Access by indexing, e.g., d["foo"]. Always commits """
        cursor = self._conn.cursor(MyCursor)
        if not self._conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        self.__setitem_nocommit(cursor, key, value)
        conn_commit(self._conn)
        cursor.close()

    def __delitem__(self, key, commit=True):
        """ Delete a key from the dictionary """
        cursor = self._conn.cursor(MyCursor)
        if not self._conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        self._table.delete(cursor, eq={"key": key})
        if commit:
            conn_commit(self._conn)
        cursor.close()
        del(self._data[key])

    def setdefault(self, key, default=None, commit=True):
        ''' Setdefault function that allows a mapping as input

        If key is None and default is a mapping, values from the default
        dict are merged into self, *not* overwriting existing values in
        self, and None is returned. Otherwise this behaves like
        dict.setdefault() '''
        if key is None and isinstance(default, collections.abc.Mapping):
            update = {key: default[key] for key in default if key not in self}
            if update:
                self.update(update, commit=commit)
            return None
        elif key not in self:
            self.update({key: default}, commit=commit)
        return self._data[key]

    def update(self, other, commit=True):
        """ Store all key:value pairs of the mapping other, within one
            transaction; commit it unless commit is False """
        cursor = self._conn.cursor(MyCursor)
        if not self._conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        for (key, value) in other.items():
            self.__setitem_nocommit(cursor, key, value)
        if commit:
            conn_commit(self._conn)
        cursor.close()

    def clear(self, args=None, commit=True):
        """ Overridden clear that allows removing several keys atomically.

            If args is None, all entries are removed; otherwise args is an
            iterable of the keys to remove """
        cursor = self._conn.cursor(MyCursor)
        if not self._conn.in_transaction:
            cursor.begin(EXCLUSIVE)
        if args is None:
            self._data.clear()
            self._table.delete(cursor)
        else:
            for key in args:
                del(self._data[key])
                self._table.delete(cursor, eq={"key": key})
        if commit:
            conn_commit(self._conn)
        cursor.close()

692

693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761
class Mapper(object):
    """ This class translates between application objects, i.e., Python
        dictionaries, and the relational data layout in an SQL DB, i.e.,
        one or more tables which possibly have foreign key relationships
        that map to hierarchical data structures. For now, only one
        foreign key / subdictionary."""

    def __init__(self, table, subtables=None):
        """ table is the table object for the top-level data; subtables,
            if given, maps a dictionary key to the table object storing
            the list of sub-dictionaries found under that key """
        self.table = table
        self.subtables = {}
        if subtables:
            for s in subtables.keys():
                self.subtables[s] = Mapper(subtables[s])

    def __sub_dict(self, d):
        """ For each key "k" that has a subtable assigned in "self.subtables",
        pop the entry with key "k" from "d", and store it in a new dictionary
        which is returned. I.e., the dictionary d is separated into
        two parts: the part which corresponds to subtables and is the return
        value, and the rest which is left in the input dictionary. """
        sub_dict = {}
        for s in self.subtables.keys():
            # Don't store s:None entries even if they exist in d
            t = d.pop(s, None)
            if t is not None:
                sub_dict[s] = t
        return sub_dict

    def getname(self):
        return self.table.getname()

    def getpk(self):
        return self.table.getpk()

    def create(self, cursor):
        """ Create the table and all subtables """
        self.table.create(cursor)
        for t in self.subtables.values():
            t.create(cursor)

    def insert(self, cursor, wus, foreign=None):
        """ Insert the dictionaries in the list wus, and the lists of
            sub-dictionaries they contain, into the table and subtables.
            The primary key of each new row is stored in the caller's
            dictionary """
        pk = self.getpk()
        for wu in wus:
            # Make copy so sub_dict does not change caller's data
            wuc = wu.copy()
            # Split off entries that refer to subtables
            sub_dict = self.__sub_dict(wuc)
            # We add the entries in wuc only if it does not have a primary
            # key yet. If it does have a primary key, we add only the data
            # for the subtables
            if pk not in wuc:
                self.table.insert(cursor, wuc, foreign=foreign)
                # Copy primary key into caller's data
                wu[pk] = wuc[pk]
            for subtable_name in sub_dict.keys():
                self.subtables[subtable_name].insert(
                    cursor, sub_dict[subtable_name], foreign=wu[pk])

    def update(self, cursor, wus):
        """ Update the rows for the dictionaries in the list wus. Each
            dictionary must contain its primary key, which selects the row
            to update; the remaining entries are the values to write.
            Lists of sub-dictionaries are updated in their subtables in the
            same way """
        pk = self.getpk()
        for wu in wus:
            assert wu[pk] is not None
            wuc = wu.copy()
            sub_dict = self.__sub_dict(wuc)
            rowid = wuc.pop(pk, None)
            # With no columns left to set, an UPDATE would be malformed
            if rowid and wuc:
                self.table.update(cursor, wuc, eq={pk: rowid})
            for s in sub_dict:
                self.subtables[s].update(cursor, sub_dict[s])

    def count(self, cursor, **cond):
        """ Return the number of rows of the top-level table matching cond """
        joinsource = self.table.tablename
        return cursor.count(joinsource, **cond)

    def where(self, cursor, limit=None, order=None, **cond):
        """ Return the rows of the top-level table matching cond as a list
            of dictionaries. For each subtable name, the dictionary has an
            entry that is None or the list of joined subtable rows.
            limit applies to the rows of the top-level table """
        # We want:
        # SELECT * FROM (SELECT * from workunits WHERE status = 2 LIMIT 1) LEFT JOIN files USING ( wurowid );
        pk = self.getpk()
        (command, values) = cursor.where_query(self.table.tablename,
                                               limit=limit, **cond)
        joinsource = "( %s )" % command
        for s in self.subtables.keys():
            # FIXME: this probably breaks with more than 2 tables
            joinsource = "%s LEFT JOIN %s USING ( %s )" \
                         % (joinsource, self.subtables[s].getname(), pk)
        # FIXME: don't get result rows as dict! Leave as tuple and
        # take them apart positionally
        rows = cursor.where_as_dict(joinsource, order=order, values=values)
        wus = []
        for r in rows:
            # Collapse rows with identical primary key
            if len(wus) == 0 or r[pk] != wus[-1][pk]:
                wus.append(self.table.dictextract(r))
                for s in self.subtables.keys():
                    wus[-1][s] = None

            for (sn, sm) in self.subtables.items():
                spk = sm.getpk()
                if spk in r and r[spk] is not None:
                    if wus[-1][sn] is None:
                        # If this sub-array is empty, init it
                        wus[-1][sn] = [sm.table.dictextract(r)]
                    elif r[spk] != wus[-1][sn][-1][spk]:
                        # If not empty, and primary key of sub-table is not
                        # same as in previous entry, add it
                        wus[-1][sn].append(sm.table.dictextract(r))
        return wus

class WuAccess(object): # {
801 802 803 804
    """ This class maps between the WORKUNIT and FILES tables 
        and a dictionary 
        {"wuid": string, ..., "timeverified": string, "files": list}
        where list is None or a list of dictionaries of the form
805 806
        {"id": int, "type": int, "wuid": string, "filename": string,
        "path": string}
807
        Operations on instances of WuAccess are directly carried 
808
        out on the database persistent storage, i.e., they behave kind 
809
        of as if the WuAccess instance were itself a persistent 
810
        storage device """
811
    
812 813 814 815 816 817 818
    def __init__(self, db):
        """ Set up access to the workunit database

        db is either a str, which is passed to sqlite3.connect() to open a
        new connection that this instance then owns, or an already opened
        connection object, which is used as it is and not owned.
        """
        opened_here = isinstance(db, str)
        self.conn = sqlite3.connect(db) if opened_here else db
        # Remember whether we opened the connection ourselves; __del__ only
        # closes connections that this instance owns
        self._ownconn = opened_here
        pragma_cursor = self.conn.cursor(MyCursor)
        pragma_cursor.pragma("foreign_keys = ON")
        conn_commit(self.conn)
        pragma_cursor.close()
        self.mapper = Mapper(WuTable(), {"files": FilesTable()})
824
    
825
    def __del__(self):
826 827 828 829
        if self._ownconn:
            if callable(conn_close):
                conn_close(self.conn)
            else:
Emmanuel Thomé's avatar
Emmanuel Thomé committed
830
                self.conn.close()
831
    
832 833
    @staticmethod
    def to_str(wus):
834
        r = []
835
        for wu in wus:
836
            s = "Workunit %s:\n" % wu["wuid"]
837 838
            for (k,v) in wu.items():
                if k != "wuid" and k != "files":
839
                    s += "  %s: %r\n" % (k, v)
840
            if "files" in wu:
841
                s += "  Associated files:\n"
842
                if wu["files"] is None:
843
                    s += "    None\n"
844 845
                else:
                    for f in wu["files"]:
846
                        s += "    %s\n" % f
847 848
            r.append(s)
        return '\n'.join(r)
849

850 851
    @staticmethod
    def _checkstatus(wu, status):
852
        # logger.debug("WuAccess._checkstatus(%s, %s)", wu, status)
853 854 855 856 857 858
        wu_status = wu["status"]
        if wu_status != status:
            msg = "Workunit %s has status %s (%s), expected %s (%s)" % \
                  (wu["wuid"], wu_status, WuStatus.get_name(wu_status), 
                   status, WuStatus.get_name(status))
            logger.error ("WuAccess._checkstatus(): %s", msg)
859
            raise StatusUpdateError(msg)
860

861 862
    @staticmethod
    def check(data):
        """ Assert that a workunit row is consistent with its status

        data is a dictionary with the columns of one workunit. The status
        value is validated with WuStatus.check(), and the id of the workunit
        parsed from data["wu"] must equal data["wuid"]. Depending on the
        status, the columns that can not have been filled in yet must be
        None. Rows with a status greater than WuStatus.RECEIVED_ERROR are
        not examined any further.

        Violations raise AssertionError.
        """
        status = data["status"]
        WuStatus.check(status)
        parsed = Workunit(data["wu"])
        assert parsed.get_id() == data["wuid"]
        if status > WuStatus.RECEIVED_ERROR:
            return
        if status == WuStatus.RECEIVED_ERROR:
            # An error result must carry a non-zero error code
            assert data["errorcode"] != 0
            return
        if status == WuStatus.RECEIVED_OK:
            errorcode = data["errorcode"]
            assert errorcode is None or errorcode == 0
            return
        # No result was received yet, so there can be no result data
        for column in ("errorcode", "timeresult", "resultclient"):
            assert data[column] is None
        if status == WuStatus.ASSIGNED:
            return
        # Not assigned yet, so there can be no assignment data
        for column in ("timeassigned", "assignedclient"):
            assert data[column] is None
        if status == WuStatus.AVAILABLE:
            return
        assert data["timecreated"] is None
        # etc.
    
    # Here come the application-visible functions that implement the 
    # "business logic": creating a new workunit from the text of a WU file,
    # assigning it to a client, receiving a result for the WU, marking it as
    # verified, or marking it as cancelled
891

892
    def _add_files(self, cursor, files, wuid=None, rowid=None):
893 894 895 896 897 898 899 900 901 902 903 904
        # Exactly one must be given
        assert not wuid is None or not rowid is None
        assert wuid is None or rowid is None
        # FIXME: allow selecting row to update directly via wuid, without 
        # doing query for rowid first
        pk = self.mapper.getpk()
        if rowid is None:
            wu = get_by_wuid(cursor, wuid)
            if wu:
                rowid = wu[pk]
            else:
                return False
905 906 907
        colnames = ("filename", "path", "type", "command")
        # zipped length is that of shortest list, so "command" is optional 
        d = (dict(zip(colnames, f)) for f in files)
908 909 910 911 912 913
        # These two should behave identically
        if True:
            self.mapper.insert(cursor, [{pk:rowid, "files": d},])
        else:
            self.mapper.subtables["files"].insert(cursor, d, foreign=rowid)

914
    def create_tables(self):
915
        cursor = self.conn.cursor(MyCursor)
916
        cursor.pragma("journal_mode=WAL")
917
        self.mapper.create(cursor)