comparison dapars.py @ 0:bb84ee2f2137 draft

planemo upload for repository https://github.com/mvdbeek/dapars commit 868f8f2f7ac5d70c39b7d725ff087833b0f24f52-dirty
author mvdbeek
date Tue, 27 Oct 2015 10:14:33 -0400
parents
children 1b20ba32b4c5
comparison
equal deleted inserted replaced
-1:000000000000 0:bb84ee2f2137
1 import argparse
2 import os
3 import csv
4 import numpy as np
5 from collections import OrderedDict, namedtuple
6 import filter_utr
7 import subprocess
8 from multiprocessing import Pool
9 import warnings
10
11
def parse_args():
    """
    Parse the DaPars command line.

    Alignment inputs are returned as lists of paths; the UTR annotation is opened for reading and the
    output file(s) for writing. No result filtering options (FDR, fold change, PDUI, Num_least, ...)
    are exposed yet.

    :return: argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA usage')
    parser.add_argument("-c", "--control_alignments", nargs="+", required=True,
                        help="Alignment files in BAM format from control condition")
    parser.add_argument("-t", "--treatment_alignments", nargs="+", required=True,
                        help="Alignment files in BAM format from treatment condition")
    # FileType('r') replaces the Python 2-only `file` builtin; an unreadable path becomes a clean
    # argparse usage error instead of an uncaught IOError traceback.
    parser.add_argument("-u", "--utr_bed_file", required=True, type=argparse.FileType('r'),
                        help="Bed file describing longest 3UTR positions")
    parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'),
                        help="file containing output")
    parser.add_argument("-cpu", required=False, type=int, default=1,
                        help="Number of CPU cores to use.")
    parser.add_argument("-s", "--search_start", required=False, type=int, default=50,
                        help="Start search for breakpoint n nucleotides downstream of UTR start")
    parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20,
                        help="minimum coverage in each aligment to be considered for determining breakpoints")
    parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
                        help="Write bedfile with coordinates of breakpoint positions to supplied path.")
    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.4')
    return parser.parse_args()
38
39
class UtrFinder():
    """
    Run the whole DaPars analysis; all work happens in the constructor.

    The constructor parses the UTR annotation, computes per-base coverage with ``bedtools coverage``,
    fits a breakpoint for every UTR, writes the result table and, when ``--breakpoint_bed`` was given,
    a BED file of breakpoint positions.
    """

    def __init__(self, args):
        """
        :param args: namespace produced by ``parse_args``.
        """
        self.control_alignments = list(args.control_alignments)
        self.treatment_alignments = list(args.treatment_alignments)
        self.n_cpus = args.cpu
        self.search_start = args.search_start
        self.coverage_threshold = args.coverage_threshold
        self.utr = args.utr_bed_file
        self.gtf_fields = filter_utr.get_gtf_fields()
        self.result_file = args.output_file
        # Controls first, then treatments: calculate_all_utr treats sample index < num_control as control.
        self.all_alignments = self.control_alignments + self.treatment_alignments
        # Ordered (not a plain dict) so every per-sample structure built from it follows all_alignments.
        # NOTE(review): alignments that share a basename map onto the same coverage file.
        self.alignment_names = OrderedDict((alignment, os.path.basename(alignment))
                                           for alignment in self.all_alignments)
        self.num_samples = len(self.all_alignments)
        self.utr_dict = self.get_utr_dict(0.2)
        self.dump_utr_dict_to_bedfile()
        print("Established dictionary of 3\'UTRs")
        self.coverage_files = self.run_bedtools_coverage()
        self.utr_coverages = self.read_coverage_result()
        print("Established dictionary of 3\'UTR coverages")
        self.coverage_weights = self.get_coverage_weights()
        self.result_tuple = self.get_result_tuple()
        self.result_d = self.calculate_apa_ratios()
        self.write_results()
        if args.breakpoint_bed:
            self.bed_output = args.breakpoint_bed
            self.write_bed()

    def dump_utr_dict_to_bedfile(self):
        """
        Write the trimmed UTRs to ``tmp_bedfile.bed`` (BED6, 0-based start) in the working directory.

        The file is closed explicitly so its contents are flushed before bedtools reads it.
        """
        with open("tmp_bedfile.bed", "w") as bed_file:
            w = csv.writer(bed_file, delimiter="\t")
            for gene, utr in self.utr_dict.items():
                w.writerow([utr["chr"], utr["new_start"] - 1, utr["new_end"], gene, ".", utr["strand"]])

    def run_bedtools_coverage(self):
        """
        Use bedtools coverage to generate per-base coverage for every alignment over the UTRs in
        ``tmp_bedfile.bed``.

        At most ``self.n_cpus`` shell pipelines run at a time. Each writes
        ``coverage_file_<alignment basename>`` with columns gene, position-in-feature, coverage.

        :return: list of the coverage file names, in ``self.all_alignments`` order.
        """
        try:
            from shlex import quote  # Python 3
        except ImportError:
            from pipes import quote  # Python 2
        coverage_files = []
        cmds = []
        for alignment_file in self.all_alignments:
            coverage_file = "coverage_file_{alignment_name}".format(
                alignment_name=self.alignment_names[alignment_file])
            # Paths are shell-quoted because the pipeline runs through the shell.
            cmds.append("sort -k1,1 -k2,2n tmp_bedfile.bed | "
                        "bedtools coverage -d -s -abam {alignment_file} -b stdin |"
                        " cut -f 4,7,8 > {coverage_file}".format(alignment_file=quote(alignment_file),
                                                                  coverage_file=quote(coverage_file)))
            coverage_files.append(coverage_file)
        n_parallel = max(1, self.n_cpus)
        for first in range(0, len(cmds), n_parallel):
            running = [subprocess.Popen(cmd, shell=True) for cmd in cmds[first:first + n_parallel]]
            for process in running:
                process.wait()
        return coverage_files

    def read_coverage_result(self):
        """
        Read the coverage files back in as numpy arrays.

        :return: dict gene -> OrderedDict(alignment basename -> coverage array of length
            ``new_end + 1 - new_start``). Samples follow ``self.all_alignments`` order. For UTRs on the
            '-' strand each array is reversed, so index 0 corresponds to ``new_end``.
        """
        names = list(OrderedDict.fromkeys(self.alignment_names[alignment]
                                          for alignment in self.all_alignments))
        coverage_dict = {}
        for gene, utr_d in self.utr_dict.items():
            length = utr_d["new_end"] + 1 - utr_d["new_start"]
            coverage_dict[gene] = OrderedDict((name, np.zeros(length)) for name in names)
        for name in names:
            with open("coverage_file_{alignment_name}".format(alignment_name=name)) as coverage_file:
                for line in coverage_file:
                    gene, position, coverage = line.strip().split("\t")
                    # bedtools -d positions are 1-based within the feature.
                    coverage_dict[gene][name][int(position) - 1] = float(coverage)
        # Reverse each '-' strand UTR itself (the loop variable is the UTR's own gene, not a leftover).
        for gene, utr_d in self.utr_dict.items():
            if utr_d["strand"] == "-":
                for name in names:
                    coverage_dict[gene][name] = coverage_dict[gene][name][::-1]
        return coverage_dict

    def get_utr_dict(self, shift):
        """
        Parse the UTR annotation from ``self.utr`` and trim each UTR by a fraction of its length.

        The trim is ``round(|end - start| * shift)``. It is taken off ``end`` on the '+' strand and off
        ``start`` otherwise, and stored as ``new_start``/``new_end``. UTRs whose trimmed span is not
        longer than 50 nt are dropped.

        :param shift: fraction of the UTR length to trim.
        :return: OrderedDict gene -> UTR record.
        """
        utr_dict = OrderedDict()
        for line in self.utr:
            if not line.startswith("#"):
                # assumes get_feature_dict adds utr_dict[gene] = [record] for this line -- TODO confirm
                filter_utr.get_feature_dict(line=line, gtf_fields=self.gtf_fields, utr_dict=utr_dict, feature="UTR")
                gene, utr_d = utr_dict.popitem()
                utr_d = utr_d[0]
                end_shift = int(round(abs(utr_d["start"] - utr_d["end"]) * shift))
                if utr_d["strand"] == "+":
                    utr_d["new_end"] = utr_d["end"] - end_shift
                    utr_d["new_start"] = utr_d["start"]
                else:
                    utr_d["new_end"] = utr_d["end"]
                    utr_d["new_start"] = utr_d["start"] + end_shift
                if utr_d["new_start"] + 50 < utr_d["new_end"]:
                    utr_dict[gene] = utr_d
        return utr_dict

    def get_utr_coverage(self):
        """
        Returns a dict:
        { UTR : [coverage_aligment1, ...]}

        NOTE(review): not called anywhere in this file, and it relies on ``self.available_chromosomes``
        and ``get_bp_coverage``, neither of which is defined in this file -- likely obsolete.
        """
        utr_coverages = {}
        for utr, utr_d in self.utr_dict.items():
            if utr_d["chr"] in self.available_chromosomes:
                if utr_d["strand"] == "+":
                    is_reverse = False
                else:
                    is_reverse = True
                utr_coverage = []
                for bam in self.all_alignments:
                    bp_coverage = get_bp_coverage(bam, utr_d["chr"], utr_d["new_start"], utr_d["new_end"], is_reverse)
                    utr_coverage.append(bp_coverage)
                utr_coverages[utr] = utr_coverage
        return utr_coverages

    def get_coverage_weights(self):
        """
        Return one normalisation weight per sample.

        The weight is the sample's total coverage summed over all UTRs, divided by the mean of those
        totals. Samples are taken in the iteration order of each UTR's coverage mapping.
        """
        coverage_per_utr = [[np.sum(vector) for vector in utr.values()]
                            for utr in self.utr_coverages.values()]  # TODO: be smarter about this.
        coverages = np.array([sum(per_sample) for per_sample in zip(*coverage_per_utr)])
        return coverages / np.mean(coverages)  # TODO: probably median is better suited?

    def get_result_tuple(self):
        """
        Build the namedtuple type describing one output row.

        The fields are the fixed UTR columns, then, for each statistic, every control sample followed by
        every treatment sample.
        """
        static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint", "control_mean_percent", "treatment_mean_percent"]
        samples_desc = []
        for statistic in ["coverage_long", "coverage_short", "percent_long"]:
            for i in range(len(self.control_alignments)):
                samples_desc.append("control_{i}_{statistic}".format(i=i, statistic=statistic))
            for i in range(len(self.treatment_alignments)):
                samples_desc.append("treatment_{i}_{statistic}".format(i=i, statistic=statistic))
        return namedtuple("result", static_desc + samples_desc)

    def calculate_apa_ratios(self):
        """
        Fit every UTR in a process pool and collect the rows.

        :return: OrderedDict gene -> ``self.result_tuple`` instance, in ``self.utr_dict`` order, for
            every UTR whose worker returned a dict.
        """
        result_d = OrderedDict()
        num_control = len(self.control_alignments)
        num_treatment = len(self.treatment_alignments)
        tasks = [(self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights,
                  self.num_samples, num_control, num_treatment, result_d, self.search_start,
                  self.coverage_threshold)
                 for utr, utr_d in self.utr_dict.items()]
        pool = Pool(self.n_cpus)
        try:
            pending = [pool.apply_async(calculate_all_utr, task) for task in tasks]
            results = [job.get() for job in pending]
        finally:
            # Release the worker processes: every result has been collected, or an error is propagating.
            pool.terminate()
            pool.join()
        for d in results:
            if isinstance(d, dict):
                result_d[d["gene"]] = self.result_tuple(**d)
        return result_d

    def write_results(self):
        """Write a tab-separated header line (first column '#chr') followed by one row per result."""
        w = csv.writer(self.result_file, delimiter='\t')
        header = list(self.result_tuple._fields)
        header[0] = "#chr"
        w.writerow(header)  # field header
        w.writerows(self.result_d.values())

    def write_bed(self):
        """Write one single-base BED6 interval per result, located at the breakpoint."""
        w = csv.writer(self.bed_output, delimiter='\t')
        bed = [(result.chr, result.breakpoint, result.breakpoint + 1, result.gene, 0, result.strand)
               for result in self.result_d.values()]
        w.writerows(bed)
204
def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control,
                      num_treatment, result_d, search_start, coverage_threshold):
    """
    Fit a breakpoint for one UTR and collect per-sample long/short coverage statistics.

    :param utr_coverage: mapping of sample name -> per-base coverage array for this UTR. Its iteration
        order is the sample order: the first ``num_control`` entries are controls, the rest treatments.
    :param utr: UTR/gene identifier, stored in the ``gene`` column.
    :param utr_d: UTR record with ``chr``, ``start``, ``end``, ``strand``, ``new_start`` and ``new_end``.
    :param result_tuple_fields: names of all output columns.
    :param coverage_weights: per-sample normalisation factors.
    :param num_samples: total number of samples.
    :param num_control: number of control samples.
    :param num_treatment: number of treatment samples.
    :param result_d: unused; kept for call compatibility.
    :param search_start: forwarded to ``estimate_coverage_extended_utr``.
    :param coverage_threshold: forwarded to ``estimate_coverage_extended_utr``.
    :return: dict with a value for every name in ``result_tuple_fields``. Returns ``None`` when the UTR
        fails the coverage/length filters or some sample has zero combined coverage, so callers that
        keep only dict results skip it.
    """
    is_reverse = utr_d["strand"] != "+"
    mse, breakpoint, abundances = estimate_coverage_extended_utr(utr_coverage,
                                                                 utr_d["new_start"],
                                                                 utr_d["new_end"],
                                                                 is_reverse,
                                                                 coverage_weights,
                                                                 search_start,
                                                                 coverage_threshold)
    if str(mse) == "Na":
        return None
    long_coverage_vector = np.asarray(abundances[0])
    short_coverage_vector = np.asarray(abundances[1])
    total_coverage = long_coverage_vector + short_coverage_vector
    # TODO: This introduces bias -- a UTR is dropped if any sample has no coverage at all.
    num_non_zero = sum(total_coverage > 0)
    if num_non_zero != num_samples:
        return None
    # long 3'UTR percentage per sample; the denominator is > 0 thanks to the check above.
    percentage_long = [float(long_coverage_vector[i]) / total_coverage[i] for i in range(num_samples)]
    # Fields not filled below stay None instead of carrying their own column name as a value.
    res = dict.fromkeys(result_tuple_fields)
    for i in range(num_control):
        res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i])
        res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i])
        res["control_{i}_percent_long".format(i=i)] = percentage_long[i]
    for k in range(num_treatment):
        i = k + num_control
        res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i])
        res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i])
        res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
    res["chr"] = utr_d["chr"]
    res["start"] = utr_d["start"]
    res["end"] = utr_d["end"]
    res["strand"] = utr_d["strand"]
    # The breakpoint is an offset along the (possibly reversed) coverage array; convert to a coordinate.
    if is_reverse:
        res["breakpoint"] = utr_d["new_end"] - breakpoint
    else:
        res["breakpoint"] = utr_d["new_start"] + breakpoint
    res["control_mean_percent"] = np.mean(np.array(percentage_long[:num_control]))
    res["treatment_mean_percent"] = np.mean(np.array(percentage_long[num_control:]))
    res["gene"] = utr
    return res
252
253
def estimate_coverage_extended_utr(utr_coverage, UTR_start,
                                   UTR_end, is_reverse, coverage_weigths, search_start, coverage_threshold):
    """
    Find the breakpoint that best splits a UTR's coverage into a short (proximal) and long (distal) part.

    :param utr_coverage: mapping of sample name -> 1D numpy coverage array for the UTR. Its iteration
        order is the sample order and must match ``coverage_weigths``.
    :param UTR_start: UTR start coordinate; only ``UTR_end - UTR_start`` is used.
    :param UTR_end: UTR end coordinate.
    :param is_reverse: strand flag; not used by the search (kept for interface compatibility).
    :param coverage_weigths: per-sample normalisation factors (name misspelled, kept for compatibility).
    :param search_start: first breakpoint offset to try.
    :param coverage_threshold: every sample's mean raw coverage over the first positions must reach this.
    :return: tuple (mse, breakpoint offset, (mean_long, mean_short)) for the breakpoint with the lowest
        mean squared error, or ('Na', 'Na', 'Na') if the UTR is filtered out or no breakpoint fits.
    """
    not_available = ('Na', 'Na', 'Na')
    utr_length = UTR_end - UTR_start
    # TODO: the last 10% of the UTR is excluded from the search. Why?
    search_point_end = int(abs(utr_length) * 0.1)
    coverages = list(utr_coverage.values())
    num_samples = len(coverages)
    # Filter on raw (unnormalised) mean coverage. The slice [0:99] covers 99 positions (the old comment
    # said 100). NOTE(review): confirm whether 100 was intended.
    start_coverage = [np.mean(coverage[0:99]) for coverage in coverages]
    is_above_threshold = all(mean >= coverage_threshold for mean in start_coverage)
    is_above_length = utr_length >= 150
    if not (is_above_threshold and is_above_length):
        return not_available
    search_end = utr_length - search_point_end
    breakpoints = range(search_start, search_end + 1)
    if len(breakpoints) == 0:
        return not_available
    normalized_utr_coverage = np.array([coverage / coverage_weigths[i] for i, coverage in enumerate(coverages)])
    mse_list = [estimate_mse(normalized_utr_coverage, bp, num_samples) for bp in breakpoints]
    min_ele_index = mse_list.index(min(mse_list))  # first minimum wins on ties
    breakpoint = breakpoints[min_ele_index]
    UTR_abundances = estimate_abundance(normalized_utr_coverage, breakpoint, num_samples)
    return mse_list[min_ele_index], breakpoint, UTR_abundances
292
293
def estimate_mse(cov, bp, num_samples):
    """
    Score a candidate breakpoint by the mean squared deviation from segment means.

    The first ``num_samples`` rows of the 2D array ``cov`` are split at column ``bp``. Each row's part
    before ``bp`` and part from ``bp`` on is centred on its own mean. The squared deviations of both
    parts are pooled, and their overall mean is returned.
    """
    with warnings.catch_warnings():
        # RuntimeWarnings (e.g. numpy's mean of an empty segment) are silenced for this computation.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        rows = cov[:num_samples]
        proximal = rows[:, 0:bp]
        distal = rows[:, bp:]
        proximal_sq_dev = (proximal - np.mean(proximal, 1)[:, np.newaxis]) ** 2
        distal_sq_dev = (distal - np.mean(distal, 1)[:, np.newaxis]) ** 2
        pooled = np.append(proximal_sq_dev, distal_sq_dev)
        result = np.mean(pooled)
    return result
308
def estimate_abundance(cov, bp, num_samples):
    """
    Return per-sample mean coverage on each side of breakpoint ``bp``.

    :param cov: 2D array of coverage, one row per sample; only the first ``num_samples`` rows are used.
    :param bp: column index where the long segment begins.
    :return: tuple (mean_long_utr, mean_short_utr): row means of columns ``bp:`` and ``0:bp``.
    """
    rows = cov[:num_samples]
    with warnings.catch_warnings():
        # RuntimeWarnings (e.g. numpy's mean of an empty segment) are silenced here.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        distal_mean = np.mean(rows[:, bp:], 1)
        proximal_mean = np.mean(rows[:, 0:bp], 1)
    return distal_mean, proximal_mean
317
318
if __name__ == '__main__':
    # UtrFinder performs the whole analysis, including writing the outputs, in its constructor.
    find_utr = UtrFinder(parse_args())
322