diff dapars.py @ 5:a5d8b08af089 draft

planemo upload for repository https://github.com/mvdbeek/dapars commit deab588a5d5ec7022de63a395fbd04e415ba0a42
author mvdbeek
date Thu, 29 Oct 2015 15:51:10 -0400
parents 73b932244237
children 538c4e2b423e
line wrap: on
line diff
--- a/dapars.py	Wed Oct 28 06:22:18 2015 -0400
+++ b/dapars.py	Thu Oct 29 15:51:10 2015 -0400
@@ -2,19 +2,27 @@
 import os
 import csv
 import numpy as np
+from scipy import stats
 from collections import OrderedDict, namedtuple
 import filter_utr
 import subprocess
 from multiprocessing import Pool
 import warnings
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from tabulate import tabulate
 
+def directory_path(str):
+    if os.path.exists(str):
+        return str
+    else:
+        os.mkdir(str)
+        return str
 
 def parse_args():
     """
     Returns floating point values except for input files.
     My initial approach will not filter anything. (FDR. fold_change, PDUI, Num_least ...)
-    :param argv:
-    :return:
     """
     parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA usage')
     parser.add_argument("-c", "--control_alignments", nargs="+", required=True,
@@ -33,7 +41,11 @@
                         help="minimum coverage in each aligment to be considered for determining breakpoints")
     parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
                         help="Write bedfile with coordinates of breakpoint positions to supplied path.")
-    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.5')
+    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.0')
+    parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path,
+                        help="If plot_path is specified will write a coverage plot for every UTR in that directory.")
+    parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'),
+                        help="Write an html file to the specified location. Only to be used within a galaxy wrapper")
     return parser.parse_args()
 
 
@@ -48,6 +60,8 @@
         self.n_cpus = args.cpu
         self.search_start = args.search_start
         self.coverage_threshold = args.coverage_threshold
+        self.plot_path = args.plot_path
+        self.html_file = args.html_file
         self.utr = args.utr_bed_file
         self.gtf_fields = filter_utr.get_gtf_fields()
         self.result_file = args.output_file
@@ -67,7 +81,8 @@
         if args.breakpoint_bed:
             self.bed_output = args.breakpoint_bed
             self.write_bed()
-
+        if self.plot_path:
+            self.write_html()
 
     def dump_utr_dict_to_bedfile(self):
         w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t")
@@ -110,6 +125,10 @@
         return coverage_dict
 
     def get_utr_dict(self, shift):
+        """
+        The UTR end is extended by UTR length * shift to discover novel distal polyA sites.
+        Set shift to 0 to disable the extension.
+        """
         utr_dict = OrderedDict()
         for line in self.utr:
             if not line.startswith("#"):
@@ -139,11 +158,11 @@
                 utr_coverage.append(np.sum(vector))
             coverage_per_alignment.append(utr_coverage)
         coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ])
-        coverage_weights = coverages / np.mean(coverages)  # TODO: proabably median is better suited?
+        coverage_weights = coverages / np.mean(coverages)  # TODO: probably median is better suited? Or even no normalization!
         return coverage_weights
 
     def get_result_tuple(self):
-        static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint",
+        static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "breakpoint",
                        "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ]
         samples_desc = []
         for statistic in ["coverage_long", "coverage_short", "percent_long"]:
@@ -162,18 +181,22 @@
                  "num_treatment":len(self.treatment_alignments),
                  "result_d":result_d}
         pool = Pool(self.n_cpus)
-        tasks = [ (self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights, self.num_samples,
-                    len(self.control_alignments), len(self.treatment_alignments), self.search_start,
-                   self.coverage_threshold) for utr, utr_d in self.utr_dict.iteritems() ]
+        tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments),
+                   len(self.treatment_alignments), self.search_start, self.coverage_threshold) \
+                  for utr, utr_d in self.utr_dict.iteritems() ]
         processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks]
-        result = [res.get() for res in processed_tasks]
-        for res_control, res_treatment in result:
-            if isinstance(res_control, dict):
-                t = self.result_tuple(**res_control)
-                result_d[res_control["gene"]+"_bp_control"] = t
-            if isinstance(res_treatment, dict):
-                t = self.result_tuple(**res_treatment)
-                result_d[res_treatment["gene"]+"_bp_treatment"] = t
+        result_list = [res.get() for res in processed_tasks]
+        for res_control, res_treatment in result_list:
+            if not res_control:
+                continue
+            for i, result in enumerate(res_control):
+                if isinstance(result, dict):
+                    t = self.result_tuple(**result)
+                    result_d[result["gene"]+"_bp_control_{i}".format(i=i)] = t
+            for i, result in enumerate(res_treatment):
+                if isinstance(result, dict):
+                    t = self.result_tuple(**result)
+                    result_d[result["gene"]+"_bp_treatment_{i}".format(i=i)] = t
         return result_d
 
     def write_results(self):
@@ -183,51 +206,47 @@
         w.writerow(header)    # field header
         w.writerows( self.result_d.values())
 
+    def write_html(self):
+        output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value ) for result in self.result_d.itervalues()]
+        if self.html_file:
+            self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html"))
+        else:
+            with open(os.path.join(self.plot_path, "index.html"), "w") as html_file:
+                html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html"))
+
     def write_bed(self):
         w = csv.writer(self.bed_output, delimiter='\t')
         bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.result_d.itervalues()]
         w.writerows(bed)
 
 
-def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control,
-                      num_treatment, search_start, coverage_threshold):
-    res_control = dict(zip(result_tuple_fields, result_tuple_fields))
-    res_treatment = res_control.copy()
+def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, coverage_threshold):
     if utr_d["strand"] == "+":
         is_reverse = False
     else:
         is_reverse = True
-    control_breakpoint, \
-    control_abundance, \
-    treatment_breakpoint, \
-    treatment_abundance  = optimize_breakpoint(utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights,
-                                                 search_start, coverage_threshold, num_control)
-    if control_breakpoint:
-        breakpoint_to_result(res_control, utr, utr_d, control_breakpoint, "control_breakpoint", control_abundance, is_reverse, num_samples,
+    control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances  = \
+        optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, coverage_threshold, num_control)
+    res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint", control_abundances, is_reverse,
                              num_control, num_treatment)
-    if treatment_breakpoint:
-        breakpoint_to_result(res_treatment, utr, utr_d, treatment_breakpoint, "treatment_breakpoint", treatment_abundance, is_reverse,
-                             num_samples, num_control, num_treatment)
-    if res_control == dict(zip(result_tuple_fields, result_tuple_fields)):
-        res_control = False
-    if res_treatment == dict(zip(result_tuple_fields, result_tuple_fields)):
-        res_treatment == False
+    res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint", treatment_abundances, is_reverse,
+                             num_control, num_treatment)
     return res_control, res_treatment
 
 
-def breakpoint_to_result(res, utr, utr_d, breakpoint, breakpoint_type,
-                         abundances, is_reverse, num_samples, num_control, num_treatment):
+def breakpoints_to_result(utr, utr_d, breakpoints, breakpoint_type,
+                         abundances, is_reverse, num_control, num_treatment):
     """
     Takes in a result dictionary res and fills the necessary fields
     """
-    long_coverage_vector = abundances[0]
-    short_coverage_vector = abundances[1]
-    num_non_zero = sum((np.array(long_coverage_vector) + np.array(short_coverage_vector)) > 0)  # TODO: This introduces bias
-    if num_non_zero == num_samples:
-        percentage_long = []
-        for i in range(num_samples):
-            ratio = float(long_coverage_vector[i]) / (long_coverage_vector[i] + short_coverage_vector[i])  # long 3'UTR percentage
-            percentage_long.append(ratio)
+    if not breakpoints:
+        return False
+    result = []
+    for breakpoint, abundance in zip(breakpoints, abundances):
+        res = {}
+        long_coverage_vector = abundance[0]
+        short_coverage_vector = abundance[1]
+        percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector)
         for i in range(num_control):
             res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i])
             res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i])
@@ -237,6 +256,7 @@
             res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i])
             res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i])
             res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
+        res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:])
         control_mean_percent = np.mean(np.array(percentage_long[:num_control]))
         treatment_mean_percent = np.mean(np.array(percentage_long[num_control:]))
         res["chr"] = utr_d["chr"]
@@ -252,43 +272,85 @@
         res["control_mean_percent"] = control_mean_percent
         res["treatment_mean_percent"] = treatment_mean_percent
         res["gene"] = utr
+        result.append(res)
+    return result
 
 
-def optimize_breakpoint(utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control):
+def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control):
     """
     We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided
     at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample.
     """
-    search_point_end = int(abs((UTR_end - UTR_start)) * 0.1)  # TODO: This is 10% of total UTR end. Why?
     num_samples = len(utr_coverage)
-    normalized_utr_coverage = np.array([coverage/ coverage_weigths[i] for i, coverage in enumerate( utr_coverage.values() )])
+    normalized_utr_coverage = np.array(utr_coverage.values())/np.expand_dims(coverage_weigths, axis=1)
     start_coverage = [np.mean(coverage[0:99]) for coverage in utr_coverage.values()]  # filters threshold on mean coverage over first 100 nt
     is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples  # This filters on the raw threshold. Why?
     is_above_length = UTR_end - UTR_start >= 150
     if (is_above_threshold) and (is_above_length):
-        search_end = UTR_end - UTR_start - search_point_end
+        search_end = UTR_end - UTR_start
         breakpoints = range(search_start, search_end + 1)
         mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples, num_control) for bp in breakpoints ]
+        mse_list = [mse_list[0] for i in xrange(search_start)] + mse_list
+        if plot_path:
+            plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control)
         if len(mse_list) > 0:
-            return mse_to_breakpoint(mse_list, normalized_utr_coverage, breakpoints, num_samples)
+            return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples)
     return False, False, False, False
 
 
-def mse_to_breakpoint(mse_list, normalized_utr_coverage, breakpoints, num_samples):
+def plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control):
     """
-    Take in mse_list with control and treatment mse and return breakpoint and utr abundance
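+    Plot the control (blue) and treatment (red) mean-squared error curves in the upper panel and the
+    normalized coverage vectors in the lower panel, mark the local minima (faint lines) and the global
+    minimum (solid line) of each mse curve on the coverage panel, and save the figure as
+    <plot_path>/<utr>.svg.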
     """
-    mse_control = [mse[0] for mse in mse_list]
-    mse_treatment = [mse[1] for mse in mse_list]
-    control_index = mse_control.index(min(mse_control))
-    treatment_index = mse_treatment.index(min(mse_treatment))
-    control_breakpoint = breakpoints[control_index]
-    treatment_breakpoint = breakpoints[treatment_index]
-    control_abundance = estimate_abundance(normalized_utr_coverage, control_breakpoint, num_samples)
-    treatment_abundance = estimate_abundance(normalized_utr_coverage, treatment_breakpoint, num_samples)
-    return control_breakpoint, control_abundance, treatment_breakpoint, treatment_abundance
+    fig = plt.figure(figsize=(8, 8))
+    gs = gridspec.GridSpec(2, 1)
+    ax1 = plt.subplot(gs[0, :])
+    ax2 = plt.subplot(gs[1, :])
+    ax1.set_title("mean-squared error plot")
+    ax1.set_ylabel("mean-squared error")
+    ax1.set_xlabel("nt after UTR start")
+    ax2.set_title("coverage plot")
+    ax2.set_xlabel("nt after UTR start")
+    ax2.set_ylabel("normalized nucleotide coverage")
+    mse_control = [ condition[0] for condition in mse_list]
+    mse_treatment = [ condition[1] for condition in mse_list]
+    minima_control = get_minima(np.array(mse_control))
+    minima_treatment = get_minima(np.array(mse_treatment))
+    control = normalized_utr_coverage[:num_control]
+    treatment = normalized_utr_coverage[num_control:]
+    ax1.plot(mse_control, "b-")
+    ax1.plot(mse_treatment, "r-")
+    [ax2.plot(cov, "b-") for cov in control]
+    [ax2.plot(cov, "r-") for cov in treatment]
+    [ax2.axvline(val, color="b", alpha=0.25) for val in minima_control]
+    ax2.axvline(mse_control.index(min(mse_control)), color="b", alpha=1)
+    [ax2.axvline(val, color="r", alpha=0.25) for val in minima_treatment]
+    ax2.axvline(mse_treatment.index(min(mse_treatment)), color="r", alpha=1)
+    fig.add_subplot(ax1)
+    fig.add_subplot(ax2)
+    gs.tight_layout(fig)
+    fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr)))
 
 
+def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples):
+    """
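+    Take in mse_list, whose items hold the control mse at index 0 and the treatment mse at index 1,
+    and return the breakpoints and utr abundances for all local minima of each condition, in the order
+    control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances.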
+    """
+    mse_control = np.array([mse[0] for mse in mse_list])
+    mse_treatment = np.array([mse[1] for mse in mse_list])
+    control_breakpoints = list(get_minima(mse_control))
+    treatment_breakpoints = list(get_minima(mse_treatment))
+    control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in control_breakpoints]
+    treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in treatment_breakpoints]
+    return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances
+
+def get_minima(a):
+    """
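+    Return the positions of the strict local minima of the 1-D numpy array a, shifted by +1
+    (index of the minimum plus one). The first and the last element count as a minimum when
+    they are smaller than their only neighbour.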
+    """
+    return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1
+
 def estimate_mse(cov, bp, num_samples, num_control):
     """
     get abundance of long utr vs short utr with breakpoint specifying the position of long and short utr.
@@ -315,6 +377,11 @@
         mean_short_utr = np.mean(short_utr_vector, 1)
         return mean_long_utr, mean_short_utr
 
+def stat_test(a,b):
+    return stats.ttest_ind(a,b)
+
+def gene_str_to_link(str):
+    return "<a href=\"{str}.svg\" type=\"image/svg+xml\" target=\"_blank\">{str}</a>".format(str=str)
 
 if __name__ == '__main__':
     args = parse_args()