Package trac :: Package db :: Package tests :: Module api

Source Code for Module trac.db.tests.api

  1  # -*- coding: utf-8 -*- 
  2  # 
  3  # Copyright (C) 2005-2020 Edgewall Software 
  4  # All rights reserved. 
  5  # 
  6  # This software is licensed as described in the file COPYING, which 
  7  # you should have received as part of this distribution. The terms 
  8  # are also available at https://trac.edgewall.org/wiki/TracLicense. 
  9  # 
 10  # This software consists of voluntary contributions made by many 
 11  # individuals. For the exact contribution history, see the revision 
 12  # history and logs, available at https://trac.edgewall.org/log/. 
 13   
 14  import copy 
 15  import os 
 16  import unittest 
 17   
 18  from trac.config import ConfigurationError 
 19  from trac.db.api import DatabaseManager, get_column_names, \ 
 20                          parse_connection_uri 
 21  from trac.db_default import (schema as default_schema, 
 22                               db_version as default_db_version) 
 23  from trac.db.schema import Column, Table 
 24  from trac.test import EnvironmentStub, get_dburi 
 25   
 26   
27 -class ParseConnectionStringTestCase(unittest.TestCase):
28
def test_sqlite_relative(self):
    """A relative SQLite path is returned unchanged under ``path``."""
    # Default syntax for specifying the DB path relative to the
    # environment directory.
    expected = ('sqlite', {'path': 'db/trac.db'})
    parsed = parse_connection_uri('sqlite:db/trac.db')
    self.assertEqual(expected, parsed)
34
def test_sqlite_absolute(self):
    """Standard and legacy absolute-path URIs parse to the same result."""
    expected = ('sqlite', {'path': '/var/db/trac.db'})
    uris = (
        'sqlite:///var/db/trac.db',  # standard syntax
        'sqlite:/var/db/trac.db',    # legacy syntax
    )
    for uri in uris:
        self.assertEqual(expected, parse_connection_uri(uri))
42
44 # In-memory database 45 self.assertEqual(('sqlite', {'path': 'db/trac.db', 46 'params': {'timeout': '10000'}}), 47 parse_connection_uri('sqlite:db/trac.db?timeout=10000'))
48
def test_sqlite_windows_path(self):
    """With ``os.name`` set to ``'nt'``, ``C|/...`` parses to ``C:/...``."""
    # ``os.name`` is patched for the duration of the assertion and is
    # always restored, even when the assertion fails.
    saved_os_name = os.name
    os.name = 'nt'
    try:
        expected = ('sqlite', {'path': 'C:/project/db/trac.db'})
        parsed = parse_connection_uri('sqlite:C|/project/db/trac.db')
        self.assertEqual(expected, parsed)
    finally:
        os.name = saved_os_name
58
def test_postgres_simple(self):
    """A host-and-path PostgreSQL URI yields ``host`` and ``path``."""
    expected_params = {'host': 'localhost', 'path': '/trac'}
    parsed = parse_connection_uri('postgres://localhost/trac')
    self.assertEqual(('postgres', expected_params), parsed)
62
def test_postgres_with_port(self):
    """An explicit port is returned as an integer under ``port``."""
    expected_params = {'host': 'localhost', 'port': 9431, 'path': '/trac'}
    parsed = parse_connection_uri('postgres://localhost:9431/trac')
    self.assertEqual(('postgres', expected_params), parsed)
67
def test_postgres_with_creds(self):
    """User and password in the URI are returned as separate entries."""
    expected_params = {
        'user': 'john',
        'password': 'letmein',
        'host': 'localhost',
        'port': 9431,
        'path': '/trac',
    }
    parsed = parse_connection_uri(
        'postgres://john:letmein@localhost:9431/trac')
    self.assertEqual(('postgres', expected_params), parsed)
