comparison MLaaS/tfaas_client.py @ 0:cbbe42422d56 draft

planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author kls286
date Tue, 28 Mar 2023 15:07:30 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:cbbe42422d56
1 #!/usr/bin/env python
2 #-*- coding: utf-8 -*-
3 #pylint: disable=
4 """
5 File : tfaas_client.py
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
7 Description: simple python client to communicate with TFaaS server
8 """
9
10 # system modules
11 import os
12 import sys
13 import pwd
14 import ssl
15 import json
16 import binascii
17 import argparse
18 import itertools
19 import mimetypes
20 if sys.version_info < (2, 7):
21 raise Exception("TFaaS client requires python 2.7 or greater")
22 # python 3
23 if sys.version.startswith('3.'):
24 import urllib.request as urllib2
25 import urllib.parse as urllib
26 import http.client as httplib
27 import http.cookiejar as cookielib
28 else:
29 import mimetools
30 import urllib
31 import urllib2
32 import httplib
33 import cookielib
34
# User-Agent prefix sent with every request: client version plus the
# major.minor version of the running Python interpreter.
TFAAS_CLIENT = 'tfaas-client/1.1::python/{0}.{1}'.format(*sys.version_info[:2])
36
class OptionParser():
    """
    Command-line option parser of the TFaaS client.

    The configured ``argparse.ArgumentParser`` is exposed as ``self.parser``;
    callers run ``self.parser.parse_args()`` themselves.
    """
    def __init__(self):
        "User based option parser"
        self.parser = argparse.ArgumentParser(prog='PROG')
        self.parser.add_argument("--url", action="store",
            dest="url", default="", help="TFaaS URL")
        self.parser.add_argument("--upload", action="store",
            dest="upload", default="", help="upload model to TFaaS")
        self.parser.add_argument("--bundle", action="store",
            dest="bundle", default="", help="upload bundle ML files to TFaaS")
        self.parser.add_argument("--predict", action="store",
            dest="predict", default="", help="fetch prediction from TFaaS")
        self.parser.add_argument("--image", action="store",
            dest="image", default="", help="fetch prediction for given image")
        self.parser.add_argument("--model", action="store",
            dest="model", default="", help="TF model to use")
        self.parser.add_argument("--delete", action="store",
            dest="delete", default="", help="delete model in TFaaS")
        self.parser.add_argument("--models", action="store_true",
            dest="models", default=False, help="show existing models in TFaaS")
        self.parser.add_argument("--verbose", action="store_true",
            dest="verbose", default=False, help="verbose output")

        # Key and certificate both default to the same X509 proxy file
        # (empty string when no proxy is found), so look it up only once.
        proxy = x509()
        msg = 'specify private key file name, default $X509_USER_PROXY'
        self.parser.add_argument("--key", action="store",
            default=proxy, dest="ckey", help=msg)
        msg = 'specify private certificate file name, default $X509_USER_PROXY'
        self.parser.add_argument("--cert", action="store",
            default=proxy, dest="cert", help=msg)

        # CA directory: $X509_CERT_DIR if it exists, otherwise the grid
        # default location if it exists, otherwise empty (system CAs).
        default_ca = os.environ.get("X509_CERT_DIR")
        if not default_ca or not os.path.exists(default_ca):
            default_ca = "/etc/grid-security/certificates"
            if not os.path.exists(default_ca):
                default_ca = ""
        if default_ca:
            msg = 'specify CA path, default currently is %s' % default_ca
        else:
            msg = 'specify CA path; defaults to system CAs.'
        self.parser.add_argument("--capath", action="store",
            default=default_ca, dest="capath", help=msg)
