1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 import re, os
18
19 from genshi import Markup
20
21 from trac.core import *
22 from trac.config import Option
23 from trac.db.api import IDatabaseConnector, _parse_db_str
24 from trac.db.util import ConnectionWrapper, IterableCursor
25 from trac.util import get_pkginfo
26 from trac.util.compat import any, close_fds
27 from trac.util.text import empty, exception_to_unicode, to_unicode
28 from trac.util.translation import _
29
# Optional dependency: the backend is only usable when psycopg2 imports.
# On success, register the type conversions this backend relies on.
try:
    import psycopg2 as psycopg
    import psycopg2.extensions
    from psycopg2 import DataError, ProgrammingError
    from psycopg2.extensions import register_type, UNICODE, \
         register_adapter, AsIs, QuotedString

    # Presumably so text results are returned as unicode objects.
    register_type(UNICODE)
    # Genshi Markup is sent to the database as its quoted unicode text.
    register_adapter(Markup, lambda markup: QuotedString(unicode(markup)))
    # Trac's `empty` value is sent as the SQL literal ''.
    register_adapter(type(empty), lambda empty: AsIs("''"))
except ImportError:
    has_psycopg = False
else:
    has_psycopg = True
45
46 _like_escape_re = re.compile(r'([/_%])')
47
48
49 _type_map = {
50 'int64': 'bigint',
51 }
52
53
def assemble_pg_dsn(path, user=None, password=None, host=None, port=None):
    """Quote the parameters and assemble the DSN.

    Every parameter that is set (truthy) becomes a libpq ``keyword='value'``
    pair; unset or empty parameters are omitted.  Inside the single-quoted
    value, backslashes and single quotes are backslash-escaped as libpq
    requires.  Passwords, paths or host names containing them therefore
    reach the server intact instead of corrupting the connection string.
    Non-string values (e.g. an integer port) are converted with ``%s``.
    """
    def quote(value):
        text = '%s' % (value,)
        # Escape backslashes first so the escapes added for quotes are
        # not themselves doubled.
        return "'%s'" % text.replace('\\', '\\\\').replace("'", "\\'")

    # A fixed order keeps the output deterministic (libpq ignores order).
    params = [('dbname', path), ('user', user), ('password', password),
              ('host', host), ('port', port)]
    return ' '.join(['%s=%s' % (name, quote(value))
                     for name, value in params if value])
