Mercurial > repos > kls286 > chap_test_20230328
view build/bdist.linux-x86_64/egg/MLaaS/tfaas_client.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author | kls286 |
---|---|
date | Tue, 28 Mar 2023 15:07:30 +0000 |
parents | |
children |
line wrap: on
line source
#!/usr/bin/env python
#-*- coding: utf-8 -*-
#pylint: disable=
"""
File       : tfaas_client.py
Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
Description: simple python client to communicate with TFaaS server
"""

# system modules
import os
import sys
import pwd
import ssl
import json
import uuid
import binascii
import argparse
import itertools
import mimetypes

if sys.version_info < (2, 7):
    raise Exception("TFaaS client requires python 2.7 or greater")

# True when running under Python 3; selects the urllib/httplib flavour below
PY3 = sys.version_info[0] >= 3

if PY3:
    import urllib.request as urllib2
    import urllib.parse as urllib
    import http.client as httplib
    import http.cookiejar as cookielib
else:
    import mimetools
    import urllib
    import urllib2
    import httplib
    import cookielib

TFAAS_CLIENT = 'tfaas-client/1.1::python/%s.%s' % sys.version_info[:2]

# text type of the running interpreter: unicode on Python 2, str on Python 3
TEXT_TYPE = type(u'')


def _to_bytes(value, encoding='utf-8'):
    """
    Return value as a byte string.

    Byte strings are returned unchanged, text is encoded with the given
    encoding and any other object is converted to text first.
    """
    if isinstance(value, bytes):
        return value
    if not isinstance(value, TEXT_TYPE):
        value = TEXT_TYPE(value)
    return value.encode(encoding)


def _to_text(value):
    """
    Return value in a form suitable for print().

    Only Python 3 byte strings are decoded (UTF-8, undecodable bytes are
    replaced); everything else is returned unchanged.
    """
    if isinstance(value, bytes) and not isinstance(value, str):
        return value.decode('utf-8', 'replace')
    return value


def _user_agent():
    "Return the User-Agent string: client version followed by $USER"
    return '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))


def _load_json(path):
    "Read and parse the JSON file at path, closing the file afterwards"
    with open(path) as istream:
        return json.load(istream)


class OptionParser(object):
    "User based option parser"

    def __init__(self):
        self.parser = argparse.ArgumentParser(prog='PROG')
        self.parser.add_argument(
            "--url", action="store", dest="url", default="",
            help="TFaaS URL")
        self.parser.add_argument(
            "--upload", action="store", dest="upload", default="",
            help="upload model to TFaaS")
        self.parser.add_argument(
            "--bundle", action="store", dest="bundle", default="",
            help="upload bundle ML files to TFaaS")
        self.parser.add_argument(
            "--predict", action="store", dest="predict", default="",
            help="fetch prediction from TFaaS")
        self.parser.add_argument(
            "--image", action="store", dest="image", default="",
            help="fetch prediction for given image")
        self.parser.add_argument(
            "--model", action="store", dest="model", default="",
            help="TF model to use")
        self.parser.add_argument(
            "--delete", action="store", dest="delete", default="",
            help="delete model in TFaaS")
        self.parser.add_argument(
            "--models", action="store_true", dest="models", default=False,
            help="show existing models in TFaaS")
        self.parser.add_argument(
            "--verbose", action="store_true", dest="verbose", default=False,
            help="verbose output")
        msg = 'specify private key file name, default $X509_USER_PROXY'
        self.parser.add_argument(
            "--key", action="store", default=x509(), dest="ckey", help=msg)
        msg = 'specify private certificate file name, default $X509_USER_PROXY'
        self.parser.add_argument(
            "--cert", action="store", default=x509(), dest="cert", help=msg)
        # CA directory: $X509_CERT_DIR if it exists, then the grid default
        # location, otherwise empty (system CAs are used)
        default_ca = os.environ.get("X509_CERT_DIR")
        if not default_ca or not os.path.exists(default_ca):
            default_ca = "/etc/grid-security/certificates"
            if not os.path.exists(default_ca):
                default_ca = ""
        if default_ca:
            msg = 'specify CA path, default currently is %s' % default_ca
        else:
            msg = 'specify CA path; defaults to system CAs.'
        self.parser.add_argument(
            "--capath", action="store", default=default_ca, dest="capath",
            help=msg)