73
75 self.assertEqual(('postgres', {'user': 'john', 'password': ':@/', 76 'host': 'localhost', 'path': '/trac'}), 77 parse_connection_uri('postgres://john:%3a%40%2f@localhost/trac'))
78
def test_mysql_simple(self):
    """A host-and-path MySQL URI yields ``host`` and ``path``."""
    expected_params = {'host': 'localhost', 'path': '/trac'}
    parsed = parse_connection_uri('mysql://localhost/trac')
    self.assertEqual(('mysql', expected_params), parsed)
82
def test_mysql_with_creds(self):
    """Credentials and port in a MySQL URI are all returned."""
    expected_params = {
        'user': 'john',
        'password': 'letmein',
        'host': 'localhost',
        'port': 3306,
        'path': '/trac',
    }
    parsed = parse_connection_uri(
        'mysql://john:letmein@localhost:3306/trac')
    self.assertEqual(('mysql', expected_params), parsed)
88
def test_empty_string(self):
    """An empty connection string raises ``ConfigurationError``."""
    with self.assertRaises(ConfigurationError):
        parse_connection_uri('')
91
def test_invalid_port(self):
    """``postgres://localhost:42:42`` raises ``ConfigurationError``."""
    bad_uri = 'postgres://localhost:42:42'
    with self.assertRaises(ConfigurationError):
        parse_connection_uri(bad_uri)
95
def test_invalid_schema(self):
    """A string with no ``scheme:`` part raises ``ConfigurationError``."""
    bad_uri = 'sqlitedb/trac.db'
    with self.assertRaises(ConfigurationError):
        parse_connection_uri(bad_uri)
99
def test_no_path(self):
    """A scheme with nothing after the colon raises ``ConfigurationError``."""
    bad_uri = 'sqlite:'
    with self.assertRaises(ConfigurationError):
        parse_connection_uri(bad_uri)
103
105 self.assertRaises(ConfigurationError, parse_connection_uri, 106 'postgres://localhost/schema?name')
107 108
109 -class StringsTestCase(unittest.TestCase):
110
def setUp(self):
    """Give each test its own stub environment as ``self.env``."""
    stub_env = EnvironmentStub()
    self.env = stub_env
113
def tearDown(self):
    """Reset the stub environment's database after each test."""
    env = self.env
    env.reset_db()
116
def test_insert_unicode(self):
    """A non-ASCII unicode value can be inserted and read back."""
    value = u'ünicöde'
    with self.env.db_transaction as db:
        quoted = db.quote('system')
        insert_sql = "INSERT INTO " + quoted + " (name,value) VALUES (%s,%s)"
        db(insert_sql, ('test-unicode', value))
    # Read back through a separate query once the transaction block
    # has been left.
    select_sql = ("SELECT value FROM " + quoted +
                  " WHERE name='test-unicode'")
    rows = self.env.db_query(select_sql)
    self.assertEqual([(value,)], rows)
124
def test_insert_empty(self):
    """Inserting ``trac.util.text.empty`` stores an empty string."""
    from trac.util.text import empty
    with self.env.db_transaction as db:
        quoted = db.quote('system')
        insert_sql = "INSERT INTO " + quoted + " (name,value) VALUES (%s,%s)"
        db(insert_sql, ('test-empty', empty))
    select_sql = ("SELECT value FROM " + quoted +
                  " WHERE name='test-empty'")
    rows = self.env.db_query(select_sql)
    self.assertEqual([(u'',)], rows)
133
def test_insert_markup(self):
    """A ``Markup`` value is stored and read back as its plain text."""
    from trac.util.html import Markup
    with self.env.db_transaction as db:
        quoted = db.quote('system')
        insert_sql = "INSERT INTO " + quoted + " (name,value) VALUES (%s,%s)"
        db(insert_sql, ('test-markup', Markup(u'<em>märkup</em>')))
    select_sql = ("SELECT value FROM " + quoted +
                  " WHERE name='test-markup'")
    rows = self.env.db_query(select_sql)
    self.assertEqual([(u'<em>märkup</em>',)], rows)
