| 0 | 1 import pysam, re, string | 
|  | 2 import matplotlib.pyplot as plt | 
|  | 3 import pandas as pd | 
|  | 4 from collections import defaultdict | 
|  | 5 from collections import OrderedDict | 
|  | 6 import argparse | 
|  | 7 | 
class MismatchFrequencies:
    '''Iterate over a SAM/BAM alignment file, collecting reads with mismatches. One
    class instance per alignment file. The result_dict attribute will contain a
    nested dictionary with name, readlength and mismatch count.

    When alignment_file is given, result_dict[name] is set to the per-length
    counts returned by get_mismatches(). Mismatches lying in the first
    ignore_5p_nucleotides or the last ignore_3p_nucleotides bases of the read
    (counted in the read's own 5'->3' orientation) are not reported.
    '''
    def __init__(self, result_dict=None, alignment_file=None, name="name", minimal_readlength=21, maximal_readlength=21,
                 number_of_allowed_mismatches=1, ignore_5p_nucleotides=0, ignore_3p_nucleotides=0):
        # Create the dict per instance: a literal {} default would be a single
        # object shared by every instance constructed without result_dict.
        if result_dict is None:
            result_dict = {}
        self.result_dict = result_dict
        self.name = name
        self.minimal_readlength = minimal_readlength
        self.maximal_readlength = maximal_readlength
        self.number_of_allowed_mismatches = number_of_allowed_mismatches
        self.ignore_5p_nucleotides = ignore_5p_nucleotides
        self.ignore_3p_nucleotides = ignore_3p_nucleotides

        if alignment_file:
            self.pysam_alignment = pysam.Samfile(alignment_file)
            result_dict[name] = self.get_mismatches(self.pysam_alignment, minimal_readlength, maximal_readlength)

    def get_mismatches(self, pysam_alignment, minimal_readlength, maximal_readlength):
        '''Count valid reads and reference->read substitutions per read length.

        Returns a dict mapping every length in
        [minimal_readlength, maximal_readlength] to a defaultdict(int) holding
        'total_mapped' (valid reads of that length, with or without mismatches)
        and one '<ref> to <read>' key per observed substitution.
        '''
        len_dict = {length: defaultdict(int)
                    for length in range(minimal_readlength, maximal_readlength + 1)}
        for alignedread in pysam_alignment:
            if not self.read_is_valid(alignedread, minimal_readlength, maximal_readlength):
                continue
            counts = len_dict[int(alignedread.rlen)]
            counts['total_mapped'] += 1
            if not self.read_has_mismatch(alignedread, self.number_of_allowed_mismatches):
                continue
            ref_bases, read_bases = self.read_to_reference_mismatch(
                alignedread.opt('MD'), alignedread.seq, alignedread.is_reverse)
            if ref_bases is None:
                continue
            for ref_base, read_base in zip(ref_bases, read_bases):
                counts[ref_base + ' to ' + read_base] += 1
        return len_dict

    def read_is_valid(self, read, min_readlength, max_readlength):
        '''Return True for mapped reads whose length (read.rlen) lies within
        [min_readlength, max_readlength]. Unmapped, too short and too long reads
        are rejected; indels are not checked here.'''
        if read.is_unmapped:
            return False
        return min_readlength <= read.rlen <= max_readlength

    def read_has_mismatch(self, read, number_of_allowed_mismatches=1):
        '''Return True when the read's NM tag is between 1 and
        number_of_allowed_mismatches (inclusive).'''
        NM = read.opt('NM')
        return 1 <= NM <= number_of_allowed_mismatches

    def mismatch_in_allowed_region(self, readseq, mismatch_position):
        '''
        Return False when the 0-based mismatch_position, counted from the read's
        5' end, falls in the first ignore_5p_nucleotides or the last
        ignore_3p_nucleotides bases of readseq; True otherwise.
        >>> M = MismatchFrequencies()
        >>> readseq = 'AAAAAA'
        >>> mismatch_position = 2
        >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
        True
        >>> M = MismatchFrequencies(ignore_3p_nucleotides=2, ignore_5p_nucleotides=2)
        >>> readseq = 'AAAAAA'
        >>> mismatch_position = 1
        >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
        False
        >>> readseq = 'AAAAAA'
        >>> mismatch_position = 4
        >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
        False
        '''
        five_p = self.ignore_5p_nucleotides
        three_p = self.ignore_3p_nucleotides
        if five_p <= 0 and three_p <= 0:
            return True
        position = mismatch_position + 1  # 1-based position from the 5' end
        return not (position <= five_p or position > len(readseq) - three_p)

    def read_to_reference_mismatch(self, MD, readseq, is_reverse):
        '''
        This is where the magic happens. The MD tag encodes mismatches and deletions
        relative to the reference, so mismatches can be called without looking at
        the genome sequence. This is a typical MD tag: 3C0G2A6.
        3 bases of the read align to the reference, then a mismatch where the
        reference base is C, directly followed (0 matching bases) by a mismatch
        where the reference base is G, then 2 matching bases, a mismatch where the
        reference base is A, and finally 6 matching bases.
        suppose a reference 'CTTCGATAATCCTT'
                             |||  || ||||||
                 and a read 'CTTATATTATCCTT'.
        This situation is represented by the above MD tag.
        Given MD tag and read sequence this function returns the reference bases
        C, G and A, and the mismatched bases A, T, T, as two strings.

        Mismatches involving N, and mismatches inside the ignored 5'/3' regions
        (measured in the read's own orientation), are dropped. For reverse-strand
        reads both strings are reverse-complemented. Returns (None, None) when the
        MD tag contains a deletion ('^') or no mismatch survives the filters.
        >>> M = MismatchFrequencies()
        >>> MD='3C0G2A6'
        >>> seq='CTTATATTATCCTT'
        >>> result=M.read_to_reference_mismatch(MD, seq, is_reverse=False)
        >>> result[0]=="CGA"
        True
        >>> result[1]=="ATT"
        True
        '''
        if '^' in MD:
            print('WARNING deletion detected, mismatch calling skipped for this read!!!')
            return (None, None)
        # NOTE(review): MD offsets count aligned positions only; if reads can carry
        # soft clips or insertions, these readseq indices are shifted - confirm
        # upstream alignments are unclipped and ungapped.
        read_length = len(readseq)
        reference_bases = []
        read_bases = []
        position = -1  # 0-based index into readseq of the current mismatch
        start_index = 0  # start of the match-length number preceding the next base
        for match in re.finditer(r'[A-Z]', MD):
            position += int(MD[start_index:match.start()]) + 1
            start_index = match.end()
            reference_base = match.group()
            read_base = readseq[position]
            if reference_base == 'N' or read_base == 'N':
                continue
            # readseq is stored in reference orientation, so for reverse-strand
            # reads the read's 5' end is the last base of readseq.
            if is_reverse:
                five_prime_position = read_length - 1 - position
            else:
                five_prime_position = position
            if not self.mismatch_in_allowed_region(readseq, five_prime_position):
                continue
            reference_bases.append(reference_base)
            read_bases.append(read_base)
        if not reference_bases:
            return (None, None)
        reference_base = ''.join(reference_bases)
        mismatched_base = ''.join(read_bases)
        if is_reverse:
            reference_base = reverseComplement(reference_base)
            mismatched_base = reverseComplement(mismatched_base)
        return (reference_base, mismatched_base)
