Mercurial > repos > bcclaywell > argo_navis
comparison bin/format_beastfile.py @ 0:d67268158946 draft
planemo upload commit a3f181f5f126803c654b3a66dd4e83a48f7e203b
| author | bcclaywell |
|---|---|
| date | Mon, 12 Oct 2015 17:43:33 -0400 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:d67268158946 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 """ | |
| 3 Formats a given BEAST XML file (possibly all ready to run) and respecifies the information needed to run the | |
| 4 classic Discrete trait. | |
| 5 | |
| 6 Some things that would be nice: | |
| 7 * specify output files/formats (could let you run from root instead of the dir) | |
| 8 """ | |
| 9 | |
| 10 from Bio import SeqIO | |
| 11 import xml.etree.ElementTree as ET | |
| 12 import argparse | |
| 13 import copy | |
| 14 import csv | |
| 15 | |
| 16 | |
def clear_children(node):
    """Empty out `node` while leaving its attributes in place.

    `Element.clear()` wipes the children, text, tail *and* attributes of an element, so the attributes are
    snapshotted beforehand and put back afterwards. This stands in for `Element.remove`, which the original
    author reported as not working as expected.
    """
    saved_attributes = dict(node.attrib)
    node.clear()
    node.attrib = saved_attributes
| 22 | |
| 23 | |
def set_alignment(xmldoc, alignment):
    """Replace the alignment data in xmldoc with the sequences in alignment.

    `xmldoc` is searched for a `data` element directly under its root; that element keeps its attributes
    but loses all of its existing children. `alignment` is an iterable of sequence records exposing `.name`
    and `.seq`; each one becomes a `sequence` child with id "seq_<name>", taxon "<name>", totalcount "4"
    and the sequence string as its value.

    Raises ValueError if the document has no `data` element to fill.
    """
    aln_node = xmldoc.find('data')
    if aln_node is None:
        raise ValueError("Template has no <data> element under its root to hold the alignment")
    # First clear out the old alignment sequences (the data node's own attributes are preserved)
    clear_children(aln_node)
    # Next, construct and throw in the new sequence nodes
    for seq_record in alignment:
        seqid = seq_record.name
        ET.SubElement(aln_node, 'sequence',
                      attrib=dict(id="seq_" + seqid,
                                  taxon=seqid,
                                  totalcount="4",
                                  value=str(seq_record.seq)))
| 40 | |
| 41 | |
def get_data_id(xmldoc):
    """Return the id of the alignment data set in xmldoc.

    The data set carries a name assigned by BEAUti (typically derived from the data file loaded into it).
    Other parts of the document refer to the data set by this id, so the code that rewires those references
    needs to look it up. The lookup takes the first `data` element anywhere in the document that has
    name="alignment" and an `id` attribute.
    """
    data_node = xmldoc.find(".//data[@name='alignment'][@id]")
    return data_node.attrib['id']
| 47 | |
| 48 | |
def default_deme_getter(metarow):
    """Default way of pulling the deme out of a metadata row.

    Prefers a non-empty 'deme' entry and otherwise falls back to whatever is stored under 'community'
    (which is None when that key is absent as well).
    """
    deme = metarow.get('deme')
    if deme:
        return deme
    return metarow.get('community')
| 53 | |
| 54 | |
def set_deme(xmldoc, metadata, deme_getter=default_deme_getter):
    """Set the deme information of xmldoc from metadata.

    The text of the first `traitSet` element in the document is replaced with one "<sequence>=<deme>" entry
    per metadata row, joined by ",\\n". Each row must provide a 'sequence' key; the deme is obtained by
    calling `deme_getter(row)` (by default the `default_deme_getter` function above).

    Raises ValueError if the document has no `traitSet` element, or if `deme_getter` returns None for a row.
    """
    # The builtin next() works on Python 2 and 3 alike, unlike the Python-2-only iterator.next() method
    trait_node = next(xmldoc.iter('traitSet'), None)
    if trait_node is None:
        raise ValueError("Template contains no <traitSet> element to receive the deme data")
    assignments = []
    for row in metadata:
        deme = deme_getter(row)
        if deme is None:
            # Fail with the offending sequence named, rather than an opaque str + None TypeError
            raise ValueError("No deme value found for sequence %r" % (row['sequence'],))
        assignments.append(row['sequence'] + "=" + deme)
    trait_node.text = ",\n".join(assignments)
| 61 | |
| 62 | |
def build_date_node(date_spec, data_id):
    """Construct the `trait` element holding the date traits.

    `date_spec` is the ready-made string form of the sequence -> date mapping and becomes the text of the
    returned element. The element gets a `taxa` child (id "TaxonSet.<data_id>"), which in turn gets a `data`
    child pointing back at the data set through `idref`.
    """
    date_node = ET.Element('trait', {
        'id': 'dateTrait.t:' + data_id,
        'spec': 'beast.evolution.tree.TraitSet',
        'traitname': 'date',
    })
    date_node.text = date_spec
    taxa_node = ET.SubElement(date_node, 'taxa', {
        'id': 'TaxonSet.' + data_id,
        'spec': 'TaxonSet',
    })
    ET.SubElement(taxa_node, 'data', {'idref': data_id, 'name': "alignment"})
    return date_node