142
def test_quote(self):
    """A quoted identifier full of quote characters survives as an alias."""
    # Backslashes, backticks, single and double quotes, including
    # doubled ones.
    identifier = r'alpha\`\"\'\\beta``gamma""delta'
    with self.env.db_query as db:
        cursor = db.cursor()
        sql = 'SELECT 1 AS %s' % db.quote(identifier)
        cursor.execute(sql)
        first_column = get_column_names(cursor)[0]
        self.assertEqual(identifier, first_column)
150
152 name = """%?`%s"%'%%""" 153 154 def test(logging=False): 155 with self.env.db_query as db: 156 cursor = db.cursor() 157 if logging: 158 cursor.log = self.env.log 159 160 cursor.execute('SELECT 1 AS ' + db.quote(name)) 161 self.assertEqual(name, get_column_names(cursor)[0]) 162 cursor.execute('SELECT %s AS ' + db.quote(name), (42,)) 163 self.assertEqual(name, get_column_names(cursor)[0]) 164 stmt = """ 165 UPDATE {0} SET value=%s WHERE 1=(SELECT 0 AS {1}) 166 """.format(db.quote('system'), db.quote(name)) 167 cursor.executemany(stmt, []) 168 cursor.executemany(stmt, [('42',), ('43',)])
169 170 test() 171 test(True)
172
def test_prefix_match_case_sensitive(self):
    """Prefix matching on ``'Blah'`` only returns names in that exact case.

    Five names that differ in case or in accented characters are
    inserted into the ``system`` table; only the two that literally
    start with ``Blah`` are expected back.
    """
    candidates = [('blahblah',), ('BlahBlah',), ('BLAHBLAH',),
                  (u'BlähBlah',), (u'BlahBläh',)]
    with self.env.db_transaction as db:
        db.executemany("""
            INSERT INTO {0} (name,value) VALUES (%s,1)
            """.format(db.quote('system')), candidates)

    with self.env.db_query as db:
        names = sorted(name for name, in db(
            "SELECT name FROM {0} WHERE name {1}"
            .format(db.quote('system'), db.prefix_match()),
            (db.prefix_match_value('Blah'),)))
    # Compare the whole list in one assertion: a result with fewer than
    # two rows then fails with a readable diff instead of raising
    # IndexError while indexing, and the length is checked implicitly.
    self.assertEqual(['BlahBlah', u'BlahBläh'], names)
189
def test_prefix_match_metachars(self):
    """Wildcard-like characters in a prefix are matched literally."""
    def names_with_prefix(prefix):
        # Names from ``system`` selected by the prefix match, ordered
        # by name.
        with self.env.db_query as db:
            sql = """
                SELECT name FROM {0} WHERE name {1} ORDER BY name
                """.format(db.quote('system'), db.prefix_match())
            rows = db(sql, (db.prefix_match_value(prefix),))
            return [name for name, in rows]

    combined = 'fo*ob?ar[fo]ob%ar_fo/obar'
    values = ['foo*bar', 'foo*bar!', 'foo?bar', 'foo?bar!',
              'foo[bar', 'foo[bar!', 'foo]bar', 'foo]bar!',
              'foo%bar', 'foo%bar!', 'foo_bar', 'foo_bar!',
              'foo/bar', 'foo/bar!', combined]
    with self.env.db_transaction as db:
        db.executemany("""
            INSERT INTO {0} (name,value) VALUES (%s,1)
            """.format(db.quote('system')),
            [(value,) for value in values])

    # Each metacharacter used as the last character of the prefix must
    # select exactly the two names that contain it.
    for char in ('*', '?', '[', ']', '%', '_', '/'):
        prefix = 'foo' + char
        self.assertEqual([prefix + 'bar', prefix + 'bar!'],
                         names_with_prefix(prefix))
    self.assertEqual([combined], names_with_prefix('fo*'))
    self.assertEqual([combined], names_with_prefix(combined))
