annotate build/lib/MLaaS/ktrain.py @ 0:cbbe42422d56 draft

planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author kls286
date Tue, 28 Mar 2023 15:07:30 +0000
parents
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
1 #!/usr/bin/env python
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
2 #-*- coding: utf-8 -*-
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
3 #pylint: disable=
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
4 """
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
5 File : ktrain.py
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
7 Description: Keras based ML network to train over MNIST dataset
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
8 """
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
9
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
10 # system modules
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
11 import os
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
12 import sys
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
13 import json
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
14 import gzip
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
15 import pickle
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
16 import argparse
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
17
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
18 # third-party modules
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
19 import numpy as np
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
20 import tensorflow as tf
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
21 from tensorflow import keras
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
22 from tensorflow.keras import layers
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
23 from tensorflow.keras import backend as K
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
24 from tensorflow.python.tools import saved_model_utils
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
25
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
26
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
def modelGraph(model_dir):
    """
    Provide input/output names used by TF Graph along with graph itself.

    The code is based on TF saved_model_cli.py script.

    :param model_dir: path to a TensorFlow SavedModel directory
    :return: tuple ``(input_names, output_names, meta_graph_def)``.
        ``input_names`` / ``output_names`` are the tensor names taken from the
        last signature that declares at least one input (empty lists if none
        does). ``meta_graph_def`` is the MetaGraphDef of the last tag set
        visited, or ``None`` when the SavedModel reports no tag sets.
    """
    input_names = []
    output_names = []
    # initialise so an empty SavedModel does not raise UnboundLocalError
    meta_graph_def = None
    tag_sets = saved_model_utils.get_saved_model_tag_sets(model_dir)
    for tag_set in sorted(tag_sets):
        print('%r' % ', '.join(sorted(tag_set)))
        # get_meta_graph_def expects the whole tag set as a comma-separated
        # string; passing only the first tag breaks multi-tag MetaGraphs
        meta_graph_def = saved_model_utils.get_meta_graph_def(
            model_dir, ','.join(tag_set))
        for meta in meta_graph_def.signature_def.values():
            if not (hasattr(meta, 'inputs') and hasattr(meta, 'outputs')):
                continue
            names = [signature.name for signature in meta.inputs.values()]
            # a signature without inputs (TF2 may add an init-op signature
            # like this) must not clobber names collected from a real one
            if not names:
                continue
            input_names = names
            output_names = [signature.name
                            for signature in meta.outputs.values()]
    return input_names, output_names, meta_graph_def
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
49
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
def readData(fin, num_classes):
    """
    Helper function to read MNIST data and provide it to
    upstream code, e.g. to the training layer.

    :param fin: path to a gzip-compressed pickle whose payload unpacks as
        ``((x_train, y_train), (x_test, y_test))``
    :param num_classes: number of label classes used for one-hot encoding
    :return: ``(x_train, y_train, x_test, y_test)``; images are cast to
        float32, divided by 255 and given a trailing channel axis, labels are
        converted with ``keras.utils.to_categorical``
    """
    # Load the data and split it between train and test sets. The expected
    # layout matches what keras.datasets.mnist.load_data() returns.
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only pass trusted files here.
    # The context manager closes the file even if unpickling fails.
    with gzip.open(fin, 'rb') as fobj:
        if sys.version_info < (3,):
            mnist_data = pickle.load(fobj)
        else:
            # encoding='bytes' allows loading pickles produced under Python 2
            mnist_data = pickle.load(fobj, encoding='bytes')
    (x_train, y_train), (x_test, y_test) = mnist_data

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Append a channel axis so each image is (28, 28, 1) for MNIST input
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    return x_train, y_train, x_test, y_test
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
80
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
81
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
def train(fin, fout=None, model_name=None, epochs=1, batch_size=128, h5=False):
    """
    Build, train, evaluate and hand off a small CNN classifier for MNIST.

    :param fin: gzip-compressed pickle with the dataset (read by ``readData``)
    :param fout: output location forwarded to ``writer``; ``writer`` does
        nothing when this is falsy
    :param model_name: name forwarded to ``writer``
    :param epochs: number of passes over the training data
    :param batch_size: mini-batch size used by ``model.fit``
    :param h5: forwarded to ``writer``; when truthy, ``writer`` uses it as the
        file name of an extra HDF5 export inside ``fout``
    """
    num_classes = 10            # size of the softmax output layer
    input_shape = (28, 28, 1)   # height, width, channels

    # two conv/pool stages followed by a dropout-regularised softmax head
    stack = [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
    model = keras.Sequential(stack)

    model.summary()
    model_in = model.input
    print("model input", model_in, type(model_in), model_in.__dict__)
    model_out = model.output
    print("model output", model_out, type(model_out), model_out.__dict__)
    model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )

    x_train, y_train, x_test, y_test = readData(fin, num_classes)
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
    )

    # evaluate returns [loss, accuracy] for the metrics compiled above
    score = model.evaluate(x_test, y_test, verbose=0)
    test_loss, test_accuracy = score[0], score[1]
    print("Test loss:", test_loss)
    print("Test accuracy:", test_accuracy)
    print("save model to", fout)
    writer(fout, model_name, model, input_shape, h5)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