77
class HTTPSClientAuthHandler(urllib2.HTTPSHandler):
    """
    Simple HTTPS client authentication class based on provided
    key/ca information

    :param key: path to the private key file (may be the proxy file itself)
    :param cert: path to the certificate (or proxy) file
    :param capath: directory with CA certificates, or empty/None
    :param level: values above 0 switch on HTTP debug output
    """
    def __init__(self, key=None, cert=None, capath=None, level=0):
        if level > 0:
            urllib2.HTTPSHandler.__init__(self, debuglevel=1)
        else:
            urllib2.HTTPSHandler.__init__(self)
        self.key = key
        self.cert = cert
        self.capath = capath

    def https_open(self, req):
        """Open request method"""
        # Rather than pass in a reference to a connection class, we pass in
        # a reference to a function which, for all intents and purposes,
        # will behave as a constructor
        return self.do_open(self.get_connection, req)

    def get_connection(self, host, timeout=300):
        """
        Connection method: return an ``HTTPSConnection`` for ``host``.

        ``timeout`` is forwarded to the connection; urllib passes the
        request's own timeout here when it calls this factory.
        """
        if self.cert and self.capath:
            # Let client and server negotiate the best common TLS version
            # instead of pinning the connection to TLS 1.0.
            protocol = getattr(ssl, 'PROTOCOL_TLS', None) or ssl.PROTOCOL_SSLv23
            context = ssl.SSLContext(protocol)
            # NOTE(review): a context built this way keeps the ssl module's
            # default verify_mode (CERT_NONE), as the previous TLSv1 context
            # did, so the CA directory is loaded but peer verification is not
            # enforced. Consider ssl.CERT_REQUIRED once all servers are known
            # to present certificates signed by CAs in capath.
            context.load_verify_locations(capath=self.capath)
            context.load_cert_chain(self.cert, keyfile=self.key or None)
            return httplib.HTTPSConnection(host, timeout=timeout,
                                           context=context)
        if self.key and self.cert:
            # HTTPSConnection lost its key_file/cert_file arguments in
            # Python 3.12; a default context with the client certificate
            # loaded is the equivalent that works on old and new versions.
            context = ssl.create_default_context()
            context.load_cert_chain(self.cert, keyfile=self.key)
            return httplib.HTTPSConnection(host, timeout=timeout,
                                           context=context)
        return httplib.HTTPSConnection(host, timeout=timeout)
110
def x509():
    """
    Helper function to get x509 either from env or tmp file

    Returns the value of $X509_USER_PROXY when it is set (without checking
    that the file exists); otherwise returns /tmp/x509up_u<uid> if that file
    exists, and an empty string if it does not.
    """
    proxy = os.environ.get('X509_USER_PROXY', '')
    if proxy:
        return proxy
    # os.getuid() is the number the passwd lookup would hand back, but it
    # also works when the uid has no passwd entry (pwd.getpwuid raises
    # KeyError in that case, e.g. for arbitrary uids inside containers).
    proxy = '/tmp/x509up_u%s' % os.getuid()
    if not os.path.isfile(proxy):
        return ''
    return proxy
119
def check_auth(key):
    "Print a warning when the client runs without a key/X509 proxy"
    if key:
        return
    print("WARNING: tfaas_client is running without user credentials/X509 proxy, create proxy via 'voms-proxy-init -voms cms -rfc'")
125
def fullpath(path):
    """
    Expand path to full path

    A leading ``~`` is expanded to the user's home directory with
    os.path.expanduser; any other value (including empty or None) is
    returned unchanged. Tildes elsewhere in the path are left alone and a
    bare ``~`` no longer fails.
    """
    if not path or not path.startswith('~'):
        return path
    return os.path.expanduser(path)
133
def choose_boundary():
    """
    Stand-in for the deprecated mimetools.choose_boundary, see
    https://stackoverflow.com/questions/27099290/where-is-mimetools-choose-boundary-function-in-python3
    https://docs.python.org/2.7/library/mimetools.html?highlight=choose_boundary#mimetools.choose_boundary

    The old helper produced strings such as
    '192.168.1.191.502.42035.1678979116.376.1'; here a random UUID4 in
    its string form is returned instead.
    """
    from uuid import uuid4
    token = uuid4()
    return str(token)
145
146 # credit: https://pymotw.com/2/urllib2/#uploading-files
class MultiPartForm(object):
    """
    Accumulate the data to be used when posting a form.

    Fields and files are collected with add_field/add_file and rendered as
    a multipart/form-data body by ``str(form)``.
    """

    def __init__(self):
        self.form_fields = []
        self.files = []
        if sys.version_info[0] >= 3:
            self.boundary = choose_boundary()
        else:
            self.boundary = mimetools.choose_boundary()

    @staticmethod
    def _as_text(value):
        """Return ``value`` as the interpreter's native ``str`` type."""
        if isinstance(value, str):
            return value
        if isinstance(value, bytes):
            # Python 3 only (on Python 2 bytes is str and was handled above).
            # surrogateescape keeps undecodable bytes recoverable, so
            # text.encode('utf-8', 'surrogateescape') restores the exact
            # original bytes.
            return value.decode('utf-8', 'surrogateescape')
        if isinstance(value, type(u'')):
            # Python 2 unicode object
            return value.encode('utf-8')
        return str(value)

    def get_content_type(self):
        """Return the Content-Type header value, including the boundary."""
        return 'multipart/form-data; boundary=%s' % self.boundary

    def add_field(self, name, value):
        """Add a simple field to the form data."""
        self.form_fields.append((name, value))

    def add_file(self, fieldname, filename, fileHandle, mimetype=None):
        """
        Add a file to be uploaded.

        The handle is read completely right away. When ``mimetype`` is not
        given it is guessed from ``filename``; content of type
        application/octet-stream is stored base64 encoded.
        """
        body = fileHandle.read()
        if mimetype is None:
            mimetype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
        if mimetype == 'application/octet-stream':
            if not isinstance(body, bytes):
                # handle was opened in text mode; b2a_base64 needs bytes
                body = body.encode('utf-8')
            body = binascii.b2a_base64(body)
        self.files.append((fieldname, filename, mimetype, body))

    def __str__(self):
        """
        Return a string representing the form data, including attached files.

        Every part is converted to native text first, so binary file bodies
        and non-string field values no longer break the join on Python 3.
        On Python 3 the result must be encoded with
        ``.encode('utf-8', 'surrogateescape')`` to obtain the bytes to send;
        its character count can differ from that byte count.
        """
        # Each part is introduced by the boundary line; lines are joined
        # with CR+LF at the end.
        part_boundary = '--' + self.boundary
        lines = []

        # Add the form fields
        for name, value in self.form_fields:
            lines.extend([
                part_boundary,
                'Content-Disposition: form-data; name="%s"' % name,
                '',
                self._as_text(value),
            ])

        # Add the files to upload
        # here we use form-data content disposition instead of file one
        # since this is how we define handlers in our Go server
        # for more info see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
        for field_name, filename, content_type, body in self.files:
            lines.extend([
                part_boundary,
                'Content-Disposition: form-data; name="%s"; filename="%s"'
                % (field_name, filename),
                'Content-Type: %s' % content_type,
                '',
                self._as_text(body),
            ])

        # Closing boundary marker followed by a final CR+LF
        lines.append(part_boundary + '--')
        lines.append('')
        return '\r\n'.join(lines)
