2121from scipy .stats import qmc
2222from scipy .optimize import brentq
2323import gpmp .num as gnp
24+ import gpmp .misc .knn_cov
2425import matplotlib .pyplot as plt
2526import matplotlib .gridspec as gridspec
2627
@@ -34,6 +35,9 @@ class ParticlesSetConfig:
3435 param_s_lower_bound : float = 1e-3
3536 jitter_initial_value : float = 1e-16
3637 jitter_max_iterations : int = 10
38+ covariance_method : str = "knn"
39+ covariance_knn_n_random : int = 50
40+ covariance_knn_n_neighbors : int = 100
3741
3842
3943@dataclass
@@ -364,7 +368,15 @@ def perturb(self):
364368 raise ParticlesSetError (self .param_s , param_s_lower , param_s_upper )
365369
366370 # Covariance matrix of the pertubation noise
367- C = self .param_s * gnp .cov (self .x .reshape (self .n , - 1 ).T )
371+ if self .config .covariance_method == "knn" :
372+ base_cov = gpmp .misc .knn_cov .estimate_cov_matrix_knn (
373+ self .x ,
374+ n_random = self .config .covariance_knn_n_random ,
375+ n_neighbors = self .config .covariance_knn_n_neighbors ,
376+ ) # shape (dim, dim)
377+ elif self .config .covariance_method == "normal" :
378+ base_cov = gpmp .misc .knn_cov .estimate_cov_matrix (self .x )
379+ C = self .param_s * base_cov
368380
369381 # Call ParticlesSet.multivariate_normal_rvs(C, self.n, self.rng)
370382 # with control on the possible degeneracy of C
@@ -424,7 +436,7 @@ def move(self):
424436 self .x = gnp .set_row_2d (self .x , accept_mask , y [accept_mask , :])
425437 # self.logpx[accept_mask] = logpy[accept_mask]
426438 self .logpx = gnp .set_elem_1d (self .logpx , accept_mask , logpy [accept_mask ])
427-
439+
428440 # Compute the acceptance rate
429441 acceptance_rate = gnp .sum (accept_mask ) / self .n
430442
@@ -1287,7 +1299,7 @@ def test_run_smc_sampling_gaussian_mixture():
12871299 from scipy import stats
12881300
12891301 # Gaussian mixture parameters
1290- m1 , s1 , w1 = 0.0 , 0.05 , 0.3
1302+ m1 , s1 , w1 = 0.0 , 0.04 , 0.3
12911303 m2 , s2 , w2 = 1.0 , 0.1 , 0.7
12921304
12931305 # Tempered log pdf: log p_T(x) = beta * log(p(x))
@@ -1301,7 +1313,7 @@ def logpdf_mixture(x, beta):
13011313
13021314 # Domain: 1D in [-1, 2]
13031315 init_box = [[- 1 ], [2 ]]
1304- initial_logpdf_param = 0.1 # initial beta
1316+ initial_logpdf_param = 0.01 # initial beta
13051317 target_logpdf_param = 1.0 # target beta
13061318
13071319 # SMC settings
@@ -1329,13 +1341,14 @@ def logpdf_mixture(x, beta):
13291341 print ("Sample variance:" , gnp .var (particles ))
13301342
13311343 # Plot target density and histogram of particles
1332- x_vals = gnp .linspace (- 1 , 2 , 300 )
1333- target_density = w1 * stats .norm .pdf (
1344+ x_vals = gnp .linspace (- 0.5 , 1.5 , 600 )
1345+ target_density = lambda x_vals : w1 * stats .norm .pdf (
13341346 x_vals , loc = m1 , scale = s1
13351347 ) + w2 * stats .norm .pdf (x_vals , loc = m2 , scale = s2 )
13361348 plt .figure (figsize = (8 , 3 ))
13371349 plt .hist (particles , bins = 100 , density = True , histtype = "step" , label = "SMC particles" )
1338- plt .plot (x_vals , target_density , "r--" , label = "Target density" )
1350+ plt .plot (x_vals , target_density (x_vals ), "r--" , label = "Target density" )
1351+ # plt.plot(particles, target_density(particles), 'b.')
13391352 plt .xlabel ("x" )
13401353 plt .ylabel ("Density" )
13411354 plt .legend ()
0 commit comments