Mercurial > repos > mvdbeek > dapars
diff dapars.py @ 5:a5d8b08af089 draft
planemo upload for repository https://github.com/mvdbeek/dapars commit deab588a5d5ec7022de63a395fbd04e415ba0a42
author | mvdbeek |
---|---|
date | Thu, 29 Oct 2015 15:51:10 -0400 |
parents | 73b932244237 |
children | 538c4e2b423e |
line wrap: on
line diff
--- a/dapars.py Wed Oct 28 06:22:18 2015 -0400 +++ b/dapars.py Thu Oct 29 15:51:10 2015 -0400 @@ -2,19 +2,27 @@ import os import csv import numpy as np +from scipy import stats from collections import OrderedDict, namedtuple import filter_utr import subprocess from multiprocessing import Pool import warnings +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from tabulate import tabulate +def directory_path(str): + if os.path.exists(str): + return str + else: + os.mkdir(str) + return str def parse_args(): """ Returns floating point values except for input files. My initial approach will not filter anything. (FDR. fold_change, PDUI, Num_least ...) - :param argv: - :return: """ parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA usage') parser.add_argument("-c", "--control_alignments", nargs="+", required=True, @@ -33,7 +41,11 @@ help="minimum coverage in each aligment to be considered for determining breakpoints") parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'), help="Write bedfile with coordinates of breakpoint positions to supplied path.") - parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.5') + parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.0') + parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path, + help="If plot_path is specified will write a coverage plot for every UTR in that directory.") + parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'), + help="Write an html file to the specified location. Only to be used within a galaxy wrapper") return parser.parse_args() @@ -48,6 +60,8 @@ self.n_cpus = args.cpu self.search_start = args.search_start self.coverage_threshold = args.coverage_threshold + self.plot_path = args.plot_path + self.html_file = args.html_file self.utr = args.utr_bed_file self.gtf_fields = filter_utr.get_gtf_fields() self.result_file = args.output_file @@ -67,7 +81,8 @@ if args.breakpoint_bed: self.bed_output = args.breakpoint_bed self.write_bed() - + if self.plot_path: + self.write_html() def dump_utr_dict_to_bedfile(self): w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t") @@ -110,6 +125,10 @@ return coverage_dict def get_utr_dict(self, shift): + """ + The utr end is extended by UTR length * shift, to discover novel distal polyA sites. + Set to 0 to disable. + """ utr_dict = OrderedDict() for line in self.utr: if not line.startswith("#"): @@ -139,11 +158,11 @@ utr_coverage.append(np.sum(vector)) coverage_per_alignment.append(utr_coverage) coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ]) - coverage_weights = coverages / np.mean(coverages) # TODO: proabably median is better suited? + coverage_weights = coverages / np.mean(coverages) # TODO: proabably median is better suited? Or even no normalization! return coverage_weights def get_result_tuple(self): - static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint", + static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "breakpoint", "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ] samples_desc = [] for statistic in ["coverage_long", "coverage_short", "percent_long"]: @@ -162,18 +181,22 @@ "num_treatment":len(self.treatment_alignments), "result_d":result_d} pool = Pool(self.n_cpus) - tasks = [ (self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights, self.num_samples, - len(self.control_alignments), len(self.treatment_alignments), self.search_start, - self.coverage_threshold) for utr, utr_d in self.utr_dict.iteritems() ] + tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments), + len(self.treatment_alignments), self.search_start, self.coverage_threshold) \ + for utr, utr_d in self.utr_dict.iteritems() ] processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks] - result = [res.get() for res in processed_tasks] - for res_control, res_treatment in result: - if isinstance(res_control, dict): - t = self.result_tuple(**res_control) - result_d[res_control["gene"]+"_bp_control"] = t - if isinstance(res_treatment, dict): - t = self.result_tuple(**res_treatment) - result_d[res_treatment["gene"]+"_bp_treatment"] = t + result_list = [res.get() for res in processed_tasks] + for res_control, res_treatment in result_list: + if not res_control: + continue + for i, result in enumerate(res_control): + if isinstance(result, dict): + t = self.result_tuple(**result) + result_d[result["gene"]+"_bp_control_{i}".format(i=i)] = t + for i, result in enumerate(res_treatment): + if isinstance(result, dict): + t = self.result_tuple(**result) + result_d[result["gene"]+"_bp_treatment_{i}".format(i=i)] = t return result_d def write_results(self): @@ -183,51 +206,47 @@ w.writerow(header) # field header w.writerows( self.result_d.values()) + def write_html(self): + output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value ) for result in self.result_d.itervalues()] + if self.html_file: + self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html")) + else: + with open(os.path.join(self.plot_path, "index.html"), "w") as html_file: + html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html")) + def write_bed(self): w = csv.writer(self.bed_output, delimiter='\t') bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.result_d.itervalues()] w.writerows(bed) -def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control, - num_treatment, search_start, coverage_threshold): - res_control = dict(zip(result_tuple_fields, result_tuple_fields)) - res_treatment = res_control.copy() +def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, coverage_threshold): if utr_d["strand"] == "+": is_reverse = False else: is_reverse = True - control_breakpoint, \ - control_abundance, \ - treatment_breakpoint, \ - treatment_abundance = optimize_breakpoint(utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, - search_start, coverage_threshold, num_control) - if control_breakpoint: - breakpoint_to_result(res_control, utr, utr_d, control_breakpoint, "control_breakpoint", control_abundance, is_reverse, num_samples, + control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances = \ + optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, coverage_threshold, num_control) + res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint", control_abundances, is_reverse, num_control, num_treatment) - if treatment_breakpoint: - breakpoint_to_result(res_treatment, utr, utr_d, treatment_breakpoint, "treatment_breakpoint", treatment_abundance, is_reverse, - num_samples, num_control, num_treatment) - if res_control == dict(zip(result_tuple_fields, result_tuple_fields)): - res_control = False - if res_treatment == dict(zip(result_tuple_fields, result_tuple_fields)): - res_treatment == False + res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint", treatment_abundances, is_reverse, + num_control, num_treatment) return res_control, res_treatment -def breakpoint_to_result(res, utr, utr_d, breakpoint, breakpoint_type, - abundances, is_reverse, num_samples, num_control, num_treatment): +def breakpoints_to_result(utr, utr_d, breakpoints, breakpoint_type, + abundances, is_reverse, num_control, num_treatment): """ Takes in a result dictionary res and fills the necessary fields """ - long_coverage_vector = abundances[0] - short_coverage_vector = abundances[1] - num_non_zero = sum((np.array(long_coverage_vector) + np.array(short_coverage_vector)) > 0) # TODO: This introduces bias - if num_non_zero == num_samples: - percentage_long = [] - for i in range(num_samples): - ratio = float(long_coverage_vector[i]) / (long_coverage_vector[i] + short_coverage_vector[i]) # long 3'UTR percentage - percentage_long.append(ratio) + if not breakpoints: + return False + result = [] + for breakpoint, abundance in zip(breakpoints, abundances): + res = {} + long_coverage_vector = abundance[0] + short_coverage_vector = abundance[1] + percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector) for i in range(num_control): res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i]) res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i]) @@ -237,6 +256,7 @@ res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i]) res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i]) res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i] + res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:]) control_mean_percent = np.mean(np.array(percentage_long[:num_control])) treatment_mean_percent = np.mean(np.array(percentage_long[num_control:])) res["chr"] = utr_d["chr"] @@ -252,43 +272,85 @@ res["control_mean_percent"] = control_mean_percent res["treatment_mean_percent"] = treatment_mean_percent res["gene"] = utr + result.append(res) + return result -def optimize_breakpoint(utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control): +def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control): """ We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample. """ - search_point_end = int(abs((UTR_end - UTR_start)) * 0.1) # TODO: This is 10% of total UTR end. Why? num_samples = len(utr_coverage) - normalized_utr_coverage = np.array([coverage/ coverage_weigths[i] for i, coverage in enumerate( utr_coverage.values() )]) + normalized_utr_coverage = np.array(utr_coverage.values())/np.expand_dims(coverage_weigths, axis=1) start_coverage = [np.mean(coverage[0:99]) for coverage in utr_coverage.values()] # filters threshold on mean coverage over first 100 nt is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples # This filters on the raw threshold. Why? is_above_length = UTR_end - UTR_start >= 150 if (is_above_threshold) and (is_above_length): - search_end = UTR_end - UTR_start - search_point_end + search_end = UTR_end - UTR_start breakpoints = range(search_start, search_end + 1) mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ] + mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list + if plot_path: + plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control) if len(mse_list) > 0: - return mse_to_breakpoint(mse_list, normalized_utr_coverage, breakpoints, num_samples) + return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples) return False, False, False, False -def mse_to_breakpoint(mse_list, normalized_utr_coverage, breakpoints, num_samples): +def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control): """ - Take in mse_list with control and treatment mse and return breakpoint and utr abundance + """ - mse_control = [mse[0] for mse in mse_list] - mse_treatment = [mse[1] for mse in mse_list] - control_index = mse_control.index(min(mse_control)) - treatment_index = mse_treatment.index(min(mse_treatment)) - control_breakpoint = breakpoints[control_index] - treatment_breakpoint = breakpoints[treatment_index] - control_abundance = estimate_abundance(normalized_utr_coverage, control_breakpoint, num_samples) - treatment_abundance = estimate_abundance(normalized_utr_coverage, treatment_breakpoint, num_samples) - return control_breakpoint, control_abundance, treatment_breakpoint, treatment_abundance + fig = plt.figure(figsize=(8, 8)) + gs = gridspec.GridSpec(2, 1) + ax1 = plt.subplot(gs[0, :]) + ax2 = plt.subplot(gs[1, :]) + ax1.set_title("mean-squared error plot") + ax1.set_ylabel("mean-squared error") + ax1.set_xlabel("nt after UTR start") + ax2.set_title("coverage plot") + ax2.set_xlabel("nt after UTR start") + ax2.set_ylabel("normalized nucleotide coverage") + mse_control = [ condition[0] for condition in mse_list] + mse_treatment = [ condition[1] for condition in mse_list] + minima_control = get_minima(np.array(mse_control)) + minima_treatment = get_minima(np.array(mse_treatment)) + control = normalized_utr_coverage[:num_control] + treatment = normalized_utr_coverage[num_control:] + ax1.plot(mse_control, "b-") + ax1.plot(mse_treatment, "r-") + [ax2.plot(cov, "b-") for cov in control] + [ax2.plot(cov, "r-") for cov in treatment] + [ax2.axvline(val, color="b", alpha=0.25) for val in minima_control] + ax2.axvline(mse_control.index(min(mse_control)), color="b", alpha=1) + [ax2.axvline(val, color="r", alpha=0.25) for val in minima_treatment] + ax2.axvline(mse_treatment.index(min(mse_treatment)), color="r", alpha=1) + fig.add_subplot(ax1) + fig.add_subplot(ax2) + gs.tight_layout(fig) + fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr))) +def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples): + """ + Take in mse_list with control and treatment mse and return breakpoint and utr abundance for all local minima + in mse_list + """ + mse_control = np.array([mse[0] for mse in mse_list]) + mse_treatment = np.array([mse[1] for mse in mse_list]) + control_breakpoints = list(get_minima(mse_control)) + treatment_breakpoints = list(get_minima(mse_treatment)) + control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in control_breakpoints] + treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in treatment_breakpoints] + return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances + +def get_minima(a): + """ + get minima for numpy array a + """ + return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1 + def estimate_mse(cov, bp, num_samples, num_control): """ get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr. @@ -315,6 +377,11 @@ mean_short_utr = np.mean(short_utr_vector, 1) return mean_long_utr, mean_short_utr +def stat_test(a,b): + return stats.ttest_ind(a,b) + +def gene_str_to_link(str): + return "<a href=\"{str}.svg\" type=\"image/svg+xml\" target=\"_blank\">{str}</a>".format(str=str) if __name__ == '__main__': args = parse_args()