Mercurial > repos > kls286 > chap_test_20230328
comparison build/lib/MLaaS/tfaas_client.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
| author | kls286 |
|---|---|
| date | Tue, 28 Mar 2023 15:07:30 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:cbbe42422d56 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 #-*- coding: utf-8 -*- | |
| 3 #pylint: disable= | |
| 4 """ | |
| 5 File : tfaas_client.py | |
| 6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com> | |
| 7 Description: simple python client to communicate with TFaaS server | |
| 8 """ | |
| 9 | |
| 10 # system modules | |
| 11 import os | |
| 12 import sys | |
| 13 import pwd | |
| 14 import ssl | |
| 15 import json | |
| 16 import binascii | |
| 17 import argparse | |
| 18 import itertools | |
| 19 import mimetypes | |
| 20 if sys.version_info < (2, 7): | |
| 21 raise Exception("TFaaS client requires python 2.7 or greater") | |
| 22 # python 3 | |
| 23 if sys.version.startswith('3.'): | |
| 24 import urllib.request as urllib2 | |
| 25 import urllib.parse as urllib | |
| 26 import http.client as httplib | |
| 27 import http.cookiejar as cookielib | |
| 28 else: | |
| 29 import mimetools | |
| 30 import urllib | |
| 31 import urllib2 | |
| 32 import httplib | |
| 33 import cookielib | |
| 34 | |
# User-Agent prefix sent with every request: client name/version followed by
# the running interpreter's major.minor version.
TFAAS_CLIENT = 'tfaas-client/1.1::python/{0}.{1}'.format(*sys.version_info[:2])
| 36 | |
class OptionParser(object):
    """Command-line options of the TFaaS client.

    The configured ``argparse.ArgumentParser`` is available as ``self.parser``.
    """

    def __init__(self):
        "User based option parser"
        # No prog= argument: argparse derives the program name from
        # sys.argv[0]; the former placeholder prog='PROG' leaked verbatim into
        # the usage and help output.
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument("--url", action="store",
            dest="url", default="", help="TFaaS URL")
        self.parser.add_argument("--upload", action="store",
            dest="upload", default="", help="upload model to TFaaS")
        self.parser.add_argument("--bundle", action="store",
            dest="bundle", default="", help="upload bundle ML files to TFaaS")
        self.parser.add_argument("--predict", action="store",
            dest="predict", default="", help="fetch prediction from TFaaS")
        self.parser.add_argument("--image", action="store",
            dest="image", default="", help="fetch prediction for given image")
        self.parser.add_argument("--model", action="store",
            dest="model", default="", help="TF model to use")
        self.parser.add_argument("--delete", action="store",
            dest="delete", default="", help="delete model in TFaaS")
        self.parser.add_argument("--models", action="store_true",
            dest="models", default=False, help="show existing models in TFaaS")
        self.parser.add_argument("--verbose", action="store_true",
            dest="verbose", default=False, help="verbose output")
        # Key and certificate both default to the X509 proxy located by x509().
        msg = 'specify private key file name, default $X509_USER_PROXY'
        self.parser.add_argument("--key", action="store",
            default=x509(), dest="ckey", help=msg)
        msg = 'specify private certificate file name, default $X509_USER_PROXY'
        self.parser.add_argument("--cert", action="store",
            default=x509(), dest="cert", help=msg)
        # CA directory: $X509_CERT_DIR if it exists, otherwise the standard
        # grid-security location if it exists, otherwise '' (system CAs).
        default_ca = os.environ.get("X509_CERT_DIR")
        if not default_ca or not os.path.exists(default_ca):
            default_ca = "/etc/grid-security/certificates"
            if not os.path.exists(default_ca):
                default_ca = ""
        if default_ca:
            msg = 'specify CA path, default currently is %s' % default_ca
        else:
            msg = 'specify CA path; defaults to system CAs.'
        self.parser.add_argument("--capath", action="store",
            default=default_ca, dest="capath", help=msg)
