Mercurial > repos > kls286 > chap_test_20230328
comparison MLaaS/mnist_img.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author | kls286 |
---|---|
date | Tue, 28 Mar 2023 15:07:30 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:cbbe42422d56 |
---|---|
1 #!/usr/bin/env python | |
2 #-*- coding: utf-8 -*- | |
3 #pylint: disable= | |
4 """ | |
5 File : mnist_img.py | |
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com> | |
7 Description: | |
8 """ | |
9 | |
10 import json | |
11 import gzip | |
12 import argparse | |
13 # from itertools import chain | |
14 | |
15 import numpy as np | |
16 import matplotlib.pyplot as plt | |
17 | |
18 | |
def readImage(fname, fout, num_images=5, imgId=2, image_size=28):
    """
    Helper function to read one MNIST image from a gzipped image file.

    The first ``num_images`` images are read from ``fname``, and the one at
    index ``imgId`` is selected. If ``fout`` is given, that image is saved
    there with ``matplotlib.pyplot.imsave``.

    :param fname: path to a gzipped MNIST image file
        (e.g. ``train-images-idx3-ubyte.gz``)
    :param fout: output image file name; when ``None`` nothing is written
        and the image is only returned
    :param num_images: number of images to read from the start of the file
    :param imgId: index of the image to select among the ``num_images`` read
    :param image_size: width and height in pixels of each square image
        (28 for MNIST)
    :return: the selected image as a 2-D ``numpy.float32`` array of shape
        ``(image_size, image_size)``
    :raises ValueError: if ``num_images`` is not positive, or the file holds
        fewer than ``num_images`` images
    :raises IndexError: if ``imgId`` is outside ``[-num_images, num_images)``
    """
    if num_images < 1:
        raise ValueError(f"num_images must be positive, got {num_images}")
    pixels_per_image = image_size * image_size
    expected = pixels_per_image * num_images
    with gzip.open(fname, 'rb') as fstream:
        # skip the 16-byte file header that precedes the raw pixel bytes
        fstream.read(16)
        buf = fstream.read(expected)
    # fail with a clear message instead of an opaque reshape error
    if len(buf) < expected:
        raise ValueError(
            f"{fname}: expected {expected} bytes of pixel data for "
            f"{num_images} images, got {len(buf)}")
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    data = data.reshape(num_images, image_size, image_size)
    if not -num_images <= imgId < num_images:
        raise IndexError(
            f"imgId {imgId} is out of range for {num_images} images read")
    image = data[imgId]
    if fout is not None:
        plt.imsave(fout, image)
        print("read:", fname, "wrote:", fout, "image:", type(image), "shape:", image.shape)
    return image
32 | |
def img2json(image, fname='img.json'):
    """
    Convert given image to JSON data format used by TFaaS and write it out.

    The image is encoded as a one-element batch in which every pixel is
    wrapped in its own single-element list. The result is a nested list of
    shape ``(1, rows, cols, 1)``, with pixel values truncated to ``int``.

    :param image: 2-D image (a numpy array or anything ``numpy.asarray``
        accepts)
    :param fname: output JSON file name, default ``img.json``
    :return: the dictionary that was written, with keys ``keys``,
        ``values`` and ``model``
    """
    rows = np.asarray(image).tolist()
    # final values should be an array of elements, e.g. single image representation
    values = [[[[int(pixel)] for pixel in row] for row in rows]]
    meta = {
        # presumably the ten MNIST digit labels '0'..'9' -- confirm with TFaaS model config
        'keys': [str(i) for i in range(0, 10)],
        'values': values,
        'model': 'mnist',
    }
    with open(fname, 'w', encoding='utf-8') as ostream:
        json.dump(meta, ostream)
    return meta
54 | |
55 | |
class OptionParser:
    """
    User based option parser for the MNIST image extraction script.

    The configured ``argparse.ArgumentParser`` is available as ``self.parser``.
    """
    def __init__(self):
        fname = "train-images-idx3-ubyte.gz"
        self.parser = argparse.ArgumentParser(
            description="Extract one image from a gzipped MNIST image file")
        self.parser.add_argument("--fin", action="store",
            dest="fin", default=fname, help=f"Input MNIST file, default {fname}")
        self.parser.add_argument("--fout", action="store",
            dest="fout", default="img.png", help="Output image file name, default img.png")
        # type=int makes argparse report invalid numbers as a usage error
        self.parser.add_argument("--nimages", action="store", type=int,
            dest="nimages", default=5, help="number of images to read, default 5")
        self.parser.add_argument("--imgid", action="store", type=int,
            dest="imgid", default=2, help="image index to use from nimages, default 2 (number 4)")
69 | |
def main():
    """
    Main function to produce an image file from the MNIST dataset.

    It parses the command-line options and passes them to ``readImage``.
    The MNIST dataset can be downloaded from
    curl -O http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    """
    opts = OptionParser().parser.parse_args()
    # convert explicitly: the parser may hand the numeric options over as strings
    num_images = int(opts.nimages)
    imgId = int(opts.imgid)
    readImage(opts.fin, opts.fout, num_images, imgId)
81 | |
if __name__ == "__main__":
    # run the command-line entry point only when executed as a script
    main()