class HTTPSClientAuthHandler(urllib2.HTTPSHandler):
    """
    Simple HTTPS client authentication class based on provided
    key/ca information
    """

    def __init__(self, key=None, cert=None, capath=None, level=0):
        if level > 0:
            urllib2.HTTPSHandler.__init__(self, debuglevel=1)
        else:
            urllib2.HTTPSHandler.__init__(self)
        self.key = key
        self.cert = cert
        self.capath = capath

    def https_open(self, req):
        """Open request method"""
        # Rather than pass in a reference to a connection class, we pass in
        # a reference to a function which, for all intents and purposes,
        # will behave as a constructor
        return self.do_open(self.get_connection, req)

    def get_connection(self, host, timeout=300):
        """
        Connection method, returns HTTPSConnection for given host.

        The connection uses ssl.create_default_context(), i.e. the protocol
        version is negotiated (instead of being pinned to TLSv1) and the
        server certificate and host name are verified against the system
        CAs plus, when set, the certificates found in self.capath.
        The client certificate (and key, when given) is attached if
        self.cert is set. The timeout handed over by urllib is forwarded
        to the connection.
        """
        context = ssl.create_default_context()
        if self.capath:
            context.load_verify_locations(capath=self.capath)
        if self.cert:
            # keyfile=None makes ssl look for the key inside the cert file,
            # which is the layout of an X509 proxy
            context.load_cert_chain(self.cert, keyfile=self.key or None)
        return httplib.HTTPSConnection(host, timeout=timeout, context=context)


def x509():
    """
    Helper function to get x509 either from env or tmp file.

    Returns $X509_USER_PROXY when it is set; otherwise /tmp/x509up_u<uid>
    if that file exists, or an empty string.
    """
    proxy = os.environ.get('X509_USER_PROXY', '')
    if not proxy:
        # os.getuid() is used directly: a passwd lookup of the same uid only
        # adds a failure mode for accounts without a passwd entry
        proxy = '/tmp/x509up_u%s' % os.getuid()
        if not os.path.isfile(proxy):
            return ''
    return proxy


def check_auth(key):
    "Warn the user when tfaas_client runs without key/cert (X509 proxy)"
    if not key:
        msg = "WARNING: tfaas_client is running without user credentials/X509 proxy, create proxy via 'voms-proxy-init -voms cms -rfc'"
        print(msg)


def fullpath(path):
    """
    Expand a leading '~' in path to the user home directory.

    Empty/None input and paths without a leading '~' are returned unchanged.
    """
    if not path:
        return path
    # expanduser handles a bare '~' and leaves any later '~' characters alone
    return os.path.expanduser(path)


def choose_boundary():
    """
    Helper function to replace deprecated mimetools.choose_boundary
    https://stackoverflow.com/questions/27099290/where-is-mimetools-choose-boundary-function-in-python3
    https://docs.python.org/2.7/library/mimetools.html?highlight=choose_boundary#mimetools.choose_boundary

    mimetools.choose_boundary() returned strings such as
    '192.168.1.191.502.42035.1678979116.376.1'; here a random UUID string
    is returned instead.
    """
    return str(uuid.uuid4())


