Mercurial > repos > mvdbeek > dapars
comparison dapars.py @ 17:917a2f7ab841 draft
planemo upload for repository https://github.com/mvdbeek/dapars commit b1b007c561ea6c9db145c88b6b128d66ecd05e24-dirty
author | mvdbeek |
---|---|
date | Fri, 30 Oct 2015 10:35:17 -0400 |
parents | f8bb40b2ff31 |
children | ed151db39c7e |
comparison
equal
deleted
inserted
replaced
16:f8bb40b2ff31 | 17:917a2f7ab841 |
---|---|
9 from multiprocessing import Pool | 9 from multiprocessing import Pool |
10 import warnings | 10 import warnings |
11 import matplotlib.pyplot as plt | 11 import matplotlib.pyplot as plt |
12 import matplotlib.gridspec as gridspec | 12 import matplotlib.gridspec as gridspec |
13 from tabulate import tabulate | 13 from tabulate import tabulate |
14 import statsmodels.sandbox.stats.multicomp as mc | |
14 | 15 |
15 def directory_path(str): | 16 def directory_path(str): |
16 if os.path.exists(str): | 17 if os.path.exists(str): |
17 return str | 18 return str |
18 else: | 19 else: |
33 help="Bed file describing longest 3UTR positions") | 34 help="Bed file describing longest 3UTR positions") |
34 parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'), | 35 parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'), |
35 help="file containing output") | 36 help="file containing output") |
36 parser.add_argument("-cpu", required=False, type=int, default=1, | 37 parser.add_argument("-cpu", required=False, type=int, default=1, |
37 help="Number of CPU cores to use.") | 38 help="Number of CPU cores to use.") |
39 parser.add_argument("-l", "--local_minimum", action='store_true') | |
38 parser.add_argument("-s", "--search_start", required=False, type=int, default=50, | 40 parser.add_argument("-s", "--search_start", required=False, type=int, default=50, |
39 help="Start search for breakpoint n nucleotides downstream of UTR start") | 41 help="Start search for breakpoint n nucleotides downstream of UTR start") |
40 parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20, | 42 parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20, |
41 help="minimum coverage in each aligment to be considered for determining breakpoints") | 43 help="minimum coverage in each aligment to be considered for determining breakpoints") |
42 parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'), | 44 parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'), |
43 help="Write bedfile with coordinates of breakpoint positions to supplied path.") | 45 help="Write bedfile with coordinates of breakpoint positions to supplied path.") |
44 parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.2') | 46 parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.3') |
45 parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path, | 47 parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path, |
46 help="If plot_path is specified will write a coverage plot for every UTR in that directory.") | 48 help="If plot_path is specified will write a coverage plot for every UTR in that directory.") |
47 parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'), | 49 parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'), |
48 help="Write an html file to the specified location. Only to be used within a galaxy wrapper") | 50 help="Write an html file to the specified location. Only to be used within a galaxy wrapper") |
49 return parser.parse_args() | 51 return parser.parse_args() |
58 self.control_alignments = [file for file in args.control_alignments] | 60 self.control_alignments = [file for file in args.control_alignments] |
59 self.treatment_alignments = [file for file in args.treatment_alignments] | 61 self.treatment_alignments = [file for file in args.treatment_alignments] |
60 self.n_cpus = args.cpu | 62 self.n_cpus = args.cpu |
61 self.search_start = args.search_start | 63 self.search_start = args.search_start |
62 self.coverage_threshold = args.coverage_threshold | 64 self.coverage_threshold = args.coverage_threshold |
65 self.local_minimum = args.local_minimum | |
63 self.plot_path = args.plot_path | 66 self.plot_path = args.plot_path |
64 self.html_file = args.html_file | 67 self.html_file = args.html_file |
65 self.utr = args.utr_bed_file | 68 self.utr = args.utr_bed_file |
66 self.gtf_fields = filter_utr.get_gtf_fields() | 69 self.gtf_fields = filter_utr.get_gtf_fields() |
67 self.result_file = args.output_file | 70 self.result_file = args.output_file |
75 self.utr_coverages = self.read_coverage_result() | 78 self.utr_coverages = self.read_coverage_result() |
76 print "Established dictionary of 3\'UTR coverages" | 79 print "Established dictionary of 3\'UTR coverages" |
77 self.coverage_weights = self.get_coverage_weights() | 80 self.coverage_weights = self.get_coverage_weights() |
78 self.result_tuple = self.get_result_tuple() | 81 self.result_tuple = self.get_result_tuple() |
79 self.result_d = self.calculate_apa_ratios() | 82 self.result_d = self.calculate_apa_ratios() |
83 self.results = self.order_by_p() | |
80 self.write_results() | 84 self.write_results() |
81 if args.breakpoint_bed: | 85 if args.breakpoint_bed: |
82 self.bed_output = args.breakpoint_bed | 86 self.bed_output = args.breakpoint_bed |
83 self.write_bed() | 87 self.write_bed() |
84 if self.plot_path: | 88 if self.plot_path: |
85 self.write_html() | 89 self.write_html() |
90 | |
def order_by_p(self):
    """Return all results sorted by Benjamini-Hochberg-adjusted p-value (ascending).

    Snapshots ``self.result_d`` exactly once so the results list and the
    p-value array are guaranteed to be aligned; the previous version iterated
    ``itervalues()`` twice and silently relied on identical iteration order.
    """
    results = list(self.result_d.itervalues())
    p_values = np.array([result.p_value for result in results])
    # fdrcorrection0 returns (reject_flags, adjusted_p_values); keep the latter.
    adj_p_values = mc.fdrcorrection0(p_values, 0.05)[1]
    sort_index = np.argsort(adj_p_values)
    return [results[i]._replace(adj_p_value=adj_p_values[i]) for i in sort_index]
