Mercurial > repos > bcclaywell > argo_navis
comparison bin/deme_downsample.py @ 0:d67268158946 draft
planemo upload commit a3f181f5f126803c654b3a66dd4e83a48f7e203b
| author | bcclaywell |
|---|---|
| date | Mon, 12 Oct 2015 17:43:33 -0400 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:d67268158946 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 """Given the clustering results of a run of alnclst, this tool takes those results and find a single | |
| 3 representative sequence for each cluster. In particular, it chooses the cluster representative closest to the | |
| 4 cluster center. | |
| 5 """ | |
| 6 | |
| 7 import argparse | |
| 8 import random | |
| 9 import alnclst | |
| 10 import csv | |
| 11 from Bio import SeqIO | |
| 12 | |
| 13 | |
# Tunables for the k-means downsampler; read by kmeans_runner.
settings = {
    # Passed positionally to alnclst as the consensus threshold.
    'consensus_threshold': None,
    # Number of independent k-means runs; the best-scoring run is kept.
    'batches': 2,
    # Iteration cap handed to alnclst as max_iters.
    'max_iters': 100,
}
| 15 | |
| 16 | |
def kmeans_runner(seqrecords, k):
    """Cluster seqrecords with k-means and return one representative name per cluster.

    The clustering is run ``settings['batches']`` times and the run with the
    lowest ``average_distance()`` is kept. From each cluster of that run, the
    member with the smallest reported distance (presumably its distance to the
    cluster centre) is chosen; ties are broken by the smaller name.

    :param seqrecords: list of sequence records to cluster.
    :param k: number of clusters requested from alnclst.
    :returns: list of representative sequence names; empty if seqrecords is empty.
    """
    # Nothing to cluster: don't hand alnclst an empty input.
    if not seqrecords:
        return []

    def clustering():
        # Class name kept exactly as in the original call (presumably how alnclst
        # spells it) -- confirm against alnclst before "correcting" it.
        return alnclst.KMeansClsutering(seqrecords, k, settings['consensus_threshold'],
                                        max_iters=settings['max_iters'])

    # Keep the best-converged run. Rank on the score alone: ranking
    # (score, clustering) tuples compares clustering objects on ties, which
    # raises TypeError on Python 3 and is arbitrary on Python 2.
    clusts = min((clustering() for _ in range(settings['batches'])),
                 key=lambda c: c.average_distance())

    # cluster_id -> (distance, seqname) of the closest member seen so far.
    clust_reps = {}
    for cluster_id, seqname, distance in clusts.mapping_iterator():
        candidate = (distance, seqname)
        if cluster_id not in clust_reps or candidate < clust_reps[cluster_id]:
            clust_reps[cluster_id] = candidate
    return [seqname for _, seqname in clust_reps.values()]
| 33 | |
| 34 | |
def random_runner(seqnames, k):
    """Pick up to k names uniformly at random, without replacement, from seqnames.

    If seqnames holds fewer than k names, the very same list is handed back
    untouched.
    """
    if len(seqnames) >= k:
        return random.sample(seqnames, k)
    return seqnames
| 41 | |
| 42 | |
def make_deme_map(deme_spec, deme_col):
    """Group sequence names by deme.

    :param deme_spec: iterable of row mappings (e.g. rows from ``csv.DictReader``),
        each with a ``'sequence'`` entry and a ``deme_col`` entry.
    :param deme_col: key of the column holding each row's deme label.
    :returns: dict mapping deme label -> list of sequence names, in input order.
    :raises KeyError: if a row lacks ``deme_col`` (with a hint about ``--deme-col``),
        or lacks ``'sequence'``.
    """
    deme_map = {}
    for row in deme_spec:
        try:
            deme = row[deme_col]
        except KeyError:
            # Call-style raise works on both Python 2 and 3 (the old
            # `raise KeyError, msg` form is a SyntaxError on Python 3).
            raise KeyError("Make sure to specify a --deme-col that's actually in the deme file")
        deme_map.setdefault(deme, []).append(row['sequence'])
    return deme_map
