Mercurial > repos > kls286 > chap_test_20230328
comparison MLaaS/tfaas_client.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author | kls286 |
---|---|
date | Tue, 28 Mar 2023 15:07:30 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:cbbe42422d56 |
---|---|
1 #!/usr/bin/env python | |
2 #-*- coding: utf-8 -*- | |
3 #pylint: disable= | |
4 """ | |
5 File : tfaas_client.py | |
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com> | |
7 Description: simple python client to communicate with TFaaS server | |
8 """ | |
9 | |
10 # system modules | |
11 import os | |
12 import sys | |
13 import pwd | |
14 import ssl | |
15 import json | |
16 import binascii | |
17 import argparse | |
18 import itertools | |
19 import mimetypes | |
20 if sys.version_info < (2, 7): | |
21 raise Exception("TFaaS client requires python 2.7 or greater") | |
22 # python 3 | |
23 if sys.version.startswith('3.'): | |
24 import urllib.request as urllib2 | |
25 import urllib.parse as urllib | |
26 import http.client as httplib | |
27 import http.cookiejar as cookielib | |
28 else: | |
29 import mimetools | |
30 import urllib | |
31 import urllib2 | |
32 import httplib | |
33 import cookielib | |
34 | |
# User-Agent prefix sent with every request, e.g. "tfaas-client/1.1::python/3.11"
TFAAS_CLIENT = 'tfaas-client/1.1::python/{0}.{1}'.format(*sys.version_info[:2])
36 | |
class OptionParser(object):
    """Command-line options understood by the TFaaS client."""

    def __init__(self):
        """Build ``self.parser`` with every supported TFaaS client option.

        ``--key`` and ``--cert`` both default to the proxy path returned by
        ``x509()`` ('' when no proxy file exists). ``--capath`` defaults to
        $X509_CERT_DIR when that directory exists, otherwise to
        /etc/grid-security/certificates when it exists, otherwise to ''.
        """
        # 'PROG' was the placeholder from the argparse documentation and was
        # shown verbatim in the usage line.
        self.parser = argparse.ArgumentParser(prog='tfaas_client')
        self.parser.add_argument("--url", action="store",
                                 dest="url", default="", help="TFaaS URL")
        self.parser.add_argument("--upload", action="store",
                                 dest="upload", default="", help="upload model to TFaaS")
        self.parser.add_argument("--bundle", action="store",
                                 dest="bundle", default="", help="upload bundle ML files to TFaaS")
        self.parser.add_argument("--predict", action="store",
                                 dest="predict", default="", help="fetch prediction from TFaaS")
        self.parser.add_argument("--image", action="store",
                                 dest="image", default="", help="fetch prediction for given image")
        self.parser.add_argument("--model", action="store",
                                 dest="model", default="", help="TF model to use")
        self.parser.add_argument("--delete", action="store",
                                 dest="delete", default="", help="delete model in TFaaS")
        self.parser.add_argument("--models", action="store_true",
                                 dest="models", default=False, help="show existing models in TFaaS")
        self.parser.add_argument("--verbose", action="store_true",
                                 dest="verbose", default=False, help="verbose output")

        # key and cert default to the same proxy file; look it up only once
        proxy = x509()
        msg = 'specify private key file name, default $X509_USER_PROXY'
        self.parser.add_argument("--key", action="store",
                                 default=proxy, dest="ckey", help=msg)
        msg = 'specify private certificate file name, default $X509_USER_PROXY'
        self.parser.add_argument("--cert", action="store",
                                 default=proxy, dest="cert", help=msg)

        # CA directory: $X509_CERT_DIR if it exists, else the standard grid
        # location if it exists, else empty (meaning: use the system CAs)
        default_ca = os.environ.get("X509_CERT_DIR")
        if not default_ca or not os.path.exists(default_ca):
            default_ca = "/etc/grid-security/certificates"
            if not os.path.exists(default_ca):
                default_ca = ""
        if default_ca:
            msg = 'specify CA path, default currently is %s' % default_ca
        else:
            msg = 'specify CA path; defaults to system CAs.'
        self.parser.add_argument("--capath", action="store",
                                 default=default_ca, dest="capath", help=msg)