98 | |
86 | 99 |
def dump_utr_dict_to_bedfile(self):
    """Write the UTR dictionary to ``tmp_bedfile.bed`` (BED6, 0-based starts).

    Fix: the previous version leaked the file handle returned by ``open``;
    the ``with`` block guarantees the file is flushed and closed.
    """
    with open("tmp_bedfile.bed", "w") as bedfile:
        w = csv.writer(bedfile, delimiter="\t")
        for gene, utr in self.utr_dict.iteritems():
            # BED is 0-based half-open, hence new_start - 1.
            w.writerow([utr["chr"], utr["new_start"]-1, utr["new_end"], gene, ".", utr["strand"]])
158 coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ]) | 171 coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ]) |
159 coverage_weights = coverages / np.mean(coverages) # TODO: proabably median is better suited? Or even no normalization! | 172 coverage_weights = coverages / np.mean(coverages) # TODO: proabably median is better suited? Or even no normalization! |
160 return coverage_weights | 173 return coverage_weights |
161 | 174 |
162 def get_result_tuple(self): | 175 def get_result_tuple(self): |
163 static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "breakpoint", | 176 static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "adj_p_value", "breakpoint", |
164 "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ] | 177 "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ] |
165 samples_desc = [] | 178 samples_desc = [] |
166 for statistic in ["coverage_long", "coverage_short", "percent_long"]: | 179 for statistic in ["coverage_long", "coverage_short", "percent_long"]: |
167 for i, sample in enumerate(self.control_alignments): | 180 for i, sample in enumerate(self.control_alignments): |
168 samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic)) | 181 samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic)) |
178 "num_control":len(self.control_alignments), | 191 "num_control":len(self.control_alignments), |
179 "num_treatment":len(self.treatment_alignments), | 192 "num_treatment":len(self.treatment_alignments), |
180 "result_d":result_d} | 193 "result_d":result_d} |
181 pool = Pool(self.n_cpus) | 194 pool = Pool(self.n_cpus) |
182 tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments), | 195 tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments), |
183 len(self.treatment_alignments), self.search_start, self.coverage_threshold) \ | 196 len(self.treatment_alignments), self.search_start, self.local_minimum, self.coverage_threshold) \ |
184 for utr, utr_d in self.utr_dict.iteritems() ] | 197 for utr, utr_d in self.utr_dict.iteritems() ] |
185 processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks] | 198 #processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks] |
186 result_list = [res.get() for res in processed_tasks] | 199 #result_list = [res.get() for res in processed_tasks] |
187 #result_list = [calculate_all_utr(*t) for t in tasks] # uncomment for easier debugging | 200 result_list = [calculate_all_utr(*t) for t in tasks] # uncomment for easier debugging |
188 for res_control, res_treatment in result_list: | 201 for res_control, res_treatment in result_list: |
189 if not res_control: | 202 if not res_control: |
190 continue | 203 continue |
191 for i, result in enumerate(res_control): | 204 for i, result in enumerate(res_control): |
192 if isinstance(result, dict): | 205 if isinstance(result, dict): |
def write_results(self):
    """Write the ordered result table to the output file as tab-separated text."""
    writer = csv.writer(self.result_file, delimiter='\t')
    column_names = list(self.result_tuple._fields)
    # Prefix the first column with '#' so downstream tools treat row 1 as a comment header.
    column_names[0] = "#chr"
    writer.writerow(column_names)
    writer.writerows(self.results)