| 57 | |
| 58 | |
def get_args():
    """Parse and return the command-line arguments.

    The positional file arguments are opened by argparse (inputs for reading,
    outputs for writing); the caller is responsible for closing them.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('alignment', type=argparse.FileType('r'), help="Alignment FASTA file")
    parser.add_argument('demes', type=argparse.FileType('r'), help="CSV metadata specifying deme info")
    parser.add_argument('-c', '--deme-col', default='deme', help="Column specifying 'deme' argument in demes spec")
    # Required: both downsampling methods need a per-deme cap; without it the
    # run used to die later with an obscure TypeError instead of a usage error.
    parser.add_argument('-k', help="Maximum number of representatives for each deme", type=int,
                        required=True)
    # Kept as a string so a given seed reproduces the same stream as before.
    parser.add_argument('-s', '--seed', help="Random seed for reproducibility")
    # Explicit default: any method other than 'random' has always meant k-means.
    parser.add_argument('-m', '--method', choices=('random', 'kmeans'), default='kmeans',
                        help="Which downsampling method should be used")
    parser.add_argument('out_alignment', type=argparse.FileType('w'), help="Downsampled alignment output")
    parser.add_argument('out_demes', type=argparse.FileType('w'), help="Downsampled metadata output")
    return parser.parse_args()
| 71 | |
| 72 | |
def main():
    """Command-line entry point: downsample an alignment deme by deme.

    Reads the FASTA alignment and the deme metadata CSV, keeps at most ``-k``
    representatives per deme using ``--method`` ('random' or k-means), and
    writes the kept sequences and their metadata rows to the two output files.
    All four file handles are closed on exit, including on error.
    """
    args = get_args()
    try:
        # Seed the RNG only when asked, so runs can be reproduced.
        if args.seed:
            random.seed(args.seed)

        # Index the alignment by record id for lookup by sequence name.
        seqrecords = SeqIO.to_dict(SeqIO.parse(args.alignment, 'fasta'))
        reader = csv.DictReader(args.demes)
        demes = list(reader)
        # Reuse the input header (in its original column order) for the output;
        # deriving it from a kept row failed when fewer than two rows were kept.
        fieldnames = reader.fieldnames or []

        # Turn our metadata into a map of deme -> seqnames
        deme_map = make_deme_map(demes, args.deme_col)

        # Run the chosen downsampling method per deme and gather the kept names.
        rep_seqnames = []
        for seqnames in deme_map.values():
            # Tolerate metadata rows with no matching alignment record, so the
            # CSV may carry "extra" sequences.
            seqnames = [n for n in seqnames if n in seqrecords]
            if args.method == 'random':
                deme_rep_seqnames = random_runner(seqnames, args.k)
            else:
                deme_sequences = [seqrecords[n] for n in seqnames]
                deme_rep_seqnames = kmeans_runner(deme_sequences, args.k)
            rep_seqnames.extend(deme_rep_seqnames)

        # Filter the data down to the representatives (set for O(1) membership).
        deme_rep_seqs = [seqrecords[n] for n in rep_seqnames]
        rep_set = set(rep_seqnames)
        deme_rep_meta = [r for r in demes if r['sequence'] in rep_set]

        out_demes = csv.DictWriter(args.out_demes, fieldnames)
        out_demes.writeheader()
        out_demes.writerows(deme_rep_meta)

        SeqIO.write(deme_rep_seqs, args.out_alignment, 'fasta')
    finally:
        for fh in (args.alignment, args.demes, args.out_alignment, args.out_demes):
            fh.close()
| 110 | |
| 111 | |
if __name__ == "__main__":
    # Run the command-line tool only when executed directly, not when imported.
    main()
| 114 | |
| 115 |
