Mercurial > repos > kls286 > chap_test_20230328
comparison CHAP/writer.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
| author | kls286 |
|---|---|
| date | Tue, 28 Mar 2023 15:07:30 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:cbbe42422d56 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 """ | |
| 3 File : writer.py | |
| 4 Author : Valentin Kuznetsov <vkuznet AT gmail dot com> | |
| 5 Description: generic Writer module | |
| 6 """ | |
| 7 | |
| 8 # system modules | |
| 9 import argparse | |
| 10 import json | |
| 11 import logging | |
| 12 import os | |
| 13 import sys | |
| 14 from time import time | |
| 15 | |
| 16 # local modules | |
| 17 # from pipeline import PipelineObject | |
| 18 | |
class Writer:
    """
    Generic file writer that appends `data` to a file.

    `write` is the public entry point; it logs the call and its duration
    and delegates the actual output to `_write`, which subclasses override
    to support other formats.
    """

    def __init__(self):
        """
        Record the concrete class name and create a logger of that name
        which does not propagate records to ancestor loggers.
        """
        cls_name = type(self).__name__
        self.__name__ = cls_name
        self.logger = logging.getLogger(cls_name)
        self.logger.propagate = False

    def write(self, data, filename, **_write_kwargs):
        """
        Write `data` to `filename` through `_write`, logging start and
        elapsed time.

        :param data: data to write to file
        :param filename: name of file to write to
        :param _write_kwargs: extra keyword arguments forwarded to `_write`
        :return: whatever `_write` returns (the data written)
        """
        start = time()
        self.logger.info(
            f'Executing "write" with filename={filename}, '
            f'type(data)={type(data)}, kwargs={_write_kwargs}')

        result = self._write(data, filename, **_write_kwargs)

        elapsed = time() - start
        self.logger.info(f'Finished "write" in {elapsed:.3f} seconds\n')
        return result

    def _write(self, data, filename):
        """
        Append `data` to `filename` (opened in text mode, so `data` must be
        a str) and return `data` unchanged.
        """
        with open(filename, 'a') as handle:
            handle.write(data)
        return data
| 54 | |
class YAMLWriter(Writer):
    '''Writer that serializes a `dict` or `list` to a YAML file.'''

    def _write(self, data, filename, force_overwrite=False):
        '''If `data` is a `dict` or `list`, write it to `filename` as YAML.

        :param data: the dictionary or list to write to `filename`.
        :type data: dict or list
        :param filename: name of the file to write to.
        :type filename: str
        :param force_overwrite: flag to allow data in `filename` to be
            overwritten if it already exists.
        :type force_overwrite: bool
        :raises TypeError: if `data` is not a `dict` or `list`.
        :raises RuntimeError: if `filename` already exists and
            `force_overwrite` is `False`.
        :return: the original input data
        :rtype: dict or list
        '''
        if not isinstance(data, (dict, list)):
            raise TypeError(
                f'{self.__name__}.write: input data must be a dict or list.')

        # Imported lazily so PyYAML is only needed when this writer is used;
        # done before opening the file so a missing package cannot leave an
        # empty output file behind.
        import yaml

        # Exclusive-create mode ('x') fails atomically when `filename`
        # already exists, avoiding the race between a separate existence
        # check and the open() call.
        mode = 'w' if force_overwrite else 'x'
        try:
            outf = open(filename, mode)
        except FileExistsError as exc:
            raise RuntimeError(
                f'{self.__name__}: {filename} already exists.') from exc

        with outf:
            yaml.dump(data, outf, sort_keys=False)

        return data
| 86 | |
class ExtractArchiveWriter(Writer):
    '''Writer that extracts a tar archive, given as bytes, into a directory.'''

    def _write(self, data, filename):
        '''Take a .tar archive represented as bytes in `data` and write the
        extracted archive to files.

        :param data: the archive data
        :type data: bytes
        :param filename: the name of a directory to which the archive files
            will be written
        :type filename: str
        :return: the original `data`
        :rtype: bytes
        '''
        from io import BytesIO
        import tarfile

        # Use the TarFile as a context manager so it is closed even if
        # extraction fails part-way (it was previously never closed).
        with tarfile.open(fileobj=BytesIO(data)) as tar:
            # NOTE(review): no extraction `filter` is passed. On Python
            # versions before 3.14 that means member paths are not
            # sanitized, so a crafted archive could write outside
            # `filename`. If `data` can come from an untrusted source,
            # consider `filter='data'` (see the tarfile extraction-filter
            # docs).
            tar.extractall(path=filename)

        return data
