Package trac :: Package db :: Module api

Source Code for Module trac.db.api

  1  # -*- coding: utf-8 -*- 
  2  # 
  3  # Copyright (C) 2005-2023 Edgewall Software 
  4  # Copyright (C) 2005 Christopher Lenz <[email protected]> 
  5  # All rights reserved. 
  6  # 
  7  # This software is licensed as described in the file COPYING, which 
  8  # you should have received as part of this distribution. The terms 
  9  # are also available at https://trac.edgewall.org/wiki/TracLicense. 
 10  # 
 11  # This software consists of voluntary contributions made by many 
 12  # individuals. For the exact contribution history, see the revision 
 13  # history and logs, available at https://trac.edgewall.org/log/. 
 14  # 
 15  # Author: Christopher Lenz <[email protected]> 
 16   
 17  import importlib 
 18  import os 
 19  import time 
 20  import urllib 
 21  from abc import ABCMeta, abstractmethod 
 22   
 23  from trac import db_default 
 24  from trac.api import IEnvironmentSetupParticipant, ISystemInfoProvider 
 25  from trac.config import BoolOption, ConfigurationError, IntOption, Option 
 26  from trac.core import * 
 27  from trac.db.pool import ConnectionPool 
 28  from trac.db.schema import Table 
 29  from trac.db.util import ConnectionWrapper 
 30  from trac.util.concurrency import ThreadLocal 
 31  from trac.util.html import tag 
 32  from trac.util.text import unicode_passwd 
 33  from trac.util.translation import _, tag_ 
class DbContextManager(object):
    """Base class for the database context managers.

    Only the outermost `DbContextManager` closes the connection.
    """

    # Set by subclasses' `__enter__` only when they own the connection.
    db = None

    def __init__(self, env):
        self.dbmgr = DatabaseManager(env)

    def execute(self, query, params=None):
        """Enter this context and run `query` with `params` on the
        resulting connection, returning whatever it returns."""
        with self as cnx:
            return cnx.execute(query, params)

    # Calling the context manager directly is the same as `execute`.
    __call__ = execute

    def executemany(self, query, params=None):
        """Enter this context and run `executemany(query, params)` on the
        resulting connection, returning whatever it returns."""
        with self as cnx:
            return cnx.executemany(query, params)


class TransactionContextManager(DbContextManager):
    """Transactioned Database Context Manager for retrieving a
    `~trac.db.util.ConnectionWrapper`.

    Only the outermost such context manager commits on normal exit, or
    rolls back when the block raised.
    """

    def __enter__(self):
        local = self.dbmgr._transaction_local
        writable = local.wdb  # outermost writable db
        if writable:
            # Nested transaction: hand out the outer connection, own nothing.
            return writable
        readonly = local.rdb  # reuse wrapped connection
        if readonly:
            writable = ConnectionWrapper(readonly.cnx, readonly.log)
        else:
            writable = self.dbmgr.get_connection()
        local.wdb = self.db = writable
        return writable

    def __exit__(self, et, ev, tb):
        cnx = self.db
        if not cnx:
            # Not the outermost transaction; the owner will finish it.
            return
        local = self.dbmgr._transaction_local
        local.wdb = None
        if et is None:
            cnx.commit()
        else:
            cnx.rollback()
        # An enclosing read-only context still shares this connection.
        if not local.rdb:
            cnx.close()


class QueryContextManager(DbContextManager):
    """Database Context Manager for retrieving a read-only
    `~trac.db.util.ConnectionWrapper`.
    """

    def __enter__(self):
        local = self.dbmgr._transaction_local
        readonly = local.rdb  # outermost readonly db
        if readonly:
            # Nested query: reuse the outer connection, own nothing.
            return readonly
        writable = local.wdb  # reuse wrapped connection
        if writable:
            readonly = ConnectionWrapper(writable.cnx, writable.log,
                                         readonly=True)
        else:
            readonly = self.dbmgr.get_connection(readonly=True)
        local.rdb = self.db = readonly
        return readonly

    def __exit__(self, et, ev, tb):
        cnx = self.db
        if not cnx:
            return
        local = self.dbmgr._transaction_local
        local.rdb = None
        # An enclosing transaction still shares this connection.
        if not local.wdb:
            cnx.close()


