Mercurial > repos > mvdbeek > dapars
comparison dapars.py @ 5:a5d8b08af089 draft
planemo upload for repository https://github.com/mvdbeek/dapars commit deab588a5d5ec7022de63a395fbd04e415ba0a42
author | mvdbeek |
---|---|
date | Thu, 29 Oct 2015 15:51:10 -0400 |
parents | 73b932244237 |
children | 538c4e2b423e |
comparison
equal
deleted
inserted
replaced
4:73b932244237 | 5:a5d8b08af089 |
---|---|
1 import argparse | 1 import argparse |
2 import os | 2 import os |
3 import csv | 3 import csv |
4 import numpy as np | 4 import numpy as np |
5 from scipy import stats | |
5 from collections import OrderedDict, namedtuple | 6 from collections import OrderedDict, namedtuple |
6 import filter_utr | 7 import filter_utr |
7 import subprocess | 8 import subprocess |
8 from multiprocessing import Pool | 9 from multiprocessing import Pool |
9 import warnings | 10 import warnings |
10 | 11 import matplotlib.pyplot as plt |
12 import matplotlib.gridspec as gridspec | |
13 from tabulate import tabulate | |
14 | |
15 def directory_path(str): | |
16 if os.path.exists(str): | |
17 return str | |
18 else: | |
19 os.mkdir(str) | |
20 return str | |
11 | 21 |
12 def parse_args(): | 22 def parse_args(): |
13 """ | 23 """ |
14 Returns floating point values except for input files. | 24 Returns floating point values except for input files. |
15 My initial approach will not filter anything. (FDR, fold_change, PDUI, Num_least ...) | 25 My initial approach will not filter anything. (FDR, fold_change, PDUI, Num_least ...) |
16 :param argv: | |
17 :return: | |
18 """ | 26 """ |
19 parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA usage') | 27 parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA usage') |
20 parser.add_argument("-c", "--control_alignments", nargs="+", required=True, | 28 parser.add_argument("-c", "--control_alignments", nargs="+", required=True, |
21 help="Alignment files in BAM format from control condition") | 29 help="Alignment files in BAM format from control condition") |
22 parser.add_argument("-t", "--treatment_alignments", nargs="+", required=True, | 30 parser.add_argument("-t", "--treatment_alignments", nargs="+", required=True, |
31 help="Start search for breakpoint n nucleotides downstream of UTR start") | 39 help="Start search for breakpoint n nucleotides downstream of UTR start") |
32 parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20, | 40 parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20, |
33 help="minimum coverage in each alignment to be considered for determining breakpoints") | 41 help="minimum coverage in each alignment to be considered for determining breakpoints") |
34 parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'), | 42 parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'), |
35 help="Write bedfile with coordinates of breakpoint positions to supplied path.") | 43 help="Write bedfile with coordinates of breakpoint positions to supplied path.") |
36 parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.5') | 44 parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.0') |
45 parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path, | |
46 help="If plot_path is specified will write a coverage plot for every UTR in that directory.") | |
47 parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'), | |
48 help="Write an html file to the specified location. Only to be used within a galaxy wrapper") | |
37 return parser.parse_args() | 49 return parser.parse_args() |
38 | 50 |
39 | 51 |
40 class UtrFinder(): | 52 class UtrFinder(): |
41 """ | 53 """ |
46 self.control_alignments = [file for file in args.control_alignments] | 58 self.control_alignments = [file for file in args.control_alignments] |
47 self.treatment_alignments = [file for file in args.treatment_alignments] | 59 self.treatment_alignments = [file for file in args.treatment_alignments] |
48 self.n_cpus = args.cpu | 60 self.n_cpus = args.cpu |
49 self.search_start = args.search_start | 61 self.search_start = args.search_start |
50 self.coverage_threshold = args.coverage_threshold | 62 self.coverage_threshold = args.coverage_threshold |
63 self.plot_path = args.plot_path | |
64 self.html_file = args.html_file | |
51 self.utr = args.utr_bed_file | 65 self.utr = args.utr_bed_file |
52 self.gtf_fields = filter_utr.get_gtf_fields() | 66 self.gtf_fields = filter_utr.get_gtf_fields() |
53 self.result_file = args.output_file | 67 self.result_file = args.output_file |
54 self.all_alignments = self.control_alignments + self.treatment_alignments | 68 self.all_alignments = self.control_alignments + self.treatment_alignments |
55 self.alignment_names = { file: os.path.basename(file) for file in self.all_alignments } | 69 self.alignment_names = { file: os.path.basename(file) for file in self.all_alignments } |
65 self.result_d = self.calculate_apa_ratios() | 79 self.result_d = self.calculate_apa_ratios() |
66 self.write_results() | 80 self.write_results() |
67 if args.breakpoint_bed: | 81 if args.breakpoint_bed: |
68 self.bed_output = args.breakpoint_bed | 82 self.bed_output = args.breakpoint_bed |
69 self.write_bed() | 83 self.write_bed() |
70 | 84 if self.plot_path: |
85 self.write_html() | |
71 | 86 |
72 def dump_utr_dict_to_bedfile(self): | 87 def dump_utr_dict_to_bedfile(self): |
73 w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t") | 88 w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t") |
74 for gene, utr in self.utr_dict.iteritems(): | 89 for gene, utr in self.utr_dict.iteritems(): |
75 w.writerow([utr["chr"], utr["new_start"]-1, utr["new_end"], gene, ".", utr["strand"]]) | 90 w.writerow([utr["chr"], utr["new_start"]-1, utr["new_end"], gene, ".", utr["strand"]]) |
108 for alignment_name in self.alignment_names.values(): | 123 for alignment_name in self.alignment_names.values(): |
109 coverage_dict[gene][alignment_name] = coverage_dict[gene][alignment_name][::-1] | 124 coverage_dict[gene][alignment_name] = coverage_dict[gene][alignment_name][::-1] |
110 return coverage_dict | 125 return coverage_dict |
111 | 126 |
112 def get_utr_dict(self, shift): | 127 def get_utr_dict(self, shift): |
128 """ | |
129 The utr end is extended by UTR length * shift, to discover novel distal polyA sites. | |
130 Set to 0 to disable. | |
131 """ | |
113 utr_dict = OrderedDict() | 132 utr_dict = OrderedDict() |
114 for line in self.utr: | 133 for line in self.utr: |
115 if not line.startswith("#"): | 134 if not line.startswith("#"): |
116 filter_utr.get_feature_dict( line=line, gtf_fields=self.gtf_fields, utr_dict=utr_dict, feature="UTR" ) | 135 filter_utr.get_feature_dict( line=line, gtf_fields=self.gtf_fields, utr_dict=utr_dict, feature="UTR" ) |
117 gene, utr_d = utr_dict.popitem() | 136 gene, utr_d = utr_dict.popitem() |
137 utr_coverage = [] | 156 utr_coverage = [] |
138 for vector in utr.itervalues(): | 157 for vector in utr.itervalues(): |
139 utr_coverage.append(np.sum(vector)) | 158 utr_coverage.append(np.sum(vector)) |
140 coverage_per_alignment.append(utr_coverage) | 159 coverage_per_alignment.append(utr_coverage) |
141 coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ]) | 160 coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ]) |
142 coverage_weights = coverages / np.mean(coverages) # TODO: probably median is better suited? | 161 coverage_weights = coverages / np.mean(coverages) # TODO: probably median is better suited? Or even no normalization! |
143 return coverage_weights | 162 return coverage_weights |
144 | 163 |
145 def get_result_tuple(self): | 164 def get_result_tuple(self): |
146 static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint", | 165 static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "breakpoint", |
147 "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ] | 166 "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ] |
148 samples_desc = [] | 167 samples_desc = [] |
149 for statistic in ["coverage_long", "coverage_short", "percent_long"]: | 168 for statistic in ["coverage_long", "coverage_short", "percent_long"]: |
150 for i, sample in enumerate(self.control_alignments): | 169 for i, sample in enumerate(self.control_alignments): |
151 samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic)) | 170 samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic)) |
160 "num_samples":self.num_samples, | 179 "num_samples":self.num_samples, |
161 "num_control":len(self.control_alignments), | 180 "num_control":len(self.control_alignments), |
162 "num_treatment":len(self.treatment_alignments), | 181 "num_treatment":len(self.treatment_alignments), |
163 "result_d":result_d} | 182 "result_d":result_d} |
164 pool = Pool(self.n_cpus) | 183 pool = Pool(self.n_cpus) |
165 tasks = [ (self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights, self.num_samples, | 184 tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments), |
166 len(self.control_alignments), len(self.treatment_alignments), self.search_start, | 185 len(self.treatment_alignments), self.search_start, self.coverage_threshold) \ |
167 self.coverage_threshold) for utr, utr_d in self.utr_dict.iteritems() ] | 186 for utr, utr_d in self.utr_dict.iteritems() ] |
168 processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks] | 187 processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks] |
169 result = [res.get() for res in processed_tasks] | 188 result_list = [res.get() for res in processed_tasks] |
170 for res_control, res_treatment in result: | 189 for res_control, res_treatment in result_list: |
171 if isinstance(res_control, dict): | 190 if not res_control: |
172 t = self.result_tuple(**res_control) | 191 continue |
173 result_d[res_control["gene"]+"_bp_control"] = t | 192 for i, result in enumerate(res_control): |
174 if isinstance(res_treatment, dict): | 193 if isinstance(result, dict): |
175 t = self.result_tuple(**res_treatment) | 194 t = self.result_tuple(**result) |
176 result_d[res_treatment["gene"]+"_bp_treatment"] = t | 195 result_d[result["gene"]+"_bp_control_{i}".format(i=i)] = t |
196 for i, result in enumerate(res_treatment): | |
197 if isinstance(result, dict): | |
198 t = self.result_tuple(**result) | |
199 result_d[result["gene"]+"_bp_treatment_{i}".format(i=i)] = t | |
177 return result_d | 200 return result_d |
178 | 201 |
179 def write_results(self): | 202 def write_results(self): |
180 w = csv.writer(self.result_file, delimiter='\t') | 203 w = csv.writer(self.result_file, delimiter='\t') |
181 header = list(self.result_tuple._fields) | 204 header = list(self.result_tuple._fields) |
182 header[0] = "#chr" | 205 header[0] = "#chr" |
183 w.writerow(header) # field header | 206 w.writerow(header) # field header |
184 w.writerows( self.result_d.values()) | 207 w.writerows( self.result_d.values()) |
185 | 208 |
209 def write_html(self): | |
210 output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value ) for result in self.result_d.itervalues()] | |
211 if self.html_file: | |
212 self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html")) | |
213 else: | |
214 with open(os.path.join(self.plot_path, "index.html"), "w") as html_file: | |
215 html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html")) | |
216 | |
186 def write_bed(self): | 217 def write_bed(self): |
187 w = csv.writer(self.bed_output, delimiter='\t') | 218 w = csv.writer(self.bed_output, delimiter='\t') |
188 bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.result_d.itervalues()] | 219 bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.result_d.itervalues()] |
189 w.writerows(bed) | 220 w.writerows(bed) |
190 | 221 |
191 | 222 |
192 def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control, | 223 def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, coverage_threshold): |
193 num_treatment, search_start, coverage_threshold): | |
194 res_control = dict(zip(result_tuple_fields, result_tuple_fields)) | |
195 res_treatment = res_control.copy() | |
196 if utr_d["strand"] == "+": | 224 if utr_d["strand"] == "+": |
197 is_reverse = False | 225 is_reverse = False |
198 else: | 226 else: |
199 is_reverse = True | 227 is_reverse = True |
200 control_breakpoint, \ | 228 control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances = \ |
201 control_abundance, \ | 229 optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, coverage_threshold, num_control) |
202 treatment_breakpoint, \ | 230 res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint", control_abundances, is_reverse, |
203 treatment_abundance = optimize_breakpoint(utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, | |
204 search_start, coverage_threshold, num_control) | |
205 if control_breakpoint: | |
206 breakpoint_to_result(res_control, utr, utr_d, control_breakpoint, "control_breakpoint", control_abundance, is_reverse, num_samples, | |
207 num_control, num_treatment) | 231 num_control, num_treatment) |
208 if treatment_breakpoint: | 232 res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint", treatment_abundances, is_reverse, |
209 breakpoint_to_result(res_treatment, utr, utr_d, treatment_breakpoint, "treatment_breakpoint", treatment_abundance, is_reverse, | 233 num_control, num_treatment) |
210 num_samples, num_control, num_treatment) | |
211 if res_control == dict(zip(result_tuple_fields, result_tuple_fields)): | |
212 res_control = False | |
213 if res_treatment == dict(zip(result_tuple_fields, result_tuple_fields)): | |
214 res_treatment == False | |
215 return res_control, res_treatment | 234 return res_control, res_treatment |
216 | 235 |
217 | 236 |
218 def breakpoint_to_result(res, utr, utr_d, breakpoint, breakpoint_type, | 237 def breakpoints_to_result(utr, utr_d, breakpoints, breakpoint_type, |
219 abundances, is_reverse, num_samples, num_control, num_treatment): | 238 abundances, is_reverse, num_control, num_treatment): |
220 """ | 239 """ |
221 Takes in a result dictionary res and fills the necessary fields | 240 Takes in a result dictionary res and fills the necessary fields |
222 """ | 241 """ |
223 long_coverage_vector = abundances[0] | 242 if not breakpoints: |
224 short_coverage_vector = abundances[1] | 243 return False |
225 num_non_zero = sum((np.array(long_coverage_vector) + np.array(short_coverage_vector)) > 0) # TODO: This introduces bias | 244 result = [] |
226 if num_non_zero == num_samples: | 245 for breakpoint, abundance in zip(breakpoints, abundances): |
227 percentage_long = [] | 246 res = {} |
228 for i in range(num_samples): | 247 long_coverage_vector = abundance[0] |
229 ratio = float(long_coverage_vector[i]) / (long_coverage_vector[i] + short_coverage_vector[i]) # long 3'UTR percentage | 248 short_coverage_vector = abundance[1] |
230 percentage_long.append(ratio) | 249 percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector) |
231 for i in range(num_control): | 250 for i in range(num_control): |
232 res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i]) | 251 res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i]) |
233 res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i]) | 252 res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i]) |
234 res["control_{i}_percent_long".format(i=i)] = percentage_long[i] | 253 res["control_{i}_percent_long".format(i=i)] = percentage_long[i] |
235 for k in range(num_treatment): | 254 for k in range(num_treatment): |
236 i = k + num_control | 255 i = k + num_control |
237 res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i]) | 256 res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i]) |
238 res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i]) | 257 res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i]) |
239 res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i] | 258 res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i] |
259 res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:]) | |
240 control_mean_percent = np.mean(np.array(percentage_long[:num_control])) | 260 control_mean_percent = np.mean(np.array(percentage_long[:num_control])) |
241 treatment_mean_percent = np.mean(np.array(percentage_long[num_control:])) | 261 treatment_mean_percent = np.mean(np.array(percentage_long[num_control:])) |
242 res["chr"] = utr_d["chr"] | 262 res["chr"] = utr_d["chr"] |
243 res["start"] = utr_d["start"] | 263 res["start"] = utr_d["start"] |
244 res["end"] = utr_d["end"] | 264 res["end"] = utr_d["end"] |
250 res["breakpoint"] = breakpoint | 270 res["breakpoint"] = breakpoint |
251 res["breakpoint_type"] = breakpoint_type | 271 res["breakpoint_type"] = breakpoint_type |
252 res["control_mean_percent"] = control_mean_percent | 272 res["control_mean_percent"] = control_mean_percent |
253 res["treatment_mean_percent"] = treatment_mean_percent | 273 res["treatment_mean_percent"] = treatment_mean_percent |
254 res["gene"] = utr | 274 res["gene"] = utr |
255 | 275 result.append(res) |
256 | 276 return result |
257 def optimize_breakpoint(utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control): | 277 |
278 | |
279 def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control): | |
258 """ | 280 """ |
259 We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided | 281 We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided |
260 at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample. | 282 at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample. |
261 """ | 283 """ |
262 search_point_end = int(abs((UTR_end - UTR_start)) * 0.1) # TODO: This is 10% of total UTR end. Why? | |
263 num_samples = len(utr_coverage) | 284 num_samples = len(utr_coverage) |
264 normalized_utr_coverage = np.array([coverage/ coverage_weigths[i] for i, coverage in enumerate( utr_coverage.values() )]) | 285 normalized_utr_coverage = np.array(utr_coverage.values())/np.expand_dims(coverage_weigths, axis=1) |
265 start_coverage = [np.mean(coverage[0:99]) for coverage in utr_coverage.values()] # filters threshold on mean coverage over first 100 nt | 286 start_coverage = [np.mean(coverage[0:99]) for coverage in utr_coverage.values()] # filters threshold on mean coverage over first 100 nt |
266 is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples # This filters on the raw threshold. Why? | 287 is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples # This filters on the raw threshold. Why? |
267 is_above_length = UTR_end - UTR_start >= 150 | 288 is_above_length = UTR_end - UTR_start >= 150 |
268 if (is_above_threshold) and (is_above_length): | 289 if (is_above_threshold) and (is_above_length): |
269 search_end = UTR_end - UTR_start - search_point_end | 290 search_end = UTR_end - UTR_start |
270 breakpoints = range(search_start, search_end + 1) | 291 breakpoints = range(search_start, search_end + 1) |
271 mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ] | 292 mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ] |
293 mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list | |
294 if plot_path: | |
295 plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control) | |
272 if len(mse_list) > 0: | 296 if len(mse_list) > 0: |
273 return mse_to_breakpoint(mse_list, normalized_utr_coverage, breakpoints, num_samples) | 297 return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples) |
274 return False, False, False, False | 298 return False, False, False, False |
275 | 299 |
276 | 300 |
277 def mse_to_breakpoint(mse_list, normalized_utr_coverage, breakpoints, num_samples): | 301 def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control): |
278 """ | 302 """ |
279 Take in mse_list with control and treatment mse and return breakpoint and utr abundance | 303 |
280 """ | 304 """ |
281 mse_control = [mse[0] for mse in mse_list] | 305 fig = plt.figure(figsize=(8, 8)) |
282 mse_treatment = [mse[1] for mse in mse_list] | 306 gs = gridspec.GridSpec(2, 1) |
283 control_index = mse_control.index(min(mse_control)) | 307 ax1 = plt.subplot(gs[0, :]) |
284 treatment_index = mse_treatment.index(min(mse_treatment)) | 308 ax2 = plt.subplot(gs[1, :]) |
285 control_breakpoint = breakpoints[control_index] | 309 ax1.set_title("mean-squared error plot") |
286 treatment_breakpoint = breakpoints[treatment_index] | 310 ax1.set_ylabel("mean-squared error") |
287 control_abundance = estimate_abundance(normalized_utr_coverage, control_breakpoint, num_samples) | 311 ax1.set_xlabel("nt after UTR start") |
288 treatment_abundance = estimate_abundance(normalized_utr_coverage, treatment_breakpoint, num_samples) | 312 ax2.set_title("coverage plot") |
289 return control_breakpoint, control_abundance, treatment_breakpoint, treatment_abundance | 313 ax2.set_xlabel("nt after UTR start") |
290 | 314 ax2.set_ylabel("normalized nucleotide coverage") |
315 mse_control = [ condition[0] for condition in mse_list] | |
316 mse_treatment = [ condition[1] for condition in mse_list] | |
317 minima_control = get_minima(np.array(mse_control)) | |
318 minima_treatment = get_minima(np.array(mse_treatment)) | |
319 control = normalized_utr_coverage[:num_control] | |
320 treatment = normalized_utr_coverage[num_control:] | |
321 ax1.plot(mse_control, "b-") | |
322 ax1.plot(mse_treatment, "r-") | |
323 [ax2.plot(cov, "b-") for cov in control] | |
324 [ax2.plot(cov, "r-") for cov in treatment] | |
325 [ax2.axvline(val, color="b", alpha=0.25) for val in minima_control] | |
326 ax2.axvline(mse_control.index(min(mse_control)), color="b", alpha=1) | |
327 [ax2.axvline(val, color="r", alpha=0.25) for val in minima_treatment] | |
328 ax2.axvline(mse_treatment.index(min(mse_treatment)), color="r", alpha=1) | |
329 fig.add_subplot(ax1) | |
330 fig.add_subplot(ax2) | |
331 gs.tight_layout(fig) | |
332 fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr))) | |
333 | |
334 | |
335 def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples): | |
336 """ | |
337 Take in mse_list with control and treatment mse and return breakpoint and utr abundance for all local minima | |
338 in mse_list | |
339 """ | |
340 mse_control = np.array([mse[0] for mse in mse_list]) | |
341 mse_treatment = np.array([mse[1] for mse in mse_list]) | |
342 control_breakpoints = list(get_minima(mse_control)) | |
343 treatment_breakpoints = list(get_minima(mse_treatment)) | |
344 control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in control_breakpoints] | |
345 treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in treatment_breakpoints] | |
346 return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances | |
347 | |
348 def get_minima(a): | |
349 """ | |
350 get minima for numpy array a | |
351 """ | |
352 return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1 | |
291 | 353 |
292 def estimate_mse(cov, bp, num_samples, num_control): | 354 def estimate_mse(cov, bp, num_samples, num_control): |
293 """ | 355 """ |
294 get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr. | 356 get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr. |
295 """ | 357 """ |
313 short_utr_vector = cov[:num_samples, 0:bp] | 375 short_utr_vector = cov[:num_samples, 0:bp] |
314 mean_long_utr = np.mean(long_utr_vector, 1) | 376 mean_long_utr = np.mean(long_utr_vector, 1) |
315 mean_short_utr = np.mean(short_utr_vector, 1) | 377 mean_short_utr = np.mean(short_utr_vector, 1) |
316 return mean_long_utr, mean_short_utr | 378 return mean_long_utr, mean_short_utr |
317 | 379 |
380 def stat_test(a,b): | |
381 return stats.ttest_ind(a,b) | |
382 | |
383 def gene_str_to_link(str): | |
384 return "<a href=\"{str}.svg\" type=\"image/svg+xml\" target=\"_blank\">{str}</a>".format(str=str) | |
318 | 385 |
319 if __name__ == '__main__': | 386 if __name__ == '__main__': |
320 args = parse_args() | 387 args = parse_args() |
321 find_utr = UtrFinder(args) | 388 find_utr = UtrFinder(args) |
322 | 389 |