77 | |
class HTTPSClientAuthHandler(urllib2.HTTPSHandler):
    """
    HTTPS handler that presents an X509 client certificate/key and, when a
    CA directory is given, verifies the server against it.
    """
    def __init__(self, key=None, cert=None, capath=None, level=0):
        """
        :param key: path to the client private key
        :param cert: path to the client certificate (may also contain the key)
        :param capath: directory of CA certificates used to verify the server
        :param level: debug level; any value > 0 turns on HTTP debug output
        """
        if level > 0:
            urllib2.HTTPSHandler.__init__(self, debuglevel=1)
        else:
            urllib2.HTTPSHandler.__init__(self)
        self.key = key
        self.cert = cert
        self.capath = capath

    def https_open(self, req):
        """Open *req* over a connection built by ``get_connection``."""
        # do_open() expects a connection class, but it only calls it as
        # http_class(host, timeout=...), so a bound factory method works too.
        return self.do_open(self.get_connection, req)

    def get_connection(self, host, timeout=300):
        """Return an HTTPSConnection to *host* configured with our credentials.

        * key and cert, no capath: client certificate + default CA verification
        * cert and capath: client certificate (key optional) + verification
          against the CA certificates in capath
        * otherwise: plain HTTPS connection with default settings

        :param host: "host" or "host:port"
        :param timeout: socket timeout forwarded to the connection
        """
        # Contexts are built explicitly: HTTPSConnection's key_file/cert_file
        # arguments were removed in Python 3.12, and the former
        # ssl.PROTOCOL_TLSv1 context pinned TLS 1.0 and never verified the
        # server certificate even though capath was loaded.
        context = None
        if self.key and self.cert and not self.capath:
            context = ssl.create_default_context()
            context.load_cert_chain(self.cert, self.key)
        elif self.cert and self.capath:
            context = ssl.create_default_context(capath=self.capath)
            # keyfile=None means the private key is stored inside cert
            context.load_cert_chain(self.cert, self.key or None)
        if context is None:
            return httplib.HTTPSConnection(host, timeout=timeout)
        return httplib.HTTPSConnection(host, timeout=timeout, context=context)
110 | |
def x509():
    """Return the path of the user's X509 proxy file, or '' if there is none.

    Uses $X509_USER_PROXY when it is set and non-empty, otherwise the
    conventional /tmp/x509up_u<uid> location. A non-empty result is always
    an existing regular file.
    """
    proxy = os.environ.get('X509_USER_PROXY', '')
    if not proxy:
        # os.getuid() is exactly what pwd.getpwuid(uid).pw_uid returned,
        # without the KeyError raised for uids that have no passwd entry
        # (common inside containers).
        proxy = '/tmp/x509up_u%s' % os.getuid()
    if not os.path.isfile(proxy):
        return ''
    return proxy
119 | |
def check_auth(key):
    """Print a warning when the client runs without an X509 key/proxy.

    :param key: path to the user's private key or proxy; any falsy value
        ('' or None) counts as "no credentials".
    """
    if key:
        return
    print("WARNING: tfaas_client is running without user credentials/X509 proxy, create proxy via 'voms-proxy-init -voms cms -rfc'")
125 | |
def fullpath(path):
    """Expand a leading ``~`` or ``~user`` in *path* to the home directory.

    Paths that do not start with '~' (absolute, relative, '' or None) are
    returned unchanged; the path is not otherwise made absolute.
    """
    if path and path.startswith('~'):
        # expanduser only rewrites the leading component, unlike the former
        # path.replace('~', '') which also dropped tildes later in the path,
        # raised IndexError for a bare '~' and KeyError when HOME was unset.
        return os.path.expanduser(path)
    return path