220 -class ConnectionTestCase(unittest.TestCase):
def setUp(self):
    """Create a stub environment and (re)create the two test tables."""
    self.env = EnvironmentStub()
    # One table with upper-case table/column names and one with
    # lower-case names; both have an auto-increment key column.
    hours_table = Table('HOURS', key='ID')[
        Column('ID', auto_increment=True),
        Column('AUTHOR')
    ]
    blog_table = Table('blog', key='bid')[
        Column('bid', auto_increment=True),
        Column('author'),
        Column('comment')
    ]
    self.schema = [hours_table, blog_table]
    manager = DatabaseManager(self.env)
    self.dbm = manager
    # Drop before create -- presumably to clear leftovers from an
    # earlier run; confirm against DatabaseManager.drop_tables.
    manager.drop_tables(self.schema)
    manager.create_tables(self.schema)
237
def tearDown(self):
    """Drop the test tables, then reset the environment's database."""
    manager = DatabaseManager(self.env)
    manager.drop_tables(self.schema)
    self.env.reset_db()
241
def test_drop_column(self):
    """Data is preserved when column is dropped."""
    blog_rows = (('author1', 'comment one'),
                 ('author2', 'comment two'))
    self.dbm.insert_into_tables([
        ('blog', ('author', 'comment'), blog_rows),
    ])

    with self.env.db_transaction as db:
        db.drop_column('blog', 'comment')

    remaining = list(self.env.db_query("SELECT * FROM blog"))
    # Only the key and the author column are left in each row.
    self.assertEqual((1, 'author1'), remaining[0])
    self.assertEqual((2, 'author2'), remaining[1])
257
259 """Error is not raised when dropping non-existent column.""" 260 table_data = [ 261 ('blog', ('author', 'comment'), 262 (('author1', 'comment one'), 263 ('author2', 'comment two'))), 264 ] 265 self.dbm.insert_into_tables(table_data) 266 267 with self.env.db_transaction as db: 268 db.drop_column('blog', 'tags') 269 270 data = list(self.env.db_query("SELECT * FROM blog")) 271 self.assertEqual((1, 'author1', 'comment one'), data[0]) 272 self.assertEqual((2, 'author2', 'comment two'), data[1])
273
275 """Transaction is rolled back when an exception occurs in the 276 transaction context manager. 277 """ 278 insert_sql = "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')" 279 try: 280 with self.env.db_transaction as db: 281 db(insert_sql) 282 db(insert_sql) 283 except self.env.db_exc.IntegrityError: 284 pass 285 286 for _, in self.env.db_query(""" 287 SELECT author FROM blog WHERE bid=42 288 """): 289 self.fail("Transaction was not rolled back")
290
292 """Transaction is rolled back when an exception occurs in the 293 inner transaction context manager. 294 """ 295 sql = "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')" 296 try: 297 with self.env.db_transaction as db_outer: 298 db_outer(sql) 299 with self.env.db_transaction as db_inner: 300 db_inner(sql) 301 except self.env.db_exc.IntegrityError: 302 pass 303 304 for _, in self.env.db_query(""" 305 SELECT author FROM blog WHERE bid=42 306 """): 307 self.fail("Transaction was not rolled back")
308
def test_get_last_id(self):
    """``get_last_id`` gives consecutive IDs before and after commit."""
    insert_sql = "INSERT INTO report (author) VALUES ('anonymous')"
    with self.env.db_transaction as db:
        cursor = db.cursor()
        cursor.execute(insert_sql)
        # First ID is read before any commit ...
        first_id = db.get_last_id(cursor, 'report')
        db.commit()
        cursor.execute(insert_sql)
        # ... the second one only after commit() has been called.
        db.commit()
        second_id = db.get_last_id(cursor, 'report')

    self.assertNotEqual(0, first_id)
    self.assertEqual(first_id + 1, second_id)
