Mercurial > repos > kls286 > chap_test_20230328
annotate build/lib/MLaaS/ktrain.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author | kls286 |
---|---|
date | Tue, 28 Mar 2023 15:07:30 +0000 |
parents | |
children |
rev | line source |
---|---|
0
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
1 #!/usr/bin/env python |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
2 #-*- coding: utf-8 -*- |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
3 #pylint: disable= |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
4 """ |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
5 File : ktrain.py |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com> |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
7 Description: Keras based ML network to train over MNIST dataset |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
8 """ |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
9 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
10 # system modules |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
11 import os |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
12 import sys |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
13 import json |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
14 import gzip |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
15 import pickle |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
16 import argparse |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
17 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
18 # third-party modules |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
19 import numpy as np |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
20 import tensorflow as tf |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
21 from tensorflow import keras |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
22 from tensorflow.keras import layers |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
23 from tensorflow.keras import backend as K |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
24 from tensorflow.python.tools import saved_model_utils |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
25 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
26 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
def modelGraph(model_dir):
    """
    Provide input/output names used by TF Graph along with graph itself.

    The code is based on the TF ``saved_model_cli.py`` script: it walks every
    tag-set stored in the SavedModel at ``model_dir``, loads the matching
    MetaGraphDef and collects the tensor names of its signatures.

    :param model_dir: path to a SavedModel directory
    :returns: tuple ``(input_names, output_names, meta_graph_def)``. Both name
        lists come from the same signature: the last visited signature that
        declares at least one input. ``meta_graph_def`` is the MetaGraphDef of
        the last tag-set visited, or ``None`` if the SavedModel has no tag-sets.
    """
    input_names = []
    output_names = []
    # Pre-bind so an empty SavedModel returns None instead of raising
    # UnboundLocalError at the return statement.
    meta_graph_def = None
    tag_sets = saved_model_utils.get_saved_model_tag_sets(model_dir)
    for tag_set in sorted(tag_sets):
        print('%r' % ', '.join(sorted(tag_set)))
        # get_meta_graph_def expects *all* tags of the set as one
        # comma-separated string (as saved_model_cli does); passing only
        # tag_set[0] fails for MetaGraphDefs saved with several tags.
        meta_graph_def = saved_model_utils.get_meta_graph_def(
            model_dir, ','.join(tag_set))
        for meta in meta_graph_def.signature_def.values():
            if not (hasattr(meta, 'inputs') and hasattr(meta, 'outputs')):
                continue
            sig_inputs = [signature.name for signature in meta.inputs.values()]
            if not sig_inputs:
                # Skip input-less signatures so input/output names always
                # stay paired from the same signature.
                continue
            input_names = sig_inputs
            output_names = [signature.name for signature in meta.outputs.values()]
    return input_names, output_names, meta_graph_def
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
49 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
def readData(fin, num_classes):
    """
    Helper function to read MNIST data and provide it to
    upstream code, e.g. to the training layer.

    :param fin: path to a gzip-compressed pickle whose payload unpacks as
        ``((x_train, y_train), (x_test, y_test))``
    :param num_classes: number of label classes passed to
        ``keras.utils.to_categorical``
    :returns: tuple ``(x_train, y_train, x_test, y_test)``. Images are cast to
        ``float32``, divided by 255 and given a trailing channel axis. Labels
        are converted to one-hot class matrices.
    """
    # Load the data and split it between train and test sets; the tuple
    # layout is the same as keras.datasets.mnist.load_data().
    # NOTE(security): pickle.load can execute arbitrary code -- only pass
    # files that come from a trusted source.
    # The context manager closes the file even if unpickling fails.
    with gzip.open(fin, 'rb') as fobj:
        if sys.version_info < (3,):
            mnist_data = pickle.load(fobj)
        else:
            # encoding='bytes' allows loading pickles written by Python 2
            mnist_data = pickle.load(fobj, encoding='bytes')
    (x_train, y_train), (x_test, y_test) = mnist_data

    # Scale images to the [0, 1] range (assumes raw 0..255 pixel values)
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Append a channel axis, e.g. (N, 28, 28) -> (N, 28, 28, 1) for MNIST
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices (one-hot encoding)
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    return x_train, y_train, x_test, y_test
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
80 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
81 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
def train(fin, fout=None, model_name=None, epochs=1, batch_size=128, h5=False):
    """
    Build, train and evaluate a small convolutional MNIST classifier.

    :param fin: path to the gzip-compressed pickle consumed by ``readData``
    :param fout: output location forwarded to ``writer``
        (``writer`` returns early without saving when it is falsy)
    :param model_name: model name forwarded to ``writer``
    :param epochs: number of training epochs for ``model.fit``
    :param batch_size: mini-batch size for ``model.fit``
    :param h5: forwarded to ``writer`` to control the extra HDF5 export
    """
    # Model / data parameters: 10 digit classes, 28x28 single-channel images
    n_classes = 10
    img_shape = (28, 28, 1)

    # Two conv/pool stages followed by a dropout-regularised softmax head
    layer_stack = [
        keras.Input(shape=img_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(n_classes, activation="softmax"),
    ]
    net = keras.Sequential(layer_stack)

    net.summary()
    # Debug dump of the symbolic input and output tensors, in that order
    for label, attr in (("model input", "input"), ("model output", "output")):
        tensor = getattr(net, attr)
        print(label, tensor, type(tensor), tensor.__dict__)
    net.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    # Fit on the training split, holding out 10% of it for validation
    x_train, y_train, x_test, y_test = readData(fin, n_classes)
    net.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

    # Evaluate on the test split: entry 0 is the loss, entry 1 the accuracy
    results = net.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", results[0])
    print("Test accuracy:", results[1])

    print("save model to", fout)
    writer(fout, model_name, net, img_shape, h5)
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
119 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
120 def writer(fout, model_name, model, input_shape, h5=False): |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
121 """ |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
122 Writer provide write function for given model |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
123 """ |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
124 if not fout: |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
125 return |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
126 model.save(fout) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
127 if h5: |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
128 model.save('{}/{}'.format(fout, h5), save_format='h5') |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
129 pbModel = '{}/saved_model.pb'.format(fout) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
130 pbtxtModel = '{}/saved_model.pbtxt'.format(fout) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
131 convert(pbModel, pbtxtModel) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
132 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
133 # get meta-data information about our ML model |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
134 input_names, output_names, model_graph = modelGraph(model_name) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
135 print("### input", input_names) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
136 print("### output", output_names) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
137 # ML uses (28,28,1) shape, i.e. 28x28 black-white images |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
138 # if we'll use color images we'll use shape (28, 28, 3) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
139 img_channels = input_shape[2] # last item represent number of colors |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
140 meta = {'name': model_name, |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
141 'model': 'saved_model.pb', |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
142 'labels': 'labels.txt', |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
143 'img_channels': img_channels, |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
144 'input_name': input_names[0].split(':')[0], |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
145 'output_name': output_names[0].split(':')[0], |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
146 'input_node': model.input.name, |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
147 'output_node': model.output.name |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
148 } |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
149 with open(fout+'/params.json', 'w') as ostream: |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
150 ostream.write(json.dumps(meta)) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
151 with open(fout+'/labels.txt', 'w') as ostream: |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
152 for i in range(0, 10): |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
153 ostream.write(str(i)+'\n') |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
154 with open(fout + '/model.graph', 'wb') as ostream: |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
155 ostream.write(model_graph.SerializeToString()) |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
156 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
def convert(fin, fout):
    """
    Convert an input binary model.pb into an output text-format model.pbtxt.

    The input file is parsed as a TensorFlow ``SavedModel`` protobuf message
    and written back out using protobuf's text format.

    Based on internet search:
    - https://www.tensorflow.org/guide/saved_model
    - https://www.programcreek.com/python/example/123317/tensorflow.core.protobuf.saved_model_pb2.SavedModel

    :param fin: path to the binary ``saved_model.pb`` file to read
    :param fout: path of the text-format ``.pbtxt`` file to write
    """
    # ``text_format`` is a submodule: ``import google.protobuf`` alone does not
    # guarantee it is loaded, so import it explicitly instead of relying on
    # another package (e.g. tensorflow) having imported it as a side effect.
    from google.protobuf import text_format
    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()

    # the .pb file is a serialized protobuf, so it must be read as bytes
    with open(fin, 'rb') as istream:
        saved_model.ParseFromString(istream.read())

    with open(fout, 'w') as ostream:
        ostream.write(text_format.MessageToString(saved_model))
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
175 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
176 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
class OptionParser:
    """Command-line option parser for the MNIST training script."""

    def __init__(self):
        "User based option parser"
        self.parser = argparse.ArgumentParser(prog='PROG')
        self.parser.add_argument("--fin", action="store",
            dest="fin", default="", help="Input MNIST file")
        self.parser.add_argument("--fout", action="store",
            dest="fout", default="", help="Output models area")
        self.parser.add_argument("--model", action="store",
            dest="model", default="mnist", help="model name")
        # type=int: without it a command-line value arrives as a str
        # (e.g. "5"), inconsistent with the int default
        self.parser.add_argument("--epochs", action="store", type=int,
            dest="epochs", default=1, help="number of epochs to use in ML training")
        self.parser.add_argument("--batch_size", action="store", type=int,
            dest="batch_size", default=128, help="batch size to use in training")
        self.parser.add_argument("--h5", action="store",
            dest="h5", default="mnist", help="h5 model file name")
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
193 |
cbbe42422d56
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
kls286
parents:
diff
changeset
|
def main():
    """Entry point: read command-line options and launch training."""
    args = OptionParser().parser.parse_args()
    train(
        args.fin,
        args.fout,
        model_name=args.model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        h5=args.h5,
    )


if __name__ == '__main__':
    main()