comparison: build/lib/MLaaS/ktrain.py @ 0:cbbe42422d56 (draft)
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
| author | kls286 |
|---|---|
| date | Tue, 28 Mar 2023 15:07:30 +0000 |
| parents | (none) |
| children | (none) |
| base revision | new revision |
|---|---|
| -1:000000000000 | 0:cbbe42422d56 |
```python
#!/usr/bin/env python
#-*- coding: utf-8 -*-
#pylint: disable=
"""
File       : ktrain.py
Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
Description: Keras-based ML network to train over the MNIST dataset
"""

# system modules
import json
import gzip
import pickle
import argparse

# third-party modules
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.tools import saved_model_utils


def modelGraph(model_dir):
    """
    Provide the input/output names used by the TF graph, along with the
    graph itself. The code is based on the TF saved_model_cli.py script.
    """
    input_names = []
    output_names = []
    tag_sets = saved_model_utils.get_saved_model_tag_sets(model_dir)
    for tag_set in sorted(tag_sets):
        print('%r' % ', '.join(sorted(tag_set)))
        meta_graph_def = saved_model_utils.get_meta_graph_def(model_dir, tag_set[0])
        for key in meta_graph_def.signature_def.keys():
            meta = meta_graph_def.signature_def[key]
            if hasattr(meta, 'inputs') and hasattr(meta, 'outputs'):
                input_signatures = list(meta.inputs.values())
                input_names = [signature.name for signature in input_signatures]
                if input_names:
                    output_signatures = list(meta.outputs.values())
                    output_names = [signature.name for signature in output_signatures]
    return input_names, output_names, meta_graph_def

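# Note (illustrative, not part of the original file): for a TF2 Keras
# SavedModel the tag set is typically 'serve' and the signature key is
# 'serving_default'; input tensor names then look like
# 'serving_default_<input_layer>:0' and outputs like 'StatefulPartitionedCall:0'.
# The exact names depend on the model and the TF version.
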
def readData(fin, num_classes):
    """
    Helper function to read MNIST data and provide it to
    upstream code, e.g. to the training layer
    """
    # Load the data and split it between train and test sets; the input file
    # is a gzipped pickle of ((x_train, y_train), (x_test, y_test)), i.e. the
    # structure returned by keras.datasets.mnist.load_data()
    with gzip.open(fin, 'rb') as istream:
        mnist_data = pickle.load(istream, encoding='bytes')
    (x_train, y_train), (x_test, y_test) = mnist_data

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    return x_train, y_train, x_test, y_test


def train(fin, fout=None, model_name=None, epochs=1, batch_size=128, h5=False):
    """
    Train function for MNIST
    """
    # Model / data parameters
    num_classes = 10
    input_shape = (28, 28, 1)

    # create ML model
    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.summary()
    print("model input", model.input, type(model.input))
    print("model output", model.output, type(model.output))
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    # train model
    x_train, y_train, x_test, y_test = readData(fin, num_classes)
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

    # evaluate trained model
    score = model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])
    print("save model to", fout)
    writer(fout, model_name, model, input_shape, h5)

def writer(fout, model_name, model, input_shape, h5=False):
    """
    Writer provides the write function for a given model
    """
    if not fout:
        return
    model.save(fout)
    if h5:
        # also save an HDF5 copy of the model; h5 holds the file name
        model.save('{}/{}'.format(fout, h5), save_format='h5')
    pbModel = '{}/saved_model.pb'.format(fout)
    pbtxtModel = '{}/saved_model.pbtxt'.format(fout)
    convert(pbModel, pbtxtModel)

    # get meta-data information about our ML model; the SavedModel was
    # written to the fout directory above, so read it back from there
    input_names, output_names, model_graph = modelGraph(fout)
    print("### input", input_names)
    print("### output", output_names)
    # ML uses the (28, 28, 1) shape, i.e. 28x28 black-and-white images;
    # for color images the shape would be (28, 28, 3)
    img_channels = input_shape[2]  # last item represents the number of colors
    meta = {'name': model_name,
            'model': 'saved_model.pb',
            'labels': 'labels.txt',
            'img_channels': img_channels,
            'input_name': input_names[0].split(':')[0],
            'output_name': output_names[0].split(':')[0],
            'input_node': model.input.name,
            'output_node': model.output.name
           }
    with open(fout + '/params.json', 'w') as ostream:
        ostream.write(json.dumps(meta))
    with open(fout + '/labels.txt', 'w') as ostream:
        for i in range(0, 10):
            ostream.write(str(i) + '\n')
    with open(fout + '/model.graph', 'wb') as ostream:
        ostream.write(model_graph.SerializeToString())

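# For reference (illustrative): with fout='mnist' and h5='mnist', writer()
# above leaves an output area along these lines:
#   mnist/
#     saved_model.pb      - TF SavedModel graph written by model.save()
#     saved_model.pbtxt   - text rendering produced by convert()
#     variables/          - model weights
#     mnist               - optional HDF5 copy of the model
#     params.json         - metadata written by writer()
#     labels.txt          - one digit label (0-9) per line
#     model.graph         - serialized MetaGraphDef
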
def convert(fin, fout):
    """
    Convert input model.pb into output model.pbtxt
    Based on:
    - https://www.tensorflow.org/guide/saved_model
    - https://www.programcreek.com/python/example/123317/tensorflow.core.protobuf.saved_model_pb2.SavedModel
    """
    from google.protobuf import text_format
    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()
    with open(fin, 'rb') as istream:
        saved_model.ParseFromString(istream.read())
    with open(fout, 'w') as ostream:
        ostream.write(text_format.MessageToString(saved_model))


class OptionParser():
    def __init__(self):
        "User based option parser"
        self.parser = argparse.ArgumentParser(prog='PROG')
        self.parser.add_argument("--fin", action="store",
            dest="fin", default="", help="Input MNIST file")
        self.parser.add_argument("--fout", action="store",
            dest="fout", default="", help="Output models area")
        self.parser.add_argument("--model", action="store",
            dest="model", default="mnist", help="model name")
        self.parser.add_argument("--epochs", action="store", type=int,
            dest="epochs", default=1, help="number of epochs to use in ML training")
        self.parser.add_argument("--batch_size", action="store", type=int,
            dest="batch_size", default=128, help="batch size to use in training")
        self.parser.add_argument("--h5", action="store",
            dest="h5", default="mnist", help="h5 model file name")

def main():
    "Main function"
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()
    train(opts.fin, opts.fout,
          model_name=opts.model,
          epochs=opts.epochs,
          batch_size=opts.batch_size,
          h5=opts.h5)

if __name__ == '__main__':
    main()
```
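
For context, `readData()` expects a gzip-compressed pickle of `((x_train, y_train), (x_test, y_test))` rather than the `.npz` archive that `keras.datasets.mnist` caches. A minimal sketch for producing such a file (the name `mnist.pkl.gz` is an arbitrary choice):

```python
# Build the gzip-pickled MNIST file that ktrain.py's readData() expects.
# Minimal sketch; mnist.pkl.gz is an arbitrary file name.
import gzip
import pickle

from tensorflow import keras

# load_data() returns ((x_train, y_train), (x_test, y_test)) as numpy arrays
mnist_data = keras.datasets.mnist.load_data()
with gzip.open('mnist.pkl.gz', 'wb') as ostream:
    pickle.dump(mnist_data, ostream)
```

The training script can then be run as, for example, `python ktrain.py --fin mnist.pkl.gz --fout mnist --epochs 2`.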