| 77 | |
class HTTPSClientAuthHandler(urllib2.HTTPSHandler):
    """
    Simple HTTPS client authentication class based on provided
    key/ca information
    """
    def __init__(self, key=None, cert=None, capath=None, level=0):
        """
        :param key: path to the client private key; may be empty when the key
            is stored in the ``cert`` file itself (as in an X509 proxy)
        :param cert: path to the client certificate
        :param capath: directory of CA certificates used to verify the server;
            when empty the system default CAs are used
        :param level: debug level, any value > 0 enables HTTP debug output
        """
        # explicit base-class call (not super()) keeps Python 2 compatibility
        if level > 0:
            urllib2.HTTPSHandler.__init__(self, debuglevel=1)
        else:
            urllib2.HTTPSHandler.__init__(self)
        self.key = key
        self.cert = cert
        self.capath = capath

    def https_open(self, req):
        """Open request method"""
        # Rather than pass in a reference to a connection class, we pass in
        # a reference to a function which, for all intents and purposes,
        # will behave as a constructor
        return self.do_open(self.get_connection, req)

    def get_connection(self, host, timeout=300):
        """Return an HTTPSConnection to ``host``.

        When a client certificate is configured it is presented to the server
        (with ``self.key`` as the private key, if set), and the server is
        verified against ``self.capath`` or, if that is empty, the system CAs.
        ``timeout`` is forwarded to the connection.
        """
        if not self.cert:
            return httplib.HTTPSConnection(host, timeout=timeout)
        # create_default_context negotiates the best TLS version both sides
        # support and enforces certificate + hostname verification. The former
        # SSLContext(PROTOCOL_TLSv1) pinned TLS 1.0 and, with its default
        # verify_mode of CERT_NONE, never actually checked the server against
        # the loaded CA path. It also replaces the key_file/cert_file
        # arguments that HTTPSConnection no longer accepts (Python 3.12).
        context = ssl.create_default_context(capath=self.capath or None)
        context.load_cert_chain(self.cert, keyfile=self.key or None)
        return httplib.HTTPSConnection(host, timeout=timeout, context=context)
| 110 | |
def x509():
    """Return the path of the user's X509 proxy, or '' when none is found.

    The path comes from $X509_USER_PROXY, falling back to the conventional
    /tmp/x509up_u<uid> location; '' is returned if that file does not exist.
    """
    proxy = os.environ.get('X509_USER_PROXY', '')
    if not proxy:
        # os.getuid() already is the numeric uid; the former round trip through
        # pwd.getpwuid(...).pw_uid added nothing and raised KeyError for uids
        # without a passwd entry (common inside containers).
        proxy = '/tmp/x509up_u%s' % os.getuid()
    return proxy if os.path.isfile(proxy) else ''
| 119 | |
def check_auth(key):
    """Warn (on stderr) when the client runs without X509 credentials.

    :param key: path to the user's private key/proxy; a falsy value means
        no credentials were found.
    """
    if not key:
        msg = "WARNING: tfaas_client is running without user credentials/X509 proxy, create proxy via 'voms-proxy-init -voms cms -rfc'"
        # stderr keeps the warning out of stdout, where main() prints the
        # server response that callers may pipe or parse. sys.stderr.write
        # (rather than print(..., file=...)) stays valid on Python 2.
        sys.stderr.write(msg + '\n')
| 125 | |
def fullpath(path):
    """Expand a leading ``~`` in ``path`` to the user's home directory.

    Paths without a leading ``~`` (including empty or None values) are
    returned unchanged.
    """
    if path and path.startswith('~'):
        # expanduser rewrites only the leading "~" / "~user" component; the
        # former str.replace('~', '') also deleted tildes elsewhere in the path
        # and raised IndexError for a bare "~".
        return os.path.expanduser(path)
    return path
| 133 | |
def choose_boundary():
    """
    Return a fresh random string to use as a multipart boundary.

    Python 3 replacement for the removed ``mimetools.choose_boundary``; any
    random string will do, so the canonical text form of a random UUID4 is
    returned (36 characters, four dashes).
    See https://stackoverflow.com/questions/27099290/where-is-mimetools-choose-boundary-function-in-python3
    """
    from uuid import uuid4
    token = uuid4()
    return '%s' % token