324
326 with self.env.db_transaction as db: 327 db("INSERT INTO report (id, author) VALUES (42, 'anonymous')") 328 cursor = db.cursor() 329 db.update_sequence(cursor, 'report') 330 331 self.env.db_transaction( 332 "INSERT INTO report (author) VALUES ('next-id')") 333 334 self.assertEqual(43, self.env.db_query( 335 "SELECT id FROM report WHERE author='next-id'")[0][0])
336
338 with self.env.db_transaction as db: 339 cursor = db.cursor() 340 cursor.execute( 341 "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')") 342 db.update_sequence(cursor, 'blog', 'bid') 343 344 self.env.db_transaction( 345 "INSERT INTO blog (author) VALUES ('next-id')") 346 347 self.assertEqual(43, self.env.db_query( 348 "SELECT bid FROM blog WHERE author='next-id'")[0][0])
349
351 """Test for regression described in comment:4:ticket:11512.""" 352 with self.env.db_transaction as db: 353 db("INSERT INTO %s (%s, %s) VALUES (42, 'anonymous')" 354 % (db.quote('HOURS'), db.quote('ID'), db.quote('AUTHOR'))) 355 cursor = db.cursor() 356 db.update_sequence(cursor, 'HOURS', 'ID') 357 358 with self.env.db_transaction as db: 359 cursor = db.cursor() 360 cursor.execute( 361 "INSERT INTO %s (%s) VALUES ('next-id')" 362 % (db.quote('HOURS'), db.quote('AUTHOR'))) 363 last_id = db.get_last_id(cursor, 'HOURS', 'ID') 364 365 self.assertEqual(43, last_id)
366
def test_get_table_names(self):
    """The connection lists the default and test tables, ignoring case."""
    all_tables = default_schema + self.schema
    with self.env.db_query as db:
        # Some DB (e.g. MariaDB) normalize the table names to lower
        # case, so both sides are lower-cased before comparing.
        expected = sorted(table.name.lower() for table in all_tables)
        actual = sorted(name.lower() for name in db.get_table_names())
        self.assertEqual(expected, actual)
374
def test_get_column_names(self):
    """The connection reports each table's columns in schema order."""
    all_tables = default_schema + self.schema
    with self.env.db_query as db:
        for table in all_tables:
            expected = [column.name for column in table.columns]
            actual = db.get_column_names(table.name)
            self.assertEqual(expected, actual)
382
384 with self.assertRaises(self.env.db_exc.OperationalError) as cm: 385 self.dbm.get_column_names('blah') 386 self.assertIn(unicode(cm.exception), ('Table "blah" not found', 387 'Table `blah` not found'))
388 389
390 -class DatabaseManagerTestCase(unittest.TestCase):
391
def setUp(self):
    """Create a stub environment with default data and its manager."""
    env = EnvironmentStub(default_data=True)
    self.env = env
    self.dbm = DatabaseManager(env)
395
def tearDown(self):
    """Reset the stub environment's database after each test."""
    env = self.env
    env.reset_db()
398
def test_destroy_db(self):
    """Database doesn't exist after calling destroy_db."""
    # Run one query first so that a connection pool exists.
    with self.env.db_query as db:
        db("SELECT name FROM " + db.quote('system'))
    manager = self.dbm
    self.assertIsNotNone(manager._cnx_pool)
    manager.destroy_db()
    self.assertIsNone(manager._cnx_pool)  # No connection pool
    scheme, params = parse_connection_uri(get_dburi())
    if scheme == 'postgres' and params.get('schema', 'public') == 'public':
        # PostgreSQL with the 'public' schema: only check that no
        # tables are left.
        self.assertEqual([], manager.get_table_names())
    else:
        self.assertFalse(manager.db_exists())
411
def test_get_column_names(self):
    """Get column names for the default database."""
    for table in default_schema:
        expected = [column.name for column in table.columns]
        actual = self.dbm.get_column_names(table.name)
        self.assertEqual(expected, actual)
