comparison lib/graphtools.py @ 0:f6ebec6e235e draft
Uploaded
| author | petrn |
|---|---|
| date | Thu, 19 Dec 2019 13:46:43 +0000 |
| parents | |
| children | |
comparison
| -1:000000000000 | 0:f6ebec6e235e |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 ''' | |
| 3 This module is mainly for large graph (i.e. hitsort) storage, parsing and clustering | |
| 4 ''' | |
| 5 import os | |
| 6 import sys | |
| 7 import sqlite3 | |
| 8 import time | |
| 9 import subprocess | |
| 10 import logging | |
| 11 from collections import defaultdict | |
| 12 import collections | |
| 13 import operator | |
| 14 import math | |
| 15 import random | |
| 16 import itertools | |
| 17 import config | |
| 18 from lib import r2py | |
| 19 from lib.utils import FilePath | |
| 20 from lib.parallel.parallel import parallel2 as parallel | |
| 21 REQUIRED_VERSION = (3, 4) | |
| 22 MAX_BUFFER_SIZE = 100000 | |
| 23 if sys.version_info < REQUIRED_VERSION: | |
| 24 raise Exception("\n\npython 3.4 or higher is required!\n") | |
| 25 LOGGER = logging.getLogger(__name__) | |
| 26 | |
| 27 | |
| 28 def dfs(start, graph): | |
| 29 """ | |
| 30 helper function for cluster merging. | |
| 31 Does depth-first search, returning a set of all nodes seen. | |
| 32 Takes: a graph in node --> [neighbors] form. | |
| 33 """ | |
| 34 visited, worklist = set(), [start] | |
| 35 | |
| 36 while worklist: | |
| 37 node = worklist.pop() | |
| 38 if node not in visited: | |
| 39 visited.add(node) | |
| 40 # Add all the neighbors to the worklist. | |
| 41 worklist.extend(graph[node]) | |
| 42 | |
| 43 return visited | |
| 44 | |
| 45 | |
| 46 def graph_components(edges): | |
| 47 """ | |
| 48 Given a graph as a list of edges, divide the nodes into components. | |
| 49 Takes a list of pairs of nodes, where the nodes are integers. | |
| 50 """ | |
| 51 | |
| 52 # Construct a graph (mapping node --> [neighbors]) from the edges. | |
| 53 graph = defaultdict(list) | |
| 54 nodes = set() | |
| 55 | |
| 56 for v1, v2 in edges: | |
| 57 nodes.add(v1) | |
| 58 nodes.add(v2) | |
| 59 | |
| 60 graph[v1].append(v2) | |
| 61 graph[v2].append(v1) | |
| 62 | |
| 63 # Traverse the graph to find the components. | |
| 64 components = [] | |
| 65 | |
| 66 # We don't care what order we see the nodes in. | |
| 67 while nodes: | |
| 68 component = dfs(nodes.pop(), graph) | |
| 69 components.append(component) | |
| 70 | |
| 71 # Remove this component from the nodes under consideration. | |
| 72 nodes -= component | |
| 73 | |
| 74 return components | |
| 75 | |
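
The two helpers above do plain connected-component search over an edge list. A minimal usage sketch (the edge values are illustrative, not taken from any real hitsort):

```python
# two components: {1, 2, 3} and {7, 8}
edges = [(1, 2), (2, 3), (7, 8)]

components = graph_components(edges)
# each component is a set of nodes; the order of components may vary
assert sorted(sorted(c) for c in components) == [[1, 2, 3], [7, 8]]
```
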
| 76 | |
| 77 class Graph(): | |
| 78 ''' | |
| 79 create Graph object stored in sqlite database, either in memory or on disk | |
| 80 structure of table is: | |
| 81 V1 V2 weight12 | |
| 82 V2 V3 weight23 | |
| 83 V4 V5 weight45 | |
| 84 ... | |
| 85 ... | |
| 86 !! this is an undirected simple graph - duplicated edges must | |
| 87 be removed on graph creation | |
| 88 | |
| 89 ''' | |
| 90 # seed for random number generator - this is necessary for reproducibility between runs | |
| 91 seed = '123' | |
| 92 | |
| 93 def __init__(self, | |
| 94 source=None, | |
| 95 filename=None, | |
| 96 new=False, | |
| 97 paired=True, | |
| 98 seqids=None): | |
| 99 ''' | |
| 100 filename : file where to store the database; if not defined it is stored in memory | |
| 101 source : ncol file which describes the graph | |
| 102 new : if False and source is not defined, the graph can be loaded from the database (filename) | |
| 103 | |
| 104 vertices_name must be in the correct order!!! | |
| 105 ''' | |
| 106 | |
| 107 self.filename = filename | |
| 108 self.source = source | |
| 109 self.paired = paired | |
| 110 # path to indexed graph - will be set later | |
| 111 self.indexed_file = None | |
| 112 self._cluster_list = None | |
| 113 # these two attributes are set after clustering | |
| 114 # communities before merging | |
| 115 self.graph_2_community0 = None | |
| 116 # communities after merging | |
| 117 self.graph_2_community = None | |
| 118 self.number_of_clusters = None | |
| 119 self.binary_file = None | |
| 120 self.cluster_sizes = None | |
| 121 self.graph_tree = None | |
| 122 self.graph_tree_log = None | |
| 123 self.weights_file = None | |
| 124 | |
| 125 if filename: | |
| 126 if os.path.isfile(filename) and (new or source): | |
| 127 os.remove(filename) | |
| 128 self.conn = sqlite3.connect(filename) | |
| 129 else: | |
| 130 self.conn = sqlite3.connect(":memory:") | |
| 131 c = self.conn.cursor() | |
| 132 | |
| 133 c.execute("PRAGMA page_size=8192") | |
| 134 c.execute("PRAGMA cache_size = 2000000 ") # this helps | |
| 135 | |
| 136 try: | |
| 137 c.execute(( | |
| 138 "create table graph (v1 integer, v2 integer, weight integer, " | |
| 139 "pair integer, v1length integer, v1start integer, v1end integer, " | |
| 140 "v2length integer, v2start integer, v2end integer, pid integer," | |
| 141 "evalue real, strand text )")) | |
| 142 except sqlite3.OperationalError: | |
| 143 pass # table already exists | |
| 144 else: | |
| 145 c.execute( | |
| 146 "create table vertices (vertexname text primary key, vertexindex integer)") | |
| 147 tables = sorted(c.execute( | |
| 148 "SELECT name FROM sqlite_master WHERE type='table'").fetchall()) | |
| 149 | |
| 150 if not [('graph', ), ('vertices', )] == tables: | |
| 151 raise Exception("tables for sqlite for graph are not correct") | |
| 152 | |
| 153 if source: | |
| 154 self._read_from_hitsort() | |
| 155 | |
| 156 if paired and seqids: | |
| 157 # vertices must be defined - create graph of paired reads: | |
| 158 # last character must distinguish pair | |
| 159 c.execute(( | |
| 160 "create table pairs (basename, vertexname1, vertexname2," | |
| 161 "v1 integer, v2 integer, cluster1 integer, cluster2 integer)")) | |
| 162 buffer = [] | |
| 163 for i, k in zip(seqids[0::2], seqids[1::2]): | |
| 164 assert i[:-1] == k[:-1], "problem with paired read ids" | |
| 165 # some vertices are not in graph - singletons | |
| 166 try: | |
| 167 index1 = self.vertices[i] | |
| 168 except KeyError: | |
| 169 index1 = -1 | |
| 170 | |
| 171 try: | |
| 172 index2 = self.vertices[k] | |
| 173 except KeyError: | |
| 174 index2 = -1 | |
| 175 | |
| 176 buffer.append((i[:-1], i, k, index1, index2)) | |
| 177 | |
| 178 self.conn.executemany( | |
| 179 "insert into pairs (basename, vertexname1, vertexname2, v1, v2) values (?,?,?,?,?)", | |
| 180 buffer) | |
| 181 self.conn.commit() | |
| 182 | |
| 183 def _read_from_hitsort(self): | |
| 184 | |
| 185 c = self.conn.cursor() | |
| 186 c.execute("delete from graph") | |
| 187 buffer = [] | |
| 188 vertices = {} | |
| 189 counter = 0 | |
| 190 v_count = 0 | |
| 191 with open(self.source, 'r') as f: | |
| 192 for i in f: | |
| 193 edge_index = {} | |
| 194 items = i.split() | |
| 195 # get or insert vertex index | |
| 196 for vn in items[0:2]: | |
| 197 if vn not in vertices: | |
| 198 vertices[vn] = v_count | |
| 199 edge_index[vn] = v_count | |
| 200 v_count += 1 | |
| 201 else: | |
| 202 edge_index[vn] = vertices[vn] | |
| 203 if self.paired: | |
| 204 pair = int(items[0][:-1] == items[1][:-1]) | |
| 205 else: | |
| 206 pair = 0 | |
| 207 buffer.append(((edge_index[items[0]], edge_index[items[1]], | |
| 208 items[2], pair) + tuple(items[3:]))) | |
| 209 if len(buffer) == MAX_BUFFER_SIZE: | |
| 210 counter += 1 | |
| 211 self.conn.executemany( | |
| 212 "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)", | |
| 213 buffer) | |
| 214 buffer = [] | |
| 215 if buffer: | |
| 216 self.conn.executemany( | |
| 217 "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)", | |
| 218 buffer) | |
| 219 | |
| 220 self.conn.commit() | |
| 221 self.vertices = vertices | |
| 222 self.vertexid2name = { | |
| 223 vertex: index | |
| 224 for index, vertex in vertices.items() | |
| 225 } | |
| 226 self.vcount = len(vertices) | |
| 227 c = self.conn.cursor() | |
| 228 c.execute("select count(*) from graph") | |
| 229 self.ecount = c.fetchone()[0] | |
| 230 # fill table of vertices | |
| 231 self.conn.executemany("insert into vertices values (?,?)", | |
| 232 vertices.items()) | |
| 233 self.conn.commit() | |
| 234 | |
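
`_read_from_hitsort` expects a whitespace-separated hitsort file whose first three columns are the two read names and the edge weight, followed by the nine remaining columns of the `graph` table (lengths, alignment coordinates, pid, evalue, strand). A minimal sketch of loading such a file into an in-memory graph, assuming the module is importable as `lib.graphtools`; the read names and alignment values below are made up:

```python
import tempfile
from lib.graphtools import Graph

# columns: v1 v2 weight v1length v1start v1end v2length v2start v2end pid evalue strand
row = "{}\t{}\t{}\t100\t1\t100\t100\t5\t100\t90\t1e-30\t+\n"
with tempfile.NamedTemporaryFile("w", suffix=".hitsort", delete=False) as f:
    for v1, v2, weight in [("readAf", "readAr", 95), ("readAf", "readBf", 80), ("readBf", "readBr", 90)]:
        f.write(row.format(v1, v2, weight))

g = Graph(source=f.name, paired=True)   # no filename -> sqlite database kept in memory
print(g.vcount, g.ecount)               # 4 vertices, 3 edges
print(g.vertices)                       # e.g. {'readAf': 0, 'readAr': 1, 'readBf': 2, 'readBr': 3}
```
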
| 235 def save_indexed_graph(self, file=None): | |
| 236 if not file: | |
| 237 self.indexed_file = "{}.int".format(self.source) | |
| 238 else: | |
| 239 self.indexed_file = file | |
| 240 c = self.conn.cursor() | |
| 241 with open(self.indexed_file, 'w') as f: | |
| 242 out = c.execute('select v1,v2,weight from graph') | |
| 243 for v1, v2, weight in out: | |
| 244 f.write('{}\t{}\t{}\n'.format(v1, v2, weight)) | |
| 245 | |
| 246 def get_subgraph(self, vertices): | |
| 247 pass | |
| 248 | |
| 249 def _levels(self): | |
| 250 with open(self.graph_tree_log, 'r') as f: | |
| 251 levels = -1 | |
| 252 for i in f: | |
| 253 if i[:5] == 'level': | |
| 254 levels += 1 | |
| 255 return levels | |
| 256 | |
| 257 def _reindex_community(self, id2com): | |
| 258 ''' | |
| 259 reindex communities and superclusters so that the biggest cluster is no. 1 | |
| 260 ''' | |
| 261 self.conn.commit() | |
| 262 _, community, supercluster = zip(*id2com) | |
| 263 (cluster_index, frq, self.cluster_sizes, | |
| 264 self.number_of_clusters) = self._get_index_and_frequency(community) | |
| 265 | |
| 266 supercluster_index, sc_frq, _, _ = self._get_index_and_frequency( | |
| 267 supercluster) | |
| 268 id2com_reindexed = [] | |
| 269 | |
| 270 for i, _ in enumerate(id2com): | |
| 271 id2com_reindexed.append((id2com[i][0], id2com[i][1], frq[ | |
| 272 i], cluster_index[i], supercluster_index[i], sc_frq[i])) | |
| 273 return id2com_reindexed | |
| 274 | |
| 275 @staticmethod | |
| 276 def _get_index_and_frequency(membership): | |
| 277 frequency_table = collections.Counter(membership) | |
| 278 frequency_table_sorted = sorted(frequency_table.items(), | |
| 279 key=operator.itemgetter(1), | |
| 280 reverse=True) | |
| 281 frq = [] | |
| 282 for i in membership: | |
| 283 frq.append(frequency_table[i]) | |
| 284 rank = {} | |
| 285 index = 0 | |
| 286 for comm, _ in frequency_table_sorted: | |
| 287 index += 1 | |
| 288 rank[comm] = index | |
| 289 cluster_index = [rank[i] for i in membership] | |
| 290 cluster_sizes = [i[1] for i in frequency_table_sorted] | |
| 291 number_of_clusters = len(frequency_table) | |
| 292 return [cluster_index, frq, cluster_sizes, number_of_clusters] | |
| 293 | |
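
`_get_index_and_frequency` ranks communities by size so that the largest one becomes cluster no. 1. A small worked example with made-up membership labels:

```python
membership = ['a', 'b', 'a', 'a', 'c', 'b']
# community sizes: a -> 3, b -> 2, c -> 1, so ranks are a=1, b=2, c=3
cluster_index, frq, cluster_sizes, number_of_clusters = Graph._get_index_and_frequency(membership)
assert cluster_index == [1, 2, 1, 1, 3, 2]   # per-element rank of its community
assert frq == [3, 2, 3, 3, 1, 2]             # per-element size of its community
assert cluster_sizes == [3, 2, 1]            # sizes sorted from largest to smallest
assert number_of_clusters == 3
```
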
| 294 def louvain_clustering(self, merge_threshold=0, cleanup=False): | |
| 295 ''' | |
| 296 input - graph | |
| 297 output - list of clusters | |
| 298 executables path ?? | |
| 299 ''' | |
| 300 LOGGER.info("converting hitsort to binary format") | |
| 301 self.binary_file = "{}.bin".format(self.indexed_file) | |
| 302 self.weights_file = "{}.weight".format(self.indexed_file) | |
| 303 self.graph_tree = "{}.graph_tree".format(self.indexed_file) | |
| 304 self.graph_tree_log = "{}.graph_tree_log".format(self.indexed_file) | |
| 305 self.graph_2_community0 = "{}.graph_2_community0".format( | |
| 306 self.indexed_file) | |
| 307 self._cluster_list = None | |
| 308 self.graph_2_community = "{}.graph_2_community".format( | |
| 309 self.indexed_file) | |
| 310 print(["louvain_convert", "-i", self.indexed_file, "-o", | |
| 311 self.binary_file, "-w", self.weights_file]) | |
| 312 subprocess.check_call( | |
| 313 ["louvain_convert", "-i", self.indexed_file, "-o", | |
| 314 self.binary_file, "-w", self.weights_file], | |
| 315 timeout=None) | |
| 316 | |
| 317 gt = open(self.graph_tree, 'w') | |
| 318 gtl = open(self.graph_tree_log, 'w') | |
| 319 LOGGER.info("running louvain clustering...") | |
| 320 subprocess.check_call( | |
| 321 ["louvain_community", self.binary_file, "-l", "-1", "-w", | |
| 322 self.weights_file, "-v ", "-s", self.seed], | |
| 323 stdout=gt, | |
| 324 stderr=gtl, | |
| 325 timeout=None) | |
| 326 gt.close() | |
| 327 gtl.close() | |
| 328 | |
| 329 LOGGER.info("creating list of communities") | |
| 330 gt2c = open(self.graph_2_community0, 'w') | |
| 331 subprocess.check_call( | |
| 332 ['louvain_hierarchy', self.graph_tree, "-l", str(self._levels())], | |
| 333 stdout=gt2c) | |
| 334 gt2c.close() | |
| 335 if merge_threshold and self.paired: | |
| 336 com2newcom = self.find_superclusters(merge_threshold) | |
| 337 elif self.paired: | |
| 338 com2newcom = self.find_superclusters(config.SUPERCLUSTER_THRESHOLD) | |
| 339 else: | |
| 340 com2newcom = {} | |
| 341 # merging of clusters, creating superclusters | |
| 342 LOGGER.info("merging clusters based on mate-pairs") | |
| 343 # modify self.graph_2_community file | |
| 344 # rewrite graph2community | |
| 345 with open(self.graph_2_community0, 'r') as fin: | |
| 346 with open(self.graph_2_community, 'w') as fout: | |
| 347 for i in fin: | |
| 348 # write graph 2 community file in format: | |
| 349 # id communityid superclusterid | |
| 350 # if merging - community and superclusters are identical | |
| 351 vi, com = i.split() | |
| 352 if merge_threshold: | |
| 353 ## merging | |
| 354 if int(com) in com2newcom: | |
| 355 fout.write("{} {} {}\n".format(vi, com2newcom[int( | |
| 356 com)], com2newcom[int(com)])) | |
| 357 else: | |
| 358 fout.write("{} {} {}\n".format(vi, com, com)) | |
| 359 else: | |
| 360 ## superclusters | |
| 361 if int(com) in com2newcom: | |
| 362 fout.write("{} {} {}\n".format(vi, com, com2newcom[ | |
| 363 int(com)])) | |
| 364 else: | |
| 365 fout.write("{} {} {}\n".format(vi, com, com)) | |
| 366 | |
| 367 LOGGER.info("loading communities into database") | |
| 368 c = self.conn.cursor() | |
| 369 c.execute(("create table communities (vertexindex integer primary key," | |
| 370 "community integer, size integer, cluster integer, " | |
| 371 "supercluster integer, supercluster_size integer)")) | |
| 372 id2com = [] | |
| 373 with open(self.graph_2_community, 'r') as f: | |
| 374 for i in f: | |
| 375 name, com, supercluster = i.split() | |
| 376 id2com.append((name, com, supercluster)) | |
| 377 id2com_reindexed = self._reindex_community(id2com) | |
| 378 c.executemany("insert into communities values (?,?,?,?,?,?)", | |
| 379 id2com_reindexed) | |
| 380 #create table of superclusters - clusters | |
| 381 c.execute(("create table superclusters as " | |
| 382 "select distinct supercluster, supercluster_size, " | |
| 383 "cluster, size from communities;")) | |
| 384 # create view id-index-cluster | |
| 385 c.execute( | |
| 386 ("CREATE VIEW vertex_cluster AS SELECT vertices.vertexname," | |
| 387 "vertices.vertexindex, communities.cluster, communities.size" | |
| 388 " FROM vertices JOIN communities USING (vertexindex)")) | |
| 389 self.conn.commit() | |
| 390 | |
| 391 # add clustering info to graph | |
| 392 LOGGER.info("updating graph table") | |
| 393 t0 = time.time() | |
| 394 | |
| 395 c.execute("alter table graph add c1 integer") | |
| 396 c.execute("alter table graph add c2 integer") | |
| 397 c.execute(("update graph set c1 = (select cluster FROM communities " | |
| 398 "where communities.vertexindex=graph.v1)")) | |
| 399 c.execute( | |
| 400 ("update graph set c2 = (select cluster FROM communities where " | |
| 401 "communities.vertexindex=graph.v2)")) | |
| 402 self.conn.commit() | |
| 403 t1 = time.time() | |
| 404 LOGGER.info("updating graph table - done in {} seconds".format(t1 - | |
| 405 t0)) | |
| 406 | |
| 407 # identify similarity connections between clusters | |
| 408 c.execute( | |
| 409 "create table cluster_connections as SELECT c1,c2 , count(*) FROM (SELECT c1, c2 FROM graph WHERE c1>c2 UNION ALL SELECT c2 as c1, c1 as c2 FROM graph WHERE c2>c1) GROUP BY c1, c2") | |
| 410 # TODO - remove directionality - summarize - | |
| 411 | |
| 412 # add cluster identity to pairs table | |
| 413 | |
| 414 if self.paired: | |
| 415 LOGGER.info("analyzing pairs ") | |
| 416 t0 = time.time() | |
| 417 c.execute( | |
| 418 "UPDATE pairs SET cluster1=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v1)") | |
| 419 t1 = time.time() | |
| 420 LOGGER.info( | |
| 421 "updating pairs table - cluster1 - done in {} seconds".format( | |
| 422 t1 - t0)) | |
| 423 | |
| 424 t0 = time.time() | |
| 425 c.execute( | |
| 426 "UPDATE pairs SET cluster2=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v2)") | |
| 427 t1 = time.time() | |
| 428 LOGGER.info( | |
| 429 "updating pairs table - cluster2 - done in {} seconds".format( | |
| 430 t1 - t0)) | |
| 431 # reorder records | |
| 432 | |
| 433 t0 = time.time() | |
| 434 c.execute( | |
| 435 "UPDATE pairs SET cluster1=cluster2, cluster2=cluster1, vertexname1=vertexname2,vertexname2=vertexname1 where cluster1<cluster2") | |
| 436 t1 = time.time() | |
| 437 LOGGER.info("sorting - done in {} seconds".format(t1 - t0)) | |
| 438 | |
| 439 t0 = time.time() | |
| 440 c.execute( | |
| 441 "create table cluster_mate_connections as select cluster1 as c1, cluster2 as c2, count(*) as N, group_concat(basename) as ids from pairs where cluster1!=cluster2 group by cluster1, cluster2;") | |
| 442 t1 = time.time() | |
| 443 LOGGER.info( | |
| 444 "creating cluster_mate_connections table - done in {} seconds".format( | |
| 445 t1 - t0)) | |
| 446 # summarize | |
| 447 t0 = time.time() | |
| 448 self._calculate_pair_bond() | |
| 449 t1 = time.time() | |
| 450 LOGGER.info( | |
| 451 "calculating cluster pair bond - done in {} seconds".format( | |
| 452 t1 - t0)) | |
| 453 t0 = time.time() | |
| 454 else: | |
| 455 # not paired - create empty tables | |
| 456 self._add_empty_tables() | |
| 457 | |
| 458 self.conn.commit() | |
| 459 t1 = time.time() | |
| 460 LOGGER.info("committing changes - done in {} seconds".format(t1 - t0)) | |
| 461 | |
| 462 if cleanup: | |
| 463 LOGGER.info("cleaning clustering temp files") | |
| 464 os.unlink(self.binary_file) | |
| 465 os.unlink(self.weights_file) | |
| 466 os.unlink(self.graph_tree) | |
| 467 os.unlink(self.graph_tree_log) | |
| 468 os.unlink(self.graph_2_community0) | |
| 469 os.unlink(self.graph_2_community) | |
| 470 os.unlink(self.indexed_file) | |
| 471 self.binary_file = None | |
| 472 self.weights_file = None | |
| 473 self.graph_tree = None | |
| 474 self.graph_tree_log = None | |
| 475 self.graph_2_community0 = None | |
| 476 self.graph_2_community = None | |
| 477 self.indexed_file = None | |
| 478 | |
| 479 # calculate k | |
| 480 | |
| 481 def find_superclusters(self, merge_threshold): | |
| 482 '''Find superclusters from clustering based on paired reads ''' | |
| 483 clsdict = {} | |
| 484 with open(self.graph_2_community0, 'r') as f: | |
| 485 for i in f: | |
| 486 vi, com = i.split() | |
| 487 if com in clsdict: | |
| 488 clsdict[com] += [self.vertexid2name[int(vi)][0:-1]] | |
| 489 else: | |
| 490 clsdict[com] = [self.vertexid2name[int(vi)][0:-1]] | |
| 491 # remove all small clusters - these will not be merged: | |
| 492 small_cls = [] | |
| 493 for i in clsdict: | |
| 494 if len(clsdict[i]) < config.MINIMUM_NUMBER_OF_READS_FOR_MERGING: | |
| 495 small_cls.append(i) | |
| 496 for i in small_cls: | |
| 497 del clsdict[i] | |
| 498 pairs = [] | |
| 499 for i, j in itertools.combinations(clsdict, 2): | |
| 500 s1 = set(clsdict[i]) | |
| 501 s2 = set(clsdict[j]) | |
| 502 wgh = len(s1 & s2) | |
| 503 if wgh < config.MINIMUM_NUMBER_OF_SHARED_PAIRS_FOR_MERGING: | |
| 504 continue | |
| 505 else: | |
| 506 n1 = len(s1) * 2 - len(clsdict[i]) | |
| 507 n2 = len(s2) * 2 - len(clsdict[j]) | |
| 508 k = 2 * wgh / (n1 + n2) | |
| 509 if k > merge_threshold: | |
| 510 pairs.append((int(i), int(j))) | |
| 511 # find connected components - will be merged | |
| 512 cls2merge = graph_components(pairs) | |
| 513 com2newcom = {} | |
| 514 for i in cls2merge: | |
| 515 newcom = min(i) | |
| 516 for j in i: | |
| 517 com2newcom[j] = newcom | |
| 518 return com2newcom | |
| 519 | |
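
The merge score compares two clusters through their shared read pairs: for each cluster, `n` counts reads whose mate lies outside the cluster (`len(set) * 2 - len(list)`), `wgh` counts pair basenames present in both clusters, and `k = 2 * wgh / (n1 + n2)` is the fraction of those broken pairs that the two clusters share. A worked example with illustrative reads (the real cut-offs come from `config.MINIMUM_NUMBER_OF_SHARED_PAIRS_FOR_MERGING` and the `merge_threshold` argument):

```python
# cluster i holds reads Af, Ar, Bf, Cf  -> basenames ['A', 'A', 'B', 'C']
# cluster j holds reads Br, Cr, Df      -> basenames ['B', 'C', 'D']
cls_i = ['A', 'A', 'B', 'C']
cls_j = ['B', 'C', 'D']
s1, s2 = set(cls_i), set(cls_j)
wgh = len(s1 & s2)                   # 2 basenames (B, C) have one mate in each cluster
n1 = len(s1) * 2 - len(cls_i)        # 2 reads of cluster i have their mate elsewhere
n2 = len(s2) * 2 - len(cls_j)        # 3 reads of cluster j have their mate elsewhere
k = 2 * wgh / (n1 + n2)              # 4 / 5 = 0.8 -> clusters are merged if k > merge_threshold
```
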
| 520 def adjust_cluster_size(self, proportion_kept, ids_kept): | |
| 521 LOGGER.info("adjusting cluster sizes") | |
| 522 c = self.conn.cursor() | |
| 523 c.execute("ALTER TABLE superclusters ADD COLUMN size_uncorrected INTEGER") | |
| 524 c.execute("UPDATE superclusters SET size_uncorrected=size") | |
| 525 if ids_kept: | |
| 526 ids_kept_set = set(ids_kept) | |
| 527 ratio = (1 - proportion_kept)/proportion_kept | |
| 528 for cl, size in c.execute("SELECT cluster,size FROM superclusters"): | |
| 529 ids = self.get_cluster_reads(cl) | |
| 530 ovl_size = len(ids_kept_set.intersection(ids)) | |
| 531 size_adjusted = int(len(ids) + ovl_size * ratio) | |
| 532 if size_adjusted > size: | |
| 533 c.execute("UPDATE superclusters SET size=? WHERE cluster=?", | |
| 534 (size_adjusted, cl)) | |
| 535 self.conn.commit() | |
| 536 LOGGER.info("adjusting cluster sizes - done") | |
| 537 | |
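
The correction extrapolates the original cluster size after read prefiltering: each read that is both in the cluster and on the `ids_kept` list stands for `(1 - proportion_kept) / proportion_kept` reads that were discarded. A short worked example with illustrative numbers:

```python
proportion_kept = 0.25                             # 25 % of over-represented reads were kept
ratio = (1 - proportion_kept) / proportion_kept    # every kept read represents 3 discarded reads
cluster_size = 100                                 # reads currently assigned to the cluster
ovl_size = 40                                      # of those, 40 are on the ids_kept list
size_adjusted = int(cluster_size + ovl_size * ratio)   # 100 + 40 * 3 = 220
```
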
| 538 def export_cls(self, path): | |
| 539 with open(path, 'w') as f: | |
| 540 for i in range(1, self.number_of_clusters + 1): | |
| 541 ids = self.get_cluster_reads(i) | |
| 542 f.write(">CL{}\t{}\n".format(i, len(ids))) | |
| 543 f.write("\t".join(ids)) | |
| 544 f.write("\n") | |
| 545 | |
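
`export_cls` writes one record per cluster: a header line with the cluster name and its size, followed by a tab-separated line of read ids. A sketch of the output, assuming `g` is a `Graph` that has already been clustered with `louvain_clustering`; the path and read ids are made up:

```python
g.export_cls("output.cls")
# output.cls then contains one record per cluster, e.g. (tab separated):
# >CL1	4
# readAf	readAr	readBf	readBr
# >CL2	2
# readCf	readCr
```
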
| 546 def _calculate_pair_bond(self): | |
| 547 c = self.conn.cursor() | |
| 548 out = c.execute("select c1, c2, ids from cluster_mate_connections") | |
| 549 buffer = [] | |
| 550 for c1, c2, ids in out: | |
| 551 w = len(set(ids.split(","))) | |
| 552 n1 = len(set([i[:-1] for i in self.get_cluster_reads(c1) | |
| 553 ])) * 2 - len(self.get_cluster_reads(c1)) | |
| 554 n2 = len(set([i[:-1] for i in self.get_cluster_reads(c2) | |
| 555 ])) * 2 - len(self.get_cluster_reads(c2)) | |
| 556 buffer.append((c1, c2, n1, n2, w, 2 * w / (n1 + n2))) | |
| 557 c.execute( | |
| 558 "CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)") | |
| 559 c.executemany(" INSERT INTO cluster_mate_bond values (?,?,?,?,?,?)", | |
| 560 buffer) | |
| 561 | |
| 562 def _add_empty_tables(self): | |
| 563 '''This is used with reads that are not paired | |
| 564 - it creates empty mate tables, this is necessary for | |
| 565 subsequent reporting to work correctly ''' | |
| 566 c = self.conn.cursor() | |
| 567 c.execute(("CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, " | |
| 568 "n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)")) | |
| 569 c.execute( | |
| 570 "CREATE TABLE cluster_mate_connections (c1 INTEGER, c2 INTEGER, N INTEGER, ids TEXT) ") | |
| 571 | |
| 572 def get_cluster_supercluster(self, cluster): | |
| 573 '''Get supercluster id for supplied cluster ''' | |
| 574 c = self.conn.cursor() | |
| 575 out = c.execute( | |
| 576 'SELECT supercluster FROM communities WHERE cluster="{0}" LIMIT 1'.format( | |
| 577 cluster)) | |
| 578 sc = out.fetchone()[0] | |
| 579 return sc | |
| 580 | |
| 581 def get_cluster_reads(self, cluster): | |
| 582 | |
| 583 if self._cluster_list: | |
| 584 return self._cluster_list[str(cluster)] | |
| 585 else: | |
| 586 # if queried first time | |
| 587 c = self.conn.cursor() | |
| 588 out = c.execute("select cluster, vertexname from vertex_cluster") | |
| 589 cluster_list = collections.defaultdict(list) | |
| 590 for clusterindex, vertexname in out: | |
| 591 cluster_list[str(clusterindex)].append(vertexname) | |
| 592 self._cluster_list = cluster_list | |
| 593 return self._cluster_list[str(cluster)] | |
| 594 | |
| 595 | |
| 596 def extract_cluster_blast(self, path, index, ids=None): | |
| 597 ''' Extract blast for cluster and save it to path; | |
| 598 return number of blast lines (i.e. number of graph edges E). | |
| 599 if ids is specified, only a subset of blast is used''' | |
| 600 c = self.conn.cursor() | |
| 601 if ids: | |
| 602 vertexindex = ( | |
| 603 "select vertexindex from vertices " | |
| 604 "where vertexname in ({})").format('"' + '","'.join(ids) + '"') | |
| 605 | |
| 606 out = c.execute(("select * from graph where c1={0} and c2={0}" | |
| 607 " and v1 in ({1}) and v2 in ({1})").format( | |
| 608 index, vertexindex)) | |
| 609 else: | |
| 610 out = c.execute( | |
| 611 "select * from graph where c1={0} and c2={0}".format(index)) | |
| 612 E = 0 | |
| 613 N = len(self.get_cluster_reads(index)) | |
| 614 with open(path, 'w') as f: | |
| 615 for i in out: | |
| 616 print(self.vertexid2name[i[0]], | |
| 617 self.vertexid2name[ | |
| 618 i[1]], | |
| 619 i[2], | |
| 620 *i[4:13], | |
| 621 sep='\t', | |
| 622 file=f) | |
| 623 E += 1 | |
| 624 return E | |
| 625 | |
| 626 def export_clusters_files_multiple(self, | |
| 627 min_size, | |
| 628 directory, | |
| 629 sequences=None, | |
| 630 tRNA_database_path=None, | |
| 631 satellite_model_path=None): | |
| 632 def load_fun(N, E): | |
| 633 ''' estimate mem usage from graph size and density''' | |
| 634 NE = math.log(float(N) * float(E), 10) | |
| 635 if NE > 11.5: | |
| 636 return 1 | |
| 637 if NE > 11: | |
| 638 return 0.9 | |
| 639 if NE > 10: | |
| 640 return 0.4 | |
| 641 if NE > 9: | |
| 642 return 0.2 | |
| 643 if NE > 8: | |
| 644 return 0.07 | |
| 645 return 0.02 | |
| 646 | |
| 647 def estimate_sample_size(NV, NE, maxv, maxe): | |
| 648 ''' estimate suitable sampling based on the graph density | |
| 649 NV, NE are |V| and |E| of the graph | |
| 650 maxv, maxe are maximal |V| and |E|''' | |
| 651 | |
| 652 d = (2 * NE) / (NV * (NV - 1)) | |
| 653 eEst = (maxv * (maxv - 1) * d) / 2 | |
| 654 nEst = (d + math.sqrt(d**2 + 8 * d * maxe)) / (2 * d) | |
| 655 if eEst >= maxe: | |
| 656 N = int(nEst) | |
| 657 if nEst >= maxv: | |
| 658 N = int(maxv) | |
| 659 return N | |
| 660 | |
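
`estimate_sample_size` works from the observed edge density `d`: `eEst` is the number of edges expected if `maxv` vertices were sampled at that density, and `nEst` solves `n * (n - 1) * d / 2 = maxe`, i.e. the largest sample whose expected edge count stays under `maxe`. Whichever limit is binding sets the sample size, so at least one of the two branches always assigns `N`. A worked example with illustrative numbers (the real limits come from `config.CLUSTER_VMAX` and `config.CLUSTER_EMAX`):

```python
import math

NV, NE = 100000, 5000000         # a dense cluster: 1e5 reads, 5e6 similarity edges
maxv, maxe = 40000, 20000000     # hypothetical sampling limits

d = (2 * NE) / (NV * (NV - 1))                          # density ~= 1.0e-3
eEst = (maxv * (maxv - 1) * d) / 2                      # ~= 8.0e5 edges for a maxv-vertex sample
nEst = (d + math.sqrt(d**2 + 8 * d * maxe)) / (2 * d)   # ~= 2.0e5 vertices allowed by maxe

# eEst < maxe, so the edge limit is not binding; nEst >= maxv, so the sample
# is capped at maxv and estimate_sample_size returns 40000 here
```
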
| 661 clusterindex = 1 | |
| 662 cluster_input_args = [] | |
| 663 ppn = [] | |
| 664 # is this a comparative analysis? | |
| 665 if sequences.prefix_length: | |
| 666 self.conn.execute("CREATE TABLE comparative_counts (clusterindex INTEGER," | |
| 667 + ", ".join(["[{}] INTEGER".format(i) for i in sequences.prefix_codes.keys()]) + ")") | |
| 668 # do for comparative analysis | |
| 669 | |
| 670 for cl in range(self.number_of_clusters): | |
| 671 prefix_codes = dict((key, 0) for key in sequences.prefix_codes.keys()) | |
| 672 for i in self.get_cluster_reads(cl): | |
| 673 prefix_codes[i[0:sequences.prefix_length]] += 1 | |
| 674 header = ", ".join(["[" + str(i) + "]" for i in prefix_codes.keys()]) | |
| 675 values = ", ".join([str(i) for i in prefix_codes.values()]) | |
| 676 self.conn.execute( | |
| 677 "INSERT INTO comparative_counts (clusterindex, {}) VALUES ({}, {})".format( | |
| 678 header, cl, values)) | |
| 679 else: | |
| 680 prefix_codes = {} | |
| 681 | |
| 682 while True: | |
| 683 read_names = self.get_cluster_reads(clusterindex) | |
| 684 supercluster = self.get_cluster_supercluster(clusterindex) | |
| 685 N = len(read_names) | |
| 686 print("sequences.ids_kept -2 ") | |
| 687 print(sequences.ids_kept) | |
| 688 if sequences.ids_kept: | |
| 689 N_adjusted = round(len(set(sequences.ids_kept).intersection(read_names)) * | |
| 690 ((1 - config.FILTER_PROPORTION_OF_KEPT) / | |
| 691 config.FILTER_PROPORTION_OF_KEPT) + N) | |
| 692 else: | |
| 693 N_adjusted = N | |
| 694 if N < min_size: | |
| 695 break | |
| 696 else: | |
| 697 LOGGER.info("exporting cluster {}".format(clusterindex)) | |
| 698 blast_file = "{dir}/dir_CL{i:04}/hitsort_part.csv".format( | |
| 699 dir=directory, i=clusterindex) | |
| 700 cluster_dir = "{dir}/dir_CL{i:04}".format(dir=directory, | |
| 701 i=clusterindex) | |
| 702 fasta_file = "{dir}/reads_selection.fasta".format(dir=cluster_dir) | |
| 703 fasta_file_full = "{dir}/reads.fasta".format(dir=cluster_dir) | |
| 704 | |
| 705 os.makedirs(os.path.dirname(blast_file), exist_ok=True) | |
| 706 E = self.extract_cluster_blast(index=clusterindex, | |
| 707 path=blast_file) | |
| 708 # check if blast must be sampled | |
| 709 n_sample = estimate_sample_size(NV=N, | |
| 710 NE=E, | |
| 711 maxv=config.CLUSTER_VMAX, | |
| 712 maxe=config.CLUSTER_EMAX) | |
| 713 LOGGER.info("directories created..") | |
| 714 if n_sample < N: | |
| 715 LOGGER.info(("cluster is too large - sampling.." | |
| 716 "original size: {N}\n" | |
| 717 "sample size: {NS}\n" | |
| 718 "").format(N=N, NS=n_sample)) | |
| 719 random.seed(self.seed) | |
| 720 read_names_sample = random.sample(read_names, n_sample) | |
| 721 LOGGER.info("reads id sampled...") | |
| 722 blast_file_sample = "{dir}/dir_CL{i:04}/blast_sample.csv".format( | |
| 723 dir=directory, i=clusterindex) | |
| 724 E_sample = self.extract_cluster_blast( | |
| 725 index=clusterindex, | |
| 726 path=blast_file, | |
| 727 ids=read_names_sample) | |
| 728 LOGGER.info("number of edges in sample: {}".format( | |
| 729 E_sample)) | |
| 730 sequences.save2fasta(fasta_file, subset=read_names_sample) | |
| 731 sequences.save2fasta(fasta_file_full, subset=read_names) | |
| 732 | |
| 733 else: | |
| 734 read_names_sample = None | |
| 735 E_sample = None | |
| 736 blast_file_sample = None | |
| 737 n_sample = None | |
| 738 sequences.save2fasta(fasta_file_full, subset=read_names) | |
| 739 ## TODO - use symlink instead of : | |
| 740 sequences.save2fasta(fasta_file, subset=read_names) | |
| 741 # export individual annotations tables: | |
| 742 # annotation is always for full cluster | |
| 743 LOGGER.info("exporting cluster annotation") | |
| 744 annotations = {} | |
| 745 annotations_custom = {} | |
| 746 for n in sequences.annotations: | |
| 747 print("sequences.annotations:", n) | |
| 748 if n.find("custom_db") == 0: | |
| 749 print("custom") | |
| 750 annotations_custom[n] = sequences.save_annotation( | |
| 751 annotation_name=n, | |
| 752 subset=read_names, | |
| 753 dir=cluster_dir) | |
| 754 else: | |
| 755 print("built in") | |
| 756 annotations[n] = sequences.save_annotation( | |
| 757 annotation_name=n, | |
| 758 subset=read_names, | |
| 759 dir=cluster_dir) | |
| 760 | |
| 761 cluster_input_args.append([ | |
| 762 n_sample, N, N_adjusted, blast_file, fasta_file, fasta_file_full, | |
| 763 clusterindex, supercluster, self.paired, | |
| 764 tRNA_database_path, satellite_model_path, sequences.prefix_codes, | |
| 765 prefix_codes, annotations, annotations_custom | |
| 766 ]) | |
| 767 clusterindex += 1 | |
| 768 ppn.append(load_fun(N, E)) | |
| 769 | |
| 770 | |
| 771 | |
| 772 self.conn.commit() | |
| 773 | |
| 774 # run in parallel: | |
| 775 # reorder jobs based on the ppn: | |
| 776 cluster_input_args = [ | |
| 777 x | |
| 778 for (y, x) in sorted( | |
| 779 zip(ppn, cluster_input_args), | |
| 780 key=lambda pair: pair[0], | |
| 781 reverse=True) | |
| 782 ] | |
| 783 ppn = sorted(ppn, reverse=True) | |
| 784 LOGGER.info("creating clusters in parallel") | |
| 785 clusters_info = parallel(Cluster, | |
| 786 *[list(i) for i in zip(*cluster_input_args)], | |
| 787 ppn=ppn) | |
| 788 # sort it back: | |
| 789 clusters_info = sorted(clusters_info, key=lambda cl: cl.index) | |
| 790 return clusters_info | |
| 791 | |
| 792 | |
| 793 class Cluster(): | |
| 794 ''' store and show information about cluster properties ''' | |
| 795 | |
| 796 def __init__(self, | |
| 797 size, | |
| 798 size_real, | |
| 799 size_adjusted, | |
| 800 blast_file, | |
| 801 fasta_file, | |
| 802 fasta_file_full, | |
| 803 index, | |
| 804 supercluster, | |
| 805 paired, | |
| 806 tRNA_database_path, | |
| 807 satellite_model_path, | |
| 808 all_prefix_codes, | |
| 809 prefix_codes, | |
| 810 annotations, | |
| 811 annotations_custom={}, | |
| 812 loop_index_threshold=0.7, | |
| 813 pair_completeness_threshold=0.40, | |
| 814 loop_index_unpaired_threshold=0.85): | |
| 815 if size: | |
| 816 # cluster was scaled down | |
| 817 self.size = size | |
| 818 self.size_real = size_real | |
| 819 else: | |
| 820 self.size = self.size_real = size_real | |
| 821 self.size_adjusted = size_adjusted | |
| 822 self.filtered = True if size_adjusted != size_real else False | |
| 823 self.all_prefix_codes = all_prefix_codes.keys | |
| 824 self.prefix_codes = prefix_codes | |
| 825 self.dir = FilePath(os.path.dirname(blast_file)) | |
| 826 self.blast_file = FilePath(blast_file) | |
| 827 self.fasta_file = FilePath(fasta_file) | |
| 828 self.fasta_file_full = FilePath(fasta_file_full) | |
| 829 self.index = index | |
| 830 self.assembly_files = {} | |
| 831 self.ltr_detection = None | |
| 832 self.supercluster = supercluster | |
| 833 self.annotations_files = annotations | |
| 834 self.annotations_files_custom = annotations_custom | |
| 835 self.annotations_summary, self.annotations_table = self._summarize_annotations( | |
| 836 annotations, size_real) | |
| 837 # add annotation | |
| 838 if len(annotations_custom): | |
| 839 self.annotations_summary_custom, self.annotations_custom_table = self._summarize_annotations( | |
| 840 annotations_custom, size_real) | |
| 841 else: | |
| 842 self.annotations_summary_custom, self.annotations_custom_table = "", "" | |
| 843 | |
| 844 self.paired = paired | |
| 845 self.graph_file = FilePath("{0}/graph_layout.GL".format(self.dir)) | |
| 846 self.directed_graph_file = FilePath( | |
| 847 "{0}/graph_layout_directed.RData".format(self.dir)) | |
| 848 self.fasta_oriented_file = FilePath("{0}/reads_selection_oriented.fasta".format( | |
| 849 self.dir)) | |
| 850 self.image_file = FilePath("{0}/graph_layout.png".format(self.dir)) | |
| 851 self.image_file_tmb = FilePath("{0}/graph_layout_tmb.png".format(self.dir)) | |
| 852 self.html_report_main = FilePath("{0}/index.html".format(self.dir)) | |
| 853 self.html_report_files = FilePath("{0}/html_files".format(self.dir)) | |
| 854 self.supercluster_best_hit = "NA" | |
| 855 TAREAN = r2py.R(config.RSOURCE_tarean) | |
| 856 LOGGER.info("creating graph no.{}".format(self.index)) | |
| 857 # FileType must be converted to str for rfunctions | |
| 858 graph_info = eval( | |
| 859 TAREAN.mgblast2graph( | |
| 860 self.blast_file, | |
| 861 seqfile=self.fasta_file, | |
| 862 seqfile_full=self.fasta_file_full, | |
| 863 graph_destination=self.graph_file, | |
| 864 directed_graph_destination=self.directed_graph_file, | |
| 865 oriented_sequences=self.fasta_oriented_file, | |
| 866 image_file=self.image_file, | |
| 867 image_file_tmb=self.image_file_tmb, | |
| 868 repex=True, | |
| 869 paired=self.paired, | |
| 870 satellite_model_path=satellite_model_path, | |
| 871 maxv=config.CLUSTER_VMAX, | |
| 872 maxe=config.CLUSTER_EMAX) | |
| 873 ) | |
| 874 print(graph_info) | |
| 875 self.ecount = graph_info['ecount'] | |
| 876 self.vcount = graph_info['vcount'] | |
| 877 self.loop_index = graph_info['loop_index'] | |
| 878 self.pair_completeness = graph_info['pair_completeness'] | |
| 879 self.orientation_score = graph_info['escore'] | |
| 880 self.satellite_probability = graph_info['satellite_probability'] | |
| 881 self.satellite = graph_info['satellite'] | |
| 882 # for paired reads: | |
| 883 cond1 = (self.paired and self.loop_index > loop_index_threshold and | |
| 884 self.pair_completeness > pair_completeness_threshold) | |
| 885 # no pairs | |
| 886 cond2 = ((not self.paired) and | |
| 887 self.loop_index > loop_index_unpaired_threshold) | |
| 888 if (cond1 or cond2) and config.ARGS.options.name != "oxford_nanopore": | |
| 889 self.putative_tandem = True | |
| 890 self.dir_tarean = FilePath("{}/tarean".format(self.dir)) | |
| 891 lock_file = self.dir + "../lock" | |
| 892 out = eval( | |
| 893 TAREAN.tarean(input_sequences=self.fasta_oriented_file, | |
| 894 output_dir=self.dir_tarean, | |
| 895 CPU=1, | |
| 896 reorient_reads=False, | |
| 897 tRNA_database_path=tRNA_database_path, | |
| 898 lock_file=lock_file) | |
| 899 ) | |
| 900 self.html_tarean = FilePath(out['htmlfile']) | |
| 901 self.tarean_contig_file = out['tarean_contig_file'] | |
| 902 self.TR_score = out['TR_score'] | |
| 903 self.TR_monomer_length = out['TR_monomer_length'] | |
| 904 self.TR_consensus = out['TR_consensus'] | |
| 905 self.pbs_score = out['pbs_score'] | |
| 906 self.max_ORF_length = out['orf_l'] | |
| 907 if (out['orf_l'] > config.ORF_THRESHOLD or | |
| 908 out['pbs_score'] > config.PBS_THRESHOLD): | |
| 909 self.tandem_rank = 3 | |
| 910 elif self.satellite: | |
| 911 self.tandem_rank = 1 | |
| 912 else: | |
| 913 self.tandem_rank = 2 | |
| 914 # some tandems could be rDNA genes - this must be checked | |
| 915 # by annotation | |
| 916 if self.annotations_table: | |
| 917 rdna_score = 0 | |
| 918 contamination_score = 0 | |
| 919 for i in self.annotations_table: | |
| 920 if 'rDNA/' in i[0]: | |
| 921 rdna_score += i[1] | |
| 922 if 'contamination' in i[0]: | |
| 923 contamination_score += i[1] | |
| 924 if rdna_score > config.RDNA_THRESHOLD: | |
| 925 self.tandem_rank = 4 | |
| 926 if contamination_score > config.CONTAMINATION_THRESHOLD: | |
| 927 self.tandem_rank = 0 # other | |
| 928 | |
| 929 # by custom annotation - custom annotation has preference | |
| 930 if self.annotations_custom_table: | |
| 931 print("custom table searching") | |
| 932 rdna_score = 0 | |
| 933 contamination_score = 0 | |
| 934 print(self.annotations_custom_table) | |
| 935 for i in self.annotations_custom_table: | |
| 936 if 'rDNA' in i[0]: | |
| 937 rdna_score += i[1] | |
| 938 if 'contamination' in i[0]: | |
| 939 contamination_score += i[1] | |
| 940 if rdna_score > 0: | |
| 941 self.tandem_rank = 4 | |
| 942 if contamination_score > config.CONTAMINATION_THRESHOLD: | |
| 943 self.tandem_rank = 0 # other | |
| 944 | |
| 945 else: | |
| 946 self.putative_tandem = False | |
| 947 self.dir_tarean = None | |
| 948 self.html_tarean = None | |
| 949 self.TR_score = None | |
| 950 self.TR_monomer_length = None | |
| 951 self.TR_consensus = None | |
| 952 self.pbs_score = None | |
| 953 self.max_ORF_length = None | |
| 954 self.tandem_rank = 0 | |
| 955 self.tarean_contig_file = None | |
| 956 | |
| 957 def __str__(self): | |
| 958 out = [ | |
| 959 "cluster no {}:".format(self.index), | |
| 960 "Number of vertices : {}".format(self.size), | |
| 961 "Number of edges : {}".format(self.ecount), | |
| 962 "Loop index : {}".format(self.loop_index), | |
| 963 "Pair completeness : {}".format(self.pair_completeness), | |
| 964 "Orientation score : {}".format(self.orientation_score) | |
| 965 ] | |
| 966 return "\n".join(out) | |
| 967 | |
| 968 def listing(self, asdict=True): | |
| 969 ''' convert attributes to dictionary for printing purposes''' | |
| 970 out = {} | |
| 971 for i in dir(self): | |
| 972 # do not show private | |
| 973 if i[:2] != "__": | |
| 974 value = getattr(self, i) | |
| 975 if not callable(value): | |
| 976 # for dictionary | |
| 977 if isinstance(value, dict): | |
| 978 for k in value: | |
| 979 out[i + "_" + k] = value[k] | |
| 980 else: | |
| 981 out[i] = value | |
| 982 if asdict: | |
| 983 return out | |
| 984 else: | |
| 985 return {'keys': list(out.keys()), 'values': list(out.values())} | |
| 986 | |
| 987 def detect_ltr(self, trna_database): | |
| 988 '''detection of LTRs in assembly files; output of analysis is stored in a file''' | |
| 989 CREATE_ANNOTATION = r2py.R(config.RSOURCE_create_annotation, verbose=False) | |
| 990 if self.assembly_files['{}.{}.ace']: | |
| 991 ace_file = self.assembly_files['{}.{}.ace'] | |
| 992 print(ace_file, "running LTR detection") | |
| 993 fout = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['BASE']) | |
| 994 subprocess.check_call([ | |
| 995 config.LTR_DETECTION, | |
| 996 '-i', ace_file, | |
| 997 '-o', fout, | |
| 998 '-p', trna_database]) | |
| 999 # evaluate LTR presence | |
| 1000 fn = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['PBS_BLAST']) | |
| 1001 self.ltr_detection = CREATE_ANNOTATION.evaluate_LTR_detection(fn) | |
| 1002 | |
| 1003 | |
| 1004 @staticmethod | |
| 1005 def _summarize_annotations(annotations_files, size): | |
| 1006 ''' will tabulate annotation results ''' | |
| 1007 # TODO | |
| 1008 summaries = {} | |
| 1009 # weight is in percentage | |
| 1010 weight = 100 / size | |
| 1011 for i in annotations_files: | |
| 1012 with open(annotations_files[i]) as f: | |
| 1013 header = f.readline().split() | |
| 1014 id_index = [ | |
| 1015 i for i, item in enumerate(header) if item == "db_id" | |
| 1016 ][0] | |
| 1017 for line in f: | |
| 1018 classification = line.split()[id_index].split("#")[1] | |
| 1019 if classification in summaries: | |
| 1020 summaries[classification] += weight | |
| 1021 else: | |
| 1022 summaries[classification] = weight | |
| 1023 # format summaries for printing | |
| 1024 annotation_string = "" | |
| 1025 annotation_table = [] | |
| 1026 for i in sorted(summaries.items(), key=lambda x: x[1], reverse=True): | |
| 1027 ## hits with smaller proportion are not shown! | |
| 1028 if i[1] > 0.1: | |
| 1029 if i[1] > 1: | |
| 1030 annotation_string += "<b>{1:.2f}% {0}</b>\n".format(*i) | |
| 1031 else: | |
| 1032 annotation_string += "{1:.2f}% {0}\n".format(*i) | |
| 1033 annotation_table.append(i) | |
| 1034 return [annotation_string, annotation_table] | |
| 1035 | |
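
Each annotated read contributes `100 / size` percent, where `size` is the full cluster size, so the summary reports what percentage of the cluster's reads hit a given classification. A small worked example with made-up hits (in the real table the classification is parsed from the `db_id` column after the `#`):

```python
size = 200                          # reads in the cluster
weight = 100 / size                 # each annotated read contributes 0.5 %
hits = ["Ty3/gypsy"] * 30 + ["rDNA/45S"] * 3
summaries = {}
for classification in hits:
    summaries[classification] = summaries.get(classification, 0) + weight
# summaries == {'Ty3/gypsy': 15.0, 'rDNA/45S': 1.5}
# entries above 1 % are printed in bold; entries of 0.1 % or less are dropped
```
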
| 1036 @staticmethod | |
| 1037 def add_cluster_table_to_database(cluster_table, db_path): | |
| 1038 '''get column names from Cluster object and create | |
| 1039 corresponding table in database; values from all | |
| 1040 clusters are filled into the database''' | |
| 1041 column_name_and_type = [] | |
| 1042 column_list = [] | |
| 1043 | |
| 1044 # get all attribute names -> they are column names | |
| 1045 # in sqlite table, detect proper sqlite type | |
| 1046 def identity(x): | |
| 1047 return (x) | |
| 1048 | |
| 1049 for i in cluster_table[1]: | |
| 1050 t = type(cluster_table[1][i]) | |
| 1051 if t == int: | |
| 1052 sqltype = "integer" | |
| 1053 convert = identity | |
| 1054 elif t == float: | |
| 1055 sqltype = "real" | |
| 1056 convert = identity | |
| 1057 elif t == bool: | |
| 1058 sqltype = "boolean" | |
| 1059 convert = bool | |
| 1060 else: | |
| 1061 sqltype = "text" | |
| 1062 convert = str | |
| 1063 column_name_and_type += ["[{}] {}".format(i, sqltype)] | |
| 1064 column_list += [tuple((i, convert))] | |
| 1065 header = ", ".join(column_name_and_type) | |
| 1066 db = sqlite3.connect(db_path) | |
| 1067 c = db.cursor() | |
| 1068 print("CREATE TABLE cluster_info ({})".format(header)) | |
| 1069 c.execute("CREATE TABLE cluster_info ({})".format(header)) | |
| 1070 # fill data into cluster_table | |
| 1071 buffer = [] | |
| 1072 for i in cluster_table: | |
| 1073 buffer.append(tuple('{}'.format(fun(i[j])) for j, fun in | |
| 1074 column_list)) | |
| 1075 wildcards = ",".join(["?"] * len(column_list)) | |
| 1076 print(buffer) | |
| 1077 c.executemany("insert into cluster_info values ({})".format(wildcards), | |
| 1078 buffer) | |
| 1079 db.commit() | |
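
`add_cluster_table_to_database` infers the sqlite column type from the Python type of each attribute (int → integer, float → real, bool → boolean, anything else → text) and inserts one row per cluster. A minimal sketch of the expected call; the database path is hypothetical and `clusters_info` stands for the list returned by `export_clusters_files_multiple`:

```python
# one flattened dictionary of attributes per cluster, as produced by Cluster.listing()
cluster_table = [cl.listing() for cl in clusters_info]
# creates table cluster_info in the given sqlite file and fills in one row per cluster
Cluster.add_cluster_table_to_database(cluster_table, "cluster_report.db")
```
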