| 145 | |
class MultiPartForm(object):
    """Accumulate the data to be used when posting a form.

    Text fields are added with ``add_field`` and files with ``add_file``;
    ``str(form)`` renders the multipart/form-data body and
    ``get_content_type()`` the matching Content-Type header value.
    Credit: https://pymotw.com/2/urllib2/#uploading-files
    """

    def __init__(self):
        self.form_fields = []  # (name, value) pairs
        self.files = []        # (fieldname, filename, mimetype, body) tuples
        if sys.version.startswith('3.'):
            self.boundary = choose_boundary()
        else:
            self.boundary = mimetools.choose_boundary()

    def get_content_type(self):
        """Return the Content-Type header value, including the boundary."""
        return 'multipart/form-data; boundary=%s' % self.boundary

    def add_field(self, name, value):
        """Add a simple field to the form data."""
        self.form_fields.append((name, value))

    def add_file(self, fieldname, filename, fileHandle, mimetype=None):
        """Add a file to be uploaded.

        The whole content of ``fileHandle`` is read immediately, so the caller
        may close it afterwards. When ``mimetype`` is not given it is guessed
        from ``filename``; ``application/octet-stream`` content is sent
        base64-encoded.
        """
        body = fileHandle.read()
        if mimetype is None:
            mimetype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
        if mimetype == 'application/octet-stream':
            if not isinstance(body, bytes):
                # text-mode handles yield str on Python 3; base64 needs bytes
                body = body.encode('utf-8')
            body = binascii.b2a_base64(body)
        self.files.append((fieldname, filename, mimetype, body))

    @staticmethod
    def _as_text(part):
        """Coerce one body line/part to text so that all parts can be joined.

        On Python 3 bytes are decoded as latin-1, which maps every byte to
        exactly one code point: encoding the rendered form with latin-1
        restores the original bytes, and len() equals the byte count.
        Non-string values (e.g. numbers from a JSON spec) are formatted with %s.
        """
        if isinstance(part, str):
            return part
        if isinstance(part, bytes):  # only reachable on Python 3
            return part.decode('latin-1')
        return '%s' % (part,)

    def __str__(self):
        """Return a string representing the form data, including attached files.

        Each part starts with the boundary line, lines are separated by CRLF
        and the body ends with the closing boundary marker.
        """
        part_boundary = '--' + self.boundary
        parts = []

        # Add the form fields
        parts.extend(
            [part_boundary,
             'Content-Disposition: form-data; name="%s"' % name,
             '',
             value]
            for name, value in self.form_fields)

        # Add the files to upload. The form-data content disposition (rather
        # than a file one) is used because that is how the handlers of the Go
        # server are defined; see
        # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
        parts.extend(
            [part_boundary,
             'Content-Disposition: form-data; name="%s"; filename="%s"' %
             (field_name, filename),
             'Content-Type: %s' % content_type,
             '',
             body]
            for field_name, filename, content_type, body in self.files)

        # Flatten and coerce every piece to text (bytes bodies made join()
        # raise TypeError on Python 3), then add the closing boundary marker.
        flattened = [self._as_text(part) for part in itertools.chain(*parts)]
        flattened.append('--' + self.boundary + '--')
        flattened.append('')
        return '\r\n'.join(flattened)
| 219 | |
def models(host, verbose=None, ckey=None, cert=None, capath=None):
    """List the models known to the TFaaS server (GET <host>/models).

    Returns the response body as produced by getdata.
    """
    endpoint = host + '/models'
    agent = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    if verbose:
        print("URL : %s" % endpoint)
    # getdata builds GET requests without a body, so this empty JSON object
    # only fills the positional payload argument.
    return getdata(endpoint,
                   {"Accept": "application/json", "User-Agent": agent},
                   json.dumps({}), ckey, cert, capath, verbose, 'GET')