418
420 """Get database version for the default entry named 421 `database_version`. 422 """ 423 self.assertEqual(default_db_version, self.dbm.get_database_version())
424
def test_get_table_names(self):
    """Get table names for the default database."""
    expected = sorted(table.name for table in default_schema)
    actual = sorted(self.dbm.get_table_names())
    self.assertEqual(expected, actual)
429
def test_has_table(self):
    """``has_table`` returns exactly ``True`` or ``False``."""
    cases = (
        ('system', True),
        ('wiki', True),
        ('trac', False),
        ('blah.blah', False),
    )
    for table_name, expected in cases:
        # assertIs: a merely truthy or falsy value is not accepted.
        self.assertIs(expected, self.dbm.has_table(table_name))
435
def test_no_database_version(self):
    """False is returned when entry doesn't exist"""
    version = self.dbm.get_database_version('trac_plugin_version')
    self.assertFalse(version)
439
441 """Set database version for the default entry named 442 `database_version`. 443 """ 444 new_db_version = default_db_version + 1 445 self.dbm.set_database_version(new_db_version) 446 self.assertEqual(new_db_version, self.dbm.get_database_version()) 447 self.assertEqual([('INFO', 'Upgraded database_version from 45 to 46')], 448 self.env.log_messages) 449 450 # Restore the previous version to avoid destroying the database 451 # on teardown 452 self.dbm.set_database_version(default_db_version) 453 self.assertEqual(default_db_version, self.dbm.get_database_version())
454
456 """Get and set database version for an entry with an 457 arbitrary name. 458 """ 459 name = 'trac_plugin_version' 460 db_ver = 1 461 462 self.dbm.set_database_version(db_ver, name) 463 self.assertEqual([], self.env.log_messages) 464 self.assertEqual(db_ver, self.dbm.get_database_version(name)) 465 # DB update will be skipped when new value equals database version 466 self.dbm.set_database_version(db_ver, name) 467 self.assertEqual([], self.env.log_messages)
468
def test_get_sequence_names(self):
    """Sequences are only expected for a ``postgres`` connection URI."""
    if self.dbm.connection_uri.startswith('postgres'):
        # One entry per auto-increment column named 'id', sorted.
        expected = sorted(
            table.name
            for table in default_schema
            for column in table.columns
            if column.name == 'id' and column.auto_increment)
    else:
        expected = []

    self.assertEqual(expected, self.dbm.get_sequence_names())
479 480
481 -class ModifyTableTestCase(unittest.TestCase):
482
def setUp(self):
    """Create three test tables and prepare the upgraded schema.

    ``self.new_schema`` is a deep copy of table1 and table3 with
    ``col2`` removed from table1 and ``col4`` added to table3, plus a
    new single-column ``table4``. ``table2`` is not part of it.
    """
    self.env = EnvironmentStub()
    self.dbm = DatabaseManager(self.env)
    table1 = Table('table1', key='col1')[
        Column('col1', auto_increment=True),
        Column('col2'),
        Column('col3'),
    ]
    table2 = Table('table2', key='col1')[
        Column('col1'),
        Column('col2'),
    ]
    table3 = Table('table3', key='col2')[
        Column('col1'),
        Column('col2', type='int'),
        Column('col3')
    ]
    self.schema = [table1, table2, table3]
    self.dbm.create_tables(self.schema)

    # Deep copy so that changing the new schema leaves the tables in
    # self.schema untouched.
    upgraded = copy.deepcopy([self.schema[0], self.schema[2]])
    self.new_schema = upgraded
    upgraded[0].remove_columns(('col2',))
    upgraded[1].columns.append(Column('col4'))
    table4 = Table('table4')[
        Column('col1'),
    ]
    upgraded.append(table4)
511
def tearDown(self):
    """Drop table1 to table4 by name, then reset the database."""
    table_names = ['table1', 'table2', 'table3', 'table4']
    self.dbm.drop_tables(table_names)
    self.env.reset_db()