| 108 | |
| 109 | |
class NexusWriter(Writer):
    '''Writer that saves NeXus objects, or xarray data converted to
    `nexusformat.nexus.NXdata`, to a NeXus file.'''

    def _write(self, data, filename, force_overwrite=False):
        '''Write `data` to a NeXus file.

        :param data: the data to write to `filename`: a
            `nexusformat.nexus.NXobject`, `xarray.Dataset` or
            `xarray.DataArray`.
        :param filename: name of the file to write to.
        :param force_overwrite: flag to allow data in `filename` to be
            overwritten, if it already exists.
        :raises TypeError: if `data` is not one of the supported types.
        :return: the original input data
        '''
        from nexusformat.nexus import NXobject

        if isinstance(data, NXobject):
            nxstructure = data
        else:
            # xarray is only needed to convert non-NeXus input, so NXobject
            # data can be written even where xarray is not installed.
            import xarray as xr

            if isinstance(data, xr.Dataset):
                nxstructure = self.get_nxdata_from_dataset(data)
            elif isinstance(data, xr.DataArray):
                nxstructure = self.get_nxdata_from_dataarray(data)
            else:
                raise TypeError(
                    f'{self.__name__}.write: unknown data format: '
                    f'{type(data).__name__}')

        # 'w' when overwriting is allowed, otherwise 'w-' (presumably the
        # h5py-style "fail if the file exists" mode — confirm against
        # nexusformat's `save`).
        mode = 'w' if force_overwrite else 'w-'
        nxstructure.save(filename, mode=mode)

        return data

    def _get_nxfield_axes(self, xarr):
        '''Return the coordinates of `xarr` as a tuple of
        `nexusformat.nexus.NXfield` objects, in `xarr.coords` order, each
        carrying the coordinate's data, name and attrs.

        :param xarr: an `xarray.Dataset` or `xarray.DataArray`
        :return: one `NXfield` per coordinate
        :rtype: tuple
        '''
        from nexusformat.nexus import NXfield

        axes = []
        for coord in xarr.coords:
            coord_var = xarr[coord]
            axes.append(NXfield(coord_var.data,
                                name=coord_var.name,
                                attrs=coord_var.attrs))
        return tuple(axes)

    def get_nxdata_from_dataset(self, dset):
        '''Return an instance of `nexusformat.nexus.NXdata` that represents the
        data and metadata attributes contained in `dset`.

        The first data variable becomes the NXdata signal, remaining data
        variables are added as extra fields keyed by variable name, and the
        coordinates become the axes. `dset.attrs` is stored JSON-encoded in
        the `xarray_attrs` attribute.

        :param dset: the input dataset to represent
        :type dset: xarray.Dataset
        :return: `dset` represented as an instance of `nexusformat.nexus.NXdata`
        :rtype: nexusformat.nexus.NXdata
        '''
        from nexusformat.nexus import NXdata, NXfield

        nxdata_args = {'signal': None, 'axes': ()}

        for var in dset.data_vars:
            data_var = dset[var]
            nxfield = NXfield(data_var.data,
                              name=data_var.name,
                              attrs=data_var.attrs)
            if nxdata_args['signal'] is None:
                nxdata_args['signal'] = nxfield
            else:
                nxdata_args[var] = nxfield

        nxdata_args['axes'] = self._get_nxfield_axes(dset)

        nxdata = NXdata(**nxdata_args)
        # json.dumps raises TypeError if attrs hold non-JSON-serializable
        # values.
        nxdata.attrs['xarray_attrs'] = json.dumps(dset.attrs)

        return nxdata

    def get_nxdata_from_dataarray(self, darr):
        '''Return an instance of `nexusformat.nexus.NXdata` that represents the
        data and metadata attributes contained in `darr`.

        The array itself becomes the NXdata signal and its coordinates become
        the axes. `darr.attrs` is stored JSON-encoded in the `xarray_attrs`
        attribute.

        :param darr: the input data array to represent
        :type darr: xarray.DataArray
        :return: `darr` represented as an instance of `nexusformat.nexus.NXdata`
        :rtype: nexusformat.nexus.NXdata
        '''
        from nexusformat.nexus import NXdata, NXfield

        signal = NXfield(darr.data,
                         name=darr.name,
                         attrs=darr.attrs)

        nxdata = NXdata(signal=signal, axes=self._get_nxfield_axes(darr))
        nxdata.attrs['xarray_attrs'] = json.dumps(darr.attrs)

        return nxdata
| 208 | |
| 209 | |
class OptionParser:
    '''Command-line option parser for the writer script.

    The configured `argparse.ArgumentParser` is exposed as `self.parser`.
    '''

    def __init__(self):
        parser = argparse.ArgumentParser(prog='PROG')

        # (flag, destination, default, help text) for the plain string options
        string_options = (
            ("--data", "data", "", "Input data"),
            ("--filename", "filename", "", "Output file"),
            ("--writer", "writer", "Writer", "Writer class name"),
        )
        for flag, dest, default, help_text in string_options:
            parser.add_argument(flag, action="store", dest=dest,
                                default=default, help=help_text)

        # Restrict the level to the names the logging module knows about.
        parser.add_argument('--log-level',
                            choices=logging._nameToLevel.keys(),
                            dest='log_level',
                            default='INFO',
                            help='logging level')

        self.parser = parser
| 222 | |
def main():
    '''Main function: write `--data` to `--filename` with the chosen writer.

    The writer class is looked up by name (`--writer`) among this module's
    attributes and must be a subclass of `Writer`; otherwise an error is
    printed and the process exits with status 1. A stream handler with the
    requested `--log-level` is attached to the writer's logger.
    '''
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()
    cls_name = opts.writer

    # getattr() can return any module attribute (imported modules,
    # functions, OptionParser, ...), so check that the result really is a
    # Writer class instead of wrapping the lookup in a bare `except:`.
    writer_cls = getattr(sys.modules[__name__], cls_name, None)
    if not (isinstance(writer_cls, type) and issubclass(writer_cls, Writer)):
        print(f'Unsupported writer {cls_name}')
        sys.exit(1)

    writer = writer_cls()
    writer.logger.setLevel(getattr(logging, opts.log_level))
    log_handler = logging.StreamHandler()
    log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{'))
    writer.logger.addHandler(log_handler)
    data = writer.write(opts.data, opts.filename)
    print(f"Writer {writer} writes to {opts.filename}, data {data}")


if __name__ == '__main__':
    main()
