diff dapars.py @ 0:bb84ee2f2137 draft
planemo upload for repository https://github.com/mvdbeek/dapars commit 868f8f2f7ac5d70c39b7d725ff087833b0f24f52-dirty
author   | mvdbeek
date     | Tue, 27 Oct 2015 10:14:33 -0400
parents  |
children | 1b20ba32b4c5
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/dapars.py	Tue Oct 27 10:14:33 2015 -0400
@@ -0,0 +1,322 @@
+import argparse
+import os
+import csv
+import numpy as np
+from collections import OrderedDict, namedtuple
+import filter_utr
+import subprocess
+from multiprocessing import Pool
+import warnings
+
+
+def parse_args():
+    """
+    Returns floating point values except for input files.
+    My initial approach will not filter anything (FDR, fold_change, PDUI, Num_least ...).
+    :param argv:
+    :return:
+    """
+    parser = argparse.ArgumentParser(prog='DaPars', description='Determines the usage of proximal polyA sites')
+    parser.add_argument("-c", "--control_alignments", nargs="+", required=True,
+                        help="Alignment files in BAM format from control condition")
+    parser.add_argument("-t", "--treatment_alignments", nargs="+", required=True,
+                        help="Alignment files in BAM format from treatment condition")
+    parser.add_argument("-u", "--utr_bed_file", required=True, type=file,
+                        help="Bed file describing longest 3UTR positions")
+    parser.add_argument("-o", "--output_file", required=True, type=argparse.FileType('w'),
+                        help="file containing output")
+    parser.add_argument("-cpu", required=False, type=int, default=1,
+                        help="Number of CPU cores to use.")
+    parser.add_argument("-s", "--search_start", required=False, type=int, default=50,
+                        help="Start search for breakpoint n nucleotides downstream of UTR start")
+    parser.add_argument("-ct", "--coverage_threshold", required=False, type=float, default=20,
+                        help="minimum coverage in each alignment to be considered for determining breakpoints")
+    parser.add_argument("-b", "--breakpoint_bed", required=False, type=argparse.FileType('w'),
+                        help="Write bedfile with coordinates of breakpoint positions to supplied path.")
+    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.1.4')
+    return parser.parse_args()
+
+
+class UtrFinder():
+    """
+    Main caller: builds the 3'UTR dictionary, computes per-base coverage and calculates APA ratios.
+    """
+
+    def __init__(self, args):
+        self.control_alignments = [file for file in args.control_alignments]
+        self.treatment_alignments = [file for file in args.treatment_alignments]
+        self.n_cpus = args.cpu
+        self.search_start = args.search_start
+        self.coverage_threshold = args.coverage_threshold
+        self.utr = args.utr_bed_file
+        self.gtf_fields = filter_utr.get_gtf_fields()
+        self.result_file = args.output_file
+        self.all_alignments = self.control_alignments + self.treatment_alignments
+        self.alignment_names = { file: os.path.basename(file) for file in self.all_alignments }
+        self.num_samples = len(self.all_alignments)
+        self.utr_dict = self.get_utr_dict(0.2)
+        self.dump_utr_dict_to_bedfile()
+        print "Established dictionary of 3\'UTRs"
+        self.coverage_files = self.run_bedtools_coverage()
+        self.utr_coverages = self.read_coverage_result()
+        print "Established dictionary of 3\'UTR coverages"
+        self.coverage_weights = self.get_coverage_weights()
+        self.result_tuple = self.get_result_tuple()
+        self.result_d = self.calculate_apa_ratios()
+        self.write_results()
+        if args.breakpoint_bed:
+            self.bed_output = args.breakpoint_bed
+            self.write_bed()
+
+    def dump_utr_dict_to_bedfile(self):
+        w = csv.writer(open("tmp_bedfile.bed", "w"), delimiter="\t")
+        for gene, utr in self.utr_dict.iteritems():
+            w.writerow([utr["chr"], utr["new_start"]-1, utr["new_end"], gene, ".", utr["strand"]])
+
+    def run_bedtools_coverage(self):
+        """
+        Use bedtools coverage to generate pileup data for all alignment files for the regions specified in utr_dict.
+        """
+        coverage_files = []
+        cmds = []
+        for alignment_file in self.all_alignments:
+            cmd = "sort -k1,1 -k2,2n tmp_bedfile.bed | "
+            cmd = cmd + "bedtools coverage -d -s -abam {alignment_file} -b stdin |" \
+                        " cut -f 4,7,8 > coverage_file_{alignment_name}".format(
+                alignment_file = alignment_file, alignment_name = self.alignment_names[alignment_file] )
+            cmds.append(cmd)
+        pool = Pool(self.n_cpus)
+        subprocesses = [subprocess.Popen([cmd], shell=True) for cmd in cmds]
+        [p.wait() for p in subprocesses]
+        coverage_files = ["coverage_file_{alignment_name}".format(
+            alignment_name = self.alignment_names[alignment_file]) for alignment_file in self.all_alignments ]
+        return coverage_files
+
+    def read_coverage_result(self):
+        """
+        Read coverages back in and store as dictionary of numpy arrays
+        """
+        coverage_dict = { gene: { name: np.zeros(utr_d["new_end"]+1-utr_d["new_start"]) for name in self.alignment_names.itervalues() }
+                          for gene, utr_d in self.utr_dict.iteritems() }
+        for alignment_name in self.alignment_names.itervalues():
+            with open("coverage_file_{alignment_name}".format(alignment_name = alignment_name)) as coverage_file:
+                for line in coverage_file:
+                    gene, position, coverage = line.strip().split("\t")
+                    coverage_dict[gene][alignment_name][int(position)-1] = coverage
+        for gene, utr_d in self.utr_dict.iteritems():
+            if utr_d["strand"] == "-":
+                for alignment_name in self.alignment_names.values():
+                    coverage_dict[gene][alignment_name] = coverage_dict[gene][alignment_name][::-1]
+        return coverage_dict
+
+    def get_utr_dict(self, shift):
+        utr_dict = OrderedDict()
+        for line in self.utr:
+            if not line.startswith("#"):
+                filter_utr.get_feature_dict( line=line, gtf_fields=self.gtf_fields, utr_dict=utr_dict, feature="UTR" )
+                gene, utr_d = utr_dict.popitem()
+                utr_d = utr_d[0]
+                end_shift = int(round(abs(utr_d["start"] - utr_d["end"]) * shift))
+                if utr_d["strand"] == "+":
+                    utr_d["new_end"] = utr_d["end"] - end_shift
+                    utr_d["new_start"] = utr_d["start"]
+                else:
+                    utr_d["new_end"] = utr_d["end"]
+                    utr_d["new_start"] = utr_d["start"] + end_shift
+                if utr_d["new_start"] + 50 < utr_d["new_end"]:
+                    utr_dict[gene] = utr_d
+        return utr_dict
+
+    def get_utr_coverage(self):
+        """
+        Returns a dict:
+        { UTR : [coverage_alignment1, ...]}
+        """
+        utr_coverages = {}
+        for utr, utr_d in self.utr_dict.iteritems():
+            if utr_d["chr"] in self.available_chromosomes:
+                if utr_d["strand"] == "+":
+                    is_reverse = False
+                else:
+                    is_reverse = True
+                utr_coverage = []
+                for bam in self.all_alignments:
+                    bp_coverage = get_bp_coverage(bam, utr_d["chr"], utr_d["new_start"], utr_d["new_end"], is_reverse)
+                    utr_coverage.append(bp_coverage)
+                utr_coverages[utr] = utr_coverage
+        return utr_coverages
+
+    def get_coverage_weights(self):
+        """
+        Return weights for normalizing coverage.
+        utr_coverage is still confusing.
+        """
+        coverage_per_alignment = []
+        for utr in self.utr_coverages.itervalues():  # TODO: be smarter about this.
+            utr_coverage = []
+            for vector in utr.itervalues():
+                utr_coverage.append(np.sum(vector))
+            coverage_per_alignment.append(utr_coverage)
+        coverages = np.array([ sum(x) for x in zip(*coverage_per_alignment) ])
+        coverage_weights = coverages / np.mean(coverages)  # TODO: probably median is better suited?
+        return coverage_weights
+
+    def get_result_tuple(self):
+        static_desc = ["chr", "start", "end", "strand", "gene", "breakpoint", "control_mean_percent", "treatment_mean_percent" ]
+        samples_desc = []
+        for statistic in ["coverage_long", "coverage_short", "percent_long"]:
+            for i, sample in enumerate(self.control_alignments):
+                samples_desc.append("control_{i}_{statistic}".format(i=i, statistic = statistic))
+            for i, sample in enumerate(self.treatment_alignments):
+                samples_desc.append("treatment_{i}_{statistic}".format(i=i, statistic = statistic))
+        return namedtuple("result", static_desc + samples_desc)
+
+    def calculate_apa_ratios(self):
+        result_d = OrderedDict()
+        arg_d = {"result_tuple": self.result_tuple,
+                 "coverage_weights": self.coverage_weights,
+                 "num_samples": self.num_samples,
+                 "num_control": len(self.control_alignments),
+                 "num_treatment": len(self.treatment_alignments),
+                 "result_d": result_d}
+        pool = Pool(self.n_cpus)
+        tasks = [ (self.utr_coverages[utr], utr, utr_d, self.result_tuple._fields, self.coverage_weights, self.num_samples,
+                   len(self.control_alignments), len(self.treatment_alignments), result_d, self.search_start, self.coverage_threshold)
+                  for utr, utr_d in self.utr_dict.iteritems() ]
+        processed_tasks = [ pool.apply_async(calculate_all_utr, t) for t in tasks]
+        result = [res.get() for res in processed_tasks]
+        for d in result:
+            if isinstance(d, dict):
+                t = self.result_tuple(**d)
+                result_d[d["gene"]] = t
+        return result_d
+
+    def write_results(self):
+        w = csv.writer(self.result_file, delimiter='\t')
+        header = list(self.result_tuple._fields)
+        header[0] = "#chr"
+        w.writerow(header)  # field header
+        w.writerows( self.result_d.values())
+
+    def write_bed(self):
+        w = csv.writer(self.bed_output, delimiter='\t')
+        bed = [(result.chr, result.breakpoint, result.breakpoint+1, result.gene, 0, result.strand) for result in self.result_d.itervalues()]
+        w.writerows(bed)
+
+
+def calculate_all_utr(utr_coverage, utr, utr_d, result_tuple_fields, coverage_weights, num_samples, num_control,
+                      num_treatment, result_d, search_start, coverage_threshold):
+    res = dict(zip(result_tuple_fields, result_tuple_fields))
+    if utr_d["strand"] == "+":
+        is_reverse = False
+    else:
+        is_reverse = True
+    mse, breakpoint, abundances = estimate_coverage_extended_utr(utr_coverage,
+                                                                 utr_d["new_start"],
+                                                                 utr_d["new_end"],
+                                                                 is_reverse,
+                                                                 coverage_weights,
+                                                                 search_start,
+                                                                 coverage_threshold)
+    if not str(mse) == "Na":
+        long_coverage_vector = abundances[0]
+        short_coverage_vector = abundances[1]
+        num_non_zero = sum((np.array(long_coverage_vector) + np.array(short_coverage_vector)) > 0)  # TODO: This introduces bias
+        if num_non_zero == num_samples:
+            percentage_long = []
+            for i in range(num_samples):
+                ratio = float(long_coverage_vector[i]) / (long_coverage_vector[i] + short_coverage_vector[i])  # long 3'UTR percentage
+                percentage_long.append(ratio)
+            for i in range(num_control):
+                res["control_{i}_coverage_long".format(i=i)] = float(long_coverage_vector[i])
+                res["control_{i}_coverage_short".format(i=i)] = float(short_coverage_vector[i])
+                res["control_{i}_percent_long".format(i=i)] = percentage_long[i]
+            for k in range(num_treatment):
+                i = k + num_control
+                res["treatment_{i}_coverage_long".format(i=k)] = float(long_coverage_vector[i])
+                res["treatment_{i}_coverage_short".format(i=k)] = float(short_coverage_vector[i])
+                res["treatment_{i}_percent_long".format(i=k)] = percentage_long[i]
+            control_mean_percent = np.mean(np.array(percentage_long[:num_control]))
+            treatment_mean_percent = np.mean(np.array(percentage_long[num_control:]))
+            res["chr"] = utr_d["chr"]
+            res["start"] = utr_d["start"]
+            res["end"] = utr_d["end"]
+            res["strand"] = utr_d["strand"]
+            if is_reverse:
+                breakpoint = utr_d["new_end"] - breakpoint
+            else:
+                breakpoint = utr_d["new_start"] + breakpoint
+            res["breakpoint"] = breakpoint
+            res["control_mean_percent"] = control_mean_percent
+            res["treatment_mean_percent"] = treatment_mean_percent
+            res["gene"] = utr
+            return res
+
+
+def estimate_coverage_extended_utr(utr_coverage, UTR_start,
+                                   UTR_end, is_reverse, coverage_weights, search_start, coverage_threshold):
+    """
+    Search for a breakpoint in coverage along the UTR.
+    utr_coverage is a list with items corresponding to numpy arrays of coverage for a sample.
+    """
+    search_point_end = int(abs((UTR_end - UTR_start)) * 0.1)  # TODO: This is 10% of total UTR end. Why?
+    num_samples = len(utr_coverage)
+    ## read coverage filtering
+    normalized_utr_coverage = [coverage / coverage_weights[i] for i, coverage in enumerate( utr_coverage.values() )]
+    start_coverage = [np.mean(coverage[0:99]) for coverage in utr_coverage.values()]  # filters threshold on mean coverage over first 100 nt
+    is_above_threshold = sum(np.array(start_coverage) >= coverage_threshold) >= num_samples  # This filters on the raw threshold. Why?
+    is_above_length = UTR_end - UTR_start >= 150
+    if (is_above_threshold) and (is_above_length):
+        if not is_reverse:
+            search_region = range(UTR_start + search_start, UTR_end - search_point_end + 1)
+        else:
+            search_region = range(UTR_end - search_start, UTR_start + search_point_end - 1, -1)
+        search_end = UTR_end - UTR_start - search_point_end
+        normalized_utr_coverage = np.array(normalized_utr_coverage)
+        breakpoints = range(search_start, search_end + 1)
+        mse_list = [ estimate_mse(normalized_utr_coverage, bp, num_samples) for bp in breakpoints ]
+        if len(mse_list) > 0:
+            min_ele_index = mse_list.index(min(mse_list))
+            breakpoint = breakpoints[min_ele_index]
+            UTR_abundances = estimate_abundance(normalized_utr_coverage, breakpoint, num_samples)
+            select_mean_squared_error = mse_list[min_ele_index]
+            selected_break_point = breakpoint
+        else:
+            select_mean_squared_error = 'Na'
+            UTR_abundances = 'Na'
+            selected_break_point = 'Na'
+    else:
+        select_mean_squared_error = 'Na'
+        UTR_abundances = 'Na'
+        selected_break_point = 'Na'
+
+    return select_mean_squared_error, selected_break_point, UTR_abundances
+
+
+def estimate_mse(cov, bp, num_samples):
+    """
+    Mean squared error of modelling the coverage as two constant segments (short and long 3'UTR) split at the given breakpoint.
+    """
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+        long_utr_vector = cov[:num_samples, bp:]
+        short_utr_vector = cov[:num_samples, 0:bp]
+        mean_long_utr = np.mean(long_utr_vector, 1)
+        mean_short_utr = np.mean(short_utr_vector, 1)
+        square_mean_centered_short_utr_vector = (short_utr_vector[:num_samples] - mean_short_utr[:, np.newaxis] )**2
+        square_mean_centered_long_utr_vector = (long_utr_vector[:num_samples] - mean_long_utr[:, np.newaxis])**2
+        mse = np.mean(np.append(square_mean_centered_short_utr_vector[:num_samples], square_mean_centered_long_utr_vector[:num_samples]))
+        return mse
+
+
+def estimate_abundance(cov, bp, num_samples):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+        long_utr_vector = cov[:num_samples, bp:]
+        short_utr_vector = cov[:num_samples, 0:bp]
+        mean_long_utr = np.mean(long_utr_vector, 1)
+        mean_short_utr = np.mean(short_utr_vector, 1)
+        return mean_long_utr, mean_short_utr
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    find_utr = UtrFinder(args)