Package skytools :: Module sqltools
[frames] | [no frames]

Source Code for Module skytools.sqltools

  1   
  2  """Database tools.""" 
  3   
  4  import os 
  5  from cStringIO import StringIO 
  6  from quoting import quote_copy, quote_literal 
  7   
  8  # 
  9  # Fully qualified table name 
 10  # 
 11   
def fq_name_parts(tbl):
    """Return fully qualified name parts as a (schema, name) tuple.

    A bare name is placed in the 'public' schema; a name containing
    exactly one dot is split at that dot.

    :raises Exception: if *tbl* contains more than one dot.
    """
    parts = tbl.split('.')
    if len(parts) == 1:
        return ('public', tbl)
    if len(parts) == 2:
        # tuple here as well, so both branches return the same type
        return tuple(parts)
    raise Exception('Syntax error in table name:' + tbl)
22
def fq_name(tbl):
    """Return *tbl* in "schema.name" form; a bare name gets schema 'public'."""
    schema, name = fq_name_parts(tbl)
    return "%s.%s" % (schema, name)
26 27 # 28 # info about table 29 #
def get_table_oid(curs, table_name):
    """Return the pg_class OID of *table_name*.

    *table_name* may be "schema.table" or a bare name (schema 'public').
    The lookup matches on namespace and relname only; relkind is not
    filtered.

    :raises Exception: 'Table not found: ...' when no row matches.
    """
    schema, name = fq_name_parts(table_name)
    sql = """select c.oid from pg_namespace n, pg_class c
             where c.relnamespace = n.oid
               and n.nspname = %s and c.relname = %s"""
    curs.execute(sql, [schema, name])
    rows = curs.fetchall()
    if not rows:
        raise Exception('Table not found: ' + table_name)
    return rows[0][0]
40
def get_table_pkeys(curs, tbl):
    """Return the primary-key column names of *tbl* as a list.

    Columns are read from the attributes of the table's primary-key index,
    ordered by their attnum in that index; dropped columns are skipped.
    Returns an empty list when the table has no primary key.

    :raises Exception: if the table cannot be found (via get_table_oid).
    """
    oid = get_table_oid(curs, tbl)
    q = ("SELECT k.attname FROM pg_index i, pg_attribute k"
         " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"
         " AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped"
         " ORDER BY k.attnum")
    curs.execute(q, [oid])
    # always a real list (map() would be a lazy iterator on Python 3)
    return [row[0] for row in curs.fetchall()]
49
def get_table_columns(curs, tbl):
    """Return the names of all live user columns of *tbl*, in attnum order.

    System columns (attnum <= 0) and dropped columns are excluded.

    :raises Exception: if the table cannot be found (via get_table_oid).
    """
    oid = get_table_oid(curs, tbl)
    q = ("SELECT k.attname FROM pg_attribute k"
         " WHERE k.attrelid = %s"
         " AND k.attnum > 0 AND NOT k.attisdropped"
         " ORDER BY k.attnum")
    curs.execute(q, [oid])
    # always a real list (map() would be a lazy iterator on Python 3)
    return [row[0] for row in curs.fetchall()]
58 59 # 60 # exist checks 61 #
def exists_schema(curs, schema):
    """Count pg_namespace rows named *schema*; non-zero means it exists."""
    curs.execute("select count(1) from pg_namespace where nspname = %s",
                 [schema])
    return curs.fetchone()[0]
67
def exists_table(curs, table_name):
    """Count tables with relkind 'r' matching *table_name*.

    *table_name* may be "schema.table" or a bare name (schema 'public').
    A non-zero result means the table exists.
    """
    schema, name = fq_name_parts(table_name)
    sql = """select count(1) from pg_namespace n, pg_class c
             where c.relnamespace = n.oid and c.relkind = 'r'
               and n.nspname = %s and c.relname = %s"""
    curs.execute(sql, [schema, name])
    return curs.fetchone()[0]
76
def exists_type(curs, type_name):
    """Count pg_type rows matching *type_name*; non-zero means it exists.

    *type_name* may be "schema.type" or a bare name (schema 'public').
    """
    schema, name = fq_name_parts(type_name)
    sql = """select count(1) from pg_namespace n, pg_type t
             where t.typnamespace = n.oid
               and n.nspname = %s and t.typname = %s"""
    curs.execute(sql, [schema, name])
    return curs.fetchone()[0]
