comparison dapars.py @ 17:917a2f7ab841 draft

planemo upload for repository https://github.com/mvdbeek/dapars commit b1b007c561ea6c9db145c88b6b128d66ecd05e24-dirty
author mvdbeek
date Fri, 30 Oct 2015 10:35:17 -0400
parents f8bb40b2ff31
children ed151db39c7e
comparison
equal deleted inserted replaced
16:f8bb40b2ff31 17:917a2f7ab841
9 from multiprocessing import Pool 9 from multiprocessing import Pool
10 import warnings 10 import warnings
11 import matplotlib.pyplot as plt 11 import matplotlib.pyplot as plt
12 import matplotlib.gridspec as gridspec 12 import matplotlib.gridspec as gridspec
13 from tabulate import tabulate 13 from tabulate import tabulate
14 import statsmodels.sandbox.stats.multicomp as mc
14 15
15 def directory_path(str): 16 def directory_path(str):
16 if os.path.exists(str): 17 if os.path.exists(str):
17 return str 18 return str
18 else: 19 else:
33 help="Bed file describing longest 3UTR positions") 34 help="Bed file describing longest 3UTR positions")
34 parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'), 35 parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'),
35 help="file containing output") 36 help="file containing output")
36 parser.add_argument("-cpu", required=False, type=int, default=1, 37 parser.add_argument("-cpu", required=False, type=int, default=1,
37 help="Number of CPU cores to use.") 38 help="Number of CPU cores to use.")
39 parser.add_argument("-l", "--local_minimum", action='store_true')
38 parser.add_argument("-s", "--search_start", required=False, type=int, default=50, 40 parser.add_argument("-s", "--search_start", required=False, type=int, default=50,
39 help="Start search for breakpoint n nucleotides downstream of UTR start") 41 help="Start search for breakpoint n nucleotides downstream of UTR start")
40 parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20, 42 parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20,
41 help="minimum coverage in each aligment to be considered for determining breakpoints") 43 help="minimum coverage in each aligment to be considered for determining breakpoints")
42 parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'), 44 parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
43 help="Write bedfile with coordinates of breakpoint positions to supplied path.") 45 help="Write bedfile with coordinates of breakpoint positions to supplied path.")
44 parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.2') 46 parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.3')
45 parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path, 47 parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path,
46 help="If plot_path is specified will write a coverage plot for every UTR in that directory.") 48 help="If plot_path is specified will write a coverage plot for every UTR in that directory.")
47 parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'), 49 parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'),
48 help="Write an html file to the specified location. Only to be used within a galaxy wrapper") 50 help="Write an html file to the specified location. Only to be used within a galaxy wrapper")
49 return parser.parse_args() 51 return parser.parse_args()
58 self.control_alignments = [file for file in args.control_alignments] 60 self.control_alignments = [file for file in args.control_alignments]
59 self.treatment_alignments = [file for file in args.treatment_alignments] 61 self.treatment_alignments = [file for file in args.treatment_alignments]
60 self.n_cpus = args.cpu 62 self.n_cpus = args.cpu
61 self.search_start = args.search_start 63 self.search_start = args.search_start
62 self.coverage_threshold = args.coverage_threshold 64 self.coverage_threshold = args.coverage_threshold
65 self.local_minimum = args.local_minimum
63 self.plot_path = args.plot_path 66 self.plot_path = args.plot_path
64 self.html_file = args.html_file 67 self.html_file = args.html_file
65 self.utr = args.utr_bed_file 68 self.utr = args.utr_bed_file
66 self.gtf_fields = filter_utr.get_gtf_fields() 69 self.gtf_fields = filter_utr.get_gtf_fields()
67 self.result_file = args.output_file 70 self.result_file = args.output_file
75 self.utr_coverages = self.read_coverage_result() 78 self.utr_coverages = self.read_coverage_result()
76 print "Established dictionary of 3\'UTR coverages" 79 print "Established dictionary of 3\'UTR coverages"
77 self.coverage_weights = self.get_coverage_weights() 80 self.coverage_weights = self.get_coverage_weights()
78 self.result_tuple = self.get_result_tuple() 81 self.result_tuple = self.get_result_tuple()
79 self.result_d = self.calculate_apa_ratios() 82 self.result_d = self.calculate_apa_ratios()
83 self.results = self.order_by_p()
80 self.write_results() 84 self.write_results()
81 if args.breakpoint_bed: 85 if args.breakpoint_bed:
82 self.bed_output = args.breakpoint_bed 86 self.bed_output = args.breakpoint_bed
83 self.write_bed() 87 self.write_bed()
84 if self.plot_path: 88 if self.plot_path:
85 self.write_html() 89 self.write_html()
90
91 def order_by_p(self):
92 results = [result for result in self.result_d.itervalues()]
93 p_values = np.array([ result.p_value for result in self.result_d.itervalues() ])
94 adj_p_values = mc.fdrcorrection0(p_values, 0.05)[1]
95 sort_index = np.argsort(adj_p_values)
96 results = [ results[i]._replace(adj_p_value=adj_p_values[i]) for i in sort_index ]
97 return results
98
86 99
87 def dump_utr_dict_to_bedfile(self): 100 def dump_utr_dict_to_bedfile(self):
88 w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t") 101 w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t")
89 for gene, utr in self.utr_dict.iteritems(): 102 for gene, utr in self.utr_dict.iteritems():
90 w.writerow([utr["chr"], utr["new_start"]-1, utr["new_end"], gene, ".", utr["strand"]]) 103 w.writerow([utr["chr"], utr["new_start"]-1, utr["new_end"], gene, ".", utr["strand"]])
158 coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ]) 171 coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ])
159 coverage_weights = coverages / np.mean(coverages) # TODO: proabably median is better suited? Or even no normalization! 172 coverage_weights = coverages / np.mean(coverages) # TODO: proabably median is better suited? Or even no normalization!
160 return coverage_weights 173 return coverage_weights
161 174
162 def get_result_tuple(self): 175 def get_result_tuple(self):
163 static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "breakpoint", 176 static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "adj_p_value", "breakpoint",
164 "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ] 177 "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ]
165 samples_desc = [] 178 samples_desc = []
166 for statistic in ["coverage_long", "coverage_short", "percent_long"]: 179 for statistic in ["coverage_long", "coverage_short", "percent_long"]:
167 for i, sample in enumerate(self.control_alignments): 180 for i, sample in enumerate(self.control_alignments):
168 samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic)) 181 samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic))
178 "num_control":len(self.control_alignments), 191 "num_control":len(self.control_alignments),
179 "num_treatment":len(self.treatment_alignments), 192 "num_treatment":len(self.treatment_alignments),
180 "result_d":result_d} 193 "result_d":result_d}
181 pool = Pool(self.n_cpus) 194 pool = Pool(self.n_cpus)
182 tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments), 195 tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments),
183 len(self.treatment_alignments), self.search_start, self.coverage_threshold) \ 196 len(self.treatment_alignments), self.search_start, self.local_minimum, self.coverage_threshold) \
184 for utr, utr_d in self.utr_dict.iteritems() ] 197 for utr, utr_d in self.utr_dict.iteritems() ]
185 processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks] 198 #processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks]
186 result_list = [res.get() for res in processed_tasks] 199 #result_list = [res.get() for res in processed_tasks]
187 #result_list = [calculate_all_utr(*t) for t in tasks] # uncomment for easier debugging 200 result_list = [calculate_all_utr(*t) for t in tasks] # uncomment for easier debugging
188 for res_control, res_treatment in result_list: 201 for res_control, res_treatment in result_list:
189 if not res_control: 202 if not res_control:
190 continue 203 continue
191 for i, result in enumerate(res_control): 204 for i, result in enumerate(res_control):
192 if isinstance(result, dict): 205 if isinstance(result, dict):
201 def write_results(self): 214 def write_results(self):
202 w = csv.writer(self.result_file, delimiter='\t') 215 w = csv.writer(self.result_file, delimiter='\t')
203 header = list(self.result_tuple._fields) 216 header = list(self.result_tuple._fields)
204 header[0] = "#chr" 217 header[0] = "#chr"
205 w.writerow(header) # field header 218 w.writerow(header) # field header
206 w.writerows( self.result_d.values()) 219 w.writerows( self.results)
207 220
208 def write_html(self): 221 def write_html(self):
209 output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value ) for result in self.result_d.itervalues()] 222 output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value ) for result in self.results]
210 if self.html_file: 223 if self.html_file:
211 self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html")) 224 self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value", "adj_p_value"], tablefmt="html"))
212 else: 225 else:
213 with open(os.path.join(self.plot_path, "index.html"), "w") as html_file: 226 with open(os.path.join(self.plot_path, "index.html"), "w") as html_file:
214 html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html")) 227 html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value", "adj_p_value"], tablefmt="html"))
228 html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value", "adj_p_value"], tablefmt="html"))
215 229
216 def write_bed(self): 230 def write_bed(self):
217 w = csv.writer(self.bed_output, delimiter='\t') 231 w = csv.writer(self.bed_output, delimiter='\t')
218 bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.result_d.itervalues()] 232 bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.results]
219 w.writerows(bed) 233 w.writerows(bed)
220 234
221 235
222 def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, coverage_threshold): 236 def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, local_minimum, coverage_threshold):
223 if utr_d["strand"] == "+": 237 if utr_d["strand"] == "+":
224 is_reverse = False 238 is_reverse = False
225 else: 239 else:
226 is_reverse = True 240 is_reverse = True
227 control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances = \ 241 control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances = \
228 optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, coverage_threshold, num_control) 242 optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, local_minimum, coverage_threshold, num_control)
229 res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint", control_abundances, is_reverse, 243 res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint", control_abundances, is_reverse,
230 num_control, num_treatment) 244 num_control, num_treatment)
231 res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint", treatment_abundances, is_reverse, 245 res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint", treatment_abundances, is_reverse,
232 num_control, num_treatment) 246 num_control, num_treatment)
233 return res_control, res_treatment 247 return res_control, res_treatment
241 if not breakpoints: 255 if not breakpoints:
242 return False 256 return False
243 result = [] 257 result = []
244 for breakpoint, abundance in zip(breakpoints, abundances): 258 for breakpoint, abundance in zip(breakpoints, abundances):
245 res = {} 259 res = {}
246 long_coverage_vector = abundance[0] 260 res["adj_p_value"] = "NA"
247 short_coverage_vector = abundance[1] 261 long_coverage_vector, short_coverage_vector = abundance
248 percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector) 262 percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector)
249 for i in range(num_control): 263 for i in range(num_control):
250 res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i]) 264 res["control_{i}_coverage_long".format(i=i)] = long_coverage_vector[i]
251 res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i]) 265 res["control_{i}_coverage_short".format(i=i)] = short_coverage_vector[i]
252 res["control_{i}_percent_long".format(i=i)] = percentage_long[i] 266 res["control_{i}_percent_long".format(i=i)] = percentage_long[i]
253 for k in range(num_treatment): 267 for k in range(num_treatment):
254 i = k + num_control 268 i = k + num_control
255 res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i]) 269 res["treatment_{i}_coverage_long".format(i=k)] = long_coverage_vector[i]
256 res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i]) 270 res["treatment_{i}_coverage_short".format(i=k)] = short_coverage_vector[i]
257 res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i] 271 res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
258 res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:]) 272 res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:])
259 control_mean_percent = np.mean(percentage_long[:num_control]) 273 control_mean_percent = np.mean(percentage_long[:num_control])
260 treatment_mean_percent = np.mean(percentage_long[num_control:]) 274 treatment_mean_percent = np.mean(percentage_long[num_control:])
261 res["chr"] = utr_d["chr"] 275 res["chr"] = utr_d["chr"]
273 res["gene"] = utr 287 res["gene"] = utr
274 result.append(res) 288 result.append(res)
275 return result 289 return result
276 290
277 291
278 def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control): 292 def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, local_minimum, coverage_threshold, num_control):
279 """ 293 """
280 We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided 294 We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided
281 at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample. 295 at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample.
282 """ 296 """
283 num_samples = len(utr_coverage) 297 num_samples = len(utr_coverage)
291 mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ] 305 mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ]
292 mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list 306 mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list
293 if plot_path: 307 if plot_path:
294 plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control) 308 plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control)
295 if len(mse_list) > 0: 309 if len(mse_list) > 0:
296 return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples) 310 return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples, local_minimum)
297 return False, False, False, False 311 return False, False, False, False
298 312
299 313
300 def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control): 314 def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control):
301 """ 315 """
329 fig.add_subplot(ax2) 343 fig.add_subplot(ax2)
330 gs.tight_layout(fig) 344 gs.tight_layout(fig)
331 fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr))) 345 fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr)))
332 346
333 347
334 def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples): 348 def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples, local_minimum):
335 """ 349 """
336 Take in mse_list with control and treatment mse and return breakpoint and utr abundance for all local minima 350 Take in mse_list with control and treatment mse and return breakpoint and utr abundance for all local minima
337 in mse_list 351 in mse_list
338 """ 352 """
339 mse_control = np.array([mse[0] for mse in mse_list]) 353 mse_control = np.array([mse[0] for mse in mse_list])
340 mse_treatment = np.array([mse[1] for mse in mse_list]) 354 mse_treatment = np.array([mse[1] for mse in mse_list])
341 control_breakpoints = list(get_minima(mse_control)) 355 control_breakpoints = list(get_minima(mse_control, local_minimum))
342 treatment_breakpoints = list(get_minima(mse_treatment)) 356 treatment_breakpoints = list(get_minima(mse_treatment, local_minimum))
343 control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in control_breakpoints] 357 control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in control_breakpoints]
344 treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in treatment_breakpoints] 358 treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in treatment_breakpoints]
345 return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances 359 return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances
346 360
347 def get_minima(a): 361 def get_minima(a, local_minimum=False):
348 """ 362 """
349 get minima for numpy array a 363 get minima for numpy array a. If local is false, only return absolute minimum
350 """ 364 """
351 return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1 365 if not local_minimum:
366 return np.where(a == a.min())
367 else:
368 return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1
352 369
353 def estimate_mse(cov, bp, num_samples, num_control): 370 def estimate_mse(cov, bp, num_samples, num_control):
354 """ 371 """
355 get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr. 372 get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr.
356 """ 373 """