dapars.py @ 0:bb84ee2f2137 (draft)
commit message: planemo upload for repository https://github.com/mvdbeek/dapars commit 868f8f2f7ac5d70c39b7d725ff087833b0f24f52-dirty

author:   mvdbeek
date:     Tue, 27 Oct 2015 10:14:33 -0400
parents:  (none)
children: 1b20ba32b4c5

import argparse
import csv
import os
import subprocess
import warnings
from collections import OrderedDict, namedtuple
from multiprocessing import Pool

import numpy as np

import filter_utr


def parse_args():
    """
    Parse command-line arguments. Thresholds are returned as numbers, input
    and output files as open file objects. No filtering (FDR, fold_change,
    PDUI, Num_least, ...) is performed at this stage.
    """
    parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA sites')
    parser.add_argument("-c", "--control_alignments", nargs="+", required=True,
                        help="Alignment files in BAM format from control condition")
    parser.add_argument("-t", "--treatment_alignments", nargs="+", required=True,
                        help="Alignment files in BAM format from treatment condition")
    parser.add_argument("-u", "--utr_bed_file", required=True, type=argparse.FileType('r'),
                        help="Bed file describing longest 3'UTR positions")
    parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'),
                        help="file containing output")
    parser.add_argument("-cpu", required=False, type=int, default=1,
                        help="Number of CPU cores to use.")
    parser.add_argument("-s", "--search_start", required=False, type=int, default=50,
                        help="Start search for breakpoint n nucleotides downstream of UTR start")
    parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20,
                        help="Minimum coverage in each alignment to be considered for determining breakpoints")
    parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
                        help="Write bedfile with coordinates of breakpoint positions to supplied path.")
    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.4')
    return parser.parse_args()


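# Illustrative only (not part of the original tool): a hypothetical way to
# exercise parse_args(). The file names are invented, and the BAM/bed/output
# files would have to exist for argparse's FileType arguments to open them.
def _example_args():
    import sys
    sys.argv = ["DaPars",
                "-c", "control1.bam", "control2.bam",
                "-t", "treatment1.bam", "treatment2.bam",
                "-u", "utrs.bed",
                "-o", "dapars_output.tsv",
                "-cpu", "4",
                "-b", "breakpoints.bed"]
    return parse_args()

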
class UtrFinder():
    """
    Main driver: builds the 3'UTR dictionary from the bed file, computes
    per-base coverage for all alignments, estimates breakpoints and APA
    ratios, and writes the results.
    """
44 | |
45 def __init__(self, args): | |
46 self.control_alignments = [file for file in args.control_alignments] | |
47 self.treatment_alignments = [file for file in args.treatment_alignments] | |
48 self.n_cpus = args.cpu | |
49 self.search_start = args.search_start | |
50 self.coverage_threshold = args.coverage_threshold | |
51 self.utr = args.utr_bed_file | |
52 self.gtf_fields = filter_utr.get_gtf_fields() | |
53 self.result_file = args.output_file | |
54 self.all_alignments = self.control_alignments + self.treatment_alignments | |
55 self.alignment_names = { file: os.path.basename(file) for file in self.all_alignments } | |
56 self.num_samples = len(self.all_alignments) | |
57 self.utr_dict = self.get_utr_dict(0.2) | |
58 self.dump_utr_dict_to_bedfile() | |
59 print "Established dictionary of 3\'UTRs" | |
60 self.coverage_files = self.run_bedtools_coverage() | |
61 self.utr_coverages = self.read_coverage_result() | |
62 print "Established dictionary of 3\'UTR coverages" | |
63 self.coverage_weights = self.get_coverage_weights() | |
64 self.result_tuple = self.get_result_tuple() | |
65 self.result_d = self.calculate_apa_ratios() | |
66 self.write_results() | |
67 if args.breakpoint_bed: | |
68 self.bed_output = args.breakpoint_bed | |
69 self.write_bed() | |
70 | |
71 | |
    def dump_utr_dict_to_bedfile(self):
        with open("tmp_bedfile.bed", "w") as bedfile:
            w = csv.writer(bedfile, delimiter="\t")
            for gene, utr in self.utr_dict.iteritems():
                # BED intervals are 0-based and half-open, hence the -1 on the start.
                w.writerow([utr["chr"], utr["new_start"] - 1, utr["new_end"], gene, ".", utr["strand"]])

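    # Illustrative sketch (not from the original code): for a hypothetical
    # plus-strand UTR entry, the row written above would be
    #
    #   >>> utr = {"chr": "chr1", "new_start": 1000, "new_end": 1500, "strand": "+"}
    #   >>> [utr["chr"], utr["new_start"] - 1, utr["new_end"], "GENE_X", ".", utr["strand"]]
    #   ['chr1', 999, 1500, 'GENE_X', '.', '+']
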
    def run_bedtools_coverage(self):
        """
        Use bedtools coverage to generate pileup data for all alignment files
        for the regions specified in utr_dict.
        """
        cmds = []
        for alignment_file in self.all_alignments:
            cmd = "sort -k1,1 -k2,2n tmp_bedfile.bed | "
            cmd = cmd + "bedtools coverage -d -s -abam {alignment_file} -b stdin |" \
                        " cut -f 4,7,8 > coverage_file_{alignment_name}".format(
                alignment_file=alignment_file, alignment_name=self.alignment_names[alignment_file])
            cmds.append(cmd)
        subprocesses = [subprocess.Popen([cmd], shell=True) for cmd in cmds]
        [p.wait() for p in subprocesses]
        # Note: the original code returned "gene_position_coverage_..." names here,
        # which did not match the "coverage_file_..." files the shell pipeline writes.
        coverage_files = ["coverage_file_{alignment_name}".format(
            alignment_name=self.alignment_names[alignment_file]) for alignment_file in self.all_alignments]
        return coverage_files

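    # For reference, the command template above expands to a pipeline of this
    # shape (alignment name is hypothetical):
    #
    #   sort -k1,1 -k2,2n tmp_bedfile.bed \
    #       | bedtools coverage -d -s -abam control1.bam -b stdin \
    #       | cut -f 4,7,8 > coverage_file_control1.bam
    #
    # With -d, bedtools coverage reports per-base depth for each bed interval,
    # so fields 4, 7 and 8 of the output are the gene name, the 1-based
    # position within the interval, and the depth at that position.
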
    def read_coverage_result(self):
        """
        Read coverages back in and store as a dictionary of numpy arrays.
        """
        coverage_dict = {gene: {name: np.zeros(utr_d["new_end"] + 1 - utr_d["new_start"])
                                for name in self.alignment_names.itervalues()}
                         for gene, utr_d in self.utr_dict.iteritems()}
        for alignment_name in self.alignment_names.itervalues():
            with open("coverage_file_{alignment_name}".format(alignment_name=alignment_name)) as coverage_file:
                for line in coverage_file:
                    gene, position, coverage = line.strip().split("\t")
                    coverage_dict[gene][alignment_name][int(position) - 1] = float(coverage)
        # Fix over the original: iterate over genes here as well, instead of
        # reusing the stale `gene` variable from the loop above.
        for gene, utr_d in self.utr_dict.iteritems():
            if utr_d["strand"] == "-":
                # Reverse minus-strand arrays so index 0 is always the 5' end.
                for alignment_name in self.alignment_names.itervalues():
                    coverage_dict[gene][alignment_name] = coverage_dict[gene][alignment_name][::-1]
        return coverage_dict

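    # Sketch of the structure returned above (values invented):
    #
    #   {"GENE_X": {"control1.bam":   np.array([12., 11., ...]),
    #               "treatment1.bam": np.array([ 9., 10., ...])}}
    #
    # One float per base of the trimmed UTR, per alignment.
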
    def get_utr_dict(self, shift):
        """
        Build an OrderedDict of UTRs from the bed file, trimming the distal
        end of each UTR by `shift` (a fraction of the UTR length). UTRs
        shorter than 50 nt after trimming are discarded.
        """
        utr_dict = OrderedDict()
        for line in self.utr:
            if not line.startswith("#"):
                filter_utr.get_feature_dict(line=line, gtf_fields=self.gtf_fields, utr_dict=utr_dict, feature="UTR")
                gene, utr_d = utr_dict.popitem()
                utr_d = utr_d[0]
                end_shift = int(round(abs(utr_d["start"] - utr_d["end"]) * shift))
                if utr_d["strand"] == "+":
                    utr_d["new_end"] = utr_d["end"] - end_shift
                    utr_d["new_start"] = utr_d["start"]
                else:
                    utr_d["new_end"] = utr_d["end"]
                    utr_d["new_start"] = utr_d["start"] + end_shift
                if utr_d["new_start"] + 50 < utr_d["new_end"]:
                    utr_dict[gene] = utr_d
        return utr_dict

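    # Worked example of the trimming above for a hypothetical plus-strand UTR
    # spanning 1000-2000 with shift=0.2:
    #
    #   >>> start, end = 1000, 2000
    #   >>> end_shift = int(round(abs(start - end) * 0.2))
    #   >>> end_shift
    #   200
    #   >>> (start, end - end_shift)   # (new_start, new_end)
    #   (1000, 1800)
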
    def get_utr_coverage(self):
        """
        Returns a dict: {UTR: [coverage_alignment1, ...]}.
        NOTE: this method is never called and depends on
        self.available_chromosomes and get_bp_coverage, neither of which is
        defined in this module; it appears to be dead code.
        """
        utr_coverages = {}
        for utr, utr_d in self.utr_dict.iteritems():
            if utr_d["chr"] in self.available_chromosomes:
                if utr_d["strand"] == "+":
                    is_reverse = False
                else:
                    is_reverse = True
                utr_coverage = []
                for bam in self.all_alignments:
                    bp_coverage = get_bp_coverage(bam, utr_d["chr"], utr_d["new_start"], utr_d["new_end"], is_reverse)
                    utr_coverage.append(bp_coverage)
                utr_coverages[utr] = utr_coverage
        return utr_coverages

    def get_coverage_weights(self):
        """
        Return one weight per alignment for normalizing coverage to the mean
        sequencing depth across all alignments.
        """
        coverage_per_alignment = []
        for utr in self.utr_coverages.itervalues():  # TODO: be smarter about this.
            utr_coverage = []
            for vector in utr.itervalues():
                utr_coverage.append(np.sum(vector))
            coverage_per_alignment.append(utr_coverage)
        coverages = np.array([sum(x) for x in zip(*coverage_per_alignment)])
        coverage_weights = coverages / np.mean(coverages)  # TODO: probably the median is better suited?
        return coverage_weights

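    # Numeric sketch of the normalization above (invented totals): with summed
    # coverages of 2, 1 and 3 million across three alignments,
    #
    #   >>> coverages = np.array([2.0e6, 1.0e6, 3.0e6])
    #   >>> coverages / np.mean(coverages)
    #   array([1. , 0.5, 1.5])
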
    def get_result_tuple(self):
        static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint", "control_mean_percent", "treatment_mean_percent"]
        samples_desc = []
        for statistic in ["coverage_long", "coverage_short", "percent_long"]:
            for i, sample in enumerate(self.control_alignments):
                samples_desc.append("control_{i}_{statistic}".format(i=i, statistic=statistic))
            for i, sample in enumerate(self.treatment_alignments):
                samples_desc.append("treatment_{i}_{statistic}".format(i=i, statistic=statistic))
        return namedtuple("result", static_desc + samples_desc)

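    # With one control and one treatment alignment the resulting namedtuple
    # has these fields, in construction order:
    #
    #   chr, start, end, strand, gene, breakpoint,
    #   control_mean_percent, treatment_mean_percent,
    #   control_0_coverage_long, treatment_0_coverage_long,
    #   control_0_coverage_short, treatment_0_coverage_short,
    #   control_0_percent_long, treatment_0_percent_long
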
    def calculate_apa_ratios(self):
        # The original also built an unused `arg_d` dict and passed result_d
        # into each task; both were dead and have been dropped.
        result_d = OrderedDict()
        pool = Pool(self.n_cpus)
        tasks = [(self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights,
                  self.num_samples, len(self.control_alignments), len(self.treatment_alignments),
                  self.search_start, self.coverage_threshold) for utr, utr_d in self.utr_dict.iteritems()]
        processed_tasks = [pool.apply_async(calculate_all_utr, t) for t in tasks]
        result = [res.get() for res in processed_tasks]
        for d in result:
            if isinstance(d, dict):
                t = self.result_tuple(**d)
                result_d[d["gene"]] = t
        return result_d

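    # The fan-out above is the plain Pool.apply_async/get idiom; a minimal
    # self-contained version of the same pattern:
    #
    #   >>> from multiprocessing import Pool
    #   >>> pool = Pool(2)
    #   >>> pending = [pool.apply_async(abs, (x,)) for x in [-1, -2, -3]]
    #   >>> [p.get() for p in pending]
    #   [1, 2, 3]
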
    def write_results(self):
        w = csv.writer(self.result_file, delimiter='\t')
        header = list(self.result_tuple._fields)
        header[0] = "#chr"
        w.writerow(header)  # field header
        w.writerows(self.result_d.values())

    def write_bed(self):
        w = csv.writer(self.bed_output, delimiter='\t')
        bed = [(result.chr, result.breakpoint, result.breakpoint + 1, result.gene, 0, result.strand) for result in self.result_d.itervalues()]
        w.writerows(bed)


def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control,
                      num_treatment, search_start, coverage_threshold):
    """
    Estimate the breakpoint and long/short 3'UTR abundances for a single UTR.
    Returns a fully populated result dict, or None if no breakpoint could be
    determined (the original returned a placeholder dict of field names here,
    which produced one garbage row keyed "gene").
    """
    if utr_d["strand"] == "+":
        is_reverse = False
    else:
        is_reverse = True
    mse, breakpoint, abundances = estimate_coverage_extended_utr(utr_coverage,
                                                                 utr_d["new_start"],
                                                                 utr_d["new_end"],
                                                                 is_reverse,
                                                                 coverage_weights,
                                                                 search_start,
                                                                 coverage_threshold)
    if str(mse) == "Na":
        return None
    long_coverage_vector = abundances[0]
    short_coverage_vector = abundances[1]
    num_non_zero = sum((np.array(long_coverage_vector) + np.array(short_coverage_vector)) > 0)  # TODO: This introduces bias
    if num_non_zero != num_samples:
        return None
    res = dict.fromkeys(result_tuple_fields)
    percentage_long = []
    for i in range(num_samples):
        ratio = float(long_coverage_vector[i]) / (long_coverage_vector[i] + short_coverage_vector[i])  # long 3'UTR percentage
        percentage_long.append(ratio)
    for i in range(num_control):
        res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i])
        res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i])
        res["control_{i}_percent_long".format(i=i)] = percentage_long[i]
    for k in range(num_treatment):
        i = k + num_control
        res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i])
        res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i])
        res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
    res["chr"] = utr_d["chr"]
    res["start"] = utr_d["start"]
    res["end"] = utr_d["end"]
    res["strand"] = utr_d["strand"]
    # Convert the breakpoint from an offset into the coverage array back to
    # a genomic coordinate.
    if is_reverse:
        breakpoint = utr_d["new_end"] - breakpoint
    else:
        breakpoint = utr_d["new_start"] + breakpoint
    res["breakpoint"] = breakpoint
    res["control_mean_percent"] = np.mean(np.array(percentage_long[:num_control]))
    res["treatment_mean_percent"] = np.mean(np.array(percentage_long[num_control:]))
    res["gene"] = utr
    return res


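# Worked example (invented numbers) of the per-sample ratio computed above:
# a sample with long-segment abundance 30 and short-segment abundance 70 has
# a long 3'UTR percentage of 30 / (30 + 70) = 0.3.
def _example_percent_long():
    long_coverage, short_coverage = 30.0, 70.0
    return long_coverage / (long_coverage + short_coverage)  # 0.3

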
def estimate_coverage_extended_utr(utr_coverage, UTR_start, UTR_end, is_reverse, coverage_weights, search_start,
                                   coverage_threshold):
    """
    Search the UTR for the breakpoint that best splits the coverage profile
    into a long and a short segment.
    utr_coverage maps each sample to a numpy array of per-base coverage.
    """
    search_point_end = int(abs((UTR_end - UTR_start)) * 0.1)  # TODO: This is 10% of the total UTR length. Why?
    num_samples = len(utr_coverage)
    # read coverage filtering
    normalized_utr_coverage = [coverage / coverage_weights[i] for i, coverage in enumerate(utr_coverage.values())]
    start_coverage = [np.mean(coverage[0:100]) for coverage in utr_coverage.values()]  # threshold on mean coverage over the first 100 nt
    is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples  # This filters on the raw threshold. Why?
    is_above_length = UTR_end - UTR_start >= 150
    if is_above_threshold and is_above_length:
        if not is_reverse:
            search_region = range(UTR_start + search_start, UTR_end - search_point_end + 1)
        else:
            search_region = range(UTR_end - search_start, UTR_start + search_point_end - 1, -1)
        # NOTE: search_region is computed but never used; the breakpoints below
        # are relative offsets into the coverage arrays.
        search_end = UTR_end - UTR_start - search_point_end
        normalized_utr_coverage = np.array(normalized_utr_coverage)
        breakpoints = range(search_start, search_end + 1)
        mse_list = [estimate_mse(normalized_utr_coverage, bp, num_samples) for bp in breakpoints]
        if len(mse_list) > 0:
            min_ele_index = mse_list.index(min(mse_list))
            breakpoint = breakpoints[min_ele_index]
            UTR_abundances = estimate_abundance(normalized_utr_coverage, breakpoint, num_samples)
            select_mean_squared_error = mse_list[min_ele_index]
            selected_break_point = breakpoint
        else:
            select_mean_squared_error = 'Na'
            UTR_abundances = 'Na'
            selected_break_point = 'Na'
    else:
        select_mean_squared_error = 'Na'
        UTR_abundances = 'Na'
        selected_break_point = 'Na'
    return select_mean_squared_error, selected_break_point, UTR_abundances


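# Sketch of the argmin step above: the selected breakpoint is the candidate
# with the smallest fitted MSE (numbers invented).
def _example_pick_breakpoint():
    breakpoints = [50, 51, 52, 53]
    mse_list = [4.2, 3.1, 3.9, 5.0]
    return breakpoints[mse_list.index(min(mse_list))]  # 51

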
def estimate_mse(cov, bp, num_samples):
    """
    Mean squared error of the per-base coverage around its segment mean, with
    breakpoint bp splitting each sample's coverage into a short (upstream)
    and a long (downstream) segment.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        long_utr_vector = cov[:num_samples, bp:]
        short_utr_vector = cov[:num_samples, 0:bp]
        mean_long_utr = np.mean(long_utr_vector, 1)
        mean_short_utr = np.mean(short_utr_vector, 1)
        square_mean_centered_short_utr_vector = (short_utr_vector[:num_samples] - mean_short_utr[:, np.newaxis]) ** 2
        square_mean_centered_long_utr_vector = (long_utr_vector[:num_samples] - mean_long_utr[:, np.newaxis]) ** 2
        mse = np.mean(np.append(square_mean_centered_short_utr_vector[:num_samples], square_mean_centered_long_utr_vector[:num_samples]))
    return mse

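
# Tiny numeric check of estimate_mse (invented data): one sample, breakpoint
# at index 2. The short segment [10, 10] and the long segment [2, 2] are both
# flat, so every mean-centered square is zero and the MSE is 0.0.
def _example_mse():
    cov = np.array([[10.0, 10.0, 2.0, 2.0]])
    return estimate_mse(cov, 2, 1)  # 0.0

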
def estimate_abundance(cov, bp, num_samples):
    """
    Mean coverage of the long (downstream of bp) and short (upstream of bp)
    segment for each sample.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        long_utr_vector = cov[:num_samples, bp:]
        short_utr_vector = cov[:num_samples, 0:bp]
        mean_long_utr = np.mean(long_utr_vector, 1)
        mean_short_utr = np.mean(short_utr_vector, 1)
    return mean_long_utr, mean_short_utr


if __name__ == '__main__':
    args = parse_args()
    find_utr = UtrFinder(args)