|
22 | 22 | from scipy.optimize import brentq |
23 | 23 | import gpmp.num as gnp |
24 | 24 | import gpmp.misc.knn_cov |
25 | | -import matplotlib.pyplot as plt |
26 | | -import matplotlib.gridspec as gridspec |
27 | 25 |
|
28 | 26 |
|
29 | 27 | @dataclass |
30 | 28 | class ParticlesSetConfig: |
31 | 29 | initial_distribution_type: str = "randunif" |
32 | 30 | resample_scheme: str = "multinomial" |
33 | 31 | param_s_initial_value: float = 0.5 |
34 | | - param_s_upper_bound: float = 1e4 |
| 32 | + param_s_upper_bound: float = 1e5 |
35 | 33 | param_s_lower_bound: float = 1e-3 |
36 | 34 | jitter_initial_value: float = 1e-16 |
37 | 35 | jitter_max_iterations: int = 10 |
38 | | - covariance_method: str = "knn" |
39 | | - covariance_knn_n_random: int = 50 |
40 | | - covariance_knn_n_neighbors: int = 100 |
| 36 | + covariance_method: str = "normal" |
| 37 | + covariance_knn_n_random: int = 20 |
| 38 | + covariance_knn_n_neighbors: int = 200 |
41 | 39 |
|
42 | 40 |
|
43 | 41 | @dataclass |
44 | 42 | class SMCConfig: |
45 | 43 | compute_next_logpdf_param_method: str = "p0" # or "ess" |
46 | 44 | mh_steps: int = 20 |
47 | | - mh_acceptation_rate_min: float = 0.2 |
48 | | - mh_acceptation_rate_max: float = 0.4 |
| 45 | + mh_acceptation_rate_min: float = 0.15 |
| 46 | + mh_acceptation_rate_max: float = 0.30 |
49 | 47 | mh_adjustment_factor: float = 1.4 |
50 | 48 | mh_adjustment_max_iterations: int = 50 |
51 | 49 |
|
@@ -312,16 +310,19 @@ def residual_resample(self, debug=False): |
312 | 310 |
|
313 | 311 | # Multinomial step on residuals |
314 | 312 | if N_residual > 0: |
315 | | - try: |
316 | | - p_vals = residuals / N_residual |
| 313 | + residuals = gnp.maximum(residuals, 0.0) |
| 314 | + total_residual = gnp.sum(residuals) |
| 315 | + if total_residual == 0.0: |
| 316 | + p_vals = gnp.full_like(residuals, 1.0 / len(residuals)) |
| 317 | + else: |
| 318 | + p_vals = residuals / total_residual |
317 | 319 |
|
318 | | - counts_res = ParticlesSet.multinomial_rvs( |
319 | | - N_residual, residuals / N_residual, self.rng |
320 | | - ) |
321 | | - except Exception: |
322 | | - extype, value, tb = __import__("sys").exc_info() |
323 | | - __import__("traceback").print_exc() |
324 | | - __import__("pdb").post_mortem(tb) |
| 320 | + # Defensive check before multinomial draw |
| 321 | + if gnp.any(p_vals < 0) or gnp.any(p_vals > 1) or gnp.any(gnp.isnan(p_vals)): |
| 322 | + print("Residual resampling error: invalid p_vals.") |
| 323 | + __import__("pdb").set_trace() |
| 324 | + |
| 325 | + counts_res = ParticlesSet.multinomial_rvs(N_residual, p_vals, self.rng) |
325 | 326 | else: |
326 | 327 | counts_res = gnp.zeros_like(counts_det) |
327 | 328 |
|
@@ -1124,6 +1125,7 @@ def plot_empirical_distributions( |
1124 | 1125 |
|
1125 | 1126 | If `parameter_indices_pooled` is not None, plot multiple marginals on the same figure. |
1126 | 1127 | """ |
| 1128 | + import matplotlib.pyplot as plt |
1127 | 1129 | from itertools import cycle |
1128 | 1130 |
|
1129 | 1131 | if self.particles.x is None: |
@@ -1254,7 +1256,9 @@ def run_smc_sampling( |
1254 | 1256 | # Create the SMC instance using the configuration objects. If none are provided, |
1255 | 1257 | # defaults are used based on the dataclass definitions. |
1256 | 1258 | if particles_config is None: |
1257 | | - particles_config = ParticlesSetConfig(resample_scheme="residual") |
| 1259 | + particles_config = ParticlesSetConfig( |
| 1260 | + resample_scheme="residual", covariance_method="normal" |
| 1261 | + ) |
1258 | 1262 | if smc_config is None: |
1259 | 1263 | smc_config = SMCConfig( |
1260 | 1264 | compute_next_logpdf_param_method=compute_next_logpdf_param_method, |
@@ -1294,6 +1298,132 @@ def run_smc_sampling( |
1294 | 1298 | return smc.particles.x, smc |
1295 | 1299 |
|
1296 | 1300 |
|
| 1301 | +def log_indicator_density(f, threshold, log_px, tail="lower"): |
 | 1302 | + """Return logpdf(x) = log(1_{f(x) < threshold} * p_X(x)) for tail='lower', or with '>' for tail='upper'.""" |
| 1303 | + |
| 1304 | + def logpdf(x): |
| 1305 | + x = gnp.asarray(x) |
| 1306 | + fx = gnp.asarray(f(x)) |
| 1307 | + logpx = log_px(x) |
| 1308 | + if tail == "lower": |
| 1309 | + return gnp.where(fx < threshold, logpx, gnp.asarray(-1e100)) |
| 1310 | + elif tail == "upper": |
| 1311 | + return gnp.where(fx > threshold, logpx, gnp.asarray(-1e100)) |
| 1312 | + else: |
| 1313 | + raise ValueError(f"Invalid tail argument: {tail}") |
| 1314 | + |
| 1315 | + return logpdf |
| 1316 | + |
| 1317 | + |
| 1318 | +def run_subset_simulation( |
| 1319 | + f, |
| 1320 | + thresholds, |
| 1321 | + init_box, |
| 1322 | + log_px, |
| 1323 | + tail="upper", |
| 1324 | + n_particles=1000, |
| 1325 | + mh_steps=20, |
| 1326 | + min_acceptation=0.15, |
| 1327 | + max_acceptation=0.30, |
| 1328 | + resample_scheme="residual", |
| 1329 | + debug=False, |
| 1330 | +): |
 | 1331 | + """Estimate the tail probability P(f(X) < u_T) or P(f(X) > u_T), depending on `tail`, via Subset Simulation. |
| 1332 | +
|
| 1333 | + Parameters |
| 1334 | + ---------- |
| 1335 | + f : callable |
| 1336 | + Function from R^d to R (performance or score function). |
| 1337 | + thresholds : list of float |
| 1338 | + Monotonic threshold sequence: decreasing for '<', increasing for '>'. |
| 1339 | + init_box : list of [lower_bounds, upper_bounds] |
| 1340 | + Sampling domain for initial distribution. |
| 1341 | + log_px : callable |
| 1342 | + Log-density of the base distribution p_X. |
| 1343 | + tail : str |
| 1344 | + Either 'lower' (f < u_i) or 'upper' (f > u_i). |
| 1345 | + Returns |
| 1346 | + ------- |
| 1347 | + p_estimate : float |
 | 1348 | + Final estimate of P(f(X) < u_T) (tail='lower') or P(f(X) > u_T) (tail='upper'). |
| 1349 | + stage_probs : list of float |
| 1350 | + Estimated conditional probabilities p_{u_i | u_{i-1}}. |
| 1351 | + smc : SMC |
| 1352 | + The SMC object with diagnostics. |
| 1353 | + """ |
| 1354 | + if tail == "lower": |
| 1355 | + assert thresholds[0] == float( |
| 1356 | + "inf" |
 | 1357 | + ), "First threshold must be +∞ for tail='lower'." |
| 1358 | + elif tail == "upper": |
| 1359 | + assert thresholds[0] == float( |
| 1360 | + "-inf" |
 | 1361 | + ), "First threshold must be -∞ for tail='upper'." |
| 1362 | + else: |
| 1363 | + raise ValueError(f"Invalid tail: {tail}") |
| 1364 | + |
| 1365 | + # Set up configs |
| 1366 | + particles_config = ParticlesSetConfig( |
| 1367 | + initial_distribution_type="randunif", |
| 1368 | + resample_scheme=resample_scheme, |
| 1369 | + ) |
| 1370 | + smc_config = SMCConfig( |
| 1371 | + compute_next_logpdf_param_method="p0", # not used |
| 1372 | + mh_steps=mh_steps, |
| 1373 | + mh_acceptation_rate_min=min_acceptation, |
| 1374 | + mh_acceptation_rate_max=max_acceptation, |
| 1375 | + ) |
| 1376 | + |
| 1377 | + smc = SMC( |
| 1378 | + init_box, |
| 1379 | + n=n_particles, |
| 1380 | + particles_config=particles_config, |
| 1381 | + smc_config=smc_config, |
| 1382 | + ) |
| 1383 | + |
| 1384 | + # Initialize particles |
| 1385 | + smc.particles.particles_init(init_box, n_particles) |
| 1386 | + smc.log_data["target_logpdf_param"] = thresholds[1] |
| 1387 | + |
| 1388 | + stage_probs = gnp.empty(len(thresholds)-1) |
| 1389 | + |
| 1390 | + for k in range(1, len(thresholds)): |
| 1391 | + uk = thresholds[k] |
| 1392 | + uk_prev = thresholds[k - 1] |
| 1393 | + if debug: |
| 1394 | + print(f"\n[Stage {k}] Threshold u_k = {uk:.2f}") |
| 1395 | + |
| 1396 | + # Construct logpdf for current level |
| 1397 | + logpdf_k = log_indicator_density(f, uk, log_px, tail=tail) |
| 1398 | + smc.particles.set_logpdf(logpdf_k) |
| 1399 | + |
| 1400 | + # Reweight |
| 1401 | + smc.particles.reweight() |
| 1402 | + |
| 1403 | + # Compute conditional probability p_{u_k | u_{k-1}} |
| 1404 | + w_sum = gnp.sum(smc.particles.w) |
| 1405 | + stage_probs[k-1] = w_sum |
| 1406 | + |
| 1407 | + if debug: |
 | 1408 | + print(f" p_{{{uk:.2f} | {uk_prev:.2f}}} = {w_sum:.2f}") |
| 1409 | + |
| 1410 | + # Normalize weights |
| 1411 | + smc.particles.w = smc.particles.w / w_sum |
| 1412 | + |
| 1413 | + # Resample and MH move |
| 1414 | + smc.particles.resample(debug=debug) |
| 1415 | + smc.move_with_controlled_acceptation_rate(debug=debug) |
| 1416 | + for _ in range(mh_steps - 1): |
| 1417 | + smc.particles.move() |
| 1418 | + |
| 1419 | + smc.stage += 1 |
| 1420 | + smc.log_snapshot() |
| 1421 | + |
| 1422 | + # Final estimate of tail probability |
| 1423 | + p_estimate = float(gnp.prod(stage_probs)) |
| 1424 | + return p_estimate, stage_probs, smc |
| 1425 | + |
| 1426 | + |
1297 | 1427 | def test_run_smc_sampling_gaussian_mixture(): |
1298 | 1428 | import matplotlib.pyplot as plt |
1299 | 1429 | from scipy import stats |
@@ -1358,5 +1488,128 @@ def logpdf_mixture(x, beta): |
1358 | 1488 | smc_instance.plot_state() |
1359 | 1489 |
|
1360 | 1490 |
|
| 1491 | +def test_subset_sampling_gaussian_icdf(): |
| 1492 | + from scipy.stats import norm |
| 1493 | + import matplotlib.pyplot as plt |
| 1494 | + import numpy as np |
| 1495 | + |
| 1496 | + # Define f(x) = inverse CDF of standard normal |
| 1497 | + def f(x): |
| 1498 | + return norm.ppf(x[:, 0]) |
| 1499 | + |
| 1500 | + # Quantile levels and corresponding thresholds |
| 1501 | + q_levels = [0.0, 0.5, 0.9, 0.97, 0.99, 1 - 1e-3, 1 - 1e-4, 1 - 1e-5, 1 - 1e-6] |
| 1502 | + thresholds = [float("-inf")] + list(norm.ppf(q_levels[1:])) |
| 1503 | + |
| 1504 | + # Domain: Uniform[0, 1] |
| 1505 | + box = [[0.0], [1.0]] |
| 1506 | + |
| 1507 | + def log_px(x): |
| 1508 | + inside = gnp.all((x >= 0.0) & (x <= 1.0), axis=1) |
| 1509 | + return gnp.where(inside, 0.0, -1e100) |
| 1510 | + |
| 1511 | + # Run subset simulation |
| 1512 | + p_hat, stage_probs, smc = run_subset_simulation( |
| 1513 | + f=f, |
| 1514 | + thresholds=thresholds, |
| 1515 | + init_box=box, |
| 1516 | + log_px=log_px, |
| 1517 | + tail="upper", |
| 1518 | + n_particles=10000, |
| 1519 | + debug=True, |
| 1520 | + ) |
| 1521 | + |
| 1522 | + # Exact conditional and sequential probabilities |
| 1523 | + exact_conditional_probs = [] |
| 1524 | + exact_sequential_probs = [] |
| 1525 | + for i in range(1, len(q_levels)): |
| 1526 | + q_i = q_levels[i] |
| 1527 | + q_prev = q_levels[i - 1] |
| 1528 | + exact_conditional_probs.append((1 - q_i) / (1 - q_prev)) |
| 1529 | + exact_sequential_probs.append(1 - q_i) |
| 1530 | + |
| 1531 | + # Estimated sequential probabilities |
| 1532 | + estimated_sequential_probs = np.cumprod(stage_probs) |
| 1533 | + |
| 1534 | + # Print results |
| 1535 | + print("\nEstimated conditional probs:", [f"{p:.2e}" for p in stage_probs]) |
| 1536 | + print("Exact conditional probs :", [f"{p:.2e}" for p in exact_conditional_probs]) |
| 1537 | + print( |
| 1538 | + "Estimated sequential probs:", [f"{p:.2e}" for p in estimated_sequential_probs] |
| 1539 | + ) |
| 1540 | + print("Exact sequential probs :", [f"{p:.2e}" for p in exact_sequential_probs]) |
| 1541 | + print(f"Estimated final probability: {p_hat:.3e}") |
| 1542 | + print(f"Exact final probability: {exact_sequential_probs[-1]:.3e}") |
| 1543 | + |
| 1544 | + # Plot conditional probabilities |
| 1545 | + stages = list(range(1, len(thresholds))) |
| 1546 | + plt.figure(figsize=(6, 4)) |
| 1547 | + plt.plot(stages, stage_probs, marker="o", label="Estimated p_{u_i | u_{i-1}}") |
| 1548 | + plt.plot( |
| 1549 | + stages, |
| 1550 | + exact_conditional_probs, |
| 1551 | + marker="x", |
| 1552 | + linestyle="--", |
| 1553 | + label="Exact p_{u_i | u_{i-1}}", |
| 1554 | + ) |
| 1555 | + plt.xlabel("Stage (i)") |
| 1556 | + plt.ylabel("p") |
| 1557 | + plt.title("Conditional probabilities") |
| 1558 | + plt.legend() |
| 1559 | + plt.grid(True) |
| 1560 | + plt.tight_layout() |
| 1561 | + plt.show() |
| 1562 | + |
| 1563 | + # Plot sequential tail probabilities |
| 1564 | + plt.figure(figsize=(6, 4)) |
| 1565 | + plt.plot( |
| 1566 | + stages, estimated_sequential_probs, marker="o", label="Estimated P(f(X) > u_i)" |
| 1567 | + ) |
| 1568 | + plt.plot( |
| 1569 | + stages, |
| 1570 | + exact_sequential_probs, |
| 1571 | + marker="x", |
| 1572 | + linestyle="--", |
| 1573 | + label="Exact P(f(X) > u_i)", |
| 1574 | + ) |
| 1575 | + plt.xlabel("Stage (i)") |
| 1576 | + plt.ylabel("P(f(X) > u_i)") |
| 1577 | + plt.yscale("log") |
| 1578 | + plt.title("Sequential tail probabilities") |
| 1579 | + plt.legend() |
| 1580 | + plt.grid(True, which="both") |
| 1581 | + plt.tight_layout() |
| 1582 | + plt.show() |
| 1583 | + |
| 1584 | + # Plot the inverse CDF: zoom on upper tail using 1 - x (log scale) |
| 1585 | + x_vals = np.linspace(1e-8, 1 - 1e-8, 1000) |
| 1586 | + y_vals = norm.ppf(x_vals) |
| 1587 | + x_tail = 1 - x_vals # Tail probability |
| 1588 | + |
| 1589 | + plt.figure(figsize=(6, 4)) |
| 1590 | + plt.semilogx(x_tail, y_vals, label=r"$f(x) = \Phi^{-1}(x)$") |
| 1591 | + |
| 1592 | + for i, q in enumerate(q_levels[1:], start=1): |
| 1593 | + x_q = q |
| 1594 | + y_q = thresholds[i] |
| 1595 | + x_tail_q = 1 - x_q |
| 1596 | + plt.axhline(y=y_q, color="gray", linestyle="dotted") |
| 1597 | + plt.plot([x_tail_q], [y_q], "ro") |
| 1598 | + plt.text( |
| 1599 | + x_tail_q, y_q, f"$1-q={1 - q:.3g}$", fontsize=8, va="bottom", ha="right" |
| 1600 | + ) |
| 1601 | + |
| 1602 | + plt.xlabel(r"$1 - x$ (tail probability)") |
| 1603 | + plt.ylabel(r"$\Phi^{-1}(x)$") |
| 1604 | + plt.title("Zoom on right tail of standard gaussian inverse cdf") |
| 1605 | + plt.grid(True, which="both") |
| 1606 | + plt.legend() |
| 1607 | + plt.tight_layout() |
| 1608 | + plt.show() |
| 1609 | + |
| 1610 | + |
1361 | 1611 | if __name__ == "__main__": |
| 1612 | + print("Sample a Gaussian mixture") |
1362 | 1613 | test_run_smc_sampling_gaussian_mixture() |
| 1614 | + print("Subset sampling") |
| 1615 | + test_subset_sampling_gaussian_icdf() |
0 commit comments