# credit: https://pymotw.com/2/urllib2/#uploading-files
class MultiPartForm(object):
    """Accumulate the data to be used when posting a form."""

    def __init__(self):
        self.form_fields = []   # list of (name, value)
        self.files = []         # list of (fieldname, filename, mimetype, body)
        if PY3:
            self.boundary = choose_boundary()
        else:
            self.boundary = mimetools.choose_boundary()

    def get_content_type(self):
        "Return Content-Type header value for this form"
        return 'multipart/form-data; boundary=%s' % self.boundary

    def add_field(self, name, value):
        """Add a simple field to the form data."""
        self.form_fields.append((name, value))

    def add_file(self, fieldname, filename, fileHandle, mimetype=None):
        """
        Add a file to be uploaded.

        The content is read from fileHandle (text or binary mode) and kept
        as bytes. When mimetype is not given it is guessed from filename;
        content of type application/octet-stream is base64 encoded.
        """
        # text-mode handles yield text; normalise so that base64 encoding
        # and the final payload assembly always operate on bytes
        body = _to_bytes(fileHandle.read())
        if mimetype is None:
            mimetype = mimetypes.guess_type(filename)[0] \
                or 'application/octet-stream'
        if mimetype == 'application/octet-stream':
            body = binascii.b2a_base64(body)
        self.files.append((fieldname, filename, mimetype, body))

    def as_bytes(self):
        """
        Return the encoded form, including attached files, as bytes.

        Each part is introduced by the boundary line, lines are separated
        by CR+LF and the data ends with the closing boundary marker. This is
        the payload to send and to compute Content-length from.
        """
        delimiter = _to_bytes('--' + self.boundary)
        lines = []
        # Add the form fields
        for name, value in self.form_fields:
            lines.extend([
                delimiter,
                _to_bytes('Content-Disposition: form-data; name="%s"' % name),
                b'',
                _to_bytes(value),
            ])
        # Add the files to upload
        # here we use form-data content disposition instead of file one
        # since this is how we define handlers in our Go server
        # for more info see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
        for field_name, filename, content_type, body in self.files:
            lines.extend([
                delimiter,
                _to_bytes(
                    'Content-Disposition: form-data; name="%s"; filename="%s"'
                    % (field_name, filename)),
                _to_bytes('Content-Type: %s' % content_type),
                b'',
                _to_bytes(body),
            ])
        # closing boundary marker followed by a final CR+LF
        lines.append(delimiter + b'--')
        lines.append(b'')
        return b'\r\n'.join(lines)

    def __str__(self):
        """
        Return a string representing the form data, including attached
        files. On Python 3 the bytes payload is decoded as UTF-8 (binary
        content is not preserved); use as_bytes() for the request body.
        """
        return _to_text(self.as_bytes())


def _send_form(url, headers, form, ckey, cert, capath, verbose, method='POST'):
    "Send given multipart form to url, sets Content-length/Content-Type headers"
    payload = form.as_bytes()
    # length of the bytes actually sent, not of a text rendering of them
    headers['Content-length'] = str(len(payload))
    headers['Content-Type'] = form.get_content_type()
    return getdata(url, headers, payload, ckey, cert, capath, verbose, method)


def models(host, verbose=None, ckey=None, cert=None, capath=None):
    "models API shows models from TFaaS server"
    url = host + '/models'
    headers = {"Accept": "application/json", "User-Agent": _user_agent()}
    if verbose:
        print("URL   : %s" % url)
    encoded_data = json.dumps({})
    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose, 'GET')


def delete(host, model, verbose=None, ckey=None, cert=None, capath=None):
    "delete API deletes given model in TFaaS server"
    url = host + '/delete'
    headers = {"User-Agent": _user_agent()}
    if verbose:
        print("URL   : %s" % url)
        print("model : %s" % model)
    form = MultiPartForm()
    form.add_field('model', model)
    return _send_form(url, headers, form, ckey, cert, capath, verbose,
                      method='DELETE')


