Mercurial > repos > kls286 > chap_test_20230328
annotate build/lib/MLaaS/ktrain.py @ 1:1016ae8f31ec draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
| author | kls286 | 
|---|---|
| date | Tue, 28 Mar 2023 15:16:40 +0000 | 
| parents | cbbe42422d56 | 
| children | 
| rev | line source | 
|---|---|
| 
0
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
1 #!/usr/bin/env python | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
2 #-*- coding: utf-8 -*- | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
3 #pylint: disable= | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
4 """ | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
5 File : ktrain.py | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com> | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
7 Description: Keras based ML network to train over MNIST dataset | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
8 """ | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
9 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
10 # system modules | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
11 import os | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
12 import sys | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
13 import json | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
14 import gzip | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
15 import pickle | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
16 import argparse | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
17 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
18 # third-party modules | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
19 import numpy as np | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
20 import tensorflow as tf | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
21 from tensorflow import keras | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
22 from tensorflow.keras import layers | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
23 from tensorflow.keras import backend as K | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
24 from tensorflow.python.tools import saved_model_utils | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
25 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
26 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
def modelGraph(model_dir):
    """
    Provide input/output names used by TF Graph along with graph itself.
    The code is based on TF saved_model_cli.py script.

    Parameters:
        model_dir: path to a TensorFlow SavedModel directory

    Returns:
        tuple (input_names, output_names, meta_graph_def) where
        input_names / output_names are the tensor names of the last
        signature (over all tag sets) that exposes inputs, and
        meta_graph_def is the MetaGraphDef of the last tag set visited,
        or None if the SavedModel contains no tag sets.
    """
    input_names = []
    output_names = []
    # Initialised so an empty SavedModel returns None instead of raising
    # UnboundLocalError at the return statement.
    meta_graph_def = None
    tag_sets = saved_model_utils.get_saved_model_tag_sets(model_dir)
    for tag_set in sorted(tag_sets):
        print('%r' % ', '.join(sorted(tag_set)))
        # get_meta_graph_def expects the complete tag set as a
        # comma-separated string (as saved_model_cli does); passing only
        # the first tag fails for multi-tag MetaGraphs and raises
        # IndexError for an empty tag set.
        meta_graph_def = saved_model_utils.get_meta_graph_def(
            model_dir, ','.join(tag_set))
        for key in meta_graph_def.signature_def.keys():
            meta = meta_graph_def.signature_def[key]
            if hasattr(meta, 'inputs') and hasattr(meta, 'outputs'):
                input_signatures = list(meta.inputs.values())
                input_names = [signature.name for signature in input_signatures]
                # only pick up outputs for signatures that actually have inputs
                if len(input_names) > 0:
                    output_signatures = list(meta.outputs.values())
                    output_names = [signature.name for signature in output_signatures]
    return input_names, output_names, meta_graph_def
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
49 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
def readData(fin, num_classes):
    """
    Helper function to read MNIST data and provide it to
    upstream code, e.g. to the training layer.

    Parameters:
        fin: path to a gzip-compressed pickle whose payload unpacks as
             ((x_train, y_train), (x_test, y_test))
        num_classes: number of label classes used for one-hot encoding

    Returns:
        tuple (x_train, y_train, x_test, y_test) where the image arrays are
        cast to float32, divided by 255 and given a trailing channel axis,
        and the label arrays are one-hot encoded with
        keras.utils.to_categorical.
    """
    # Load the data and split it between train and test sets.
    # The context manager closes the file even if unpickling raises.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only pass trusted dataset files here.
    with gzip.open(fin, 'rb') as f:
        if sys.version_info < (3,):
            mnist_data = pickle.load(f)
        else:
            # encoding='bytes' allows Python 3 to read pickles written by
            # Python 2 (presumably the classic mnist.pkl.gz) -- confirm source.
            mnist_data = pickle.load(f, encoding='bytes')
    (x_train, y_train), (x_test, y_test) = mnist_data

    # Scale images to the [0, 1] range (assumes 8-bit pixel values)
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Add a trailing channel axis, e.g. (N, 28, 28) -> (N, 28, 28, 1) for MNIST
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    return x_train, y_train, x_test, y_test
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
80 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
81 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
def train(fin, fout=None, model_name=None, epochs=1, batch_size=128, h5=False):
    """
    Build, train and evaluate a small convolutional MNIST classifier.

    Parameters:
        fin: dataset path handed to readData
        fout: output location handed to writer (writer returns without
              saving when it is falsy)
        model_name: forwarded unchanged to writer
        epochs: number of passes over the training data in model.fit
        batch_size: mini-batch size used by model.fit
        h5: forwarded unchanged to writer
    """
    # Model / data parameters
    n_classes = 10
    image_shape = (28, 28, 1)

    # conv -> pool -> conv -> pool -> flatten -> dropout -> softmax head
    stack = [
        keras.Input(shape=image_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(n_classes, activation="softmax"),
    ]
    model = keras.Sequential(stack)

    model.summary()
    # debug dump of the graph endpoints, input first then output
    for label, attr in (("model input", "input"), ("model output", "output")):
        endpoint = getattr(model, attr)
        print(label, endpoint, type(endpoint), endpoint.__dict__)
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    # fit on the training split, holding out 10% for validation
    x_train, y_train, x_test, y_test = readData(fin, n_classes)
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

    # score on the held-out test split
    score = model.evaluate(x_test, y_test, verbose=0)
    test_loss, test_accuracy = score[0], score[1]
    print("Test loss:", test_loss)
    print("Test accuracy:", test_accuracy)
    print("save model to", fout)
    writer(fout, model_name, model, image_shape, h5)
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
119 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
120 def writer(fout, model_name, model, input_shape, h5=False): | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
121 """ | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
122 Writer provide write function for given model | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
123 """ | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
124 if not fout: | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
125 return | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
126 model.save(fout) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
127 if h5: | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
128 model.save('{}/{}'.format(fout, h5), save_format='h5') | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
129 pbModel = '{}/saved_model.pb'.format(fout) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
130 pbtxtModel = '{}/saved_model.pbtxt'.format(fout) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
131 convert(pbModel, pbtxtModel) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
132 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
133 # get meta-data information about our ML model | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
134 input_names, output_names, model_graph = modelGraph(model_name) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
135 print("### input", input_names) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
136 print("### output", output_names) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
137 # ML uses (28,28,1) shape, i.e. 28x28 black-white images | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
138 # if we'll use color images we'll use shape (28, 28, 3) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
139 img_channels = input_shape[2] # last item represent number of colors | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
140 meta = {'name': model_name, | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
141 'model': 'saved_model.pb', | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
142 'labels': 'labels.txt', | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
143 'img_channels': img_channels, | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
144 'input_name': input_names[0].split(':')[0], | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
145 'output_name': output_names[0].split(':')[0], | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
146 'input_node': model.input.name, | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
147 'output_node': model.output.name | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
148 } | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
149 with open(fout+'/params.json', 'w') as ostream: | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
150 ostream.write(json.dumps(meta)) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
151 with open(fout+'/labels.txt', 'w') as ostream: | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
152 for i in range(0, 10): | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
153 ostream.write(str(i)+'\n') | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
154 with open(fout + '/model.graph', 'wb') as ostream: | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
155 ostream.write(model_graph.SerializeToString()) | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
156 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
def convert(fin, fout):
    """
    Convert a binary SavedModel protobuf (e.g. ``saved_model.pb``) into its
    protobuf text representation (e.g. ``saved_model.pbtxt``).

    Based on internet search:
    - https://www.tensorflow.org/guide/saved_model
    - https://www.programcreek.com/python/example/123317/tensorflow.core.protobuf.saved_model_pb2.SavedModel

    :param fin: path to the input binary SavedModel protobuf file
    :param fout: path of the text file to write; it is overwritten if present
    """
    # Import the text_format submodule explicitly: ``import google.protobuf``
    # only loads the package, so ``google.protobuf.text_format`` is not
    # guaranteed to exist as an attribute. The previous code presumably
    # worked only because importing TensorFlow loads it as a side effect.
    from google.protobuf import text_format
    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()

    # The input is a serialized (binary) protobuf message.
    with open(fin, 'rb') as istream:
        saved_model.ParseFromString(istream.read())

    # MessageToString returns a str in protobuf text format.
    with open(fout, 'w', encoding='utf-8') as ostream:
        ostream.write(text_format.MessageToString(saved_model))
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
175 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
176 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
class OptionParser:
    """Command-line options for training a model and writing it out."""

    def __init__(self):
        """Build the argparse parser and expose it as ``self.parser``."""
        self.parser = argparse.ArgumentParser(prog='PROG')
        self.parser.add_argument("--fin", action="store",
            dest="fin", default="", help="Input MNIST file")
        self.parser.add_argument("--fout", action="store",
            dest="fout", default="", help="Output models area")
        self.parser.add_argument("--model", action="store",
            dest="model", default="mnist", help="model name")
        # type=int: without it, argparse passes user-supplied values through
        # as strings while the defaults are ints, so the training code would
        # receive a different type depending on how it was invoked.
        self.parser.add_argument("--epochs", action="store", type=int,
            dest="epochs", default=1,
            help="number of epochs to use in ML training")
        self.parser.add_argument("--batch_size", action="store", type=int,
            dest="batch_size", default=128,
            help="batch size to use in training")
        self.parser.add_argument("--h5", action="store",
            dest="h5", default="mnist", help="h5 model file name")
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
193 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
def main():
    """Parse command-line options and hand them to ``train``."""
    args = OptionParser().parser.parse_args()
    train(
        args.fin,
        args.fout,
        model_name=args.model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        h5=args.h5,
    )
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
203 | 
| 
 
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
 
kls286 
parents:  
diff
changeset
 | 
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