219
def models(host, verbose=None, ckey=None, cert=None, capath=None):
    "List the models known to the TFaaS server (GET <host>/models)"
    endpoint = host + '/models'
    agent = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    if verbose:
        print("URL : %s" % endpoint)
    request_headers = {"Accept": "application/json", "User-Agent": agent}
    # an empty JSON document is handed over as payload of the GET call
    return getdata(endpoint, request_headers, json.dumps({}),
                   ckey, cert, capath, verbose, 'GET')
229
def delete(host, model, verbose=None, ckey=None, cert=None, capath=None):
    "Ask the TFaaS server to remove the given model (DELETE <host>/delete)"
    endpoint = host + '/delete'
    agent = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    if verbose:
        print("URL : %s" % endpoint)
        print("model : %s" % model)
    # the model name travels as the single field of a multipart form
    form = MultiPartForm()
    form.add_field('model', model)
    payload = str(form)
    request_headers = {
        "User-Agent": agent,
        'Content-length': len(payload),
        'Content-Type': form.get_content_type(),
    }
    return getdata(endpoint, request_headers, payload,
                   ckey, cert, capath, verbose, method='DELETE')
244
def bundle(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """
    bundle API uploads given bundle model files to TFaaS server

    The content of ``ifile`` is read in binary mode and posted as is to
    <host>/upload with gzip content encoding headers.
    """
    url = host + '/upload'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"User-Agent": client, "Content-Encoding": "gzip", "Content-Type": "application/octet-stream"}
    if verbose:
        # same diagnostics as the other API calls print
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
    # context manager closes the bundle file once it has been read
    with open(ifile, 'rb') as istream:
        data = istream.read()
    return getdata(url, headers, data, ckey, cert, capath, verbose)
252
def upload(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """
    upload API uploads given model to TFaaS server

    ``ifile`` is a JSON file describing the upload. Its 'model', 'labels'
    and 'params' entries name files whose content is attached to the form
    (the model file is read in binary mode); every other entry is sent as
    a plain form field.
    """
    url = host + '/upload'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"User-Agent": client}
    with open(ifile) as istream:
        params = json.load(istream)
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))

    form = MultiPartForm()
    for key, value in params.items():
        if key in ('model', 'labels', 'params'):
            mode = 'rb' if key == 'model' else 'r'
            # add_file reads the handle immediately, so it is safe to
            # close the file as soon as the call returns
            with open(value, mode) as fobj:
                form.add_file(key, value, fileHandle=fobj)
        else:
            form.add_field(key, value)
    edata = str(form)
    headers['Content-length'] = len(edata)
    headers['Content-Type'] = form.get_content_type()
    headers['Content-Encoding'] = 'base64'
    return getdata(url, headers, edata, ckey, cert, capath, verbose)
279
def predict(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    """
    predict API get predictions from TFaaS server

    The JSON content of ``ifile`` is posted to <host>/json; a non-empty
    ``model`` replaces the 'model' entry of that content.
    """
    url = host + '/json'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"Accept": "application/json", "User-Agent": client}
    # context manager closes the input file once it has been parsed
    with open(ifile) as istream:
        params = json.load(istream)
    if model: # overwrite model name in given input file
        params['model'] = model
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))
    encoded_data = json.dumps(params)
    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose)
