Skip to content

Commit 726b3dc

Browse files
committed
smc.py: knn-based covariance estimation
1 parent 8ad41a7 commit 726b3dc

File tree

2 files changed

+110
-7
lines changed

2 files changed

+110
-7
lines changed

gpmp/misc/knn_cov.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import gpmp.num as gnp
2+
3+
try:
4+
import hnswlib
5+
HNSWLIB_AVAILABLE = True
6+
except ImportError:
7+
HNSWLIB_AVAILABLE = False
8+
print(
9+
"hnswlib is not installed. estimate_cov_matrix_knn will fall back to classical covariance estimation."
10+
)
11+
12+
def estimate_cov_matrix(x):
    """
    Classical (global) covariance estimate: the sample covariance of x.

    Parameters
    ----------
    x : array, shape (N, d)
        One observation per row.

    Returns
    -------
    The value of gpmp.num.cov applied to the transposed data; documented
    upstream as a (d, d) covariance matrix.
    """
    # gpmp.num.cov takes variables along rows, i.e. shape (d, N),
    # whereas x stores observations along rows: flip the axes first.
    variables_by_rows = x.T
    return gnp.cov(variables_by_rows)
20+
21+
def estimate_cov_matrix_knn(
    x,
    n_random=50,
    n_neighbors=50,
    ef=50,
    max_ef_construction=200,
    M=16,
):
    """
    Estimate the covariance matrix of x by averaging local k-NN covariances.

    Steps:
      1) Build an HNSW index over all points of x (if hnswlib is available).
      2) Randomly sample 'n_random' points from x, without replacement.
      3) For each sampled point, find its n_neighbors nearest neighbors.
      4) Compute the covariance of each neighborhood.
      5) Average the local covariances into a single matrix.

    Fallback: the classical sample covariance (estimate_cov_matrix) is
    returned when hnswlib could not be imported, or when the effective
    settings cannot yield a local covariance (no sampled point, or fewer
    than two neighbors per neighborhood).

    Parameters
    ----------
    x : ndarray, shape (N, d)
        The dataset.
    n_random : int, optional
        Number of random points from x to sample for local covariance.
        Clamped to N.
    n_neighbors : int, optional
        Number of neighbors for each point to compute local covariances.
        Clamped to N, since the index cannot return more neighbors than
        it contains.
    ef : int, optional
        'ef' parameter for HNSW search (larger means more accurate at some
        cost). Raised to n_neighbors when smaller, as hnswlib requires
        ef >= k.
    max_ef_construction : int, optional
        'ef_construction' parameter for HNSW index building.
    M : int, optional
        HNSW parameter controlling the connectivity of the graph.

    Returns
    -------
    C_avg : ndarray, shape (d, d)
        The averaged covariance matrix (or classical covariance if fallback).
    """
    if not HNSWLIB_AVAILABLE:
        # If import hnswlib failed at module load, fallback to classical
        return estimate_cov_matrix(x)

    N, d = x.shape

    # Neither the sample size nor the neighborhood size can exceed the
    # number of available points; asking hnswlib for k > N makes
    # knn_query raise.
    n_random = min(n_random, N)
    n_neighbors = min(n_neighbors, N)

    # A covariance needs at least two points, and the average needs at
    # least one neighborhood; otherwise the result would be NaN/empty.
    if n_random < 1 or n_neighbors < 2:
        return estimate_cov_matrix(x)

    # 1) Build the HNSW index
    index = hnswlib.Index(space="l2", dim=d)
    index.init_index(max_elements=N, ef_construction=max_ef_construction, M=M)
    index.add_items(x)
    # hnswlib documents that ef must not be lower than the number of
    # queried neighbors k.
    index.set_ef(max(ef, n_neighbors))

    # 2) Randomly select points and compute local covariances
    random_indices = gnp.choice(N, size=n_random, replace=False)
    local_covs = []

    for idx in random_indices:
        # knn_query returns (labels, distances), each of shape (1, k);
        # only the labels are needed. Labels are used as row indices of x,
        # which relies on add_items assigning sequential ids by default.
        labels, _ = index.knn_query(x[idx], k=n_neighbors)
        neighbors_x = x[labels[0]]  # shape (n_neighbors, d)

        # gpmp.num.cov expects shape (d, n_neighbors), hence the transpose
        local_covs.append(gnp.cov(neighbors_x.T))

    # 3) Average all local covariance matrices
    C_avg = gnp.mean(local_covs, axis=0)  # shape (d, d)
    return gnp.asarray(C_avg)