515
def _insert_data(self):
    """Insert two rows into each of table1, table2 and table3."""
    table1_rows = (('data1', 'data2'),
                   ('data3', 'data4'))
    table2_rows = (('data5', 'data6'),
                   ('data7', 'data8'))
    table3_rows = (('data9', 10, 'data11'),
                   ('data12', 13, 'data14'))
    # table1's 'col1' is left out of the insert; it is declared with
    # auto_increment=True in setUp.
    self.dbm.insert_into_tables([
        ('table1', ('col2', 'col3'), table1_rows),
        ('table2', ('col1', 'col2'), table2_rows),
        ('table3', ('col1', 'col2', 'col3'), table3_rows),
    ])
529
def test_drop_columns(self):
    """Data is preserved when column is dropped."""
    self._insert_data()

    self.dbm.drop_columns('table1', ('col2',))

    remaining_columns = self.dbm.get_column_names('table1')
    self.assertEqual(['col1', 'col3'], remaining_columns)
    rows = list(self.env.db_query("SELECT * FROM table1"))
    self.assertEqual((1, 'data2'), rows[0])
    self.assertEqual((2, 'data4'), rows[1])
540
542 """Data is preserved when columns are dropped.""" 543 self._insert_data() 544 545 self.dbm.drop_columns('table3', ('col1', 'col3')) 546 547 self.assertEqual(['col2'], self.dbm.get_column_names('table3')) 548 data = list(self.env.db_query("SELECT * FROM table3")) 549 self.assertEqual((10,), data[0]) 550 self.assertEqual((13,), data[1])
551
553 with self.assertRaises(self.env.db_exc.OperationalError) as cm: 554 self.dbm.drop_columns('blah', ('col1',)) 555 self.assertIn(unicode(cm.exception), ('Table "blah" not found', 556 'Table `blah` not found'))
557
559 """The upgraded tables have the new schema.""" 560 self.dbm.upgrade_tables(self.new_schema) 561 562 for table in self.new_schema: 563 self.assertEqual([col.name for col in table.columns], 564 self.dbm.get_column_names(table.name))
565
567 """The data is migrated to the upgraded tables.""" 568 self._insert_data() 569 570 self.dbm.upgrade_tables(self.new_schema) 571 self.env.db_transaction(""" 572 INSERT INTO table1 (col3) VALUES ('data12') 573 """) 574 575 data = list(self.env.db_query("SELECT * FROM table1")) 576 self.assertEqual((1, 'data2'), data[0]) 577 self.assertEqual((2, 'data4'), data[1]) 578 self.assertEqual(3, self.env.db_query(""" 579 SELECT col1 FROM table1 WHERE col3='data12'""")[0][0]) 580 data = list(self.env.db_query("SELECT * FROM table2")) 581 self.assertEqual(('data5', 'data6'), data[0]) 582 self.assertEqual(('data7', 'data8'), data[1]) 583 data = list(self.env.db_query("SELECT * FROM table3")) 584 self.assertEqual(('data9', 10, 'data11', None), data[0]) 585 self.assertEqual(('data12', 13, 'data14', None), data[1])
586
588 schema = [ 589 Table('table1', key='id')[ 590 Column('id', auto_increment=True), 591 Column('name'), 592 Column('value'), 593 ], 594 ] 595 self.dbm.upgrade_tables(schema) 596 self.assertEqual(['id', 'name', 'value'], 597 self.dbm.get_column_names('table1')) 598 self.assertEqual([], list(self.env.db_query("SELECT * FROM table1")))
599 600
def test_suite():
    """Return a suite holding all five test case classes of this module."""
    case_classes = (
        ParseConnectionStringTestCase,
        StringsTestCase,
        ConnectionTestCase,
        DatabaseManagerTestCase,
        ModifyTableTestCase,
    )
    suite = unittest.TestSuite()
    for case_class in case_classes:
        suite.addTest(unittest.makeSuite(case_class))
    return suite


# Run the suite built by test_suite() when executed as a script.
if __name__ == '__main__':
    unittest.main(defaultTest='test_suite')