133 | |
def choose_boundary():
    """
    Return a fresh random string usable as a multipart boundary.

    Stand-in for ``mimetools.choose_boundary``, which does not exist in
    Python 3; see
    https://stackoverflow.com/questions/27099290/where-is-mimetools-choose-boundary-function-in-python3
    https://docs.python.org/2.7/library/mimetools.html?highlight=choose_boundary#mimetools.choose_boundary
    Only uniqueness matters, so the text form of a random UUID4 is used.
    """
    from uuid import uuid4
    token = uuid4()
    return '%s' % token
145 | |
# credit: https://pymotw.com/2/urllib2/#uploading-files
class MultiPartForm(object):
    """Accumulate fields and files and serialize them as multipart/form-data."""

    def __init__(self):
        self.form_fields = []  # (name, value) pairs
        self.files = []  # (fieldname, filename, mimetype, body) tuples
        # choose_boundary() is uuid-based and works on Python 2 and 3 alike,
        # so the deprecated mimetools variant is no longer needed.
        self.boundary = choose_boundary()

    def get_content_type(self):
        """Return the Content-Type header value, including the boundary."""
        return 'multipart/form-data; boundary=%s' % self.boundary

    def add_field(self, name, value):
        """Add a simple field to the form data."""
        self.form_fields.append((name, value))

    def add_file(self, fieldname, filename, fileHandle, mimetype=None):
        """Add a file to be uploaded.

        The whole content of *fileHandle* is read immediately; closing the
        handle is left to the caller. When the MIME type (given, or guessed
        from *filename*) is application/octet-stream the content is
        base64-encoded.
        """
        body = fileHandle.read()
        if mimetype is None:
            mimetype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
        if mimetype == 'application/octet-stream':
            if not isinstance(body, bytes):
                # text-mode handles yield str, which b2a_base64 rejects
                body = body.encode('utf-8')
            body = binascii.b2a_base64(body)
        self.files.append((fieldname, filename, mimetype, body))

    @staticmethod
    def _as_text(value):
        """Coerce one form line/part to text so all parts can be CRLF-joined.

        On Python 3, bytes (raw or base64 file bodies) are decoded with the
        'surrogateescape' handler so every byte is restored by a later
        ``.encode('utf-8', 'surrogateescape')``; non-string values such as
        numbers are rendered with str().
        """
        if isinstance(value, bytes) and not isinstance(value, str):
            return value.decode('utf-8', 'surrogateescape')
        if isinstance(value, (str, type(u''))):
            return value
        return str(value)

    def __str__(self):
        """Return the form data, including attached files, as one string.

        Each part is: boundary line, Content-Disposition (plus Content-Type
        for files), an empty line, then the value/body. Lines are joined
        with CRLF and the closing boundary ends the body.
        """
        part_boundary = '--' + self.boundary
        lines = []
        for name, value in self.form_fields:
            lines.extend([
                part_boundary,
                'Content-Disposition: form-data; name="%s"' % name,
                '',
                value,
            ])
        # Files also use the form-data content disposition instead of the
        # file one, since this is how the handlers are defined in the TFaaS
        # Go server; for more info see:
        # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
        for field_name, filename, content_type, body in self.files:
            lines.extend([
                part_boundary,
                'Content-Disposition: form-data; name="%s"; filename="%s"'
                % (field_name, filename),
                'Content-Type: %s' % content_type,
                '',
                body,
            ])
        lines.append('--' + self.boundary + '--')
        lines.append('')
        return '\r\n'.join(self._as_text(line) for line in lines)

    def __bytes__(self):
        """Return the form data encoded as bytes (Python 3), ready to send."""
        return str(self).encode('utf-8', 'surrogateescape')
219 | |
def models(host, verbose=None, ckey=None, cert=None, capath=None):
    """List the models known to the TFaaS server at *host*.

    Issues a GET request to ``<host>/models`` through getdata(); the
    credentials are forwarded unchanged.
    """
    endpoint = host + '/models'
    user_agent = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    request_headers = {"Accept": "application/json", "User-Agent": user_agent}
    if verbose:
        print("URL : %s" % endpoint)
    empty_payload = json.dumps({})
    return getdata(endpoint, request_headers, empty_payload,
                   ckey, cert, capath, verbose, 'GET')