| 229 | |
def delete(host, model, verbose=None, ckey=None, cert=None, capath=None):
    """Ask the TFaaS server to delete ``model`` (DELETE <host>/delete).

    The model name travels as the ``model`` field of a multipart form.
    """
    endpoint = host + '/delete'
    if verbose:
        print("URL : %s" % endpoint)
        print("model : %s" % model)
    form = MultiPartForm()
    form.add_field('model', model)
    payload = str(form)
    headers = {
        "User-Agent": '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')),
        'Content-length': len(payload),
        'Content-Type': form.get_content_type(),
    }
    return getdata(endpoint, headers, payload, ckey, cert, capath, verbose,
                   method='DELETE')
| 244 | |
def bundle(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """Upload a bundle of model files to the TFaaS server (POST <host>/upload).

    The raw bytes of ``ifile`` are sent as the request body with the headers
    ``Content-Encoding: gzip`` and ``Content-Type: application/octet-stream``.
    NOTE(review): this presumably expects ``ifile`` to be a gzip-compressed
    archive; confirm against the server's upload handler.
    """
    url = host + '/upload'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"User-Agent": client, "Content-Encoding": "gzip",
               "Content-Type": "application/octet-stream"}
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
    # read inside "with" so the handle is closed (it was previously leaked)
    with open(ifile, 'rb') as istream:
        data = istream.read()
    return getdata(url, headers, data, ckey, cert, capath, verbose)
| 252 | |
def upload(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """Upload a model to the TFaaS server (POST <host>/upload).

    ``ifile`` is a JSON file describing the upload. Its ``model``, ``labels``
    and ``params`` entries are file names whose contents are attached as files
    (``model`` is read in binary mode, the others in text mode); every other
    entry is sent as a plain form field.
    """
    url = host + '/upload'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"User-Agent": client}
    with open(ifile) as istream:
        params = json.load(istream)
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))

    form = MultiPartForm()
    for key, value in params.items():
        if key in ('model', 'labels', 'params'):
            mode = 'rb' if key == 'model' else 'r'
            # add_file reads the whole handle, so it can be closed right after
            with open(value, mode) as fobj:
                form.add_file(key, value, fileHandle=fobj)
        else:
            form.add_field(key, value)
    edata = str(form)
    headers['Content-length'] = len(edata)
    headers['Content-Type'] = form.get_content_type()
    headers['Content-Encoding'] = 'base64'
    return getdata(url, headers, edata, ckey, cert, capath, verbose)
| 279 | |
def predict(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    """Fetch predictions for JSON input from the TFaaS server (POST <host>/json).

    :param ifile: JSON file holding the request payload
    :param model: when non-empty, overrides the payload's ``model`` key
    :returns: the server reply, which getdata re-serializes to a JSON string
        because the URL ends with 'json'
    """
    url = host + '/json'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"Accept": "application/json", "User-Agent": client}
    # "with" closes the input file (the handle was previously leaked)
    with open(ifile) as istream:
        params = json.load(istream)
    if model:  # overwrite model name in given input file
        params['model'] = model
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))
    encoded_data = json.dumps(params)
    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose)