85
def exists_function(curs, function_name, nargs):
    """Count functions named *function_name* that take *nargs* arguments.

    *function_name* may be "schema.func" or a bare name (schema 'public').
    A non-zero result means at least one such function exists.
    """
    # this does not check arg types, so may match several functions
    schema, name = fq_name_parts(function_name)
    sql = """select count(1) from pg_namespace n, pg_proc p
             where p.pronamespace = n.oid and p.pronargs = %s
               and n.nspname = %s and p.proname = %s"""
    curs.execute(sql, [nargs, schema, name])
    return curs.fetchone()[0]
95
def exists_language(curs, lang_name):
    """Count pg_language rows named *lang_name*; non-zero means installed."""
    sql = """select count(1) from pg_language
             where lanname = %s"""
    curs.execute(sql, [lang_name])
    return curs.fetchone()[0]
102 103 # 104 # Support for PostgreSQL snapshot 105 # 106
class Snapshot(object):
    """A PostgreSQL snapshot parsed from its "xmin:xmax:txid,txid,..." text.

    Attributes: sn_str (original text), xmin, xmax (ints) and
    txid_list (list of ints; empty when the third field is empty).
    """

    def __init__(self, str):
        """Parse snapshot text *str*.

        :raises Exception: if *str* does not have exactly three ':' fields.
        :raises ValueError: if a field is not an integer.
        """
        self.sn_str = str
        fields = str.split(':')
        if len(fields) != 3:
            raise Exception('Unknown format for snapshot')
        xmin_txt, xmax_txt, xip_txt = fields
        self.xmin = int(xmin_txt)
        self.xmax = int(xmax_txt)
        if xip_txt:
            self.txid_list = [int(x) for x in xip_txt.split(',')]
        else:
            self.txid_list = []

    def contains(self, txid):
        """Return True if *txid* is visible in this snapshot.

        Visible means: below xmin, or below xmax and not in txid_list.
        *txid* may be anything int() accepts.
        """
        txid = int(txid)
        if txid < self.xmin:
            return True
        return txid < self.xmax and txid not in self.txid_list
136 137 # 138 # Copy helpers 139 # 140
def _gen_dict_copy(tbl, row, fields):
    """Format dict *row* as one tab-separated COPY line (no newline).

    Values are taken in *fields* order and escaped with quote_copy().
    *tbl* is unused; it keeps the signature shared by all row formatters.
    """
    return "\t".join([quote_copy(row[f]) for f in fields])
147
def _gen_dict_insert(tbl, row, fields):
    """Format dict *row* as one INSERT statement into *tbl*.

    Values come from *row* in *fields* order, escaped with quote_literal();
    the table and column names are inserted verbatim.
    """
    values = ",".join([quote_literal(row[f]) for f in fields])
    return "insert into %s (%s) values (%s);" % (tbl, ",".join(fields), values)
155
def _gen_list_copy(tbl, row, fields):
    """Format positional *row* as one tab-separated COPY line (no newline).

    Only the first len(fields) items of *row* are used and escaped with
    quote_copy(); a shorter row fails on indexing.
    *tbl* is unused; it keeps the signature shared by all row formatters.
    """
    ncols = len(fields)
    return "\t".join([quote_copy(row[i]) for i in range(ncols)])
162
def _gen_list_insert(tbl, row, fields):
    """Format positional *row* as one INSERT statement into *tbl*.

    The first len(fields) items of *row* are escaped with quote_literal()
    and paired with *fields* by position; names are inserted verbatim.
    """
    ncols = len(fields)
    values = ",".join([quote_literal(row[i]) for i in range(ncols)])
    return "insert into %s (%s) values (%s);" % (tbl, ",".join(fields), values)