229 | |
def delete(host, model, verbose=None, ckey=None, cert=None, capath=None):
    """Ask the TFaaS server at *host* to delete *model*.

    Sends an HTTP DELETE to ``<host>/delete`` whose multipart body carries
    a single 'model' field.
    """
    endpoint = host + '/delete'
    user_agent = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    if verbose:
        print("URL : %s" % endpoint)
        print("model : %s" % model)
    form = MultiPartForm()
    form.add_field('model', model)
    payload = str(form)
    request_headers = {
        "User-Agent": user_agent,
        'Content-length': len(payload),
        'Content-Type': form.get_content_type(),
    }
    return getdata(endpoint, request_headers, payload,
                   ckey, cert, capath, verbose, method='DELETE')
244 | |
def bundle(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """Upload a bundle of ML model files to the TFaaS server.

    The content of *ifile* is POSTed unchanged to ``<host>/upload`` with
    ``Content-Encoding: gzip`` and ``Content-Type: application/octet-stream``;
    presumably *ifile* is a gzipped archive of the model files — TODO
    confirm the exact format the server expects.
    """
    url = host + '/upload'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"User-Agent": client, "Content-Encoding": "gzip",
               "Content-Type": "application/octet-stream"}
    if verbose:
        # same diagnostics as the other API helpers
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
    # read inside a with-block so the file descriptor is released promptly
    with open(ifile, 'rb') as istream:
        data = istream.read()
    return getdata(url, headers, data, ckey, cert, capath, verbose)
252 | |
def upload(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """Upload a model to TFaaS as described by the JSON file *ifile*.

    In *ifile*, the values of the 'model', 'labels' and 'params' keys are
    file names whose contents are attached as files ('model' is read in
    binary mode, the others in text mode); every other key/value pair is
    sent as a plain form field. The multipart body is POSTed to
    ``<host>/upload``.
    """
    url = host + '/upload'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"User-Agent": client}
    with open(ifile) as istream:
        params = json.load(istream)
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))

    form = MultiPartForm()
    for key, value in params.items():
        if key in ('model', 'labels', 'params'):
            flag = 'rb' if key == 'model' else 'r'
            # add_file() reads the whole file, so it can be closed right away
            with open(value, flag) as fdesc:
                form.add_file(key, value, fileHandle=fdesc)
        else:
            form.add_field(key, value)
    edata = str(form)
    headers['Content-length'] = len(edata)
    headers['Content-Type'] = form.get_content_type()
    headers['Content-Encoding'] = 'base64'
    return getdata(url, headers, edata, ckey, cert, capath, verbose)
279 | |
def predict(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    """Get predictions from TFaaS for the JSON input in *ifile*.

    The JSON content of *ifile* is POSTed to ``<host>/json``; a non-empty
    *model* overrides the 'model' key found in the file. Returns the
    server reply as a JSON string (see getdata()).
    """
    url = host + '/json'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"Accept": "application/json", "User-Agent": client}
    with open(ifile) as istream:
        params = json.load(istream)
    if model:  # overwrite model name in given input file
        params['model'] = model
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))
    encoded_data = json.dumps(params)
    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose)
294 | |
def predictImage(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    """Get predictions from TFaaS for the image file *ifile* using *model*.

    The image and a 'model' field are POSTed as multipart/form-data to
    ``<host>/image``.
    """
    url = host + '/image'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"Accept": "application/json", "User-Agent": client}
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("model : %s" % model)
    form = MultiPartForm()
    # binary mode: image bytes must not go through text decoding;
    # add_file() reads everything, so the file can be closed right away
    with open(ifile, 'rb') as istream:
        form.add_file('image', ifile, fileHandle=istream)
    form.add_field('model', model)
    edata = str(form)
    headers['Content-length'] = len(edata)
    headers['Content-Type'] = form.get_content_type()
    return getdata(url, headers, edata, ckey, cert, capath, verbose)