class ConnectionBase(object):
    """Abstract base class that database connection classes implement.

    Every method below is abstract; its body only documents the contract.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def cast(self, column, type):
        """Return a clause casting `column` as `type`."""

    @abstractmethod
    def concat(self, *args):
        """Return a clause concatenating the sequence `args`."""

    @abstractmethod
    def drop_column(self, table, column):
        """Drop `column` from `table`."""

    @abstractmethod
    def drop_table(self, table):
        """Drop `table`."""

    @abstractmethod
    def get_column_names(self, table):
        """Return the list of column names in `table`."""

    @abstractmethod
    def get_last_id(self, cursor, table, column='id'):
        """Return the current value of the primary key sequence for `table`.
        The primary key `column` may be given; it defaults to `id`."""

    @abstractmethod
    def get_sequence_names(self):
        """Return a list of the sequence names."""

    @abstractmethod
    def get_table_names(self):
        """Return a list of the table names."""

    @abstractmethod
    def has_table(self, table):
        """Return whether `table` exists."""

    @abstractmethod
    def like(self):
        """Return a case-insensitive `LIKE` clause."""

    @abstractmethod
    def like_escape(self, text):
        """Return `text` escaped for use in a `LIKE` clause."""

    @abstractmethod
    def prefix_match(self):
        """Return a case-sensitive prefix-matching operator."""

    @abstractmethod
    def prefix_match_value(self, prefix):
        """Return a value for the case-sensitive prefix-matching operator."""

    @abstractmethod
    def quote(self, identifier):
        """Return `identifier` quoted."""

    @abstractmethod
    def reset_tables(self):
        """Delete all data from the tables and reset autoincrement indexes.

        :return: list of names of the tables that were reset.
        """

    @abstractmethod
    def update_sequence(self, cursor, table, column='id'):
        """Update the current value of the primary key sequence for `table`.
        The primary key `column` may be given; it defaults to `id`."""


class IDatabaseConnector(Interface):
    """Extension point interface for components that support the
    connection to relational databases.

    As with other Trac interfaces, methods are declared without `self`;
    implementing components define the real signatures.
    """

    def get_supported_schemes():
        """Return the connection URL schemes supported by the
        connector, and their relative priorities as an iterable of
        `(scheme, priority)` tuples.

        If `priority` is a negative number, this is indicative of an
        error condition with the connector. An error message should be
        attached to the `error` attribute of the connector.
        """

    def get_connection(path, log=None, **kwargs):
        """Create a new connection to the database."""

    def get_exceptions():
        """Return an object (typically a module) containing all the
        backend-specific exception types as attributes, named
        according to the Python Database API
        (http://www.python.org/dev/peps/pep-0249/).
        """

    def init_db(path, schema=None, log=None, **kwargs):
        """Initialize the database."""

    def destroy_db(path, log=None, **kwargs):
        """Destroy the database."""

    def db_exists(path, log=None, **kwargs):
        """Return `True` if the database exists."""

    def to_sql(table):
        """Return the DDL statements necessary to create the specified
        table, including indices."""

    def backup(dest):
        """Backup the database to a location defined by
        trac.backup_dir"""

    def get_system_info():
        """Yield a sequence of `(name, version)` tuples describing the
        name and version information of external packages used by the
        connector.
        """


class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations."""

    implements(IEnvironmentSetupParticipant, ISystemInfoProvider)

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option('trac', 'database', 'sqlite:db/trac.db',
        """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
        """Database backup location""")

    timeout = IntOption('trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''.""")

    debug_sql = BoolOption('trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        """)

    def __init__(self):
        # The connection pool is created lazily by `get_connection()`.
        self._cnx_pool = None
        # Per-thread outermost writable (`wdb`) and read-only (`rdb`)
        # connections, used by the transaction/query context managers.
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Initialize the database using `db_default.schema`."""
        connector, args = self.get_connector()
        args['schema'] = db_default.schema
        connector.init_db(**args)

    def destroy_db(self):
        """Destroy the database after shutting down the connection pool."""
        connector, args = self.get_connector()
        # Connections to on-disk db must be closed before deleting it.
        self.shutdown()
        connector.destroy_db(**args)

    def db_exists(self):
        """Return whether the configured database exists."""
        connector, args = self.get_connector()
        return connector.db_exists(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_columns(self, table, columns):
        """Drops the specified columns from table.

        :param table: a `Table` object or table name.
        :param columns: an iterable of column names.
        :raises OperationalError: (from `env.db_exc`) if the table does
                                  not exist.

        :since: version 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_transaction as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            for col in columns:
                db.drop_column(table_name, col)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def insert_into_tables(self, data_or_callable):
        """Insert data into existing tables.

        :param data_or_callable: Nested tuples of table names, column names
                                 and row data::

                                   (table1,
                                    (column1, column2),
                                    ((row1col1, row1col2),
                                     (row2col1, row2col2)),
                                    table2, ...)

                                 or a callable that takes a single parameter
                                 `db` and returns the aforementioned nested
                                 tuple.
        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            data = data_or_callable(db) if callable(data_or_callable) \
                   else data_or_callable
            for table, cols, vals in data:
                db.executemany("INSERT INTO %s (%s) VALUES (%s)"
                               % (db.quote(table), ','.join(cols),
                                  ','.join(['%s'] * len(cols))), vals)

    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            return db.reset_tables()

    def upgrade_tables(self, new_schema):
        """Upgrade table schema to `new_schema`, preserving data in
        columns that exist in the current schema and `new_schema`.

        :param new_schema: tuple or list of `Table` objects

        :since: version 1.2
        """
        with self.env.db_transaction as db:
            cursor = db.cursor()
            for new_table in new_schema:
                temp_table_name = new_table.name + '_old'
                table_exists = self.has_table(new_table)
                column_names = set()
                if table_exists:
                    old_column_names = set(self.get_column_names(new_table))
                    new_column_names = {col.name for col in new_table.columns}
                    column_names = old_column_names & new_column_names
                    if column_names:
                        cols_to_copy = ','.join(db.quote(name)
                                                for name in column_names)
                        # Keep the old rows aside while the table is rebuilt.
                        cursor.execute("""
                            CREATE TEMPORARY TABLE %s AS SELECT * FROM %s
                            """ % (db.quote(temp_table_name),
                                   db.quote(new_table.name)))
                self.drop_tables((new_table,))
                self.create_tables((new_table,))
                if table_exists and column_names:
                    cursor.execute("""
                        INSERT INTO %s (%s) SELECT %s FROM %s
                        """ % (db.quote(new_table.name), cols_to_copy,
                               cols_to_copy, db.quote(temp_table_name)))
                    # Copied rows keep their ids, so bump the sequences.
                    for col in new_table.columns:
                        if col.auto_increment:
                            db.update_sequence(cursor, new_table.name,
                                               col.name)
                    # The temporary copy only exists on this path.
                    self.drop_tables((temp_table_name,))

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        # A timeout of 0 means "no timeout", hence `or None`.
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_database_version(self, name='database_version'):
        """Returns the database version from the SYSTEM table as an int,
        or `False` if the entry is not found.

        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        with self.env.db_query as db:
            for value, in db("""
                    SELECT value FROM {0} WHERE name=%s
                    """.format(db.quote('system')), (name,)):
                return int(value)
            else:
                return False

    def get_exceptions(self):
        """Return the connector's exception namespace (see
        `IDatabaseConnector.get_exceptions`)."""
        return self.get_connector()[0].get_exceptions()

    def get_sequence_names(self):
        """Returns a list of the sequence names.

        :since: 1.3.2
        """
        with self.env.db_query as db:
            return db.get_sequence_names()

    def get_table_names(self):
        """Returns a list of the table names.

        :since: 1.1.6
        """
        with self.env.db_query as db:
            return db.get_table_names()

    def get_column_names(self, table):
        """Returns a list of the column names for `table`.

        :param table: a `Table` object or table name.
        :raises OperationalError: (from `env.db_exc`) if the table does
                                  not exist.

        :since: 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            return db.get_column_names(table_name)

    def has_table(self, table):
        """Returns whether the table exists.

        :param table: a `Table` object or table name.
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            return db.has_table(table_name)

    def set_database_version(self, version, name='database_version'):
        """Sets the database version in the SYSTEM table.

        :param version: an integer database version.
        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        current_database_version = self.get_database_version(name)
        if current_database_version is False:
            with self.env.db_transaction as db:
                db("""
                    INSERT INTO {0} (name, value) VALUES (%s, %s)
                    """.format(db.quote('system')), (name, version))
        else:
            with self.env.db_transaction as db:
                db("""
                    UPDATE {0} SET value=%s WHERE name=%s
                    """.format(db.quote('system')), (version, name))
            self.log.info("Upgraded %s from %d to %d",
                          name, current_database_version, version)

    def needs_upgrade(self, version, name='database_version'):
        """Checks the database version to determine if an upgrade is needed.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.

        :return: `True` if the stored version is less than the expected
                  version, `False` if it is equal to the expected version.
        :raises TracError: if the stored version is greater than the expected
                           version.
        """
        dbver = self.get_database_version(name)
        if dbver == version:
            return False
        elif dbver > version:
            raise TracError(_("Need to downgrade %(name)s.", name=name))
        self.log.info("Need to upgrade %s from %d to %d",
                      name, dbver, version)
        return True

    def upgrade(self, version, name='database_version', pkg='trac.upgrades'):
        """Invokes `do_upgrade(env, version, cursor)` in module
        `"%s/db%i.py" % (pkg, version)`, for each required version upgrade.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.
        :param pkg: the package containing the upgrade modules.

        :raises TracError: if the package or module doesn't exist.
        """
        dbver = self.get_database_version(name)
        for i in xrange(dbver + 1, version + 1):
            module = '%s.db%i' % (pkg, i)
            try:
                upgrader = importlib.import_module(module)
            except ImportError:
                raise TracError(_("No upgrade module %(module)s.py",
                                  module=module))
            # Each step runs in its own transaction and records its version.
            with self.env.db_transaction as db:
                cursor = db.cursor()
                upgrader.do_upgrade(self.env, i, cursor)
                self.set_database_version(i, name)

    def shutdown(self, tid=None):
        """Shut down pooled connections (for thread `tid` if given).

        Without a `tid` the pool itself is discarded as well.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to. When omitted, a name built
                     from the database scheme, the database version and the
                     current time is placed in `[trac] backup_dir`.

        Returns the file actually written.
        """
        connector, args = self.get_connector()
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            db_str = self.config.get('trac', 'database')
            db_name = db_str.split(":", 1)[0]
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.database_version,
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # A bare filename has an empty dirname, which os.makedirs() rejects;
        # it refers to the current directory, so nothing needs creating.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def get_connector(self):
        """Return the highest-priority connector for the configured scheme
        together with the connection arguments parsed from the URI.

        :raises TracError: if no connector supports the scheme, or the
                           selected connector reports an error (negative
                           priority).
        """
        scheme, args = parse_connection_uri(self.connection_uri)
        candidates = [
            (priority, connector)
            for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(_('Unsupported database type "%(scheme)s"',
                              scheme=scheme))
        priority, connector = max(candidates)
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            if args['path'] == ':memory:':
                # Special case for SQLite in-memory database, always get
                # the /same/ connection over
                pass
            elif not os.path.isabs(args['path']):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path,
                                            args['path'].lstrip('/'))

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    # IEnvironmentSetupParticipant methods

    def environment_created(self):
        """Insert default data into the database."""
        self.insert_into_tables(db_default.get_data)

    def environment_needs_upgrade(self):
        # NOTE(review): this method was missing from the reviewed listing;
        # restored to the upstream Trac implementation (relies on
        # `db_default.db_version`) -- confirm against the repository.
        return self.needs_upgrade(db_default.db_version)

    def upgrade_environment(self):
        # NOTE(review): the body was missing from the reviewed listing;
        # restored to the upstream Trac implementation -- confirm.
        self.upgrade(db_default.db_version)

    # ISystemInfoProvider methods

    def get_system_info(self):
        connector = self.get_connector()[0]
        for info in connector.get_system_info():
            yield info


def get_column_names(cursor):
    """Return the column names found in `cursor.description`.

    Byte-string names are decoded as UTF-8; other names are returned
    unchanged. A cursor without a description yields an empty list.
    """
    if not cursor.description:
        return []
    names = []
    for column in cursor.description:
        label = column[0]
        if isinstance(label, str):
            label = unicode(label, 'utf-8')
        names.append(label)
    return names


def parse_connection_uri(db_str):
    """Parse the database connection string.

    The database connection string for an environment is specified through
    the `database` option in the `[trac]` section of trac.ini.

    :return: a tuple containing the scheme and a dictionary of attributes:
             `user`, `password`, `host`, `port`, `path`, `params`.
             Attributes with an empty/false value are omitted.
    :raises ConfigurationError: if `db_str` is empty or malformed.
    :since: 1.1.3
    """
    if not db_str:
        section = tag.a("[trac]",
                        title=_("TracIni documentation"),
                        class_='trac-target-new',
                        href='https://trac.edgewall.org/wiki/TracIni'
                             '#trac-section')
        raise ConfigurationError(
            tag_("Database connection string is empty. Set the %(option)s "
                 "configuration option in the %(section)s section of "
                 "trac.ini. Please refer to the %(doc)s for help.",
                 option=tag.code("database"), section=section,
                 doc=_doc_db_str()))

    try:
        scheme, rest = db_str.split(':', 1)
    except ValueError:
        raise _invalid_db_str(db_str)

    if not rest.startswith('/'):
        if scheme == 'sqlite' and rest:
            # Support for relative and in-memory SQLite connection strings
            host = None
            path = rest
        else:
            raise _invalid_db_str(db_str)
    else:
        if not rest.startswith('//'):
            # scheme:/path -- no authority part
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            # scheme:///path -- empty authority part
            host = None
            rest = rest[3:]
        else:
            # scheme://authority[/path]
            rest = rest[2:]
            if '/' in rest:
                host, rest = rest.split('/', 1)
            else:
                host = rest
                rest = ''
        path = None

    if host and '@' in host:
        # Split on the *last* '@': a host name never contains one, while an
        # unescaped user name or password might (same rule as urlsplit).
        user, host = host.rsplit('@', 1)
        if ':' in user:
            user, password = user.split(':', 1)
        else:
            password = None
        if user:
            user = urllib.unquote(user)
        if password:
            password = unicode_passwd(urllib.unquote(password))
    else:
        user = password = None

    if host and ':' in host:
        host, port = host.split(':', 1)
        try:
            port = int(port)
        except ValueError:
            raise _invalid_db_str(db_str)
    else:
        port = None

    if not path:
        path = '/' + rest
        if os.name == 'nt':
            # Support local paths containing drive letters on Win32
            if len(rest) > 1 and rest[1] == '|':
                path = "%s:%s" % (rest[0], rest[2:])

    params = {}
    if '?' in path:
        path, qs = path.split('?', 1)
        qs = qs.split('&')
        for param in qs:
            try:
                name, value = param.split('=', 1)
            except ValueError:
                raise _invalid_db_str(db_str)
            value = urllib.unquote(value)
            params[name] = value

    args = zip(('user', 'password', 'host', 'port', 'path', 'params'),
               (user, password, host, port, path, params))
    return scheme, {key: value for key, value in args if value}


def _invalid_db_str(db_str):
    """Build (without raising) the `ConfigurationError` that reports
    `db_str` as a malformed database connection string."""
    message = tag_("Invalid format %(db_str)s for the database connection string. "
                   "Please refer to the %(doc)s for help.",
                   db_str=tag.code(db_str), doc=_doc_db_str())
    return ConfigurationError(message)


def _doc_db_str():
    """Return an HTML link to the Database Connection Strings section of
    the TracIni wiki page, opened in a new target."""
    label = _("documentation")
    link_title = _("Database Connection Strings documentation")
    return tag.a(label, title=link_title, class_='trac-target-new',
                 href='https://trac.edgewall.org/wiki/'
                      'TracIni#DatabaseConnectionStrings')