170
def magic_insert(curs, tablename, data, fields=None, use_insert=0):
    """Copy/insert a list of dict/list data to database.

    If curs is None, the COPY text (or the INSERT statements) is returned
    as a string instead of being executed. For a list of dicts the field
    list is optional, as it is guessed from the keys of the first row.

    :param curs: cursor with execute()/copy_from(), or None for text output.
    :param tablename: target table, used verbatim in the generated SQL.
    :param data: sequence of rows; dicts (including dict subclasses) or
        positional sequences, judged by the first row.
    :param fields: column names; required for non-dict rows.
    :param use_insert: if true, generate INSERT statements instead of COPY.
    :returns: generated text when curs is None, else None. Empty *data*
        returns None without touching the cursor.
    :raises Exception: if rows are not dicts and no field list is given.
    """
    if len(data) == 0:
        return

    # isinstance() so dict subclasses also take the dict path
    if isinstance(data[0], dict):
        if fields is None:
            fields = list(data[0].keys())
        if use_insert:
            row_func = _gen_dict_insert
        else:
            row_func = _gen_dict_copy
    else:
        if fields is None:
            raise Exception("Non-dict data needs field list")
        if use_insert:
            row_func = _gen_list_insert
        else:
            row_func = _gen_list_copy

    buf = StringIO()
    # Text-only COPY output needs its own header line; when executing,
    # copy_from() is given the table/column header instead.
    # use_insert is tested for truthiness everywhere so that header,
    # row format and terminator always agree.
    if curs is None and not use_insert:
        fmt = "COPY %s (%s) FROM STDIN;\n"
        buf.write(fmt % (tablename, ",".join(fields)))

    for row in data:
        buf.write(row_func(tablename, row, fields))
        buf.write("\n")

    # if user needs only string, return it
    if curs is None:
        if not use_insert:
            buf.write("\\.\n")
        return buf.getvalue()

    if use_insert:
        # all INSERT statements go to the server in a single execute()
        curs.execute(buf.getvalue())
    else:
        buf.seek(0)
        hdr = "%s (%s)" % (tablename, ",".join(fields))
        curs.copy_from(buf, hdr)
221
def db_copy_from_dict(curs, tablename, dict_list, fields=None):
    """Do a COPY FROM STDIN using list of dicts as source.

    :param curs: cursor providing copy_from().
    :param tablename: target table, used verbatim in the COPY header.
    :param dict_list: rows as dicts; nothing is done when it is empty.
    :param fields: column names in output order; defaults to the keys of
        the first dict.
    """
    if len(dict_list) == 0:
        return

    if fields is None:
        fields = list(dict_list[0].keys())

    buf = StringIO()
    for dat in dict_list:
        buf.write("\t".join([quote_copy(dat[k]) for k in fields]))
        buf.write("\n")
    buf.seek(0)

    hdr = "%s (%s)" % (tablename, ",".join(fields))
    curs.copy_from(buf, hdr)
243
def db_copy_from_list(curs, tablename, row_list, fields):
    """Do a COPY FROM STDIN using list of lists as source.

    :param curs: cursor providing copy_from().
    :param tablename: target table, used verbatim in the COPY header.
    :param row_list: positional rows; nothing is done when it is empty.
    :param fields: column names; row items are paired with them by
        position and only the first len(fields) items of each row are used.
    :raises Exception: 'Need field list' if *fields* is None or empty.
    """
    if len(row_list) == 0:
        return

    if not fields:
        raise Exception('Need field list')

    ncols = len(fields)
    buf = StringIO()
    for dat in row_list:
        buf.write("\t".join([quote_copy(dat[i]) for i in range(ncols)]))
        buf.write("\n")
    buf.seek(0)

    hdr = "%s (%s)" % (tablename, ",".join(fields))
    curs.copy_from(buf, hdr)
265 266 # 267 # Full COPY of table from one db to another 268 # 269
class CopyPipe(object):
    """Splits one big COPY to chunks.

    Acts as the file object for a source cursor's copy_to(): data is
    buffered and forwarded to *dstcurs*.copy_from() in chunks of roughly
    *limit* characters, each chunk ending at a newline.
    """

    def __init__(self, dstcurs, tablename, limit=512*1024, cancel_func=None):
        self.tablename = tablename
        self.dstcurs = dstcurs
        self.buf = StringIO()
        self.limit = limit
        # optional callback invoked at the start of every flush()
        self.cancel_func = cancel_func
        # running totals: characters received and newlines seen
        self.total_rows = 0
        self.total_bytes = 0

    def write(self, data):
        """Accept a piece of COPY data from the source cursor.

        Once the buffer holds at least ``limit`` characters, the incoming
        piece is split after its first newline. The buffer plus that head
        is flushed, and the remainder starts the next buffer.
        """
        self.total_bytes += len(data)
        self.total_rows += data.count("\n")

        if self.buf.tell() >= self.limit:
            pos = data.find('\n')
            if pos >= 0:
                # split at newline so the flushed chunk ends on a line boundary
                self.buf.write(data[:pos + 1])
                self.flush()
                data = data[pos + 1:]

        self.buf.write(data)

    def flush(self):
        """Call cancel_func (if set), then send buffered data and reset."""
        if self.cancel_func:
            self.cancel_func()

        if self.buf.tell() > 0:
            self.buf.seek(0)
            self.dstcurs.copy_from(self.buf, self.tablename)
            self.buf.seek(0)
            self.buf.truncate()