312 | |
def getdata(url, headers, encoded_data, ckey, cert, capath, verbose=None, method='POST'):
    """Send one HTTP request to *url* and return the server's reply.

    :param url: full request URL
    :param headers: dict of request headers (the caller's dict is not modified)
    :param encoded_data: request body as str or bytes; ignored for GET
    :param ckey: client private key path (a leading '~' is expanded)
    :param cert: client certificate path (a leading '~' is expanded)
    :param capath: CA certificate directory used to verify the server
    :param verbose: truthy value enables HTTP debug output
    :param method: 'POST' (default), 'GET' or 'DELETE'
    :returns: for URLs ending in 'json' the decoded reply re-serialized as a
        JSON string, otherwise the raw response body as read.
    On an HTTP error status the server's error body is printed and the
    process exits with status 1.
    """
    debug = 1 if verbose else 0
    headers = dict(headers)  # private copy: Content-length may be rewritten
    if method == 'GET':
        # GET carries no body, encoded_data is ignored
        req = urllib2.Request(url=url, headers=headers)
    else:
        body = encoded_data
        if sys.version_info[0] >= 3 and isinstance(body, str):
            # urllib.request rejects str bodies with TypeError. The
            # surrogateescape handler restores raw bytes that MultiPartForm
            # may have carried through a str, and is a no-op for normal text.
            body = body.encode('utf-8', 'surrogateescape')
        if isinstance(body, bytes):
            # callers compute Content-length on the str form (a character
            # count); the wire needs the byte count
            for name in [hdr for hdr in headers if hdr.lower() == 'content-length']:
                del headers[name]
            headers['Content-length'] = len(body)
        req = urllib2.Request(url=url, headers=headers, data=body)
        if method == 'DELETE':
            req.get_method = lambda: 'DELETE'
    if ckey and cert:
        ckey = fullpath(ckey)
        cert = fullpath(cert)
        http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug)
    elif cert and capath:
        cert = fullpath(cert)
        http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug)
    else:
        http_hdlr = urllib2.HTTPHandler(debuglevel=debug)
    proxy_handler = urllib2.ProxyHandler({})  # ignore proxy env variables
    cookie_jar = cookielib.CookieJar()
    cookie_handler = urllib2.HTTPCookieProcessor(cookie_jar)
    opener = urllib2.build_opener(http_hdlr, proxy_handler, cookie_handler)
    try:
        fdesc = opener.open(req)
    except urllib2.HTTPError as error:
        print(error.read())
        sys.exit(1)
    try:
        if url.endswith('json'):
            data = json.load(fdesc)
        else:
            data = fdesc.read()
    finally:
        fdesc.close()  # release the connection even if decoding fails
    if url.endswith('json'):
        return json.dumps(data)
    return data
348 | |
def main():
    """Parse the command line, run the requested TFaaS action and print its result.

    Actions are mutually exclusive; when several are given, the first one
    in this order wins: --upload, --bundle, --delete, --models, --predict,
    --image.
    """
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()
    wanted = (opts.upload or opts.bundle or opts.delete or opts.models
              or opts.predict or opts.image)
    if wanted and not opts.url:
        # without a base URL requests would go to e.g. '/models' and urllib
        # would fail with an obscure "unknown url type" error
        optmgr.parser.error('--url is required')
    check_auth(opts.ckey)
    res = ''
    if opts.upload:
        res = upload(opts.url, opts.upload, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.bundle:  # was a separate 'if', so --upload --bundle ran both
        res = bundle(opts.url, opts.bundle, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.delete:
        res = delete(opts.url, opts.delete, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.models:
        res = models(opts.url, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.predict:
        res = predict(opts.url, opts.predict, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.image:
        res = predictImage(opts.url, opts.image, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
    if res:
        print(res)
369 | |
# Run the command-line client only when this file is executed as a script.
if __name__ == '__main__':
    main()