Mercurial > repos > bgruening > sklearn_generalized_linear
comparison keras_deep_learning.py @ 40:a8771df897b2 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening |
|---|---|
| date | Wed, 09 Aug 2023 11:13:19 +0000 |
| parents | 34f295eb5782 |
| children | |
comparison
equal
deleted
inserted
replaced
| 39:1a72afcb0752 | 40:a8771df897b2 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | 2 import json |
| 3 import pickle | |
| 4 import warnings | 3 import warnings |
| 5 from ast import literal_eval | 4 from ast import literal_eval |
| 6 | 5 |
| 7 import keras | |
| 8 import pandas as pd | |
| 9 import six | 6 import six |
| 10 from galaxy_ml.utils import get_search_params, SafeEval, try_get_attr | 7 from galaxy_ml.model_persist import dump_model_to_h5 |
| 11 from keras.models import Model, Sequential | 8 from galaxy_ml.utils import SafeEval, try_get_attr |
| | 9 from tensorflow import keras |
| | 10 from tensorflow.keras.models import Model, Sequential |
| 12 | 11 |
| 13 safe_eval = SafeEval() | 12 safe_eval = SafeEval() |
| 14 | 13 |
| 15 | 14 |
| 16 def _handle_shape(literal): | 15 def _handle_shape(literal): |
| 17 """ | 16 """Eval integer or list/tuple of integers from string |
| 18 Eval integer or list/tuple of integers from string | |
| 19 | 17 |
| 20 Parameters: | 18 Parameters: |
| 21 ----------- | 19 ----------- |
| 22 literal : str. | 20 literal : str. |
| 23 """ | 21 """ |
| 30 print(e) | 28 print(e) |
| 31 return literal | 29 return literal |
| 32 | 30 |
| 33 | 31 |
| 34 def _handle_regularizer(literal): | 32 def _handle_regularizer(literal): |
| 35 """ | 33 """Construct regularizer from string literal |
| 36 Construct regularizer from string literal | |
| 37 | 34 |
| 38 Parameters | 35 Parameters |
| 39 ---------- | 36 ---------- |
| 40 literal : str. E.g. '(0.1, 0)' | 37 literal : str. E.g. '(0.1, 0)' |
| 41 """ | 38 """ |
| 55 | 52 |
| 56 return keras.regularizers.l1_l2(l1=l1, l2=l2) | 53 return keras.regularizers.l1_l2(l1=l1, l2=l2) |
| 57 | 54 |
| 58 | 55 |
| 59 def _handle_constraint(config): | 56 def _handle_constraint(config): |
| 60 """ | 57 """Construct constraint from galaxy tool parameters. |
| 61 Construct constraint from galaxy tool parameters. | |
| 62 Suppose correct dictionary format | 58 Suppose correct dictionary format |
| 63 | 59 |
| 64 Parameters | 60 Parameters |
| 65 ---------- | 61 ---------- |
| 66 config : dict. E.g. | 62 config : dict. E.g. |
| 89 def _handle_lambda(literal): | 85 def _handle_lambda(literal): |
| 90 return None | 86 return None |
| 91 | 87 |
| 92 | 88 |
| 93 def _handle_layer_parameters(params): | 89 def _handle_layer_parameters(params): |
| 94 """ | 90 """Access to handle all kinds of parameters""" |
| 95 Access to handle all kinds of parameters | |
| 96 """ | |
| 97 for key, value in six.iteritems(params): | 91 for key, value in six.iteritems(params): |
| 98 if value in ("None", ""): | 92 if value in ("None", ""): |
| 99 params[key] = None | 93 params[key] = None |
| 100 continue | 94 continue |
| 101 | 95 |
| 102 if type(value) in [int, float, bool] or ( | 96 if type(value) in [int, float, bool] or ( |
| 103 type(value) is str and value.isalpha() | 97 type(value) is str and value.isalpha() |
| 104 ): | 98 ): |
| 105 continue | 99 continue |
| 106 | 100 |
| 107 if ( | 101 if key in [ |
| 108 key | 102 "input_shape", |
| 109 in [ | 103 "noise_shape", |
| 110 "input_shape", | 104 "shape", |
| 111 "noise_shape", | 105 "batch_shape", |
| 112 "shape", | 106 "target_shape", |
| 113 "batch_shape", | 107 "dims", |
| 114 "target_shape", | 108 "kernel_size", |
| 115 "dims", | 109 "strides", |
| 116 "kernel_size", | 110 "dilation_rate", |
| 117 "strides", | 111 "output_padding", |
| 118 "dilation_rate", | 112 "cropping", |
| 119 "output_padding", | 113 "size", |
| 120 "cropping", | 114 "padding", |
| 121 "size", | 115 "pool_size", |
| 122 "padding", | 116 "axis", |
| 123 "pool_size", | 117 "shared_axes", |
| 124 "axis", | 118 ] and isinstance(value, str): |
| 125 "shared_axes", | |
| 126 ] | |
| 127 and isinstance(value, str) | |
| 128 ): | |
| 129 params[key] = _handle_shape(value) | 119 params[key] = _handle_shape(value) |
| 130 | 120 |
| 131 elif key.endswith("_regularizer") and isinstance(value, dict): | 121 elif key.endswith("_regularizer") and isinstance(value, dict): |
| 132 params[key] = _handle_regularizer(value) | 122 params[key] = _handle_regularizer(value) |
| 133 | 123 |
| 139 | 129 |
| 140 return params | 130 return params |
| 141 | 131 |
| 142 | 132 |
| 143 def get_sequential_model(config): | 133 def get_sequential_model(config): |
| 144 """ | 134 """Construct keras Sequential model from Galaxy tool parameters |
| 145 Construct keras Sequential model from Galaxy tool parameters | |
| 146 | 135 |
| 147 Parameters: | 136 Parameters: |
| 148 ----------- | 137 ----------- |
| 149 config : dictionary, galaxy tool parameters loaded by JSON | 138 config : dictionary, galaxy tool parameters loaded by JSON |
| 150 """ | 139 """ |
| 163 if kwargs: | 152 if kwargs: |
| 164 kwargs = safe_eval("dict(" + kwargs + ")") | 153 kwargs = safe_eval("dict(" + kwargs + ")") |
| 165 options.update(kwargs) | 154 options.update(kwargs) |
| 166 | 155 |
| 167 # add input_shape to the first layer only | 156 # add input_shape to the first layer only |
| 168 if not getattr(model, "_layers") and input_shape is not None: | 157 if not model.get_config()["layers"] and input_shape is not None: |
| 169 options["input_shape"] = input_shape | 158 options["input_shape"] = input_shape |
| 170 | 159 |
| 171 model.add(klass(**options)) | 160 model.add(klass(**options)) |
| 172 | 161 |
| 173 return model | 162 return model |
| 174 | 163 |
| 175 | 164 |
| 176 def get_functional_model(config): | 165 def get_functional_model(config): |
| 177 """ | 166 """Construct keras functional model from Galaxy tool parameters |
| 178 Construct keras functional model from Galaxy tool parameters | |
| 179 | 167 |
| 180 Parameters | 168 Parameters |
| 181 ----------- | 169 ----------- |
| 182 config : dictionary, galaxy tool parameters loaded by JSON | 170 config : dictionary, galaxy tool parameters loaded by JSON |
| 183 """ | 171 """ |
| 219 | 207 |
| 220 return Model(inputs=input_layers, outputs=output_layers) | 208 return Model(inputs=input_layers, outputs=output_layers) |
| 221 | 209 |
| 222 | 210 |
| 223 def get_batch_generator(config): | 211 def get_batch_generator(config): |
| 224 """ | 212 """Construct keras online data generator from Galaxy tool parameters |
| 225 Construct keras online data generator from Galaxy tool parameters | |
| 226 | 213 |
| 227 Parameters | 214 Parameters |
| 228 ----------- | 215 ----------- |
| 229 config : dictionary, galaxy tool parameters loaded by JSON | 216 config : dictionary, galaxy tool parameters loaded by JSON |
| 230 """ | 217 """ |
| 244 | 231 |
| 245 return klass(**config) | 232 return klass(**config) |
| 246 | 233 |
| 247 | 234 |
| 248 def config_keras_model(inputs, outfile): | 235 def config_keras_model(inputs, outfile): |
| 249 """ | 236 """config keras model layers and output JSON |
| 250 config keras model layers and output JSON | |
| 251 | 237 |
| 252 Parameters | 238 Parameters |
| 253 ---------- | 239 ---------- |
| 254 inputs : dict | 240 inputs : dict |
| 255 loaded galaxy tool parameters from `keras_model_config` | 241 loaded galaxy tool parameters from `keras_model_config` |
| 269 | 255 |
| 270 with open(outfile, "w") as f: | 256 with open(outfile, "w") as f: |
| 271 json.dump(json.loads(json_string), f, indent=2) | 257 json.dump(json.loads(json_string), f, indent=2) |
| 272 | 258 |
| 273 | 259 |
| 274 def build_keras_model( | 260 def build_keras_model(inputs, outfile, model_json, batch_mode=False): |
| 275 inputs, | 261 """for `keras_model_builder` tool |
| 276 outfile, | |
| 277 model_json, | |
| 278 infile_weights=None, | |
| 279 batch_mode=False, | |
| 280 outfile_params=None, | |
| 281 ): | |
| 282 """ | |
| 283 for `keras_model_builder` tool | |
| 284 | 262 |
| 285 Parameters | 263 Parameters |
| 286 ---------- | 264 ---------- |
| 287 inputs : dict | 265 inputs : dict |
| 288 loaded galaxy tool parameters from `keras_model_builder` tool. | 266 loaded galaxy tool parameters from `keras_model_builder` tool. |
| 289 outfile : str | 267 outfile : str |
| 290 Path to galaxy dataset containing the keras_galaxy model output. | 268 Path to galaxy dataset containing the keras_galaxy model output. |
| 291 model_json : str | 269 model_json : str |
| 292 Path to dataset containing keras model JSON. | 270 Path to dataset containing keras model JSON. |
| 293 infile_weights : str or None | |
| 294 If string, path to dataset containing model weights. | |
| 295 batch_mode : bool, default=False | 271 batch_mode : bool, default=False |
| 296 Whether to build online batch classifier. | 272 Whether to build online batch classifier. |
| 297 outfile_params : str, default=None | |
| 298 File path to search parameters output. | |
| 299 """ | 273 """ |
| 300 with open(model_json, "r") as f: | 274 with open(model_json, "r") as f: |
| 301 json_model = json.load(f) | 275 json_model = json.load(f) |
| 302 | 276 |
| 303 config = json_model["config"] | 277 config = json_model["config"] |
| 305 options = {} | 279 options = {} |
| 306 | 280 |
| 307 if json_model["class_name"] == "Sequential": | 281 if json_model["class_name"] == "Sequential": |
| 308 options["model_type"] = "sequential" | 282 options["model_type"] = "sequential" |
| 309 klass = Sequential | 283 klass = Sequential |
| 310 elif json_model["class_name"] == "Model": | 284 elif json_model["class_name"] == "Functional": |
| 311 options["model_type"] = "functional" | 285 options["model_type"] = "functional" |
| 312 klass = Model | 286 klass = Model |
| 313 else: | 287 else: |
| 314 raise ValueError("Unknow Keras model class: %s" % json_model["class_name"]) | 288 raise ValueError("Unknow Keras model class: %s" % json_model["class_name"]) |
| 315 | 289 |
| 316 # load prefitted model | 290 # load prefitted model |
| 317 if inputs["mode_selection"]["mode_type"] == "prefitted": | 291 if inputs["mode_selection"]["mode_type"] == "prefitted": |
| 318 estimator = klass.from_config(config) | 292 # estimator = klass.from_config(config) |
| 319 estimator.load_weights(infile_weights) | 293 # estimator.load_weights(infile_weights) |
| | 294 raise Exception("Prefitted was deprecated!") |
| 320 # build train model | 295 # build train model |
| 321 else: | 296 else: |
| 322 cls_name = inputs["mode_selection"]["learning_type"] | 297 cls_name = inputs["mode_selection"]["learning_type"] |
| 323 klass = try_get_attr("galaxy_ml.keras_galaxy_models", cls_name) | 298 klass = try_get_attr("galaxy_ml.keras_galaxy_models", cls_name) |
| 324 | 299 |
| 336 ] | 311 ] |
| 337 ) | 312 ) |
| 338 ) | 313 ) |
| 339 | 314 |
| 340 train_metrics = inputs["mode_selection"]["compile_params"]["metrics"] | 315 train_metrics = inputs["mode_selection"]["compile_params"]["metrics"] |
| | 316 if not isinstance(train_metrics, list): # for older galaxy |
| | 317 train_metrics = train_metrics.split(",") |
| 341 if train_metrics[-1] == "none": | 318 if train_metrics[-1] == "none": |
| 342 train_metrics = train_metrics[:-1] | 319 train_metrics.pop() |
| 343 options["metrics"] = train_metrics | 320 options["metrics"] = train_metrics |
| 344 | 321 |
| 345 options.update(inputs["mode_selection"]["fit_params"]) | 322 options.update(inputs["mode_selection"]["fit_params"]) |
| 346 options["seed"] = inputs["mode_selection"]["random_seed"] | 323 options["seed"] = inputs["mode_selection"]["random_seed"] |
| 347 | 324 |
| 353 options["prediction_steps"] = inputs["mode_selection"]["prediction_steps"] | 330 options["prediction_steps"] = inputs["mode_selection"]["prediction_steps"] |
| 354 options["class_positive_factor"] = inputs["mode_selection"][ | 331 options["class_positive_factor"] = inputs["mode_selection"][ |
| 355 "class_positive_factor" | 332 "class_positive_factor" |
| 356 ] | 333 ] |
| 357 estimator = klass(config, **options) | 334 estimator = klass(config, **options) |
| 358 if outfile_params: | |
| 359 hyper_params = get_search_params(estimator) | |
| 360 # TODO: remove this after making `verbose` tunable | |
| 361 for h_param in hyper_params: | |
| 362 if h_param[1].endswith("verbose"): | |
| 363 h_param[0] = "@" | |
| 364 df = pd.DataFrame(hyper_params, columns=["", "Parameter", "Value"]) | |
| 365 df.to_csv(outfile_params, sep="\t", index=False) | |
| 366 | 335 |
| 367 print(repr(estimator)) | 336 print(repr(estimator)) |
| 368 # save model by pickle | 337 # save model |
| 369 with open(outfile, "wb") as f: | 338 dump_model_to_h5(estimator, outfile, verbose=1) |
| 370 pickle.dump(estimator, f, pickle.HIGHEST_PROTOCOL) | |
| 371 | 339 |
| 372 | 340 |
| 373 if __name__ == "__main__": | 341 if __name__ == "__main__": |
| 374 warnings.simplefilter("ignore") | 342 warnings.simplefilter("ignore") |
| 375 | 343 |
| 376 aparser = argparse.ArgumentParser() | 344 aparser = argparse.ArgumentParser() |
| 377 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 345 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
| 378 aparser.add_argument("-m", "--model_json", dest="model_json") | 346 aparser.add_argument("-m", "--model_json", dest="model_json") |
| 379 aparser.add_argument("-t", "--tool_id", dest="tool_id") | 347 aparser.add_argument("-t", "--tool_id", dest="tool_id") |
| 380 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | |
| 381 aparser.add_argument("-o", "--outfile", dest="outfile") | 348 aparser.add_argument("-o", "--outfile", dest="outfile") |
| 382 aparser.add_argument("-p", "--outfile_params", dest="outfile_params") | |
| 383 args = aparser.parse_args() | 349 args = aparser.parse_args() |
| 384 | 350 |
| 385 input_json_path = args.inputs | 351 input_json_path = args.inputs |
| 386 with open(input_json_path, "r") as param_handler: | 352 with open(input_json_path, "r") as param_handler: |
| 387 inputs = json.load(param_handler) | 353 inputs = json.load(param_handler) |
| 388 | 354 |
| 389 tool_id = args.tool_id | 355 tool_id = args.tool_id |
| 390 outfile = args.outfile | 356 outfile = args.outfile |
| 391 outfile_params = args.outfile_params | |
| 392 model_json = args.model_json | 357 model_json = args.model_json |
| 393 infile_weights = args.infile_weights | |
| 394 | 358 |
| 395 # for keras_model_config tool | 359 # for keras_model_config tool |
| 396 if tool_id == "keras_model_config": | 360 if tool_id == "keras_model_config": |
| 397 config_keras_model(inputs, outfile) | 361 config_keras_model(inputs, outfile) |
| 398 | 362 |
| 401 batch_mode = False | 365 batch_mode = False |
| 402 if tool_id == "keras_batch_models": | 366 if tool_id == "keras_batch_models": |
| 403 batch_mode = True | 367 batch_mode = True |
| 404 | 368 |
| 405 build_keras_model( | 369 build_keras_model( |
| 406 inputs=inputs, | 370 inputs=inputs, model_json=model_json, batch_mode=batch_mode, outfile=outfile |
| 407 model_json=model_json, | |
| 408 infile_weights=infile_weights, | |
| 409 batch_mode=batch_mode, | |
| 410 outfile=outfile, | |
| 411 outfile_params=outfile_params, | |
| 412 ) | 371 ) |
