Skip to content

Commit 0e0e424

Browse files
committed
smc.py: add subset sampling
1 parent f414f70 commit 0e0e424

3 files changed

Lines changed: 274 additions & 21 deletions

File tree

gpmp/misc/knn_cov.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def estimate_cov_matrix_knn(
2222
x,
2323
n_random=50,
2424
n_neighbors=50,
25-
ef=50,
25+
ef=100,
2626
max_ef_construction=200,
2727
M=16,
2828
):

gpmp/misc/smc.py

Lines changed: 271 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,28 @@
2222
from scipy.optimize import brentq
2323
import gpmp.num as gnp
2424
import gpmp.misc.knn_cov
25-
import matplotlib.pyplot as plt
26-
import matplotlib.gridspec as gridspec
2725

2826

2927
@dataclass
class ParticlesSetConfig:
    """Configuration of the particle set used by the SMC sampler.

    NOTE(review): field semantics are inferred from names and defaults
    visible here — confirm against the ParticlesSet implementation.
    """

    # How the initial particle cloud is drawn ("randunif" = uniform over a box).
    initial_distribution_type: str = "randunif"
    # Resampling scheme; "residual" is also supported elsewhere in this module.
    resample_scheme: str = "multinomial"
    # Initial value and admissible bounds for the proposal scale parameter "s".
    param_s_initial_value: float = 0.5
    param_s_upper_bound: float = 1e5
    param_s_lower_bound: float = 1e-3
    # Numerical jitter: starting value and max number of enlargement attempts.
    jitter_initial_value: float = 1e-16
    jitter_max_iterations: int = 10
    # Covariance estimation method for proposals: "normal" or "knn".
    covariance_method: str = "normal"
    # kNN covariance estimation parameters (used when covariance_method == "knn").
    covariance_knn_n_random: int = 20
    covariance_knn_n_neighbors: int = 200
4139

4240

4341
@dataclass
4442
class SMCConfig:
4543
compute_next_logpdf_param_method: str = "p0" # or "ess"
4644
mh_steps: int = 20
47-
mh_acceptation_rate_min: float = 0.2
48-
mh_acceptation_rate_max: float = 0.4
45+
mh_acceptation_rate_min: float = 0.15
46+
mh_acceptation_rate_max: float = 0.30
4947
mh_adjustment_factor: float = 1.4
5048
mh_adjustment_max_iterations: int = 50
5149

@@ -312,16 +310,19 @@ def residual_resample(self, debug=False):
312310

313311
# Multinomial step on residuals
314312
if N_residual > 0:
315-
try:
316-
p_vals = residuals / N_residual
313+
residuals = gnp.maximum(residuals, 0.0)
314+
total_residual = gnp.sum(residuals)
315+
if total_residual == 0.0:
316+
p_vals = gnp.full_like(residuals, 1.0 / len(residuals))
317+
else:
318+
p_vals = residuals / total_residual
317319

318-
counts_res = ParticlesSet.multinomial_rvs(
319-
N_residual, residuals / N_residual, self.rng
320-
)
321-
except Exception:
322-
extype, value, tb = __import__("sys").exc_info()
323-
__import__("traceback").print_exc()
324-
__import__("pdb").post_mortem(tb)
320+
# Defensive check before multinomial draw
321+
if gnp.any(p_vals < 0) or gnp.any(p_vals > 1) or gnp.any(gnp.isnan(p_vals)):
322+
print("Residual resampling error: invalid p_vals.")
323+
__import__("pdb").set_trace()
324+
325+
counts_res = ParticlesSet.multinomial_rvs(N_residual, p_vals, self.rng)
325326
else:
326327
counts_res = gnp.zeros_like(counts_det)
327328

