diff dapars.py @ 17:917a2f7ab841 draft

planemo upload for repository https://github.com/mvdbeek/dapars commit b1b007c561ea6c9db145c88b6b128d66ecd05e24-dirty
author mvdbeek
date Fri, 30 Oct 2015 10:35:17 -0400
parents f8bb40b2ff31
children ed151db39c7e
line wrap: on
line diff
--- a/dapars.py	Fri Oct 30 07:31:36 2015 -0400
+++ b/dapars.py	Fri Oct 30 10:35:17 2015 -0400
@@ -11,6 +11,7 @@
 import matplotlib.pyplot as plt
 import matplotlib.gridspec as gridspec
 from tabulate import tabulate
+import statsmodels.sandbox.stats.multicomp as mc
 
 def directory_path(str):
     if os.path.exists(str):
@@ -35,13 +36,14 @@
                         help="file containing output")
     parser.add_argument("-cpu", required=False, type=int, default=1,
                         help="Number of CPU cores to use.")
+    parser.add_argument("-l", "--local_minimum", action='store_true', help="Return all local minima of the MSE search instead of only the global minimum when determining breakpoints")
     parser.add_argument("-s", "--search_start", required=False, type=int, default=50,
                         help="Start search for breakpoint n nucleotides downstream of UTR start")
     parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20,
                         help="minimum coverage in each aligment to be considered for determining breakpoints")
     parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
                         help="Write bedfile with coordinates of breakpoint positions to supplied path.")
-    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.2')
+    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.2.3')
     parser.add_argument("-p", "--plot_path", default=None, required=False, type=directory_path,
                         help="If plot_path is specified will write a coverage plot for every UTR in that directory.")
     parser.add_argument("-html", "--html_file", default=None, required=False, type=argparse.FileType('w'),
@@ -60,6 +62,7 @@
         self.n_cpus = args.cpu
         self.search_start = args.search_start
         self.coverage_threshold = args.coverage_threshold
+        self.local_minimum = args.local_minimum
         self.plot_path = args.plot_path
         self.html_file = args.html_file
         self.utr = args.utr_bed_file
@@ -77,6 +80,7 @@
         self.coverage_weights = self.get_coverage_weights()
         self.result_tuple = self.get_result_tuple()
         self.result_d = self.calculate_apa_ratios()
+        self.results = self.order_by_p()
         self.write_results()
         if args.breakpoint_bed:
             self.bed_output = args.breakpoint_bed
@@ -84,6 +88,15 @@
         if self.plot_path:
             self.write_html()
 
+    def order_by_p(self):
+        """Attach BH-adjusted p-values to results and return them sorted ascending."""
+        results = list(self.result_d.itervalues())
+        p_values = np.array([result.p_value for result in results])
+        adj_p_values = mc.fdrcorrection0(p_values, 0.05)[1]
+        sort_index = np.argsort(adj_p_values)
+        return [results[i]._replace(adj_p_value=adj_p_values[i]) for i in sort_index]
+
+
     def dump_utr_dict_to_bedfile(self):
         w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t")
         for gene, utr in self.utr_dict.iteritems():
@@ -160,7 +173,7 @@
         return coverage_weights
 
     def get_result_tuple(self):
-        static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "breakpoint",
+        static_desc = ["chr", "start", "end", "strand", "gene", "t_stat", "p_value", "adj_p_value", "breakpoint",
                        "breakpoint_type", "control_mean_percent", "treatment_mean_percent" ]
         samples_desc = []
         for statistic in ["coverage_long", "coverage_short", "percent_long"]:
@@ -180,11 +193,11 @@
                  "result_d":result_d}
         pool = Pool(self.n_cpus)
         tasks = [ (self.utr_coverages[utr], self.plot_path, utr, utr_d, self.coverage_weights, len(self.control_alignments),
-                   len(self.treatment_alignments), self.search_start, self.coverage_threshold) \
+                   len(self.treatment_alignments), self.search_start, self.local_minimum, self.coverage_threshold) \
                   for utr, utr_d in self.utr_dict.iteritems() ]
-        processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks]
-        result_list = [res.get() for res in processed_tasks]
-        #result_list = [calculate_all_utr(*t) for t in tasks]  # uncomment for easier debugging
+        processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks]
+        result_list = [res.get() for res in processed_tasks]
+        #result_list = [calculate_all_utr(*t) for t in tasks]  # uncomment for easier debugging
         for res_control, res_treatment in result_list:
             if not res_control:
                 continue
@@ -203,29 +216,29 @@
         header = list(self.result_tuple._fields)
         header[0] = "#chr"
         w.writerow(header)    # field header
-        w.writerows( self.result_d.values())
+        w.writerows( self.results)
 
     def write_html(self):
-        output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value ) for result in self.result_d.itervalues()]
+        output_lines = [(gene_str_to_link(result.gene), result.breakpoint, result.breakpoint_type, result.p_value, result.adj_p_value ) for result in self.results]
         if self.html_file:
