uniprot.py @ 9:f31d8d59ffb6 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/uniprot_rest_interface commit 1c020106d4d7f957c9f1ec0d9885bbb2d56e70e7

author    bgruening
date      Tue, 06 Aug 2024 14:49:34 +0000
parents   8:d2ad6e2c55d1
children  (none)
#!/usr/bin/env python
"""
UniProt Python interface
to access the UniProt database

Based on work from Jan Rudolph: https://github.com/jdrudolph/uniprot
available services:
    map
    retrieve

rewritten using inspiration from: https://findwork.dev/blog/advanced-usage-python-requests-timeouts-retries-hooks/
"""
import argparse
import json
import re
import sys
import time
import zlib
from time import sleep
from urllib.parse import (
    parse_qs,
    urlencode,
    urlparse,
)
from xml.etree import ElementTree

import requests
from requests.adapters import (
    HTTPAdapter,
    Retry,
)


BATCH_SIZE = 50000  # Limit at UniProt is 100k
POLLING_INTERVAL = 5
API_URL = "https://rest.uniprot.org"


retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retries))


def check_response(response):
    # raise if the HTTP request failed
    try:
        response.raise_for_status()
    except requests.HTTPError:
        raise


def submit_id_mapping(from_db, to_db, ids):
    # submit an asynchronous ID-mapping job and return its job id
    print(f"{from_db} {to_db}")
    request = requests.post(
        f"{API_URL}/idmapping/run",
        data={"from": from_db, "to": to_db, "ids": ",".join(ids)},
    )
    check_response(request)
    return request.json()["jobId"]


def get_next_link(headers):
    # pagination: return the URL from the 'Link: <...>; rel="next"' header, if any
    re_next_link = re.compile(r'<(.+)>; rel="next"')
    if "Link" in headers:
        match = re_next_link.match(headers["Link"])
        if match:
            return match.group(1)


def check_id_mapping_results_ready(job_id):
    # poll the job status until it has finished
    while True:
        request = session.get(f"{API_URL}/idmapping/status/{job_id}")
        check_response(request)
        j = request.json()
        if "jobStatus" in j:
            if j["jobStatus"] in ["NEW", "RUNNING"]:
                print(f"Retrying in {POLLING_INTERVAL}s")
                time.sleep(POLLING_INTERVAL)
            else:
                raise Exception(j["jobStatus"])
        else:
            return bool(j["results"] or j["failedIds"])


def get_batch(batch_response, file_format, compressed):
    # follow the pagination links and yield the decoded results of each page
    batch_url = get_next_link(batch_response.headers)
    while batch_url:
        batch_response = session.get(batch_url)
        batch_response.raise_for_status()
        yield decode_results(batch_response, file_format, compressed)
        batch_url = get_next_link(batch_response.headers)


def combine_batches(all_results, batch_results, file_format):
    if file_format == "json":
        for key in ("results", "failedIds"):
            if key in batch_results and batch_results[key]:
                all_results[key] += batch_results[key]
    elif file_format == "tsv":
        return all_results + batch_results[1:]
    else:
        return all_results + batch_results
    return all_results


def get_id_mapping_results_link(job_id):
    url = f"{API_URL}/idmapping/details/{job_id}"
    request = session.get(url)
    check_response(request)
    return request.json()["redirectURL"]


def decode_results(response, file_format, compressed):
    if compressed:
        decompressed = zlib.decompress(response.content, 16 + zlib.MAX_WBITS)
        if file_format == "json":
            j = json.loads(decompressed.decode("utf-8"))
            return j
        elif file_format in ["tsv", "gff"]:
            return [line for line in decompressed.decode("utf-8").split("\n") if line]
        elif file_format == "xlsx":
            return [decompressed]
        elif file_format == "xml":
            return [decompressed.decode("utf-8")]
        else:
            return decompressed.decode("utf-8")
    elif file_format == "json":
        return response.json()
    elif file_format in ["tsv", "gff"]:
        return [line for line in response.text.split("\n") if line]
    elif file_format == "xlsx":
        return [response.content]
    elif file_format == "xml":
        return [response.text]
    return response.text


def get_xml_namespace(element):
    m = re.match(r"\{(.*)\}", element.tag)
    return m.groups()[0] if m else ""


def merge_xml_results(xml_results):
    merged_root = ElementTree.fromstring(xml_results[0])
    for result in xml_results[1:]:
        root = ElementTree.fromstring(result)
        for child in root.findall("{http://uniprot.org/uniprot}entry"):
            merged_root.insert(-1, child)
    ElementTree.register_namespace("", get_xml_namespace(merged_root[0]))
    return ElementTree.tostring(merged_root, encoding="utf-8", xml_declaration=True)


def print_progress_batches(batch_index, size, total):
    n_fetched = min((batch_index + 1) * size, total)
    print(f"Fetched: {n_fetched} / {total}")


def get_id_mapping_results_search(url, first):
    parsed = urlparse(url)
    query = parse_qs(parsed.query)
    file_format = query["format"][0] if "format" in query else "json"
    if "size" in query:
        size = int(query["size"][0])
    else:
        size = 500
        query["size"] = size
    compressed = (
        query["compressed"][0].lower() == "true" if "compressed" in query else False
    )
    parsed = parsed._replace(query=urlencode(query, doseq=True))
    url = parsed.geturl()
    request = session.get(url)
    check_response(request)
    results = decode_results(request, file_format, compressed)
    total = int(request.headers["x-total-results"])
    print_progress_batches(0, size, total)
    for i, batch in enumerate(get_batch(request, file_format, compressed), 1):
        results = combine_batches(results, batch, file_format)
        print_progress_batches(i, size, total)
    if len(results) > 1 and file_format == "tsv" and not first:
        results = results[1:]
    if file_format == "xml":
        return merge_xml_results(results)
    return results


# print(results)
# {'results': [{'from': 'P05067', 'to': 'CHEMBL2487'}], 'failedIds': ['P12345']}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="retrieve uniprot mapping")
    subparsers = parser.add_subparsers(dest="tool")

    mapping = subparsers.add_parser("map")
    mapping.add_argument("f", help="from")
    mapping.add_argument("t", help="to")
    mapping.add_argument(
        "inp",
        nargs="?",
        type=argparse.FileType("r"),
        default=sys.stdin,
        help="input file (default: stdin)",
    )
    mapping.add_argument(
        "out",
        nargs="?",
        type=argparse.FileType("w"),
        default=sys.stdout,
        help="output file (default: stdout)",
    )
    mapping.add_argument("--format", default="tab", help="output format")

    retrieve = subparsers.add_parser("retrieve")
    retrieve.add_argument(
        "inp",
        metavar="in",
        nargs="?",
        type=argparse.FileType("r"),
        default=sys.stdin,
        help="input file (default: stdin)",
    )
    retrieve.add_argument(
        "out",
        nargs="?",
        type=argparse.FileType("w"),
        default=sys.stdout,
        help="output file (default: stdout)",
    )
    retrieve.add_argument("-f", "--format", help="specify output format", default="txt")
    mapping = subparsers.add_parser("menu")

    args = parser.parse_args()

    # code for auto generating the from - to conditional
    if args.tool == "menu":
        from lxml import etree

        request = session.get("https://rest.uniprot.org/configure/idmapping/fields")
        check_response(request)
        fields = request.json()

        tos = dict()
        from_cond = etree.Element("conditional", name="from_cond")
        from_select = etree.SubElement(
            from_cond, "param", name="from", type="select", label="Source database:"
        )

        rules = dict()
        for rule in fields["rules"]:
            rules[rule["ruleId"]] = rule["tos"]

        for group in fields["groups"]:
            group_name = group["groupName"]
            group_name = group_name.replace("databases", "DBs")
            for item in group["items"]:
                if item["to"]:
                    tos[item["name"]] = f"{group_name} - {item['displayName']}"

        for group in fields["groups"]:
            group_name = group["groupName"]
            group_name = group_name.replace("databases", "DBs")
            for item in group["items"]:
                if not item["from"]:
                    continue
                option = etree.SubElement(from_select, "option", value=item["name"])
                option.text = f"{group_name} - {item['displayName']}"
                when = etree.SubElement(from_cond, "when", value=item["name"])

                to_select = etree.SubElement(
                    when, "param", name="to", type="select", label="Target database:"
                )
                ruleId = item["ruleId"]
                for to in rules[ruleId]:
                    option = etree.SubElement(to_select, "option", value=to)
                    option.text = tos[to]
        etree.indent(from_cond, space=" ")
        print(etree.tostring(from_cond, pretty_print=True, encoding="unicode"))
        sys.exit(0)

    # get the unique IDs from the input file
    query = set()
    for line in args.inp:
        query.add(line.strip())
    query = list(query)
    results = []
    first = True  # if False the header is removed
    while len(query) > 0:
        batch = query[:BATCH_SIZE]
        query = query[BATCH_SIZE:]
        print(f"processing {len(batch)} left {len(query)}")
        if args.tool == "map":
            job_id = submit_id_mapping(from_db=args.f, to_db=args.t, ids=batch)
        elif args.tool == "retrieve":
            job_id = submit_id_mapping(from_db="UniProtKB_AC-ID", to_db="UniProtKB", ids=batch)

        if check_id_mapping_results_ready(job_id):
            link = get_id_mapping_results_link(job_id)
            link = f"{link}?format={args.format}"
            print(link)
            results.extend(get_id_mapping_results_search(link, first))
            first = False
        print(f"got {len(results)} results so far")
        if len(query):
            sleep(5)

    if not isinstance(results, str):
        results = "\n".join(results)
    args.out.write(f"{results}\n")
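
The ID-mapping helpers above can also be used programmatically. A minimal sketch, assuming the script is saved as uniprot.py and is importable, reusing the P05067 / P12345 example from the comment above; the "ChEMBL" target name and the tsv format are illustrative choices, not the only valid ones:

# map two UniProtKB accessions to ChEMBL and print the resulting TSV rows
from uniprot import (
    check_id_mapping_results_ready,
    get_id_mapping_results_link,
    get_id_mapping_results_search,
    submit_id_mapping,
)

# submit the asynchronous ID-mapping job, then poll until it has finished
job_id = submit_id_mapping(from_db="UniProtKB_AC-ID", to_db="ChEMBL", ids=["P05067", "P12345"])
if check_id_mapping_results_ready(job_id):
    # fetch every result page as TSV and keep the header row (first=True)
    link = get_id_mapping_results_link(job_id)
    rows = get_id_mapping_results_search(f"{link}?format=tsv", first=True)
    print("\n".join(rows))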