comparison scripts/infer_stasis_clusters.py @ 3:35224ab3a175 draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/main/tools/hyphy commit cee1ce4bd7d82088b9bf62403bc175c13223e020
author iuc
date Wed, 11 Mar 2026 11:10:09 +0000
parents
children
comparison
equal deleted inserted replaced
2:7c25f1fb45da 3:35224ab3a175
1 #!/usr/bin/env python3
2 """
3 B-STILL Stasis Cluster Inference Tool
4 ====================================
5 Identifies regional footprints of extreme purifying selection (stasis) in B-STILL
6 JSON results using a FWER-controlled Hypergeometric Scan Statistic.
7
8 Usage:
9 python3 infer_stasis_clusters.py input.json --ebf 10 --permutations 10000 --output results.json
10 """
11
12 import argparse
13 import json
14 import sys
15 import time
16
17 import numpy as np
18 from scipy.stats import hypergeom
19
20
def get_sf_optimized(n, d, L, K, cache):
    """Return the memoized Hypergeometric survival probability P(X >= n).

    X ~ Hypergeom(population L, successes K, draws d). Because L and K are
    fixed for a given run, the pair (n, d) uniquely identifies the value,
    so it serves as the cache key.
    """
    try:
        return cache[(n, d)]
    except KeyError:
        value = hypergeom.sf(n - 1, L, K, d)
        cache[(n, d)] = value
        return value
27
28
def scan_intervals(indices, L, K, max_size, sf_cache, threshold=None):
    """
    Evaluate every candidate interval anchored at two stasis events.

    In "minimum" mode (threshold is None) the smallest p-value across all
    scanned intervals is returned. Otherwise, every interval whose p-value
    is at or below the threshold is reported as a candidate segment dict.
    """
    count = len(indices)
    largest = min(max_size, count)
    collecting = threshold is not None

    minimum_p = 1.0
    hits = []

    # Intervals must contain at least 3 stasis events to be meaningful.
    for size in range(3, largest + 1):
        for start in range(count - size + 1):
            left = indices[start]
            right = indices[start + size - 1]
            span = right - left + 1
            pval = get_sf_optimized(size, span, L, K, sf_cache)

            if collecting:
                if pval <= threshold:
                    hits.append({
                        "start": int(left + 1),
                        "end": int(right + 1),
                        "p_value": pval,
                        "k": size,
                        "d": int(span),
                    })
            elif pval < minimum_p:
                minimum_p = pval

    return hits if collecting else minimum_p
57
58
def merge_segments(segments, merge_dist=15):
    """Merge overlapping or nearby significant segments.

    Parameters
    ----------
    segments : list of dict
        Candidate segments carrying 'start', 'end', 'p_value', 'k', 'd'.
    merge_dist : int
        A segment starting within this many codons of the previous
        segment's end is coalesced into it.

    Returns
    -------
    list of dict
        Merged segments ordered by 'start'. A merged segment keeps the
        best (minimum) p-value of its members and a recomputed span 'd'.

    Notes
    -----
    Fix: the previous implementation sorted the caller's list in place and
    mutated the caller's segment dicts; this version works on copies so
    the input is left untouched.
    """
    if not segments:
        return []

    # Copy each dict so merging never mutates the caller's data.
    ordered = sorted((dict(s) for s in segments), key=lambda s: s['start'])

    merged = []
    curr = ordered[0]
    for nxt in ordered[1:]:
        if nxt['start'] <= curr['end'] + merge_dist:
            # Coalesce: extend the span and keep the strongest p-value.
            curr['end'] = max(curr['end'], nxt['end'])
            curr['p_value'] = min(curr['p_value'], nxt['p_value'])
            curr['d'] = curr['end'] - curr['start'] + 1
        else:
            merged.append(curr)
            curr = nxt
    merged.append(curr)
    return merged