|  | 142 | 
|  | 143 | 
def reverseComplement(sequence):
    '''Return the reverse complement of a DNA sequence.

    The input is upper-cased first; A/T and C/G are swapped, N maps to itself
    and any other character is kept unchanged.
    >>> reverseComplement('ATGC')=='GCAT'
    True
    >>> reverseComplement('acgn')=='NCGT'
    True
    '''
    # A plain dict lookup works on both Python 2 and 3 (string.maketrans
    # no longer exists in Python 3).
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'}
    return ''.join(complement.get(base, base) for base in reversed(sequence.upper()))
|  | 153 | 
def barplot(df, library, axes):
    '''Draw df as an unstacked bar chart with a legend on axes, titled with library.'''
    # pandas expects a bool (or 'reverse') for legend; the previous 'test'
    # string only worked because it is truthy.
    df.plot(kind='bar', ax=axes, subplots=False,
            stacked=False, legend=True,
            title='Mismatches in TE small RNAs from {0}'.format(library))
|  | 158 | 
def result_dict_to_df(result_dict):
    '''Flatten {library: {readsize: {mismatch: count}}} into one DataFrame.

    The result has a (library, readsize) MultiIndex and one column per mismatch
    key (including 'total_mapped'); keys absent for a row become NaN.
    '''
    library_names = []
    frames = []
    # .items() (not the Python-2-only iteritems) keeps this portable.
    for library_name, per_length_counts in result_dict.items():
        library_names.append(library_name)
        frames.append(pd.DataFrame.from_dict(per_length_counts, orient='index'))
    df = pd.concat(frames, keys=library_names)
    df.index.names = ['library', 'readsize']
    return df
|  | 168 | 
def df_to_tab(df, output):
    '''Save df, index included, to output as tab-separated text.'''
    df.to_csv(path_or_buf=output, sep='\t')