@@ -1124,6 +1125,7 @@ def plot_empirical_distributions(
11241125
11251126
If `parameter_indices_pooled` is not None, plot multiple marginals on the same figure.
11261127
"""
1128+
import matplotlib.pyplot as plt
11271129
from itertools import cycle
11281130

11291131
if self.particles.x is None:
@@ -1254,7 +1256,9 @@ def run_smc_sampling(
12541256
# Create the SMC instance using the configuration objects. If none are provided,
12551257
# defaults are used based on the dataclass definitions.
12561258
if particles_config is None:
1257-
particles_config = ParticlesSetConfig(resample_scheme="residual")
1259+
particles_config = ParticlesSetConfig(
1260+
resample_scheme="residual", covariance_method="normal"
1261+
)
12581262
if smc_config is None:
12591263
smc_config = SMCConfig(
12601264
compute_next_logpdf_param_method=compute_next_logpdf_param_method,
@@ -1294,6 +1298,132 @@ def run_smc_sampling(
12941298
return smc.particles.x, smc
12951299

12961300

1301+
def log_indicator_density(f, threshold, log_px, tail="lower"):
    """Build the log-density of p_X restricted to a tail event of f.

    The returned callable evaluates
    ``log(1_{f(x) < threshold} * p_X(x))`` when ``tail == "lower"``, and
    ``log(1_{f(x) > threshold} * p_X(x))`` when ``tail == "upper"``.

    Parameters
    ----------
    f : callable
        Score function mapping an array of points to scalar scores.
    threshold : float
        Threshold defining the tail event.
    log_px : callable
        Log-density of the base distribution p_X.
    tail : str
        Either "lower" (event f(x) < threshold) or "upper" (f(x) > threshold).

    Returns
    -------
    callable
        x -> log-density, equal to log_px(x) inside the event and to -1e100
        (a large negative sentinel standing in for -inf) outside.

    Raises
    ------
    ValueError
        If `tail` is neither "lower" nor "upper".
    """
    # Fail fast on a bad `tail` instead of raising at the first evaluation.
    if tail not in ("lower", "upper"):
        raise ValueError(f"Invalid tail argument: {tail}")

    def logpdf(x):
        x = gnp.asarray(x)
        fx = gnp.asarray(f(x))
        logpx = log_px(x)
        in_event = fx < threshold if tail == "lower" else fx > threshold
        # -1e100 keeps results finite so downstream weight arithmetic
        # does not produce inf - inf NaNs.
        return gnp.where(in_event, logpx, gnp.asarray(-1e100))

    return logpdf
1316+
1317+
1318+
def run_subset_simulation(
    f,
    thresholds,
    init_box,
    log_px,
    tail="upper",
    n_particles=1000,
    mh_steps=20,
    min_acceptation=0.15,
    max_acceptation=0.30,
    resample_scheme="residual",
    debug=False,
):
    """Estimate P(f(X) < u_T) or P(f(X) > u_T) via Subset Simulation.

    Parameters
    ----------
    f : callable
        Function from R^d to R (performance or score function).
    thresholds : list of float
        Monotonic threshold sequence: decreasing for tail='lower' (first
        element +inf), increasing for tail='upper' (first element -inf).
    init_box : list of [lower_bounds, upper_bounds]
        Sampling domain for initial distribution.
    log_px : callable
        Log-density of the base distribution p_X.
    tail : str
        Either 'lower' (events f < u_i) or 'upper' (events f > u_i).
    n_particles : int
        Number of particles.
    mh_steps : int
        Number of Metropolis-Hastings moves applied at each stage.
    min_acceptation, max_acceptation : float
        Target bounds for the MH acceptance rate.
    resample_scheme : str
        Resampling scheme passed to ParticlesSetConfig.
    debug : bool
        If True, print per-stage diagnostics.

    Returns
    -------
    p_estimate : float
        Final estimate of the tail probability at thresholds[-1].
    stage_probs : array
        Estimated conditional probabilities p_{u_i | u_{i-1}}.
    smc : SMC
        The SMC object with diagnostics.

    Raises
    ------
    ValueError
        If `tail` is invalid or the first threshold is not the matching
        infinity.
    RuntimeError
        If at some stage no particle satisfies the current threshold.
    """
    # Validate inputs with exceptions rather than asserts (asserts are
    # stripped under `python -O`).
    if tail == "lower":
        if thresholds[0] != float("inf"):
            raise ValueError("First threshold must be +inf for tail='lower'.")
    elif tail == "upper":
        if thresholds[0] != float("-inf"):
            raise ValueError("First threshold must be -inf for tail='upper'.")
    else:
        raise ValueError(f"Invalid tail: {tail}")

    # Set up configs
    particles_config = ParticlesSetConfig(
        initial_distribution_type="randunif",
        resample_scheme=resample_scheme,
    )
    smc_config = SMCConfig(
        compute_next_logpdf_param_method="p0",  # not used in this routine
        mh_steps=mh_steps,
        mh_acceptation_rate_min=min_acceptation,
        mh_acceptation_rate_max=max_acceptation,
    )

    smc = SMC(
        init_box,
        n=n_particles,
        particles_config=particles_config,
        smc_config=smc_config,
    )

    # Initialize particles
    smc.particles.particles_init(init_box, n_particles)
    smc.log_data["target_logpdf_param"] = thresholds[1]

    # Accumulate stage probabilities in a plain list: item assignment on a
    # gnp array may not be supported by immutable backends (e.g. JAX).
    stage_probs = []

    for k in range(1, len(thresholds)):
        uk = thresholds[k]
        uk_prev = thresholds[k - 1]
        if debug:
            print(f"\n[Stage {k}] Threshold u_k = {uk:.2f}")

        # Construct logpdf for current level
        logpdf_k = log_indicator_density(f, uk, log_px, tail=tail)
        smc.particles.set_logpdf(logpdf_k)

        # Reweight
        smc.particles.reweight()

        # Compute conditional probability p_{u_k | u_{k-1}}
        w_sum = gnp.sum(smc.particles.w)
        if w_sum == 0.0:
            # No survivor: normalizing below would divide by zero and the
            # final product would be meaningless, so abort explicitly.
            raise RuntimeError(
                f"Subset simulation failed at stage {k}: "
                f"no particle satisfies the threshold {uk}."
            )
        stage_probs.append(w_sum)

        if debug:
            print(f"  p_{{{uk:.2f} | {uk_prev:.2f}}} = {w_sum:.2f}")

        # Normalize weights
        smc.particles.w = smc.particles.w / w_sum

        # Resample and MH move
        smc.particles.resample(debug=debug)
        smc.move_with_controlled_acceptation_rate(debug=debug)
        for _ in range(mh_steps - 1):
            smc.particles.move()

        smc.stage += 1
        smc.log_snapshot()

    stage_probs = gnp.asarray(stage_probs)

    # Final estimate of the tail probability: product of conditional stage
    # probabilities.
    p_estimate = float(gnp.prod(stage_probs))
    return p_estimate, stage_probs, smc
1425+
1426+
12971427
def test_run_smc_sampling_gaussian_mixture():
12981428
import matplotlib.pyplot as plt
12991429
from scipy import stats
@@ -1358,5 +1488,128 @@ def logpdf_mixture(x, beta):
13581488
smc_instance.plot_state()
13591489

13601490

1491+
def test_subset_sampling_gaussian_icdf():
    """Sanity-check subset simulation on the standard normal inverse CDF."""
    from scipy.stats import norm
    import matplotlib.pyplot as plt
    import numpy as np

    # Score function: f(x) = Phi^{-1}(x), applied to the first coordinate.
    def f(x):
        return norm.ppf(x[:, 0])

    # Quantile levels and the matching (increasing) threshold sequence.
    q_levels = [0.0, 0.5, 0.9, 0.97, 0.99, 1 - 1e-3, 1 - 1e-4, 1 - 1e-5, 1 - 1e-6]
    thresholds = [float("-inf")] + list(norm.ppf(q_levels[1:]))

    # Base distribution: Uniform on [0, 1]
    box = [[0.0], [1.0]]

    def log_px(x):
        in_box = gnp.all((x >= 0.0) & (x <= 1.0), axis=1)
        return gnp.where(in_box, 0.0, -1e100)

    # Run subset simulation
    p_hat, stage_probs, smc = run_subset_simulation(
        f=f,
        thresholds=thresholds,
        init_box=box,
        log_px=log_px,
        tail="upper",
        n_particles=10000,
        debug=True,
    )

    # Reference values: conditional and cumulative tail probabilities.
    exact_conditional_probs = [
        (1 - q) / (1 - q_prev) for q_prev, q in zip(q_levels, q_levels[1:])
    ]
    exact_sequential_probs = [1 - q for q in q_levels[1:]]

    # Estimated cumulative tail probabilities
    estimated_sequential_probs = np.cumprod(stage_probs)

    # Report estimated vs exact values
    print("\nEstimated conditional probs:", [f"{p:.2e}" for p in stage_probs])
    print("Exact conditional probs :", [f"{p:.2e}" for p in exact_conditional_probs])
    print(
        "Estimated sequential probs:", [f"{p:.2e}" for p in estimated_sequential_probs]
    )
    print("Exact sequential probs :", [f"{p:.2e}" for p in exact_sequential_probs])
    print(f"Estimated final probability: {p_hat:.3e}")
    print(f"Exact final probability: {exact_sequential_probs[-1]:.3e}")

    stage_ids = list(range(1, len(thresholds)))

    # Conditional probabilities, estimated vs exact
    plt.figure(figsize=(6, 4))
    plt.plot(stage_ids, stage_probs, marker="o", label="Estimated p_{u_i | u_{i-1}}")
    plt.plot(
        stage_ids,
        exact_conditional_probs,
        marker="x",
        linestyle="--",
        label="Exact p_{u_i | u_{i-1}}",
    )
    plt.xlabel("Stage (i)")
    plt.ylabel("p")
    plt.title("Conditional probabilities")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Cumulative tail probabilities, estimated vs exact (log scale)
    plt.figure(figsize=(6, 4))
    plt.plot(
        stage_ids,
        estimated_sequential_probs,
        marker="o",
        label="Estimated P(f(X) > u_i)",
    )
    plt.plot(
        stage_ids,
        exact_sequential_probs,
        marker="x",
        linestyle="--",
        label="Exact P(f(X) > u_i)",
    )
    plt.xlabel("Stage (i)")
    plt.ylabel("P(f(X) > u_i)")
    plt.yscale("log")
    plt.title("Sequential tail probabilities")
    plt.legend()
    plt.grid(True, which="both")
    plt.tight_layout()
    plt.show()

    # Inverse CDF with the thresholds marked, zoomed on the upper tail by
    # plotting against the tail probability 1 - x on a logarithmic axis.
    u_grid = np.linspace(1e-8, 1 - 1e-8, 1000)
    icdf_vals = norm.ppf(u_grid)
    tail_probs = 1 - u_grid

    plt.figure(figsize=(6, 4))
    plt.semilogx(tail_probs, icdf_vals, label=r"$f(x) = \Phi^{-1}(x)$")

    for i, q in enumerate(q_levels[1:], start=1):
        u_q = thresholds[i]
        plt.axhline(y=u_q, color="gray", linestyle="dotted")
        plt.plot([1 - q], [u_q], "ro")
        plt.text(
            1 - q, u_q, f"$1-q={1 - q:.3g}$", fontsize=8, va="bottom", ha="right"
        )

    plt.xlabel(r"$1 - x$ (tail probability)")
    plt.ylabel(r"$\Phi^{-1}(x)$")
    plt.title("Zoom on right tail of standard gaussian inverse cdf")
    plt.grid(True, which="both")
    plt.legend()
    plt.tight_layout()
    plt.show()
1609+
1610+
13611611
if __name__ == "__main__":
    # Run the two demos in sequence, announcing each one.
    demos = (
        ("Sample a Gaussian mixture", test_run_smc_sampling_gaussian_mixture),
        ("Subset sampling", test_subset_sampling_gaussian_icdf),
    )
    for banner, demo in demos:
        print(banner)
        demo()

gpmp/num.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -964,8 +964,8 @@ def inftobigf(a, bigf=fmax / 1000.0):
964964

965965
class DifferentiableFunction:
966966
def __init__(self, f):
967-
self.f = jax.jit(f)
968-
self.f_grad = jax.jit(jax.grad(self.f))
967+
self.f = f # jax.jit(f)
968+
self.f_grad = jax.grad(self.f) # jax.jit(jax.grad(self.f))
969969
self.f_value = None
970970
self.x_value = None
971971

0 commit comments

Comments
 (0)