312
def full_copy(tablename, src_curs, dst_curs, column_list=None):
    """COPY table from one db to another.

    Output of ``src_curs.copy_to()`` is fed through a CopyPipe, which
    forwards it in chunks to ``dst_curs.copy_from()``.

    :param column_list: optional column names; when non-empty, both sides
        use "table (col,...)" as the COPY target.
    :returns: (total_bytes, total_rows) as counted by the pipe.
    """
    if column_list:
        hdr = "%s (%s)" % (tablename, ",".join(column_list))
    else:
        hdr = tablename
    pipe = CopyPipe(dst_curs, hdr)
    src_curs.copy_to(pipe, hdr)
    # send whatever is still buffered below the chunk limit
    pipe.flush()

    return (pipe.total_bytes, pipe.total_rows)
325 326 327 # 328 # SQL installer 329 # 330
class DBObject(object):
    """Base class for installable DB objects."""
    name = None
    sql = None
    sql_file = None

    # Directories searched, in order, for a relative sql_file.
    contrib_dirs = (
        "/opt/pgsql/share/contrib",
        "/usr/share/postgresql/8.0/contrib",
        "/usr/share/postgresql/8.1/contrib",
        "/usr/share/postgresql/8.2/contrib",
    )

    def __init__(self, name, sql=None, sql_file=None):
        self.name = name
        self.sql = sql
        self.sql_file = sql_file

    def get_sql(self):
        """Return the SQL that creates this object.

        Inline ``sql`` takes precedence. Otherwise ``sql_file`` is read:
        an absolute path is used as-is, a relative one is looked up in
        ``contrib_dirs``.

        :raises Exception: 'File not found: ...' if sql_file cannot be
            located; 'object not defined' if neither sql nor sql_file is set.
        """
        if self.sql:
            return self.sql
        if not self.sql_file:
            raise Exception('object not defined')

        if self.sql_file.startswith("/"):
            candidates = [self.sql_file]
        else:
            candidates = [os.path.join(d, self.sql_file)
                          for d in self.contrib_dirs]
        for fn in candidates:
            if os.path.isfile(fn):
                f = open(fn, "r")
                try:
                    return f.read()
                finally:
                    f.close()
        raise Exception('File not found: ' + self.sql_file)

    def create(self, curs):
        """Execute the creation SQL on *curs*."""
        curs.execute(self.get_sql())
362
class DBSchema(DBObject):
    """Installable schema, looked up by its name."""

    def exists(self, curs):
        """Return exists_schema()'s count for this schema (0 means missing)."""
        return exists_schema(curs, self.name)
367
class DBTable(DBObject):
    """Installable table; existence is checked with exists_table()."""

    def exists(self, curs):
        """Return exists_table()'s count for this table (0 means missing)."""
        return exists_table(curs, self.name)
372
class DBFunction(DBObject):
    """Installable function, identified by name and argument count."""

    def __init__(self, name, nargs, sql=None, sql_file=None):
        DBObject.__init__(self, name, sql=sql, sql_file=sql_file)
        # argument count matched by exists(); argument types are not checked
        self.nargs = nargs

    def exists(self, curs):
        """Count functions with this name that take self.nargs arguments."""
        return exists_function(curs, self.name, self.nargs)
380
class DBLanguage(DBObject):
    """Installable procedural language, created with "create language"."""

    def __init__(self, name):
        create_sql = "create language %s" % name
        DBObject.__init__(self, name, sql=create_sql)

    def exists(self, curs):
        """Count pg_language rows with this language's name."""
        return exists_language(curs, self.name)
387
def db_install(curs, list, log=None):
    """Create every object in *list* that does not exist yet.

    :param curs: cursor handed to each object's exists() and create().
    :param list: objects providing ``name``, ``exists(curs)`` and
        ``create(curs)``; processed in order.
    :param log: optional object with an info() method; it gets one
        message per object.
    """
    for obj in list:
        if obj.exists(curs):
            if log:
                log.info('%s is installed' % obj.name)
            continue
        if log:
            log.info('Installing %s' % obj.name)
        obj.create(curs)
398