Comparison of dapars.py @ 0:bb84ee2f2137 (draft)
planemo upload for repository https://github.com/mvdbeek/dapars commit 868f8f2f7ac5d70c39b7d725ff087833b0f24f52-dirty
author: mvdbeek
date: Tue, 27 Oct 2015 10:14:33 -0400
parents: (none)
children: 1b20ba32b4c5
The comparison is against the null revision -1:000000000000; the file is introduced in 0:bb84ee2f2137.

```python
import argparse
import os
import csv
import numpy as np
from collections import OrderedDict, namedtuple
import filter_utr
import subprocess
from multiprocessing import Pool
import warnings


def parse_args():
    """
    Parse command-line arguments; all options besides the input files are numeric.
    This initial approach does not filter anything (FDR, fold_change, PDUI, Num_least, ...).
    """
    parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA sites')
    parser.add_argument("-c", "--control_alignments", nargs="+", required=True,
                        help="Alignment files in BAM format from control condition")
    parser.add_argument("-t", "--treatment_alignments", nargs="+", required=True,
                        help="Alignment files in BAM format from treatment condition")
    parser.add_argument("-u", "--utr_bed_file", required=True, type=file,
                        help="Bed file describing longest 3'UTR positions")
    parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'),
                        help="file containing output")
    parser.add_argument("-cpu", required=False, type=int, default=1,
                        help="Number of CPU cores to use.")
    parser.add_argument("-s", "--search_start", required=False, type=int, default=50,
                        help="Start search for breakpoint n nucleotides downstream of UTR start")
    parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20,
                        help="minimum coverage in each alignment to be considered for determining breakpoints")
    parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
                        help="Write bedfile with coordinates of breakpoint positions to supplied path.")
    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.4')
    return parser.parse_args()
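
# Usage sketch (hypothetical file names; flags as defined in parse_args above):
#   python dapars.py -c control_1.bam control_2.bam \
#                    -t treatment_1.bam treatment_2.bam \
#                    -u longest_3utrs.bed -o dapars_results.tsv \
#                    -cpu 4 -b breakpoints.bed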


class UtrFinder():
    """
    Main driver: builds the 3'UTR dictionary, computes per-base coverage with
    bedtools, and calculates APA ratios for control versus treatment.
    """

    def __init__(self, args):
        self.control_alignments = list(args.control_alignments)
        self.treatment_alignments = list(args.treatment_alignments)
        self.n_cpus = args.cpu
        self.search_start = args.search_start
        self.coverage_threshold = args.coverage_threshold
        self.utr = args.utr_bed_file
        self.gtf_fields = filter_utr.get_gtf_fields()
        self.result_file = args.output_file
        self.all_alignments = self.control_alignments + self.treatment_alignments
        self.alignment_names = {path: os.path.basename(path) for path in self.all_alignments}
        self.num_samples = len(self.all_alignments)
        self.utr_dict = self.get_utr_dict(0.2)
        self.dump_utr_dict_to_bedfile()
        print "Established dictionary of 3'UTRs"
        self.coverage_files = self.run_bedtools_coverage()
        self.utr_coverages = self.read_coverage_result()
        print "Established dictionary of 3'UTR coverages"
        self.coverage_weights = self.get_coverage_weights()
        self.result_tuple = self.get_result_tuple()
        self.result_d = self.calculate_apa_ratios()
        self.write_results()
        if args.breakpoint_bed:
            self.bed_output = args.breakpoint_bed
            self.write_bed()

    def dump_utr_dict_to_bedfile(self):
        w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t")
        for gene, utr in self.utr_dict.iteritems():
            w.writerow([utr["chr"], utr["new_start"] - 1, utr["new_end"], gene, ".", utr["strand"]])

    def run_bedtools_coverage(self):
        """
        Use bedtools coverage to generate per-base coverage for all alignment
        files over the regions specified in utr_dict.
        """
        cmds = []
        for alignment_file in self.all_alignments:
            cmd = "sort -k1,1 -k2,2n tmp_bedfile.bed | "
            cmd = cmd + "bedtools coverage -d -s -abam {alignment_file} -b stdin |" \
                        " cut -f 4,7,8 > coverage_file_{alignment_name}".format(
                alignment_file=alignment_file, alignment_name=self.alignment_names[alignment_file])
            cmds.append(cmd)
        # NOTE: all pipelines are launched concurrently; the -cpu option does
        # not throttle this step.
        subprocesses = [subprocess.Popen(cmd, shell=True) for cmd in cmds]
        [p.wait() for p in subprocesses]
        # Return the file names actually written above (and read back in
        # read_coverage_result).
        coverage_files = ["coverage_file_{alignment_name}".format(
            alignment_name=self.alignment_names[alignment_file]) for alignment_file in self.all_alignments]
        return coverage_files
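
    # For orientation: `bedtools coverage -d` appends a 1-based position within
    # each BED interval plus the depth at that position, so after `cut -f 4,7,8`
    # each coverage file should hold lines of gene name, position and depth
    # (values illustrative):
    #   GENE_A    1    17
    #   GENE_A    2    19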

    def read_coverage_result(self):
        """
        Read the coverage files back in and store them as a dictionary of
        numpy arrays, keyed by gene and then by alignment name.
        """
        coverage_dict = {gene: {name: np.zeros(utr_d["new_end"] + 1 - utr_d["new_start"])
                                for name in self.alignment_names.itervalues()}
                         for gene, utr_d in self.utr_dict.iteritems()}
        for alignment_name in self.alignment_names.itervalues():
            with open("coverage_file_{alignment_name}".format(alignment_name=alignment_name)) as coverage_file:
                for line in coverage_file:
                    gene, position, coverage = line.strip().split("\t")
                    coverage_dict[gene][alignment_name][int(position) - 1] = float(coverage)
        # Reverse minus-strand coverage so that index 0 is always the 5' end.
        for gene, utr_d in self.utr_dict.iteritems():
            if utr_d["strand"] == "-":
                for alignment_name in self.alignment_names.values():
                    coverage_dict[gene][alignment_name] = coverage_dict[gene][alignment_name][::-1]
        return coverage_dict
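
    # Resulting structure (shape only; names and values illustrative):
    #   {"GENE_A": {"control_1.bam": np.array([17., 19., ...]),
    #               "treatment_1.bam": np.array([...]),
    #               ...},
    #    ...}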

    def get_utr_dict(self, shift):
        """
        Parse the UTR annotation and trim each UTR at its distal end by
        `shift` (a fraction of the UTR length).
        """
        utr_dict = OrderedDict()
        for line in self.utr:
            if not line.startswith("#"):
                filter_utr.get_feature_dict(line=line, gtf_fields=self.gtf_fields, utr_dict=utr_dict, feature="UTR")
                gene, utr_d = utr_dict.popitem()
                utr_d = utr_d[0]
                end_shift = int(round(abs(utr_d["start"] - utr_d["end"]) * shift))
                if utr_d["strand"] == "+":
                    utr_d["new_end"] = utr_d["end"] - end_shift
                    utr_d["new_start"] = utr_d["start"]
                else:
                    utr_d["new_end"] = utr_d["end"]
                    utr_d["new_start"] = utr_d["start"] + end_shift
                # Keep only UTRs that are still at least 50 nt long after trimming.
                if utr_d["new_start"] + 50 < utr_d["new_end"]:
                    utr_dict[gene] = utr_d
        return utr_dict
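
    # Example of the trimming arithmetic: a plus-strand UTR with start=1000,
    # end=2000 and shift=0.2 gives end_shift=200, hence new_start=1000 and
    # new_end=1800, i.e. the distal 20% of the UTR is excluded.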

    def get_utr_coverage(self):
        """
        Returns a dict: {UTR: [coverage_alignment1, ...]}.
        NOTE: not called anywhere in this module and relies on
        self.available_chromosomes and get_bp_coverage, which are not
        defined here.
        """
        utr_coverages = {}
        for utr, utr_d in self.utr_dict.iteritems():
            if utr_d["chr"] in self.available_chromosomes:
                is_reverse = utr_d["strand"] == "-"
                utr_coverage = []
                for bam in self.all_alignments:
                    bp_coverage = get_bp_coverage(bam, utr_d["chr"], utr_d["new_start"], utr_d["new_end"], is_reverse)
                    utr_coverage.append(bp_coverage)
                utr_coverages[utr] = utr_coverage
        return utr_coverages

    def get_coverage_weights(self):
        """
        Return per-alignment weights for normalizing coverage: each
        alignment's total coverage over all UTRs, divided by the mean of
        those totals (a simple sequencing-depth correction).
        """
        coverage_per_alignment = []
        for utr in self.utr_coverages.itervalues():  # TODO: be smarter about this.
            utr_coverage = []
            for vector in utr.itervalues():
                utr_coverage.append(np.sum(vector))
            coverage_per_alignment.append(utr_coverage)
        coverages = np.array([sum(x) for x in zip(*coverage_per_alignment)])
        coverage_weights = coverages / np.mean(coverages)  # TODO: probably the median is better suited?
        return coverage_weights
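
    # Example: per-alignment totals of [100., 200., 300.] have mean 200., so
    # the weights are [0.5, 1.0, 1.5]; dividing each coverage vector by its
    # weight puts all alignments on a comparable scale.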

    def get_result_tuple(self):
        static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint",
                       "control_mean_percent", "treatment_mean_percent"]
        samples_desc = []
        for statistic in ["coverage_long", "coverage_short", "percent_long"]:
            for i, sample in enumerate(self.control_alignments):
                samples_desc.append("control_{i}_{statistic}".format(i=i, statistic=statistic))
            for i, sample in enumerate(self.treatment_alignments):
                samples_desc.append("treatment_{i}_{statistic}".format(i=i, statistic=statistic))
        return namedtuple("result", static_desc + samples_desc)

    def calculate_apa_ratios(self):
        result_d = OrderedDict()
        pool = Pool(self.n_cpus)
        # One task per UTR; the tuple order mirrors the signature of
        # calculate_all_utr below.
        tasks = [(self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights,
                  self.num_samples, len(self.control_alignments), len(self.treatment_alignments),
                  self.search_start, self.coverage_threshold)
                 for utr, utr_d in self.utr_dict.iteritems()]
        processed_tasks = [pool.apply_async(calculate_all_utr, t) for t in tasks]
        result = [res.get() for res in processed_tasks]
        for d in result:
            if isinstance(d, dict):
                t = self.result_tuple(**d)
                result_d[d["gene"]] = t
        return result_d

    def write_results(self):
        w = csv.writer(self.result_file, delimiter='\t')
        header = list(self.result_tuple._fields)
        header[0] = "#chr"
        w.writerow(header)  # field header
        w.writerows(self.result_d.values())

    def write_bed(self):
        w = csv.writer(self.bed_output, delimiter='\t')
        bed = [(result.chr, result.breakpoint, result.breakpoint + 1, result.gene, 0, result.strand)
               for result in self.result_d.itervalues()]
        w.writerows(bed)


def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control,
                      num_treatment, search_start, coverage_threshold):
    """
    Estimate the breakpoint for a single UTR and assemble a result dict.
    Returns None when no breakpoint can be determined, so that
    calculate_apa_ratios can skip the entry.
    """
    res = dict.fromkeys(result_tuple_fields)
    is_reverse = utr_d["strand"] == "-"
    mse, breakpoint, abundances = estimate_coverage_extended_utr(utr_coverage,
                                                                 utr_d["new_start"],
                                                                 utr_d["new_end"],
                                                                 is_reverse,
                                                                 coverage_weights,
                                                                 search_start,
                                                                 coverage_threshold)
    if str(mse) == "Na":
        return None
    long_coverage_vector = abundances[0]
    short_coverage_vector = abundances[1]
    num_non_zero = sum((np.array(long_coverage_vector) + np.array(short_coverage_vector)) > 0)  # TODO: This introduces bias
    if num_non_zero != num_samples:
        return None
    percentage_long = []
    for i in range(num_samples):
        ratio = float(long_coverage_vector[i]) / (long_coverage_vector[i] + short_coverage_vector[i])  # long 3'UTR percentage
        percentage_long.append(ratio)
    for i in range(num_control):
        res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i])
        res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i])
        res["control_{i}_percent_long".format(i=i)] = percentage_long[i]
    for k in range(num_treatment):
        i = k + num_control
        res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i])
        res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i])
        res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
    res["control_mean_percent"] = np.mean(np.array(percentage_long[:num_control]))
    res["treatment_mean_percent"] = np.mean(np.array(percentage_long[num_control:]))
    res["chr"] = utr_d["chr"]
    res["start"] = utr_d["start"]
    res["end"] = utr_d["end"]
    res["strand"] = utr_d["strand"]
    # Convert the breakpoint from an offset within the search window to a
    # genomic coordinate.
    if is_reverse:
        breakpoint = utr_d["new_end"] - breakpoint
    else:
        breakpoint = utr_d["new_start"] + breakpoint
    res["breakpoint"] = breakpoint
    res["gene"] = utr
    return res


def estimate_coverage_extended_utr(utr_coverage, UTR_start, UTR_end, is_reverse, coverage_weights,
                                   search_start, coverage_threshold):
    """
    Search for the coverage breakpoint that best splits a UTR into a proximal
    (short + long isoform) and a distal (long isoform only) segment.
    utr_coverage is a dict whose values are numpy arrays of per-base coverage,
    one per sample.
    """
    search_point_end = int(abs((UTR_end - UTR_start)) * 0.1)  # TODO: the distal 10% of the UTR is excluded from the search. Why?
    num_samples = len(utr_coverage)
    # Read-coverage filtering. Note that enumerate(utr_coverage.values()) must
    # yield samples in the same order used to compute coverage_weights.
    normalized_utr_coverage = [coverage / coverage_weights[i] for i, coverage in enumerate(utr_coverage.values())]
    start_coverage = [np.mean(coverage[0:99]) for coverage in utr_coverage.values()]  # mean coverage over the first 99 nt
    is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples  # all samples must pass the raw threshold
    is_above_length = UTR_end - UTR_start >= 150
    if is_above_threshold and is_above_length:
        # Breakpoints are offsets into the coverage arrays, not genomic
        # coordinates; the caller converts them back per strand.
        search_end = UTR_end - UTR_start - search_point_end
        normalized_utr_coverage = np.array(normalized_utr_coverage)
        breakpoints = range(search_start, search_end + 1)
        mse_list = [estimate_mse(normalized_utr_coverage, bp, num_samples) for bp in breakpoints]
        if len(mse_list) > 0:
            min_ele_index = mse_list.index(min(mse_list))
            breakpoint = breakpoints[min_ele_index]
            UTR_abundances = estimate_abundance(normalized_utr_coverage, breakpoint, num_samples)
            select_mean_squared_error = mse_list[min_ele_index]
            selected_break_point = breakpoint
        else:
            select_mean_squared_error = 'Na'
            UTR_abundances = 'Na'
            selected_break_point = 'Na'
    else:
        select_mean_squared_error = 'Na'
        UTR_abundances = 'Na'
        selected_break_point = 'Na'

    return select_mean_squared_error, selected_break_point, UTR_abundances
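
# Example of the search window: a 500 nt UTR with search_start=50 gives
# search_point_end = 50 (10% of 500) and search_end = 450, so candidate
# breakpoints 50..450 are scored.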


def estimate_mse(cov, bp, num_samples):
    """
    Mean squared error of a two-mean model: each sample's coverage is
    summarized by one mean before the breakpoint bp (short segment) and one
    mean after it (long segment); the MSE over both segments measures how
    well bp explains the coverage drop.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        long_utr_vector = cov[:num_samples, bp:]
        short_utr_vector = cov[:num_samples, 0:bp]
        mean_long_utr = np.mean(long_utr_vector, 1)
        mean_short_utr = np.mean(short_utr_vector, 1)
        square_mean_centered_short_utr_vector = (short_utr_vector - mean_short_utr[:, np.newaxis]) ** 2
        square_mean_centered_long_utr_vector = (long_utr_vector - mean_long_utr[:, np.newaxis]) ** 2
        mse = np.mean(np.append(square_mean_centered_short_utr_vector, square_mean_centered_long_utr_vector))
    return mse


def estimate_abundance(cov, bp, num_samples):
    """
    Per-sample mean coverage of the long (after bp) and short (before bp)
    segments.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        long_utr_vector = cov[:num_samples, bp:]
        short_utr_vector = cov[:num_samples, 0:bp]
        mean_long_utr = np.mean(long_utr_vector, 1)
        mean_short_utr = np.mean(short_utr_vector, 1)
    return mean_long_utr, mean_short_utr
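
# Worked example (illustrative): for a single sample with normalized coverage
# np.array([[5., 5., 5., 5., 1., 1., 1., 1.]]), the split at bp=4 yields
# segment means of 5.0 and 1.0 with zero within-segment variance, so
# estimate_mse(cov, 4, 1) == 0.0; any other bp gives a positive MSE, which is
# why the scan in estimate_coverage_extended_utr selects bp=4.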


if __name__ == '__main__':
    args = parse_args()
    find_utr = UtrFinder(args)
```