294
def predictImage(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    """
    predict API get predictions from TFaaS server for a given image

    The image file ``ifile`` (read in binary mode) and the ``model`` name
    are posted as a multipart form to <host>/image.
    """
    url = host + '/image'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"Accept": "application/json", "User-Agent": client}
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("model : %s" % model)
    form = MultiPartForm()
    # add_file reads the handle immediately, so the image file can be
    # closed as soon as the call returns
    with open(ifile, 'rb') as istream:
        form.add_file('image', ifile, fileHandle=istream)
    form.add_field('model', model)
    edata = str(form)
    headers['Content-length'] = len(edata)
    headers['Content-Type'] = form.get_content_type()
    return getdata(url, headers, edata, ckey, cert, capath, verbose)
312
def getdata(url, headers, encoded_data, ckey, cert, capath, verbose=None, method='POST'):
    """
    helper function to use in predict/upload APIs, it place given URL call to the server

    :param url: full URL to call
    :param headers: dict of HTTP headers (not modified)
    :param encoded_data: request payload; ignored for GET
    :param ckey: private key file name, may be empty
    :param cert: certificate file name, may be empty
    :param capath: CA directory, may be empty
    :param verbose: true value switches on HTTP debug output
    :param method: 'POST' (default), 'GET' or 'DELETE'
    :return: for URLs ending with 'json' the server reply re-serialized as
        a JSON string, otherwise the raw response body

    On an HTTP error the server's error body is printed and the process
    exits with status 1.
    """
    debug = 1 if verbose else 0
    headers = dict(headers)  # work on a copy, the caller's dict stays intact
    if method == 'GET':
        req = urllib2.Request(url=url, headers=headers)
    else:
        if isinstance(encoded_data, type(u'')):
            # urllib on Python 3 refuses text payloads, so encode to bytes.
            # surrogateescape turns escaped code points back into the raw
            # bytes they stand for; ordinary text becomes plain UTF-8.
            errors = 'surrogateescape' if sys.version_info[0] >= 3 else 'strict'
            encoded_data = encoded_data.encode('utf-8', errors)
            # A length computed by the caller counted characters; the
            # header must carry the number of bytes actually sent.
            for name in list(headers):
                if name.lower() == 'content-length':
                    headers[name] = len(encoded_data)
        req = urllib2.Request(url=url, headers=headers, data=encoded_data)
        if method == 'DELETE':
            req.get_method = lambda: 'DELETE'

    if cert and (ckey or capath):
        # client certificate authentication
        ckey = fullpath(ckey) if ckey else ckey
        cert = fullpath(cert)
        http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug)
    else:
        http_hdlr = urllib2.HTTPHandler(debuglevel=debug)
    proxy_handler = urllib2.ProxyHandler({})
    cookie_jar = cookielib.CookieJar()
    cookie_handler = urllib2.HTTPCookieProcessor(cookie_jar)
    opener = urllib2.build_opener(http_hdlr, proxy_handler, cookie_handler)

    is_json = url.endswith('json')
    try:
        fdesc = opener.open(req)
    except urllib2.HTTPError as error:
        body = error.read()
        if not isinstance(body, str):
            # show the server message as text rather than as a bytes repr
            body = body.decode('utf-8', 'replace')
        print(body)
        sys.exit(1)
    try:
        data = json.load(fdesc) if is_json else fdesc.read()
    finally:
        # close the response even when reading or parsing it fails
        fdesc.close()
    if is_json:
        return json.dumps(data)
    return data
348
def main():
    """
    Main function

    Parses the command line, runs the single requested action (upload,
    bundle, delete, models, predict or image, first match wins) and prints
    its result if there is one.
    """
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()
    check_auth(opts.ckey)
    action_requested = (opts.upload or opts.bundle or opts.delete
                        or opts.models or opts.predict or opts.image)
    if action_requested and not opts.url:
        # without a server URL every request would fail deep inside urllib
        optmgr.parser.error('--url is required')
    res = ''
    if opts.upload:
        res = upload(opts.url, opts.upload, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.bundle:
        # elif keeps this a single dispatch chain: a second plain "if" here
        # ran another action after an upload and overwrote its result
        res = bundle(opts.url, opts.bundle, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.delete:
        res = delete(opts.url, opts.delete, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.models:
        res = models(opts.url, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.predict:
        res = predict(opts.url, opts.predict, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.image:
        res = predictImage(opts.url, opts.image, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
    if res:
        if isinstance(res, bytes) and not isinstance(res, str):
            # raw response bodies are bytes on Python 3; print them as text
            res = res.decode('utf-8', 'replace')
        print(res)

if __name__ == '__main__':
    main()