gpmp/misc/smc.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from scipy.stats import qmc
2222
from scipy.optimize import brentq
2323
import gpmp.num as gnp
24+
import gpmp.misc.knn_cov
2425
import matplotlib.pyplot as plt
2526
import matplotlib.gridspec as gridspec
2627

@@ -34,6 +35,9 @@ class ParticlesSetConfig:
3435
param_s_lower_bound: float = 1e-3
3536
jitter_initial_value: float = 1e-16
3637
jitter_max_iterations: int = 10
38+
covariance_method: str = "knn"
39+
covariance_knn_n_random: int = 50
40+
covariance_knn_n_neighbors: int = 100
3741

3842

3943
@dataclass
@@ -364,7 +368,15 @@ def perturb(self):
364368
raise ParticlesSetError(self.param_s, param_s_lower, param_s_upper)
365369

366370
# Covariance matrix of the perturbation noise
367-
C = self.param_s * gnp.cov(self.x.reshape(self.n, -1).T)
371+
if self.config.covariance_method == "knn":
372+
base_cov = gpmp.misc.knn_cov.estimate_cov_matrix_knn(
373+
self.x,
374+
n_random=self.config.covariance_knn_n_random,
375+
n_neighbors=self.config.covariance_knn_n_neighbors,
376+
) # shape (dim, dim)
377+
elif self.config.covariance_method == "normal":
378+
base_cov = gpmp.misc.knn_cov.estimate_cov_matrix(self.x)
379+
C = self.param_s * base_cov
368380

369381
# Call ParticlesSet.multivariate_normal_rvs(C, self.n, self.rng)
370382
# with control on the possible degeneracy of C
@@ -424,7 +436,7 @@ def move(self):
424436
self.x = gnp.set_row_2d(self.x, accept_mask, y[accept_mask, :])
425437
# self.logpx[accept_mask] = logpy[accept_mask]
426438
self.logpx = gnp.set_elem_1d(self.logpx, accept_mask, logpy[accept_mask])
427-
439+
428440
# Compute the acceptance rate
429441
acceptance_rate = gnp.sum(accept_mask) / self.n
430442

@@ -1287,7 +1299,7 @@ def test_run_smc_sampling_gaussian_mixture():
12871299
from scipy import stats
12881300

12891301
# Gaussian mixture parameters
1290-
m1, s1, w1 = 0.0, 0.05, 0.3
1302+
m1, s1, w1 = 0.0, 0.04, 0.3
12911303
m2, s2, w2 = 1.0, 0.1, 0.7
12921304

12931305
# Tempered log pdf: log p_T(x) = beta * log(p(x))
@@ -1301,7 +1313,7 @@ def logpdf_mixture(x, beta):
13011313

13021314
# Domain: 1D in [-1, 2]
13031315
init_box = [[-1], [2]]
1304-
initial_logpdf_param = 0.1 # initial beta
1316+
initial_logpdf_param = 0.01 # initial beta
13051317
target_logpdf_param = 1.0 # target beta
13061318

13071319
# SMC settings
@@ -1329,13 +1341,14 @@ def logpdf_mixture(x, beta):
13291341
print("Sample variance:", gnp.var(particles))
13301342

13311343
# Plot target density and histogram of particles
1332-
x_vals = gnp.linspace(-1, 2, 300)
1333-
target_density = w1 * stats.norm.pdf(
1344+
x_vals = gnp.linspace(-0.5, 1.5, 600)
1345+
target_density = lambda x_vals: w1 * stats.norm.pdf(
13341346
x_vals, loc=m1, scale=s1
13351347
) + w2 * stats.norm.pdf(x_vals, loc=m2, scale=s2)
13361348
plt.figure(figsize=(8, 3))
13371349
plt.hist(particles, bins=100, density=True, histtype="step", label="SMC particles")
1338-
plt.plot(x_vals, target_density, "r--", label="Target density")
1350+
plt.plot(x_vals, target_density(x_vals), "r--", label="Target density")
1351+
# plt.plot(particles, target_density(particles), 'b.')
13391352
plt.xlabel("x")
13401353
plt.ylabel("Density")
13411354
plt.legend()

0 commit comments

Comments
 (0)