| 79 | |
| 80 | |
def set_date(xmldoc, metadata, date_attr='date'):
    """Build a dateTrait node via `build_date_node` and insert it into the `.//state/tree` node.

    The `tree` node normally already contains a `taxonset` node (with a `data` node inside) whose id is the
    same as that of the `taxa` node nested in the new date `trait` node. To avoid a duplicate id error the
    existing `taxonset` is removed and a new one is appended that merely points, via `idref`, at the `taxa`
    node inside the `trait` node. This is rather convoluted, and it has not been verified that a file with
    multiple datasets would survive it, but the strategy seems to work for now.

    Each metadata row must provide a 'sequence' key and a `date_attr` key.

    Raises ValueError if the document has no `.//state/tree` node.
    """
    # First get our tree node; we'll be adding the date data to this
    tree_node = xmldoc.find('.//state/tree')
    if tree_node is None:
        raise ValueError("Template has no state/tree element to attach the date trait to")
    # Construct our trait string, just as we do for `set_deme`
    trait_string = ",\n".join([row['sequence'] + "=" + row[date_attr] for row in metadata])
    # Build the date trait node, and carry out all the weird mucking to get the new `taxonset` node in, as
    # described in the docstring
    data_id = get_data_id(xmldoc)
    date_node = build_date_node(trait_string, data_id)
    old_taxonset = tree_node.find("./taxonset")
    tree_node.insert(0, date_node)
    # Element.remove cannot be handed None, so only remove a taxonset that is actually there
    if old_taxonset is not None:
        tree_node.remove(old_taxonset)
    ET.SubElement(tree_node, "taxonset", idref="TaxonSet." + data_id)
| 101 | |
| 102 | |
def get_current_interval(xmldoc):
    """Return the `logEvery` interval shared by the file loggers of xmldoc.

    Looks at every `logger` directly under the `run` element, ignoring the one whose id is 'screenlog', and
    returns their common `logEvery` value as an int.

    Raises ValueError if those loggers disagree on the interval, or if there are none to read it from.
    """
    run_node = xmldoc.find('run')
    intervals = set(
        int(logger.get('logEvery'))
        for logger in run_node.findall('logger')
        if logger.get('id') != 'screenlog'
    )
    if not intervals:
        raise ValueError("Cannot get an interval for this xml doc; it has no loggers besides the screenlog")
    if len(intervals) > 1:
        # Raising a plain string is itself a TypeError, so use a real exception class
        raise ValueError("Cannot get an interval for this xml doc; there are multiple such values")
    return intervals.pop()
| 110 | |
| 111 | |
def set_mcmc(xmldoc, samples, sampling_interval):
    """Set the MCMC chain settings (how often to log, how long to run, etc.).

    The `run` element's chainLength becomes samples * sampling_interval + 1. Every `logger` under `run` logs
    every `sampling_interval` states, except the 'screenlog' logger which logs ten times less often.
    """
    run_node = xmldoc.find('run')
    # XXX Should really make it so that you only have to specify _one_, and it will find current value of
    # other so that chain length doesn't break.
    total_states = samples * sampling_interval + 1
    run_node.set('chainLength', str(total_states))
    for logger in run_node.findall('logger'):
        if logger.get('id') == 'screenlog':
            every = sampling_interval * 10
        else:
            every = sampling_interval
        logger.set('logEvery', str(every))
| 123 | |
| 124 | |
def normalize_filenames(xmldoc, logger_filename="posterior.log", treefile_filename="posterior.trait.trees"):
    """Point the trace and tree loggers of xmldoc at fixed output file names.

    The `fileName` of the 'tracelog' logger is set to `logger_filename` and that of the
    'treeWithTraitLogger.deme' logger to `treefile_filename`; both loggers are looked up directly under the
    `run` element.
    """
    run_node = xmldoc.find('run')
    targets = (
        ('logger[@id="tracelog"]', logger_filename),
        ('logger[@id="treeWithTraitLogger.deme"]', treefile_filename),
    )
    for xpath, filename in targets:
        run_node.find(xpath).set('fileName', filename)