207 | 220 |
def write_html(self):
    """Render the result table as an HTML table.

    Writes to ``self.html_file`` when given (galaxy wrapper mode), otherwise
    to ``index.html`` inside ``self.plot_path``.

    Fixes: the header row advertised five columns (including ``adj_p_value``)
    while each data row carried only four values -- ``adj_p_value`` is now
    included in the rows; the fallback branch also wrote the table twice
    (duplicated ``html_file.write`` call), which is removed here.
    """
    output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type,
                     result.p_value, result.adj_p_value) for result in self.results]
    headers = ["gene", "breakpoint", "breakpoint_type", "p_value", "adj_p_value"]
    if self.html_file:
        self.html_file.write(tabulate(output_lines, headers=headers, tablefmt="html"))
    else:
        with open(os.path.join(self.plot_path, "index.html"), "w") as html_file:
            html_file.write(tabulate(output_lines, headers=headers, tablefmt="html"))
215 | 229 |
def write_bed(self):
    """Write one BED6 line per breakpoint to the breakpoint bed output."""
    writer = csv.writer(self.bed_output, delimiter='\t')
    for result in self.results:
        start = result.breakpoint
        feature_name = "{0}_{1}".format(result.gene, result.breakpoint_type)
        # Single-base interval: [breakpoint, breakpoint + 1); score column fixed at 0.
        writer.writerow((result.chr, start, int(start) + 1, feature_name, 0, result.strand))
220 | 234 |
221 | 235 |
def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, local_minimum, coverage_threshold):
    """Locate breakpoints for one UTR and build per-group result dictionaries.

    Returns a (res_control, res_treatment) pair as produced by
    breakpoints_to_result, or False entries when no breakpoint was found.
    """
    # Minus-strand UTRs are handled in reverse orientation downstream.
    is_reverse = utr_d["strand"] != "+"
    breakpoint_data = optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"],
                                          utr_d["new_end"], coverage_weights, search_start,
                                          local_minimum, coverage_threshold, num_control)
    control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances = breakpoint_data
    res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint",
                                        control_abundances, is_reverse, num_control, num_treatment)
    res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint",
                                          treatment_abundances, is_reverse, num_control, num_treatment)
    return res_control, res_treatment
