Mercurial > repos > kls286 > chap_test_20230328
diff build/bdist.linux-x86_64/egg/MLaaS/tfaas_client.py @ 0:cbbe42422d56 draft
planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author | kls286 |
---|---|
date | Tue, 28 Mar 2023 15:07:30 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/build/bdist.linux-x86_64/egg/MLaaS/tfaas_client.py Tue Mar 28 15:07:30 2023 +0000 @@ -0,0 +1,371 @@ +#!/usr/bin/env python +#-*- coding: utf-8 -*- +#pylint: disable= +""" +File : tfaas_client.py +Author : Valentin Kuznetsov <vkuznet AT gmail dot com> +Description: simple python client to communicate with TFaaS server +""" + +# system modules +import os +import sys +import pwd +import ssl +import json +import binascii +import argparse +import itertools +import mimetypes +if sys.version_info < (2, 7): + raise Exception("TFaaS client requires python 2.7 or greater") +# python 3 +if sys.version.startswith('3.'): + import urllib.request as urllib2 + import urllib.parse as urllib + import http.client as httplib + import http.cookiejar as cookielib +else: + import mimetools + import urllib + import urllib2 + import httplib + import cookielib + +TFAAS_CLIENT = 'tfaas-client/1.1::python/%s.%s' % sys.version_info[:2] + +class OptionParser(): + def __init__(self): + "User based option parser" + self.parser = argparse.ArgumentParser(prog='PROG') + self.parser.add_argument("--url", action="store", + dest="url", default="", help="TFaaS URL") + self.parser.add_argument("--upload", action="store", + dest="upload", default="", help="upload model to TFaaS") + self.parser.add_argument("--bundle", action="store", + dest="bundle", default="", help="upload bundle ML files to TFaaS") + self.parser.add_argument("--predict", action="store", + dest="predict", default="", help="fetch prediction from TFaaS") + self.parser.add_argument("--image", action="store", + dest="image", default="", help="fetch prediction for given image") + self.parser.add_argument("--model", action="store", + dest="model", default="", help="TF model to use") + self.parser.add_argument("--delete", action="store", + dest="delete", default="", help="delete model in TFaaS") + self.parser.add_argument("--models", action="store_true", + dest="models", default=False, help="show existing models in TFaaS") + self.parser.add_argument("--verbose", action="store_true", + dest="verbose", default=False, help="verbose output") + msg = 'specify private key file name, default $X509_USER_PROXY' + self.parser.add_argument("--key", action="store", + default=x509(), dest="ckey", help=msg) + msg = 'specify private certificate file name, default $X509_USER_PROXY' + self.parser.add_argument("--cert", action="store", + default=x509(), dest="cert", help=msg) + default_ca = os.environ.get("X509_CERT_DIR") + if not default_ca or not os.path.exists(default_ca): + default_ca = "/etc/grid-security/certificates" + if not os.path.exists(default_ca): + default_ca = "" + if default_ca: + msg = 'specify CA path, default currently is %s' % default_ca + else: + msg = 'specify CA path; defaults to system CAs.' + self.parser.add_argument("--capath", action="store", + default=default_ca, dest="capath", help=msg) + msg = 'specify number of retries upon busy DAS server message' + +class HTTPSClientAuthHandler(urllib2.HTTPSHandler): + """ + Simple HTTPS client authentication class based on provided + key/ca information + """ + def __init__(self, key=None, cert=None, capath=None, level=0): + if level > 0: + urllib2.HTTPSHandler.__init__(self, debuglevel=1) + else: + urllib2.HTTPSHandler.__init__(self) + self.key = key + self.cert = cert + self.capath = capath + + def https_open(self, req): + """Open request method""" + #Rather than pass in a reference to a connection class, we pass in + # a reference to a function which, for all intents and purposes, + # will behave as a constructor + return self.do_open(self.get_connection, req) + + def get_connection(self, host, timeout=300): + """Connection method""" + if self.key and self.cert and not self.capath: + return httplib.HTTPSConnection(host, key_file=self.key, + cert_file=self.cert) + elif self.cert and self.capath: + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_verify_locations(capath=self.capath) + context.load_cert_chain(self.cert) + return httplib.HTTPSConnection(host, context=context) + return httplib.HTTPSConnection(host) + +def x509(): + "Helper function to get x509 either from env or tmp file" + proxy = os.environ.get('X509_USER_PROXY', '') + if not proxy: + proxy = '/tmp/x509up_u%s' % pwd.getpwuid( os.getuid() ).pw_uid + if not os.path.isfile(proxy): + return '' + return proxy + +def check_auth(key): + "Check if user runs das_client with key/cert and warn users to switch" + if not key: + msg = "WARNING: tfaas_client is running without user credentials/X509 proxy, create proxy via 'voms-proxy-init -voms cms -rfc'" + print(msg) + +def fullpath(path): + "Expand path to full path" + if path and path[0] == '~': + path = path.replace('~', '') + path = path[1:] if path[0] == '/' else path + path = os.path.join(os.environ['HOME'], path) + return path + +def choose_boundary(): + """ + Helper function to replace deprecated mimetools.choose_boundary + https://stackoverflow.com/questions/27099290/where-is-mimetools-choose-boundary-function-in-python3 + https://docs.python.org/2.7/library/mimetools.html?highlight=choose_boundary#mimetools.choose_boundary + >>> mimetools.choose_boundary() + '192.168.1.191.502.42035.1678979116.376.1' + """ + # we will return any random string + import uuid + return str(uuid.uuid4()) + +# credit: https://pymotw.com/2/urllib2/#uploading-files +class MultiPartForm(object): + """Accumulate the data to be used when posting a form.""" + + def __init__(self): + self.form_fields = [] + self.files = [] + if sys.version.startswith('3.'): + self.boundary = choose_boundary() + else: + self.boundary = mimetools.choose_boundary() + return + + def get_content_type(self): + return 'multipart/form-data; boundary=%s' % self.boundary + + def add_field(self, name, value): + """Add a simple field to the form data.""" + self.form_fields.append((name, value)) + return + + def add_file(self, fieldname, filename, fileHandle, mimetype=None): + """Add a file to be uploaded.""" + body = fileHandle.read() + if mimetype is None: + mimetype = mimetypes.guess_type(filename)[0] or 'application/octet-stream' + if mimetype == 'application/octet-stream': + body = binascii.b2a_base64(body) +# if isinstance(body, bytes): +# body = body.decode("utf-8") + self.files.append((fieldname, filename, mimetype, body)) + return + + def __str__(self): + """Return a string representing the form data, including attached files.""" + # Build a list of lists, each containing "lines" of the + # request. Each part is separated by a boundary string. + # Once the list is built, return a string where each + # line is separated by '\r\n'. + parts = [] + part_boundary = '--' + self.boundary + + # Add the form fields + parts.extend( + [ part_boundary, + 'Content-Disposition: form-data; name="%s"' % name, + '', + value, + ] + for name, value in self.form_fields + ) + + # Add the files to upload + # here we use form-data content disposition instead of file one + # since this is how we define handlers in our Go server + # for more info see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition + parts.extend( + [ part_boundary, + 'Content-Disposition: form-data; name="%s"; filename="%s"' % \ + (field_name, filename), + 'Content-Type: %s' % content_type, + '', + body, + ] + for field_name, filename, content_type, body in self.files + ) + + # Flatten the list and add closing boundary marker, + # then return CR+LF separated data + flattened = list(itertools.chain(*parts)) + flattened.append('--' + self.boundary + '--') + flattened.append('') + return '\r\n'.join(flattened) + +def models(host, verbose=None, ckey=None, cert=None, capath=None): + "models API shows models from TFaaS server" + url = host + '/models' + client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')) + headers = {"Accept": "application/json", "User-Agent": client} + if verbose: + print("URL : %s" % url) + encoded_data = json.dumps({}) + return getdata(url, headers, encoded_data, ckey, cert, capath, verbose, 'GET') + +def delete(host, model, verbose=None, ckey=None, cert=None, capath=None): + "delete API deletes given model in TFaaS server" + url = host + '/delete' + client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')) + headers = {"User-Agent": client} + if verbose: + print("URL : %s" % url) + print("model : %s" % model) + form = MultiPartForm() + form.add_field('model', model) + edata = str(form) + headers['Content-length'] = len(edata) + headers['Content-Type'] = form.get_content_type() + return getdata(url, headers, edata, ckey, cert, capath, verbose, method='DELETE') + +def bundle(host, ifile, verbose=None, ckey=None, cert=None, capath=None): + "bundle API uploads given bundle model files to TFaaS server" + url = host + '/upload' + client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')) + headers = {"User-Agent": client, "Content-Encoding": "gzip", "Content-Type": "application/octet-stream"} + data = open(ifile, 'rb').read() + return getdata(url, headers, data, ckey, cert, capath, verbose) + +def upload(host, ifile, verbose=None, ckey=None, cert=None, capath=None): + "upload API uploads given model to TFaaS server" + url = host + '/upload' + client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')) + headers = {"User-Agent": client} + params = json.load(open(ifile)) + if verbose: + print("URL : %s" % url) + print("ifile : %s" % ifile) + print("params: %s" % json.dumps(params)) + + form = MultiPartForm() + for key in params.keys(): + if key in ['model', 'labels', 'params']: + flag = 'r' + if key == 'model': + flag = 'rb' + name = params[key] + form.add_file(key, name, fileHandle=open(name, flag)) + else: + form.add_field(key, params[key]) + edata = str(form) + headers['Content-length'] = len(edata) + headers['Content-Type'] = form.get_content_type() + headers['Content-Encoding'] = 'base64' + return getdata(url, headers, edata, ckey, cert, capath, verbose) + +def predict(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None): + "predict API get predictions from TFaaS server" + url = host + '/json' + client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')) + headers = {"Accept": "application/json", "User-Agent": client} + params = json.load(open(ifile)) + if model: # overwrite model name in given input file + params['model'] = model + if verbose: + print("URL : %s" % url) + print("ifile : %s" % ifile) + print("params: %s" % json.dumps(params)) + encoded_data = json.dumps(params) + return getdata(url, headers, encoded_data, ckey, cert, capath, verbose) + +def predictImage(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None): + "predict API get predictions from TFaaS server" + url = host + '/image' + client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', '')) + headers = {"Accept": "application/json", "User-Agent": client} + if verbose: + print("URL : %s" % url) + print("ifile : %s" % ifile) + print("model : %s" % model) + form = MultiPartForm() +# form.add_file('image', ifile, fileHandle=open(ifile, 'r')) + form.add_file('image', ifile, fileHandle=open(ifile, 'rb')) + form.add_field('model', model) + edata = str(form) + headers['Content-length'] = len(edata) + headers['Content-Type'] = form.get_content_type() + return getdata(url, headers, edata, ckey, cert, capath, verbose) + +def getdata(url, headers, encoded_data, ckey, cert, capath, verbose=None, method='POST'): + "helper function to use in predict/upload APIs, it place given URL call to the server" + debug = 1 if verbose else 0 + req = urllib2.Request(url=url, headers=headers, data=encoded_data) + if method == 'DELETE': + req.get_method = lambda: 'DELETE' + elif method == 'GET': + req = urllib2.Request(url=url, headers=headers) + if ckey and cert: + ckey = fullpath(ckey) + cert = fullpath(cert) + http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug) + elif cert and capath: + cert = fullpath(cert) + http_hdlr = HTTPSClientAuthHandler(ckey, cert, capath, debug) + else: + http_hdlr = urllib2.HTTPHandler(debuglevel=debug) + proxy_handler = urllib2.ProxyHandler({}) + cookie_jar = cookielib.CookieJar() + cookie_handler = urllib2.HTTPCookieProcessor(cookie_jar) + data = {} + try: + opener = urllib2.build_opener(http_hdlr, proxy_handler, cookie_handler) + fdesc = opener.open(req) + if url.endswith('json'): + data = json.load(fdesc) + else: + data = fdesc.read() + fdesc.close() + except urllib2.HTTPError as error: + print(error.read()) + sys.exit(1) + if url.endswith('json'): + return json.dumps(data) + return data + +def main(): + "Main function" + optmgr = OptionParser() + opts = optmgr.parser.parse_args() + check_auth(opts.ckey) + res = '' + if opts.upload: + res = upload(opts.url, opts.upload, opts.verbose, opts.ckey, opts.cert, opts.capath) + if opts.bundle: + res = bundle(opts.url, opts.bundle, opts.verbose, opts.ckey, opts.cert, opts.capath) + elif opts.delete: + res = delete(opts.url, opts.delete, opts.verbose, opts.ckey, opts.cert, opts.capath) + elif opts.models: + res = models(opts.url, opts.verbose, opts.ckey, opts.cert, opts.capath) + elif opts.predict: + res = predict(opts.url, opts.predict, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath) + elif opts.image: + res = predictImage(opts.url, opts.image, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath) + if res: + print(res) + +if __name__ == '__main__': + main()