|  | 171 | 
def plot_result(result_dict, args):
    '''Draw one bar chart per library in args.name and save them as a PDF.

    The charts are laid out on a two-column grid. Each mismatch count is shown
    as a percentage of the reads mapped at that read length ('total_mapped').
    The figure is written to args.output_pdf. result_dict is not modified.
    '''
    names = args.name
    nrows = len(names) // 2 + 1  # integer division: add_subplot needs an int
    fig = plt.figure(figsize=(16, 32))
    for i, library in enumerate(names):
        axes = fig.add_subplot(nrows, 2, i + 1)
        percentages = {}
        for length, counts in result_dict[library].items():
            # Lengths without any reads have no 'total_mapped' entry (and no
            # mismatch entries), so use .get instead of indexing/deleting it.
            total_mapped = counts.get('total_mapped', 0)
            percentages[length] = {
                mismatch: count / float(total_mapped) * 100
                for mismatch, count in counts.items()
                if mismatch != 'total_mapped'}
        df = pd.DataFrame(percentages)
        barplot(df, library, axes)
        axes.set_ylabel('Percent of mapped reads with mismatches')
    fig.savefig(args.output_pdf, format='pdf')
    plt.close(fig)
|  | 189 | 
def setup_MismatchFrequencies(args):
    '''Build one MismatchFrequencies keyword dict per (input file, name) pair.

    All keyword dicts share a single OrderedDict as result_dict, so results are
    collected in command-line order. Returns (kw_list, resultDict).
    Raises ValueError when args.input and args.name differ in length, rather than
    letting zip() silently drop the unmatched entries.
    '''
    if len(args.input) != len(args.name):
        raise ValueError('--input and --name must have the same number of values '
                         '({0} input files, {1} names)'.format(len(args.input), len(args.name)))
    resultDict = OrderedDict()
    kw_list = [{'result_dict': resultDict,
                'alignment_file': alignment_file,
                'name': name,
                'minimal_readlength': args.min,
                'maximal_readlength': args.max,
                'number_of_allowed_mismatches': args.n_mm,
                'ignore_5p_nucleotides': args.five_p,
                'ignore_3p_nucleotides': args.three_p}
               for alignment_file, name in zip(args.input, args.name)]
    return (kw_list, resultDict)
|  | 202 | 
def run_MismatchFrequencies(args):
    '''Analyse every input file described by args and return the shared result dict.

    Each MismatchFrequencies instance stores its per-length counts in the
    OrderedDict created by setup_MismatchFrequencies, under its library name.
    '''
    kw_list, resultDict = setup_MismatchFrequencies(args)
    # Instances are created only for their side effect of filling resultDict,
    # so a plain loop is used instead of a throwaway list comprehension.
    for kw_dict in kw_list:
        MismatchFrequencies(**kw_dict)
    return resultDict
|  | 207 | 
def main(args=None):
    '''Run the mismatch analysis and write the tab-separated table and the PDF plot.

    args: parsed command-line namespace. When omitted, the module-level ``args``
    created by the ``__main__`` block is used, so the existing ``main()`` call
    keeps working while other code can pass its own namespace.
    '''
    if args is None:
        args = globals()['args']
    result_dict = run_MismatchFrequencies(args)
    # Build the table before plotting so it always holds the raw counts.
    df = result_dict_to_df(result_dict)
    plot_result(result_dict, args)
    df_to_tab(df, args.output_tab)
|  | 213 | 
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Produce mismatch statistics for BAM/SAM alignment files.')
    # Inputs, names and outputs have no usable fallback; require them so argparse
    # reports a clear error instead of a later TypeError.
    parser.add_argument('--input', nargs='*', required=True, help='Input files in SAM/BAM format')
    parser.add_argument('--name', nargs='*', required=True, help='Name for input file to display in output file. Should have same length as the number of inputs')
    parser.add_argument('--output_pdf', required=True, help='Output filename for graph')
    parser.add_argument('--output_tab', required=True, help='Output filename for table')
    # Defaults match MismatchFrequencies' own read-length defaults.
    parser.add_argument('--min', '--minimal_readlength', type=int, default=21, help='minimum readlength')
    parser.add_argument('--max', '--maximal_readlength', type=int, default=21, help='maximum readlength')
    parser.add_argument('--n_mm', '--number_allowed_mismatches', type=int, default=1, help='discard reads with more than n mismatches')
    parser.add_argument('--five_p', '--ignore_5p_nucleotides', type=int, default=0, help='when calculating nucleotide mismatch frequencies ignore the first N nucleotides of the read')
    # NOTE(review): default of 1 differs from the class default (0); presumably
    # deliberate - confirm.
    parser.add_argument('--three_p', '--ignore_3p_nucleotides', type=int, default=1, help='when calculating nucleotide mismatch frequencies ignore the last N nucleotides of the read')
    # Example:
    #   --input a.bam b.bam --name lib1 lib2 --five_p 3 --three_p 3
    #   --output_pdf out.pdf --output_tab out.tab --min 21 --max 21
    args = parser.parse_args()
    main()
|  | 229 |