77
78
def main():
    """Command-line entry point.

    Loads a B-STILL JSON result file, extracts per-site Empirical Bayes
    Factors (column 12 of MLE content), calibrates a gene-specific critical
    p-value by permuting stasis-site positions (FWER control), scans the
    observed sequence for significant clusters, merges nearby clusters, and
    prints/saves the results.

    Exits with status 1 on unreadable input, 0 (with a message) when the
    alignment is too short or has too few stasis sites to analyse.
    """
    parser = argparse.ArgumentParser(description="Infer stasis clusters from B-STILL JSON.")
    parser.add_argument("input", help="Path to B-STILL JSON result file")
    parser.add_argument("--ebf", type=float, default=10.0, help="EBF threshold for defining stasis sites (default: 10.0)")
    parser.add_argument("--permutations", type=int, default=10000, help="Number of permutations for FWER control (default: 10000)")
    parser.add_argument("--alpha", type=float, default=0.05, help="Family-wise error rate threshold (default: 0.05)")
    parser.add_argument("--max-cluster", type=int, default=30, help="Maximum number of stasis sites per interval scan (default: 30)")
    parser.add_argument("--merge", type=int, default=15, help="Distance in codons to merge adjacent clusters (default: 15)")
    parser.add_argument("--output", help="Path to save results in JSON format")

    args = parser.parse_args()

    try:
        with open(args.input, "r") as f:
            data = json.load(f)
    except Exception as e:
        # Fix: report failures on stderr so errors are not mixed into
        # result output when stdout is piped or captured.
        print("Error loading JSON: {0}".format(e), file=sys.stderr)
        sys.exit(1)

    # Each site row is a list; index 12 holds the Empirical Bayes Factor.
    # Non-numeric or missing entries are treated as 0 (never stasis).
    sites = data.get("MLE", {}).get("content", {}).get("0", [])
    ebfs = [s[12] if (len(s) > 12 and isinstance(s[12], (int, float))) else 0 for s in sites]
    L = len(ebfs)

    if L < 10:
        print("Alignment too short for cluster analysis.")
        sys.exit(0)

    # 0-based codon indices of high-confidence stasis sites.
    stasis_indices = np.array([i for i, val in enumerate(ebfs) if val >= args.ebf])
    K = len(stasis_indices)

    print("--- B-STILL Cluster Inference ---")
    print("Input: {0}".format(args.input))
    print("Gene Length (L): {0} codons".format(L))
    print("Stasis Sites (K): {0} (EBF >= {1})".format(K, args.ebf))

    if K < 3:
        print("Insufficient stasis sites to form clusters (minimum 3 required).")
        sys.exit(0)

    # Permutation null: place K sites uniformly at random and record the
    # minimum scan p-value; the alpha-quantile of these minima is the
    # gene-specific critical value controlling the family-wise error rate.
    print("Running {0} permutations for FWER control...".format(args.permutations))
    null_min_ps = []
    all_positions = np.arange(L)
    sf_cache = {}  # shared across permutations and the observed scan (L, K fixed)

    start_time = time.time()
    for i in range(args.permutations):
        if i > 0 and i % 1000 == 0:
            elapsed = time.time() - start_time
            print("  Processed {0} permutations... ({1:.1f} per sec)".format(i, i / elapsed))
        shuffled = sorted(np.random.choice(all_positions, K, replace=False))
        min_p = scan_intervals(shuffled, L, K, args.max_cluster, sf_cache)
        null_min_ps.append(min_p)

    crit_p = np.percentile(null_min_ps, args.alpha * 100)
    print("Gene-specific Critical P-value (FWER {0}): {1:.2e}".format(args.alpha, crit_p))

    print("Scanning observed sequence for significant clusters...")
    raw_segments = scan_intervals(stasis_indices, L, K, args.max_cluster, sf_cache, threshold=crit_p)

    final_clusters = merge_segments(raw_segments, merge_dist=args.merge)

    # Recompute k per merged cluster: count stasis sites (0-based indices)
    # falling within the cluster's 1-based [start, end] range.
    for c in final_clusters:
        c['k'] = sum(1 for idx in stasis_indices if c['start'] <= idx + 1 <= c['end'])

    print("\nFound {0} significant stasis clusters:".format(len(final_clusters)))
    if final_clusters:
        print("\nLegend:")
        print("  k : Number of high-confidence stasis sites within the cluster")
        print("  d : Total span of the cluster in codons")
        print("\n{:<8} | {:<8} | {:<5} | {:<5} | {:<10}".format("Start", "End", "k", "d", "P-value"))
        print("-" * 45)
        for c in final_clusters:
            print("{:<8} | {:<8} | {:<5} | {:<5} | {:.2e}".format(c['start'], c['end'], c['k'], c['d'], c['p_value']))

    if args.output:
        output_data = {
            "input_file": args.input,
            "parameters": vars(args),
            "summary": {
                "gene_length": L,
                "total_stasis_sites": K,
                "critical_p_value": float(crit_p),
                "num_clusters": len(final_clusters)
            },
            "clusters": final_clusters
        }
        with open(args.output, "w") as f:
            json.dump(output_data, f, indent=4)
        print("\nDetailed results saved to {0}".format(args.output))
168
169
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()