| 131 | |
| 132 | |
def set_deme_count(xmldoc, metadata, deme_getter=default_deme_getter):
    """Update the model specs based on the number of demes in the data set.

    The distinct values of `deme_getter(row)` over `metadata` are sorted and numbered from 0. With n demes:

    * the "relativeGeoRates.s:deme" parameter and "rateIndicator.s:deme" state node get dimension
      n * (n - 1) / 2;
    * the `userDataType` node gets states = n and a codeMap of "<deme>=<index>" pairs followed by a "?"
      entry listing every index;
    * the "traitfrequencies.s:deme" parameter gets dimension n and the uniform frequency 1/n as its text.

    Raises ValueError if `metadata` yields no demes at all.
    """
    demes = sorted(set(map(deme_getter, metadata)))
    deme_count = len(demes)
    if deme_count == 0:
        # Bail out before touching the document; 1.0 / deme_count below would otherwise blow up half way
        raise ValueError("Cannot set the deme count: the metadata contains no rows")
    # Floor division keeps the dimension an integer on Python 3 too; n * (n - 1) is always even
    mig_dim = (deme_count - 1) * deme_count // 2
    for xpath in ['.//parameter[@id="relativeGeoRates.s:deme"]', './/stateNode[@id="rateIndicator.s:deme"]']:
        xmldoc.find(xpath).set('dimension', str(mig_dim))
    code_map = ",".join([deme + "=" + str(index) for index, deme in enumerate(demes)])
    code_map += ",? = " + " ".join(map(str, range(deme_count))) + " "
    user_data_type_node = xmldoc.find('.//userDataType')
    user_data_type_node.set('codeMap', code_map)
    user_data_type_node.set('states', str(deme_count))
    trait_frequencies_param = xmldoc.find('.//frequencies/parameter[@id="traitfrequencies.s:deme"]')
    trait_frequencies_param.set('dimension', str(deme_count))
    trait_frequencies_param.text = str(1.0 / deme_count)
| 149 | |
| 150 | |
| 151 | |
def get_args():
    """Parse the command line and return the resulting argparse namespace.

    Positional arguments are the template BEAST XML (opened for reading) and the output BEAST XML (opened
    for writing). The optional arguments select the alignment, the metadata file (opened for reading), the
    deme and date columns, and the MCMC sample count and sampling interval.
    """
    def int_or_floatify(string):
        # Going through float first lets users write counts such as "1e6" or "2500.0"
        return int(float(string))
    parser = argparse.ArgumentParser()
    parser.add_argument('template', type=argparse.FileType('r'),
                        help="""A template BEAST XML (presumably created by Beauti) ready for insertion of an
                        alignment and discrete trait.""")
    parser.add_argument('-a', '--alignment',
                        help="Replace alignment in beast file with this alignment; Fasta format.")
    parser.add_argument('-m', '--metadata', type=argparse.FileType('r'),
                        help="""CSV file with a 'sequence' column plus a column referencing the deme ('deme'
                        or 'community', unless --deme-col is given).""")
    parser.add_argument('-s', '--samples', type=int_or_floatify,
                        help="Number of samples in output log file(s).")
    parser.add_argument('-d', '--deme-col',
                        help="""Specifies the deme column for metadata; defaults to deme or community
                        (whichever is present) if not specified.""")
    parser.add_argument('-D', '--date-col',
                        help="If specified, will add a date specification to the output BEAST XML file.")
    parser.add_argument('-i', '--sampling-interval', type=int_or_floatify,
                        help="""Number of chain states to simulate between successive states sampled for
                        logfiles. The total chain length is therefore samples * sampling_interval + 1.""")
    parser.add_argument('beastfile', type=argparse.FileType('w'),
                        help="Output BEAST XML file.")
    return parser.parse_args()
| 176 | |
| 177 | |
def main(args):
    """Rewrite the template BEAST XML according to `args` and write it to `args.beastfile`.

    Depending on which options are set, this replaces the alignment, resets the deme trait and deme-count
    dependent model settings from the metadata, adds a date trait, and adjusts the MCMC chain settings. The
    logger output file names are always normalized before writing.

    Raises ValueError if --date-col is given without --metadata, or --sampling-interval without --samples.
    """
    # Read in old data
    xmldoc = ET.parse(args.template)

    # Modify the data sets
    if args.alignment:
        alignment = SeqIO.parse(args.alignment, 'fasta')
        set_alignment(xmldoc, alignment)
    if args.metadata:
        metadata = list(csv.DictReader(args.metadata))

        # Set the deme getter
        def deme_getter(row):
            return row[args.deme_col] if args.deme_col else default_deme_getter(row)

        set_deme(xmldoc, metadata, deme_getter)
        # _could_ do something smart here where we look at which sequences in the XML file traitset that match
        # alignment passed in if _only_ alignment is passed in. Probably not worth it though...
        set_deme_count(xmldoc, metadata, deme_getter)
        if args.date_col:
            set_date(xmldoc, metadata, args.date_col)
    elif args.date_col:
        # The dates live in the metadata file, so there is nothing to read them from without it
        raise ValueError("--date-col requires --metadata to be specified as well")

    if args.samples or args.sampling_interval:
        if args.samples is None:
            # set_mcmc needs both numbers to compute the chain length
            raise ValueError("--samples must be specified when --sampling-interval is given")
        interval = args.sampling_interval if args.sampling_interval else get_current_interval(xmldoc)
        set_mcmc(xmldoc, args.samples, interval)

    # Make sure that we always have the same file names out. These are specified as defaults of the function,
    # but could be customized here or through the cl args if needed.
    normalize_filenames(xmldoc)

    # Write the output
    xmldoc.write(args.beastfile)
| 207 | |
| 208 | |
# Script entry point: parse the command line and run the formatter only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main(get_args())
| 211 | |
| 212 |