119
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
120 def writer(fout, model_name, model, input_shape, h5=False):
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
121 """
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
122 Writer provide write function for given model
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
123 """
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
124 if not fout:
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
125 return
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
126 model.save(fout)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
127 if h5:
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
128 model.save('{}/{}'.format(fout, h5), save_format='h5')
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
129 pbModel = '{}/saved_model.pb'.format(fout)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
130 pbtxtModel = '{}/saved_model.pbtxt'.format(fout)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
131 convert(pbModel, pbtxtModel)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
132
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
133 # get meta-data information about our ML model
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
134 input_names, output_names, model_graph = modelGraph(model_name)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
135 print("### input", input_names)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
136 print("### output", output_names)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
137 # ML uses (28,28,1) shape, i.e. 28x28 black-white images
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
138 # if we'll use color images we'll use shape (28, 28, 3)
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
139 img_channels = input_shape[2] # last item represent number of colors
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
140 meta = {'name': model_name,
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
141 'model': 'saved_model.pb',
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
142 'labels': 'labels.txt',
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
143 'img_channels': img_channels,
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
144 'input_name': input_names[0].split(':')[0],
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
145 'output_name': output_names[0].split(':')[0],
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
146 'input_node': model.input.name,
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
147 'output_node': model.output.name
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
148 }
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
149 with open(fout+'/params.json', 'w') as ostream:
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
150 ostream.write(json.dumps(meta))
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
151 with open(fout+'/labels.txt', 'w') as ostream:
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
152 for i in range(0, 10):
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
153 ostream.write(str(i)+'\n')
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
154 with open(fout + '/model.graph', 'wb') as ostream:
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
155 ostream.write(model_graph.SerializeToString())
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
156
cbbe42422d56 planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff changeset
def convert(fin, fout):
    """
    Convert input model.pb into output model.pbtxt

    Reads a binary SavedModel protobuf from ``fin``, parses it into a
    ``saved_model_pb2.SavedModel`` message and writes its protobuf text
    representation to ``fout`` (the file is overwritten).

    Based on internet search:
    - https://www.tensorflow.org/guide/saved_model
    - https://www.programcreek.com/python/example/123317/tensorflow.core.protobuf.saved_model_pb2.SavedModel

    :param fin: path to the input binary ``.pb`` file
    :param fout: path of the output text ``.pbtxt`` file
    """
    # text_format is a submodule: a bare ``import google.protobuf`` does not
    # guarantee that ``google.protobuf.text_format`` is loaded, so import it
    # explicitly instead of relying on another import having done so.
    from google.protobuf import text_format
    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()

    with open(fin, 'rb') as istream:
        saved_model.ParseFromString(istream.read())

    with open(fout, 'w', encoding='utf-8') as ostream:
        ostream.write(text_format.MessageToString(saved_model))


class OptionParser:
    "Command-line option parser for the Keras MNIST training script"

    def __init__(self):
        """
        User based option parser

        Exposes the configured ``argparse.ArgumentParser`` as ``self.parser``.
        ``--epochs`` and ``--batch_size`` are converted to ``int`` so that
        values given on the command line have the same type as the defaults.
        """
        self.parser = argparse.ArgumentParser(prog='PROG')
        self.parser.add_argument("--fin", action="store",
            dest="fin", default="", help="Input MNIST file")
        self.parser.add_argument("--fout", action="store",
            dest="fout", default="", help="Output models area")
        self.parser.add_argument("--model", action="store",
            dest="model", default="mnist", help="model name")
        # type=int: without it a user-supplied value would be the string "5",
        # unlike the integer default.
        self.parser.add_argument("--epochs", action="store", type=int,
            dest="epochs", default=1, help="number of epochs to use in ML training")
        self.parser.add_argument("--batch_size", action="store", type=int,
            dest="batch_size", default=128, help="batch size to use in training")
        self.parser.add_argument("--h5", action="store",
            dest="h5", default="mnist", help="h5 model file name")

def main():
    "Entry point: parse command-line options and launch training with them"
    cli = OptionParser().parser.parse_args()
    # everything except the input/output paths is passed by keyword
    settings = dict(model_name=cli.model,
                    epochs=cli.epochs,
                    batch_size=cli.batch_size,
                    h5=cli.h5)
    train(cli.fin, cli.fout, **settings)

# Run the training only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()