| 294 | |
def predictImage(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    """Fetch predictions for an image from the TFaaS server (POST <host>/image).

    The image file ``ifile`` is sent (binary) as the ``image`` file part and
    ``model`` as the ``model`` field of a multipart form.
    """
    url = host + '/image'
    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
    headers = {"Accept": "application/json", "User-Agent": client}
    if verbose:
        print("URL : %s" % url)
        print("ifile : %s" % ifile)
        print("model : %s" % model)
    form = MultiPartForm()
    # add_file reads the whole handle, so "with" can close it immediately
    with open(ifile, 'rb') as istream:
        form.add_file('image', ifile, fileHandle=istream)
    form.add_field('model', model)
    edata = str(form)
    headers['Content-length'] = len(edata)
    headers['Content-Type'] = form.get_content_type()
    return getdata(url, headers, edata, ckey, cert, capath, verbose)
| 312 | |
def _encode_payload(data, headers):
    """Return ``(body, headers)`` with a text ``body`` converted to bytes.

    urllib on Python 3 rejects ``str`` request bodies, while this module's API
    helpers build payloads as text (json.dumps / str(MultiPartForm)). Text is
    encoded as latin-1 when possible: it maps code points 0-255 one-to-one onto
    bytes, so binary content decoded as latin-1 round-trips unchanged and a
    Content-length computed from len(text) stays correct. Other text falls
    back to UTF-8, and any Content-length header is updated to the encoded
    size. Non-text data (None, bytes, file objects) is returned unchanged;
    the caller's headers dict is never mutated.
    """
    if not isinstance(data, type(u'')):
        return data, headers
    try:
        return data.encode('latin-1'), headers
    except UnicodeEncodeError:
        body = data.encode('utf-8')
    headers = dict(headers)
    for name in list(headers):
        if name.lower() == 'content-length':
            headers[name] = len(body)
    return body, headers


def getdata(url, headers, encoded_data, ckey, cert, capath, verbose=None, method='POST'):
    """Send a request to ``url`` and return the server response.

    Helper used by the predict/upload/... APIs.

    :param url: full request URL; when it ends with 'json' the response is
        parsed as JSON and returned re-serialized by json.dumps
    :param headers: dict of HTTP headers
    :param encoded_data: request body (text or bytes); not sent for GET
    :param ckey: client private key path (a leading ~ is expanded)
    :param cert: client certificate path (a leading ~ is expanded)
    :param capath: CA directory used to verify the server
    :param verbose: truthy enables HTTP debug output
    :param method: 'POST' (default), 'GET' or 'DELETE'
    :returns: JSON string for URLs ending in 'json', otherwise the raw body
    On an HTTP error the error body is printed and the process exits with 1.
    """
    debug = 1 if verbose else 0
    if method == 'GET':
        req = urllib2.Request(url=url, headers=headers)
    else:
        encoded_data, headers = _encode_payload(encoded_data, headers)
        req = urllib2.Request(url=url, headers=headers, data=encoded_data)
        if method == 'DELETE':
            req.get_method = lambda: 'DELETE'
    if ckey and cert:
        ckey = fullpath(ckey)
        cert = fullpath(cert)
        http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug)
    elif cert and capath:
        cert = fullpath(cert)
        http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug)
    else:
        http_hdlr = urllib2.HTTPHandler(debuglevel=debug)
    # an empty mapping disables proxies picked up from the environment
    proxy_handler = urllib2.ProxyHandler({})
    cookie_handler = urllib2.HTTPCookieProcessor(cookielib.CookieJar())
    opener = urllib2.build_opener(http_hdlr, proxy_handler, cookie_handler)
    try:
        fdesc = opener.open(req)
    except urllib2.HTTPError as error:
        print(error.read())
        sys.exit(1)
    try:
        if url.endswith('json'):
            data = json.load(fdesc)
        else:
            data = fdesc.read()
    finally:
        # close the response even when decoding fails
        fdesc.close()
    if url.endswith('json'):
        return json.dumps(data)
    return data
| 348 | |
def main():
    """Command-line entry point: parse options, run one TFaaS action, print it.

    Actions are checked in the order upload, bundle, delete, models, predict,
    image; only the first one given on the command line is executed.
    """
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()
    check_auth(opts.ckey)
    res = ''
    # A single if/elif chain: --upload used to sit in its own "if", so combining
    # it with another action sent two requests and discarded the upload result.
    if opts.upload:
        res = upload(opts.url, opts.upload, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.bundle:
        res = bundle(opts.url, opts.bundle, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.delete:
        res = delete(opts.url, opts.delete, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.models:
        res = models(opts.url, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.predict:
        res = predict(opts.url, opts.predict, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
    elif opts.image:
        res = predictImage(opts.url, opts.image, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
    if res:
        print(res)


if __name__ == '__main__':
    main()