241 if not breakpoints: | 255 if not breakpoints: |
242 return False | 256 return False |
243 result = [] | 257 result = [] |
244 for breakpoint, abundance in zip(breakpoints, abundances): | 258 for breakpoint, abundance in zip(breakpoints, abundances): |
245 res = {} | 259 res = {} |
246 long_coverage_vector = abundance[0] | 260 res["adj_p_value"] = "NA" |
247 short_coverage_vector = abundance[1] | 261 long_coverage_vector, short_coverage_vector = abundance |
248 percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector) | 262 percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector) |
249 for i in range(num_control): | 263 for i in range(num_control): |
250 res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i]) | 264 res["control_{i}_coverage_long".format(i=i)] = long_coverage_vector[i] |
251 res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i]) | 265 res["control_{i}_coverage_short".format(i=i)] = short_coverage_vector[i] |
252 res["control_{i}_percent_long".format(i=i)] = percentage_long[i] | 266 res["control_{i}_percent_long".format(i=i)] = percentage_long[i] |
253 for k in range(num_treatment): | 267 for k in range(num_treatment): |
254 i = k + num_control | 268 i = k + num_control |
255 res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i]) | 269 res["treatment_{i}_coverage_long".format(i=k)] = long_coverage_vector[i] |
256 res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i]) | 270 res["treatment_{i}_coverage_short".format(i=k)] = short_coverage_vector[i] |
257 res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i] | 271 res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i] |
258 res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:]) | 272 res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:]) |
259 control_mean_percent = np.mean(percentage_long[:num_control]) | 273 control_mean_percent = np.mean(percentage_long[:num_control]) |
260 treatment_mean_percent = np.mean(percentage_long[num_control:]) | 274 treatment_mean_percent = np.mean(percentage_long[num_control:]) |
261 res["chr"] = utr_d["chr"] | 275 res["chr"] = utr_d["chr"] |
273 res["gene"] = utr | 287 res["gene"] = utr |
274 result.append(res) | 288 result.append(res) |
275 return result | 289 return result |
276 | 290 |
277 | 291 |
278 def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control): | 292 def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, local_minimum, coverage_threshold, num_control): |
279 """ | 293 """ |
280 We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided | 294 We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided |
281 at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample. | 295 at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample. |
282 """ | 296 """ |
283 num_samples = len(utr_coverage) | 297 num_samples = len(utr_coverage) |
291 mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ] | 305 mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ] |
292 mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list | 306 mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list |
293 if plot_path: | 307 if plot_path: |
294 plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control) | 308 plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control) |
295 if len(mse_list) > 0: | 309 if len(mse_list) > 0: |
296 return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples) | 310 return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples, local_minimum) |
297 return False, False, False, False | 311 return False, False, False, False |
298 | 312 |
299 | 313 |
300 def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control): | 314 def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control): |
301 """ | 315 """ |
329 fig.add_subplot(ax2) | 343 fig.add_subplot(ax2) |
330 gs.tight_layout(fig) | 344 gs.tight_layout(fig) |
331 fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr))) | 345 fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr))) |
332 | 346 |
333 | 347 |
def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples, local_minimum):
    """
    Take in mse_list with control and treatment mse and return breakpoints and utr
    abundances for the selected minima (all local minima, or the global minimum only,
    depending on local_minimum).
    """
    # mse_list holds (control_mse, treatment_mse) pairs per position; split by column.
    mse_control = np.array([pair[0] for pair in mse_list])
    mse_treatment = np.array([pair[1] for pair in mse_list])
    control_breakpoints = list(get_minima(mse_control, local_minimum))
    treatment_breakpoints = list(get_minima(mse_treatment, local_minimum))
    control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples)
                          for bp in control_breakpoints]
    treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples)
                            for bp in treatment_breakpoints]
    return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances
346 | 360 |
def get_minima(a, local_minimum=False):
    """
    Get minima positions for numpy array a.

    If local_minimum is False, return only the position(s) of the absolute
    minimum; otherwise return all strict local minima.

    Always returns a 1-D index array. Fix: the absolute-minimum branch used to
    return the raw np.where tuple, so callers doing list(get_minima(...)) got a
    one-element list containing an array instead of integer positions.

    NOTE(review): the local-minima branch shifts indices by +1 while the
    absolute-minimum branch does not -- confirm this asymmetry is intended.
    """
    if not local_minimum:
        return np.where(a == a.min())[0]
    # Strict local minima: smaller than both neighbours (array ends count as open).
    return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0] + 1
352 | 369 |
353 def estimate_mse(cov, bp, num_samples, num_control): | 370 def estimate_mse(cov, bp, num_samples, num_control): |
354 """ | 371 """ |
355 get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr. | 372 get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr. |
356 """ | 373 """ |