def bundle(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    "bundle API uploads given bundle model files to TFaaS server"
    url = host + '/upload'
    headers = {"User-Agent": _user_agent(),
               "Content-Encoding": "gzip",
               "Content-Type": "application/octet-stream"}
    with open(ifile, 'rb') as istream:
        data = istream.read()
    return getdata(url, headers, data, ckey, cert, capath, verbose)


def upload(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
    """
    upload API uploads given model to TFaaS server.

    ifile is a JSON file; its 'model', 'labels' and 'params' entries name
    files which are attached to the form, every other entry is sent as a
    plain form field.
    """
    url = host + '/upload'
    headers = {"User-Agent": _user_agent()}
    params = _load_json(ifile)
    if verbose:
        print("URL   : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))

    form = MultiPartForm()
    for key, value in params.items():
        if key in ('model', 'labels', 'params'):
            # read every attachment as bytes so it is sent verbatim
            with open(value, 'rb') as istream:
                form.add_file(key, value, fileHandle=istream)
        else:
            form.add_field(key, value)
    headers['Content-Encoding'] = 'base64'
    return _send_form(url, headers, form, ckey, cert, capath, verbose)


def predict(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    "predict API get predictions from TFaaS server"
    url = host + '/json'
    headers = {"Accept": "application/json", "User-Agent": _user_agent()}
    params = _load_json(ifile)
    if model:  # overwrite model name in given input file
        params['model'] = model
    if verbose:
        print("URL   : %s" % url)
        print("ifile : %s" % ifile)
        print("params: %s" % json.dumps(params))
    encoded_data = json.dumps(params)
    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose)


def predictImage(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
    "predict API get predictions for given image from TFaaS server"
    url = host + '/image'
    headers = {"Accept": "application/json", "User-Agent": _user_agent()}
    if verbose:
        print("URL   : %s" % url)
        print("ifile : %s" % ifile)
        print("model : %s" % model)
    form = MultiPartForm()
    with open(ifile, 'rb') as istream:
        form.add_file('image', ifile, fileHandle=istream)
    form.add_field('model', model)
    return _send_form(url, headers, form, ckey, cert, capath, verbose)


def getdata(url, headers, encoded_data, ckey, cert, capath, verbose=None, method='POST'):
    """
    helper function to use in predict/upload APIs, it place given URL call
    to the server.

    For method 'GET' no body is sent; otherwise encoded_data (text is
    encoded as UTF-8) is sent as the request body. If url ends with 'json'
    the response is parsed and returned as a JSON string, otherwise the raw
    response body is returned. On HTTP error the server message is printed
    and the program exits with status 1.
    """
    debug = 1 if verbose else 0
    if method == 'GET':
        req = urllib2.Request(url=url, headers=headers)
    else:
        # Python 3 urllib only accepts bytes as request body
        payload = None if encoded_data is None else _to_bytes(encoded_data)
        req = urllib2.Request(url=url, headers=headers, data=payload)
        if method == 'DELETE':
            req.get_method = lambda: 'DELETE'

    # client authentication needs a certificate plus either a key or a CA path
    if cert and (ckey or capath):
        http_hdlr = HTTPSClientAuthHandler(
            fullpath(ckey), fullpath(cert), capath, debug)
    else:
        http_hdlr = urllib2.HTTPHandler(debuglevel=debug)
    proxy_handler = urllib2.ProxyHandler({})
    cookie_jar = cookielib.CookieJar()
    cookie_handler = urllib2.HTTPCookieProcessor(cookie_jar)
    is_json = url.endswith('json')
    try:
        opener = urllib2.build_opener(http_hdlr, proxy_handler, cookie_handler)
        fdesc = opener.open(req)
        try:
            raw = fdesc.read()
        finally:
            fdesc.close()
    except urllib2.HTTPError as error:
        print(_to_text(error.read()))
        sys.exit(1)
    if is_json:
        return json.dumps(json.loads(raw.decode('utf-8')))
    return raw


def main():
    """
    Main function: parse options and run the requested action.

    Exactly one action is executed; when several are given the first one in
    the order upload, bundle, delete, models, predict, image wins.
    """
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()
    check_auth(opts.ckey)
    res = ''
    if opts.upload:
        res = upload(opts.url, opts.upload, opts.verbose,
                     opts.ckey, opts.cert, opts.capath)
    elif opts.bundle:
        res = bundle(opts.url, opts.bundle, opts.verbose,
                     opts.ckey, opts.cert, opts.capath)
    elif opts.delete:
        res = delete(opts.url, opts.delete, opts.verbose,
                     opts.ckey, opts.cert, opts.capath)
    elif opts.models:
        res = models(opts.url, opts.verbose,
                     opts.ckey, opts.cert, opts.capath)
    elif opts.predict:
        res = predict(opts.url, opts.predict, opts.model, opts.verbose,
                      opts.ckey, opts.cert, opts.capath)
    elif opts.image:
        res = predictImage(opts.url, opts.image, opts.model, opts.verbose,
                           opts.ckey, opts.cert, opts.capath)
    if res:
        # raw server responses are bytes on Python 3
        print(_to_text(res))


if __name__ == '__main__':
    main()