diff CHAP/writer.py @ 0:cbbe42422d56 draft

planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author kls286
date Tue, 28 Mar 2023 15:07:30 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/CHAP/writer.py	Tue Mar 28 15:07:30 2023 +0000
@@ -0,0 +1,243 @@
+#!/usr/bin/env python
+"""
+File       : writer.py
+Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
+Description: generic Writer module
+"""
+
+# system modules
+import argparse
+import json
+import logging
+import os
+import sys
+from time import time
+
+# local modules
+# from pipeline import PipelineObject
+
+class Writer():
+    """
+    Writer represent generic file writer
+    """
+
+    def __init__(self):
+        """
+        Constructor of Writer class
+        """
+        self.__name__ = self.__class__.__name__
+        self.logger = logging.getLogger(self.__name__)
+        self.logger.propagate = False
+
+    def write(self, data, filename, **_write_kwargs):
+        """
+        write API
+
+        :param filename: Name of file to write to
+        :param data: data to write to file
+        :return: data written to file
+        """
+
+        t0 = time()
+        self.logger.info(f'Executing "write" with filename={filename}, type(data)={type(data)}, kwargs={_write_kwargs}')
+
+        data = self._write(data, filename, **_write_kwargs)
+
+        self.logger.info(f'Finished "write" in {time()-t0:.3f} seconds\n')
+
+        return(data)
+
+    def _write(self, data, filename):
+        with open(filename, 'a') as file:
+            file.write(data)
+        return(data)
+
+class YAMLWriter(Writer):
+    def _write(self, data, filename, force_overwrite=False):
+        '''If `data` is a `dict`, write it to `filename`.
+
+        :param data: the dictionary to write to `filename`.
+        :type data: dict
+        :param filename: name of the file to write to.
+        :type filename: str
+        :param force_overwrite: flag to allow data in `filename` to be
+            overwritten if it already exists.
+        :type force_overwrite: bool
+        :raises TypeError: if `data` is not a `dict`
+        :raises RuntimeError: if `filename` already exists and
+            `force_overwrite` is `False`.
+        :return: the original input data
+        :rtype: dict
+        '''
+
+        import yaml
+
+        if not isinstance(data, (dict, list)):
+            raise(TypeError(f'{self.__name__}.write: input data must be a dict or list.'))
+
+        if not force_overwrite:
+            if os.path.isfile(filename):
+                raise(RuntimeError(f'{self.__name__}: {filename} already exists.'))
+
+        with open(filename, 'w') as outf:
+            yaml.dump(data, outf, sort_keys=False)
+
+        return(data)
+
+class ExtractArchiveWriter(Writer):
+    def _write(self, data, filename):
+        '''Take a .tar archive represented as bytes in `data` and write the
+        extracted archive to files.
+
+        :param data: the archive data
+        :type data: bytes
+        :param filename: the name of a directory to which the archive files will
+            be written
+        :type filename: str
+        :return: the original `data`
+        :rtype: bytes
+        '''
+
+        from io import BytesIO
+        import tarfile
+
+        tar = tarfile.open(fileobj=BytesIO(data))
+        tar.extractall(path=filename)
+
+        return(data)
+
+
+class NexusWriter(Writer):
+    def _write(self, data, filename, force_overwrite=False):
+        '''Write `data` to a NeXus file
+
+        :param data: the data to write to `filename`.
+        :param filename: name of the file to write to.
+        :param force_overwrite: flag to allow data in `filename` to be
+            overwritten, if it already exists. 
+        :return: the original input data
+        '''
+
+        from nexusformat.nexus import NXobject
+        import xarray as xr
+        
+        if isinstance(data, NXobject):
+            nxstructure = data
+
+        elif isinstance(data, xr.Dataset):
+            nxstructure = self.get_nxdata_from_dataset(data)
+
+        elif isinstance(data, xr.DataArray):
+            nxstructure = self.get_nxdata_from_dataarray(data)
+        
+        else:
+            raise(TypeError(f'{self.__name__}.write: unknown data format: {type(data).__name__}'))
+
+        mode = 'w' if force_overwrite else 'w-'
+        nxstructure.save(filename, mode=mode)
+
+        return(data)
+
+
+    def get_nxdata_from_dataset(self, dset):
+        '''Return an instance of `nexusformat.nexus.NXdata` that represents the
+        data and metadata attributes contained in `dset`.
+
+        :param dset: the input dataset to represent
+        :type data: xarray.Dataset
+        :return: `dset` represented as an instance of `nexusformat.nexus.NXdata`
+        :rtype: nexusformat.nexus.NXdata
+        '''
+
+        from nexusformat.nexus import NXdata, NXfield
+
+        nxdata_args = {'signal':None, 'axes':()}
+
+        for var in dset.data_vars:
+            data_var = dset[var]
+            nxfield = NXfield(data_var.data,
+                              name=data_var.name,
+                              attrs=data_var.attrs)
+            if nxdata_args['signal'] is None:
+                nxdata_args['signal'] = nxfield
+            else:
+                nxdata_args[var] = nxfield
+
+        for coord in dset.coords:
+            coord_var = dset[coord]
+            nxfield = NXfield(coord_var.data,
+                              name=coord_var.name,
+                              attrs=coord_var.attrs)
+            nxdata_args['axes'] = (*nxdata_args['axes'], nxfield)
+
+        nxdata = NXdata(**nxdata_args)
+        nxdata.attrs['xarray_attrs'] = json.dumps(dset.attrs)
+
+        return(nxdata)
+
+    def get_nxdata_from_dataarray(self, darr):
+        '''Return an instance of `nexusformat.nexus.NXdata` that represents the
+        data and metadata attributes contained in `darr`.
+
+        :param darr: the input dataset to represent
+        :type darr: xarray.DataArray
+        :return: `darr` represented as an instance of `nexusformat.nexus.NXdata`
+        :rtype: nexusformat.nexus.NXdata
+        '''
+
+        from nexusformat.nexus import NXdata, NXfield
+
+        nxdata_args = {'signal':None, 'axes':()}
+
+        nxdata_args['signal'] = NXfield(darr.data,
+                                        name=darr.name,
+                                        attrs=darr.attrs)
+
+
+        for coord in darr.coords:
+            coord_var = darr[coord]
+            nxfield = NXfield(coord_var.data,
+                              name=coord_var.name,
+                              attrs=coord_var.attrs)
+            nxdata_args['axes'] = (*nxdata_args['axes'], nxfield)
+
+        nxdata = NXdata(**nxdata_args)
+        nxdata.attrs['xarray_attrs'] = json.dumps(darr.attrs)
+
+        return(nxdata)
+
+
+class OptionParser():
+    '''User based option parser'''
+    def __init__(self):
+        self.parser = argparse.ArgumentParser(prog='PROG')
+        self.parser.add_argument("--data", action="store",
+            dest="data", default="", help="Input data")
+        self.parser.add_argument("--filename", action="store",
+            dest="filename", default="", help="Output file")
+        self.parser.add_argument("--writer", action="store",
+            dest="writer", default="Writer", help="Writer class name")
+        self.parser.add_argument('--log-level', choices=logging._nameToLevel.keys(),
+            dest='log_level', default='INFO', help='logging level')
+
+def main():
+    '''Main function'''
+    optmgr  = OptionParser()
+    opts = optmgr.parser.parse_args()
+    clsName = opts.writer
+    try:
+        writerCls = getattr(sys.modules[__name__],clsName)
+    except:
+        print(f'Unsupported writer {clsName}')
+        sys.exit(1)
+
+    writer = writerCls()
+    writer.logger.setLevel(getattr(logging, opts.log_level))
+    log_handler = logging.StreamHandler()
+    log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{'))
+    writer.logger.addHandler(log_handler)
+    data = writer.write(opts.data, opts.filename)
+    print(f"Writer {writer} writes to {opts.filename}, data {data}")
+
+if __name__ == '__main__':
+    main()