-            self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html"))
+            self.html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value", "adj_p_value"], tablefmt="html"))
         else:
             with open(os.path.join(self.plot_path, "index.html"), "w") as html_file:
-                html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value"], tablefmt="html"))
+                html_file.write(tabulate(output_lines, headers=["gene", "breakpoint", "breakpoint_type", "p_value", "adj_p_value"], tablefmt="html"))
 
     def write_bed(self):
         w = csv.writer(self.bed_output, delimiter='\t')
-        bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.result_d.itervalues()]
+        bed = [(result.chr, result.breakpoint, int(result.breakpoint)+1, result.gene+"_"+result.breakpoint_type, 0, result.strand) for result in self.results]
         w.writerows(bed)
 
 
-def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, coverage_threshold):
+def calculate_all_utr(utr_coverage, plot_path, utr, utr_d, coverage_weights, num_control, num_treatment, search_start, local_minimum, coverage_threshold):
     if utr_d["strand"] == "+":
         is_reverse = False
     else:
         is_reverse = True
     control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances  = \
-        optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, coverage_threshold, num_control)
+        optimize_breakpoint(plot_path, utr, utr_coverage, utr_d["new_start"], utr_d["new_end"], coverage_weights, search_start, local_minimum, coverage_threshold, num_control)
     res_control = breakpoints_to_result(utr, utr_d, control_breakpoints, "control_breakpoint", control_abundances, is_reverse,
                              num_control, num_treatment)
     res_treatment = breakpoints_to_result(utr, utr_d, treatment_breakpoints, "treatment_breakpoint", treatment_abundances, is_reverse,
@@ -243,17 +257,17 @@
     result = []
     for breakpoint, abundance in zip(breakpoints, abundances):
         res = {}
-        long_coverage_vector = abundance[0]
-        short_coverage_vector = abundance[1]
+        res["adj_p_value"] = "NA"
+        long_coverage_vector, short_coverage_vector = abundance
         percentage_long = long_coverage_vector/(long_coverage_vector+short_coverage_vector)
         for i in range(num_control):
-            res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i])
-            res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i])
+            res["control_{i}_coverage_long".format(i=i)] = long_coverage_vector[i]
+            res["control_{i}_coverage_short".format(i=i)] = short_coverage_vector[i]
             res["control_{i}_percent_long".format(i=i)] = percentage_long[i]
         for k in range(num_treatment):
             i = k + num_control
-            res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i])
-            res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i])
+            res["treatment_{i}_coverage_long".format(i=k)] = long_coverage_vector[i]
+            res["treatment_{i}_coverage_short".format(i=k)] = short_coverage_vector[i]
             res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
         res["t_stat"], res["p_value"] = stat_test(percentage_long[:num_control], percentage_long[num_control:])
         control_mean_percent = np.mean(percentage_long[:num_control])
@@ -275,7 +289,7 @@
     return result
 
 
-def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, coverage_threshold, num_control):
+def optimize_breakpoint(plot_path, utr, utr_coverage, UTR_start, UTR_end, coverage_weigths, search_start, local_minimum, coverage_threshold, num_control):
     """
     We are searching for a point within the UTR that minimizes the mean squared error, if the coverage vector was divided
     at that point. utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample.
@@ -293,7 +307,7 @@
         if plot_path:
             plot_coverage_breakpoint(plot_path, utr, mse_list, normalized_utr_coverage, num_control)
         if len(mse_list) > 0:
-            return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples)
+            return mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples, local_minimum)
     return False, False, False, False
 
 
@@ -331,24 +345,27 @@
     fig.savefig(os.path.join(plot_path, "{utr}.svg".format(utr=utr)))
 
 
-def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples):
+def mse_to_breakpoint(mse_list, normalized_utr_coverage, num_samples, local_minimum):
     """
     Take in mse_list with control and treatment mse and return breakpoint and utr abundance for all local minima
     in mse_list
     """
     mse_control = np.array([mse[0] for mse in mse_list])
     mse_treatment = np.array([mse[1] for mse in mse_list])
-    control_breakpoints = list(get_minima(mse_control))
-    treatment_breakpoints = list(get_minima(mse_treatment))
+    control_breakpoints = list(get_minima(mse_control, local_minimum))
+    treatment_breakpoints = list(get_minima(mse_treatment, local_minimum))
     control_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in control_breakpoints]
     treatment_abundances = [estimate_abundance(normalized_utr_coverage, bp, num_samples) for bp in treatment_breakpoints]
     return control_breakpoints, control_abundances, treatment_breakpoints, treatment_abundances
 
-def get_minima(a):
+def get_minima(a, local_minimum=False):
+    """
+    get minima for numpy array a. If local is false, only return absolute minimum
     """
-    get minima for numpy array a
-    """
-    return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1
+    if not local_minimum:
+        return np.where(a == a.min())[0]+1
+    else:
+        return np.where(np.r_[True, a[1:] < a[:-1]] & np.r_[a[:-1] < a[1:], True])[0]+1
 
 def estimate_mse(cov, bp, num_samples, num_control):
     """