60
61
62 -class PostgreSQLConnector(Component):
63 """Database connector for PostgreSQL.
64
65 Database URLs should be of the form:
66 {{{
67 postgres://user[:password]@host[:port]/database[?schema=my_schema]
68 }}}
69 """
70 implements(IDatabaseConnector)
71
72 pg_dump_path = Option('trac', 'pg_dump_path', 'pg_dump',
73 """Location of pg_dump for Postgres database backups""")
74
76 self._version = None
77 self.error = None
78
80 if not has_psycopg:
81 self.error = _("Cannot load Python bindings for PostgreSQL")
82 yield ('postgres', self.error and -1 or 1)
83
84 - def get_connection(self, path, log=None, user=None, password=None,
85 host=None, port=None, params={}):
86 cnx = PostgreSQLConnection(path, log, user, password, host, port,
87 params)
88 if not self._version:
89 self._version = get_pkginfo(psycopg).get('version',
90 psycopg.__version__)
91 self.env.systeminfo.append(('psycopg2', self._version))
92 self.required = True
93 return cnx
94
95 - def init_db(self, path, log=None, user=None, password=None, host=None,
96 port=None, params={}):
97 cnx = self.get_connection(path, log, user, password, host, port,
98 params)
99 cursor = cnx.cursor()
100 if cnx.schema:
101 cursor.execute('CREATE SCHEMA "%s"' % cnx.schema)
102 cursor.execute('SET search_path TO %s', (cnx.schema,))
103 from trac.db_default import schema
104 for table in schema:
105 for stmt in self.to_sql(table):
106 cursor.execute(stmt)
107 cnx.commit()
108
109 - def to_sql(self, table):
110 sql = ['CREATE TABLE "%s" (' % table.name]
111 coldefs = []
112 for column in table.columns:
113 ctype = column.type
114 ctype = _type_map.get(ctype, ctype)
115 if column.auto_increment:
116 ctype = 'SERIAL'
117 if len(table.key) == 1 and column.name in table.key:
118 ctype += ' PRIMARY KEY'
119 coldefs.append(' "%s" %s' % (column.name, ctype))
120 if len(table.key) > 1:
121 coldefs.append(' CONSTRAINT "%s_pk" PRIMARY KEY ("%s")'
122 % (table.name, '","'.join(table.key)))
123 sql.append(',\n'.join(coldefs) + '\n)')
124 yield '\n'.join(sql)
125 for index in table.indices:
126 unique = index.unique and 'UNIQUE' or ''
127 yield 'CREATE %s INDEX "%s_%s_idx" ON "%s" ("%s")' % \
128 (unique, table.name,
129 '_'.join(index.columns), table.name,
130 '","'.join(index.columns))
131
132 - def alter_column_types(self, table, columns):
133 """Yield SQL statements altering the type of one or more columns of
134 a table.
135
136 Type changes are specified as a `columns` dict mapping column names
137 to `(from, to)` SQL type tuples.
138 """
139 alterations = []
140 for name, (from_, to) in sorted(columns.iteritems()):
141 to = _type_map.get(to, to)
142 if to != _type_map.get(from_, from_):
143 alterations.append((name, to))
144 if alterations:
145 yield "ALTER TABLE %s %s" % (table,
146 ', '.join("ALTER COLUMN %s TYPE %s" % each
147 for each in alterations))
148
149 - def backup(self, dest_file):
150 from subprocess import Popen, PIPE
151 db_url = self.env.config.get('trac', 'database')
152 scheme, db_prop = _parse_db_str(db_url)
153 db_params = db_prop.setdefault('params', {})
154 db_name = os.path.basename(db_prop['path'])
155
156 args = [self.pg_dump_path, '-C', '--inserts', '-x', '-Z', '8']
157 if 'user' in db_prop:
158 args.extend(['-U', db_prop['user']])
159 if 'host' in db_params:
160 host = db_params['host']
161 else:
162 host = db_prop.get('host')
163 if host:
164 args.extend(['-h', host])
165 if '/' not in host:
166 args.extend(['-p', str(db_prop.get('port', '5432'))])
167
168 if 'schema' in db_params:
169 try:
170 p = Popen([self.pg_dump_path, '--version'], stdout=PIPE,
171 close_fds=close_fds)
172 except OSError, e:
173 raise TracError(_("Unable to run %(path)s: %(msg)s",
174 path=self.pg_dump_path,
175 msg=exception_to_unicode(e)))
176
177 version = p.communicate()[0]
178 if re.search(r' 8\.[01]\.', version):
179 args.extend(['-n', db_params['schema']])
180 else:
181 args.extend(['-n', '"%s"' % db_params['schema']])
182
183 dest_file += ".gz"
184 args.extend(['-f', dest_file, db_name])
185
186 environ = os.environ.copy()
187 if 'password' in db_prop:
188 environ['PGPASSWORD'] = str(db_prop['password'])
189 try:
190 p = Popen(args, env=environ, stderr=PIPE, close_fds=close_fds)
191 except OSError, e:
192 raise TracError(_("Unable to run %(path)s: %(msg)s",
193 path=self.pg_dump_path,
194 msg=exception_to_unicode(e)))
195 errmsg = p.communicate()[1]
196 if p.returncode != 0:
197 raise TracError(_("pg_dump failed: %(msg)s",
198 msg=to_unicode(errmsg.strip())))
199 if not os.path.exists(dest_file):
200 raise TracError(_("No destination file created"))
201 return dest_file
202
203
class PostgreSQLConnection(ConnectionWrapper):
    """Connection wrapper for PostgreSQL.

    Wraps a psycopg2 connection and provides the PostgreSQL-specific SQL
    helpers: casts, concatenation, LIKE clauses, identifier quoting,
    sequence handling and table dropping.
    """

    poolable = True

    def __init__(self, path, log=None, user=None, password=None, host=None,
                 port=None, params={}):
        # Drop a leading '/' from the database name, if present.
        if path.startswith('/'):
            path = path[1:]
        # A `host` entry in `params` overrides the `host` argument.
        host = params.get('host', host)

        dsn = assemble_pg_dsn(path, user, password, host, port)
        cnx = psycopg.connect(dsn)
        cnx.set_client_encoding('UNICODE')

        self.schema = None
        if 'schema' in params:
            self.schema = params['schema']
            try:
                cnx.cursor().execute('SET search_path TO %s', (self.schema,))
                cnx.commit()
            except (DataError, ProgrammingError):
                # Tolerated on purpose: the schema may not exist yet
                # (PostgreSQLConnector.init_db creates it after connecting).
                cnx.rollback()
        ConnectionWrapper.__init__(self, cnx, log)

        self._version = self._get_version()

    def cast(self, column, type):
        """Return an SQL expression casting `column` to `type`, translating
        Trac's generic type names through `_type_map`."""
        sql_type = _type_map.get(type, type)
        return 'CAST(%s AS %s)' % (column, sql_type)

    def concat(self, *args):
        """Return an SQL expression concatenating `args` with `||`."""
        sep = '||'
        return sep.join(args)

    def like(self):
        """Return a case-insensitive LIKE clause."""
        clause = "ILIKE %s ESCAPE '/'"
        return clause

    def like_escape(self, text):
        """Escape '/', '_' and '%' in `text` by prefixing them with '/'."""
        return _like_escape_re.sub(r'/\1', text)

    def prefix_match(self):
        """Return a case sensitive prefix-matching operator."""
        clause = "LIKE %s ESCAPE '/'"
        return clause

    def prefix_match_value(self, prefix):
        """Return a value for case sensitive prefix-matching operator."""
        escaped = self.like_escape(prefix)
        return escaped + '%'

    def quote(self, identifier):
        """Return the quoted identifier."""
        escaped = identifier.replace('"', '""')
        return '"%s"' % escaped

    def get_last_id(self, cursor, table, column='id'):
        """Return the current value of the sequence for `table`.`column`."""
        seq = self.quote(self._sequence_name(table, column))
        cursor.execute("SELECT CURRVAL(%s)", (seq,))
        row = cursor.fetchone()
        return row[0]

    def update_sequence(self, cursor, table, column='id'):
        """Set the sequence for `table`.`column` to the column's maximum."""
        query = "SELECT SETVAL(%%s, (SELECT MAX(%s) FROM %s))" \
                % (self.quote(column), self.quote(table))
        seq = self.quote(self._sequence_name(table, column))
        cursor.execute(query, (seq,))

    def cursor(self):
        """Return a new cursor wrapped in an `IterableCursor` that uses
        this connection's log."""
        return IterableCursor(self.cnx.cursor(), self.log)

    def drop_table(self, table):
        """Drop `table` if it exists."""
        cursor = self.cursor()
        legacy_server = self._version and \
                        any(self._version.startswith(prefix)
                            for prefix in ('8.0.', '8.1.'))
        if not legacy_server:
            cursor.execute("DROP TABLE IF EXISTS " + self.quote(table))
            return
        # NOTE: 8.0.x/8.1.x servers take this path, presumably because they
        # lack DROP TABLE IF EXISTS; look the table up in the catalog first.
        cursor.execute("""SELECT table_name FROM information_schema.tables
                          WHERE table_schema=current_schema()
                          AND table_name=%s""", (table,))
        for row in cursor:
            if row[0] == table:
                cursor.execute("DROP TABLE " + self.quote(table))
                break

    def _sequence_name(self, table, column):
        """Return '<table>_<column>_seq', presumably the name PostgreSQL
        gives the sequence behind a SERIAL column."""
        return '%s_%s_seq' % (table, column)

    def _get_version(self):
        """Return the server version number parsed from `SELECT version()`
        (the second word of a "PostgreSQL ..." banner), or None."""
        cursor = self.cursor()
        cursor.execute('SELECT version()')
        for (banner,) in cursor:
            if banner.startswith('PostgreSQL '):
                return banner.split(' ', 2)[1]
        return None
295