diff build/bdist.linux-x86_64/egg/MLaaS/tfaas_client.py @ 0:cbbe42422d56 draft

planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author kls286
date Tue, 28 Mar 2023 15:07:30 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/build/bdist.linux-x86_64/egg/MLaaS/tfaas_client.py	Tue Mar 28 15:07:30 2023 +0000
@@ -0,0 +1,371 @@
+#!/usr/bin/env python
+#-*- coding: utf-8 -*-
+#pylint: disable=
+"""
+File       : tfaas_client.py
+Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
+Description: simple python client to communicate with TFaaS server
+"""
+
+# system modules
+import os
+import sys
+import pwd
+import ssl
+import json
+import binascii
+import argparse
+import itertools
+import mimetypes
+if  sys.version_info < (2, 7):
+    raise Exception("TFaaS client requires python 2.7 or greater")
+# python 3
+if  sys.version.startswith('3.'):
+    import urllib.request as urllib2
+    import urllib.parse as urllib
+    import http.client as httplib
+    import http.cookiejar as cookielib
+else:
+    import mimetools
+    import urllib
+    import urllib2
+    import httplib
+    import cookielib
+
+TFAAS_CLIENT = 'tfaas-client/1.1::python/%s.%s' % sys.version_info[:2]
+
+class OptionParser():
+    def __init__(self):
+        "User based option parser"
+        self.parser = argparse.ArgumentParser(prog='PROG')
+        self.parser.add_argument("--url", action="store",
+            dest="url", default="", help="TFaaS URL")
+        self.parser.add_argument("--upload", action="store",
+            dest="upload", default="", help="upload model to TFaaS")
+        self.parser.add_argument("--bundle", action="store",
+            dest="bundle", default="", help="upload bundle ML files to TFaaS")
+        self.parser.add_argument("--predict", action="store",
+            dest="predict", default="", help="fetch prediction from TFaaS")
+        self.parser.add_argument("--image", action="store",
+            dest="image", default="", help="fetch prediction for given image")
+        self.parser.add_argument("--model", action="store",
+            dest="model", default="", help="TF model to use")
+        self.parser.add_argument("--delete", action="store",
+            dest="delete", default="", help="delete model in TFaaS")
+        self.parser.add_argument("--models", action="store_true",
+            dest="models", default=False, help="show existing models in TFaaS")
+        self.parser.add_argument("--verbose", action="store_true",
+            dest="verbose", default=False, help="verbose output")
+        msg  = 'specify private key file name, default $X509_USER_PROXY'
+        self.parser.add_argument("--key", action="store",
+                               default=x509(), dest="ckey", help=msg)
+        msg  = 'specify private certificate file name, default $X509_USER_PROXY'
+        self.parser.add_argument("--cert", action="store",
+                               default=x509(), dest="cert", help=msg)
+        default_ca = os.environ.get("X509_CERT_DIR")
+        if not default_ca or not os.path.exists(default_ca):
+            default_ca = "/etc/grid-security/certificates"
+            if not os.path.exists(default_ca):
+                default_ca = ""
+        if default_ca:
+            msg = 'specify CA path, default currently is %s' % default_ca
+        else:
+            msg = 'specify CA path; defaults to system CAs.'
+        self.parser.add_argument("--capath", action="store",
+                               default=default_ca, dest="capath", help=msg)
+        msg  = 'specify number of retries upon busy DAS server message'
+
+class HTTPSClientAuthHandler(urllib2.HTTPSHandler):
+    """
+    Simple HTTPS client authentication class based on provided
+    key/ca information
+    """
+    def __init__(self, key=None, cert=None, capath=None, level=0):
+        if  level > 0:
+            urllib2.HTTPSHandler.__init__(self, debuglevel=1)
+        else:
+            urllib2.HTTPSHandler.__init__(self)
+        self.key = key
+        self.cert = cert
+        self.capath = capath
+
+    def https_open(self, req):
+        """Open request method"""
+        #Rather than pass in a reference to a connection class, we pass in
+        # a reference to a function which, for all intents and purposes,
+        # will behave as a constructor
+        return self.do_open(self.get_connection, req)
+
+    def get_connection(self, host, timeout=300):
+        """Connection method"""
+        if  self.key and self.cert and not self.capath:
+            return httplib.HTTPSConnection(host, key_file=self.key,
+                                                cert_file=self.cert)
+        elif self.cert and self.capath:
+            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            context.load_verify_locations(capath=self.capath)
+            context.load_cert_chain(self.cert)
+            return httplib.HTTPSConnection(host, context=context)
+        return httplib.HTTPSConnection(host)
+
+def x509():
+    "Helper function to get x509 either from env or tmp file"
+    proxy = os.environ.get('X509_USER_PROXY', '')
+    if  not proxy:
+        proxy = '/tmp/x509up_u%s' % pwd.getpwuid( os.getuid() ).pw_uid
+        if  not os.path.isfile(proxy):
+            return ''
+    return proxy
+
+def check_auth(key):
+    "Check if user runs das_client with key/cert and warn users to switch"
+    if  not key:
+        msg  = "WARNING: tfaas_client is running without user credentials/X509 proxy, create proxy via 'voms-proxy-init -voms cms -rfc'"
+        print(msg)
+
+def fullpath(path):
+    "Expand path to full path"
+    if  path and path[0] == '~':
+        path = path.replace('~', '')
+        path = path[1:] if path[0] == '/' else path
+        path = os.path.join(os.environ['HOME'], path)
+    return path
+
+def choose_boundary():
+    """
+    Helper function to replace deprecated mimetools.choose_boundary
+    https://stackoverflow.com/questions/27099290/where-is-mimetools-choose-boundary-function-in-python3
+    https://docs.python.org/2.7/library/mimetools.html?highlight=choose_boundary#mimetools.choose_boundary
+    >>> mimetools.choose_boundary()
+    '192.168.1.191.502.42035.1678979116.376.1'
+    """
+    # we will return any random string
+    import uuid
+    return str(uuid.uuid4())
+
+# credit: https://pymotw.com/2/urllib2/#uploading-files
+class MultiPartForm(object):
+    """Accumulate the data to be used when posting a form."""
+
+    def __init__(self):
+        self.form_fields = []
+        self.files = []
+        if  sys.version.startswith('3.'):
+            self.boundary = choose_boundary()
+        else:
+            self.boundary = mimetools.choose_boundary()
+        return
+    
+    def get_content_type(self):
+        return 'multipart/form-data; boundary=%s' % self.boundary
+
+    def add_field(self, name, value):
+        """Add a simple field to the form data."""
+        self.form_fields.append((name, value))
+        return
+
+    def add_file(self, fieldname, filename, fileHandle, mimetype=None):
+        """Add a file to be uploaded."""
+        body = fileHandle.read()
+        if mimetype is None:
+            mimetype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
+        if mimetype == 'application/octet-stream':
+            body = binascii.b2a_base64(body)
+#         if isinstance(body, bytes):
+#             body = body.decode("utf-8") 
+        self.files.append((fieldname, filename, mimetype, body))
+        return
+    
+    def __str__(self):
+        """Return a string representing the form data, including attached files."""
+        # Build a list of lists, each containing "lines" of the
+        # request.  Each part is separated by a boundary string.
+        # Once the list is built, return a string where each
+        # line is separated by '\r\n'.  
+        parts = []
+        part_boundary = '--' + self.boundary
+        
+        # Add the form fields
+        parts.extend(
+            [ part_boundary,
+              'Content-Disposition: form-data; name="%s"' % name,
+              '',
+              value,
+            ]
+            for name, value in self.form_fields
+            )
+        
+        # Add the files to upload
+        # here we use form-data content disposition instead of file one
+        # since this is how we define handlers in our Go server
+        # for more info see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
+        parts.extend(
+            [ part_boundary,
+              'Content-Disposition: form-data; name="%s"; filename="%s"' % \
+                 (field_name, filename),
+              'Content-Type: %s' % content_type,
+              '',
+              body,
+            ]
+            for field_name, filename, content_type, body in self.files
+            )
+        
+        # Flatten the list and add closing boundary marker,
+        # then return CR+LF separated data
+        flattened = list(itertools.chain(*parts))
+        flattened.append('--' + self.boundary + '--')
+        flattened.append('')
+        return '\r\n'.join(flattened)
+
+def models(host, verbose=None, ckey=None, cert=None, capath=None):
+    "models API shows models from TFaaS server"
+    url = host + '/models'
+    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
+    headers = {"Accept": "application/json", "User-Agent": client}
+    if verbose:
+        print("URL   : %s" % url)
+    encoded_data = json.dumps({})
+    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose, 'GET')
+
+def delete(host, model, verbose=None, ckey=None, cert=None, capath=None):
+    "delete API deletes given model in TFaaS server"
+    url = host + '/delete'
+    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
+    headers = {"User-Agent": client}
+    if verbose:
+        print("URL   : %s" % url)
+        print("model : %s" % model)
+    form = MultiPartForm()
+    form.add_field('model', model)
+    edata = str(form)
+    headers['Content-length'] = len(edata)
+    headers['Content-Type'] = form.get_content_type()
+    return getdata(url, headers, edata, ckey, cert, capath, verbose, method='DELETE')
+
+def bundle(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
+    "bundle API uploads given bundle model files to TFaaS server"
+    url = host + '/upload'
+    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
+    headers = {"User-Agent": client, "Content-Encoding": "gzip", "Content-Type": "application/octet-stream"}
+    data = open(ifile, 'rb').read()
+    return getdata(url, headers, data, ckey, cert, capath, verbose)
+
+def upload(host, ifile, verbose=None, ckey=None, cert=None, capath=None):
+    "upload API uploads given model to TFaaS server"
+    url = host + '/upload'
+    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
+    headers = {"User-Agent": client}
+    params = json.load(open(ifile))
+    if verbose:
+        print("URL   : %s" % url)
+        print("ifile : %s" % ifile)
+        print("params: %s" % json.dumps(params))
+
+    form = MultiPartForm()
+    for key in params.keys():
+        if key in ['model', 'labels', 'params']:
+            flag = 'r'
+            if key == 'model':
+                flag = 'rb'
+            name = params[key]
+            form.add_file(key, name, fileHandle=open(name, flag))
+        else:
+            form.add_field(key, params[key])
+    edata = str(form)
+    headers['Content-length'] = len(edata)
+    headers['Content-Type'] = form.get_content_type()
+    headers['Content-Encoding'] = 'base64'
+    return getdata(url, headers, edata, ckey, cert, capath, verbose)
+
+def predict(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
+    "predict API get predictions from TFaaS server"
+    url = host + '/json'
+    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
+    headers = {"Accept": "application/json", "User-Agent": client}
+    params = json.load(open(ifile))
+    if model: # overwrite model name in given input file
+        params['model'] = model
+    if verbose:
+        print("URL   : %s" % url)
+        print("ifile : %s" % ifile)
+        print("params: %s" % json.dumps(params))
+    encoded_data = json.dumps(params)
+    return getdata(url, headers, encoded_data, ckey, cert, capath, verbose)
+
+def predictImage(host, ifile, model, verbose=None, ckey=None, cert=None, capath=None):
+    "predict API get predictions from TFaaS server"
+    url = host + '/image'
+    client = '%s (%s)' % (TFAAS_CLIENT, os.environ.get('USER', ''))
+    headers = {"Accept": "application/json", "User-Agent": client}
+    if verbose:
+        print("URL   : %s" % url)
+        print("ifile : %s" % ifile)
+        print("model : %s" % model)
+    form = MultiPartForm()
+#     form.add_file('image', ifile, fileHandle=open(ifile, 'r'))
+    form.add_file('image', ifile, fileHandle=open(ifile, 'rb'))
+    form.add_field('model', model)
+    edata = str(form)
+    headers['Content-length'] = len(edata)
+    headers['Content-Type'] = form.get_content_type()
+    return getdata(url, headers, edata, ckey, cert, capath, verbose)
+
+def getdata(url, headers, encoded_data, ckey, cert, capath, verbose=None, method='POST'):
+    "helper function to use in predict/upload APIs, it place given URL call to the server"
+    debug = 1 if verbose else 0
+    req = urllib2.Request(url=url, headers=headers, data=encoded_data)
+    if method == 'DELETE':
+        req.get_method = lambda: 'DELETE'
+    elif method == 'GET':
+        req = urllib2.Request(url=url, headers=headers)
+    if  ckey and cert:
+        ckey = fullpath(ckey)
+        cert = fullpath(cert)
+        http_hdlr  = HTTPSClientAuthHandler(ckey, cert, capath, debug)
+    elif cert and capath:
+        cert = fullpath(cert)
+        http_hdlr  = HTTPSClientAuthHandler(ckey, cert, capath, debug)
+    else:
+        http_hdlr  = urllib2.HTTPHandler(debuglevel=debug)
+    proxy_handler  = urllib2.ProxyHandler({})
+    cookie_jar     = cookielib.CookieJar()
+    cookie_handler = urllib2.HTTPCookieProcessor(cookie_jar)
+    data = {}
+    try:
+        opener = urllib2.build_opener(http_hdlr, proxy_handler, cookie_handler)
+        fdesc = opener.open(req)
+        if url.endswith('json'):
+            data = json.load(fdesc)
+        else:
+            data = fdesc.read()
+        fdesc.close()
+    except urllib2.HTTPError as error:
+        print(error.read())
+        sys.exit(1)
+    if url.endswith('json'):
+        return json.dumps(data)
+    return data
+
+def main():
+    "Main function"
+    optmgr  = OptionParser()
+    opts = optmgr.parser.parse_args()
+    check_auth(opts.ckey)
+    res = ''
+    if opts.upload:
+        res = upload(opts.url, opts.upload, opts.verbose, opts.ckey, opts.cert, opts.capath)
+    if opts.bundle:
+        res = bundle(opts.url, opts.bundle, opts.verbose, opts.ckey, opts.cert, opts.capath)
+    elif opts.delete:
+        res = delete(opts.url, opts.delete, opts.verbose, opts.ckey, opts.cert, opts.capath)
+    elif opts.models:
+        res = models(opts.url, opts.verbose, opts.ckey, opts.cert, opts.capath)
+    elif opts.predict:
+        res = predict(opts.url, opts.predict, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
+    elif opts.image:
+        res = predictImage(opts.url, opts.image, opts.model, opts.verbose, opts.ckey, opts.cert, opts.capath)
+    if res:
+        print(res)
+
+if __name__ == '__main__':
+    main()