Skip to content

Commit 8054282

Browse files
authored
Merge pull request #107 from GWmodel-Lab/fix/gwpca
Fix/gwpca
2 parents 1d1d56b + 2c14e5b commit 8054282

5 files changed

Lines changed: 264 additions & 38 deletions

File tree

include/gwmodel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,6 @@
4040
#include "gwmodelpp/GWAverage.h"
4141
#include "gwmodelpp/GWCorrelation.h"
4242
#include "gwmodelpp/GWPCA.h"
43+
#include "gwmodelpp/GTWR.h"
4344

44-
#endif // GWMODEL_H
45+
#endif // GWMODEL_H

include/gwmodelpp/GWPCA.h

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace gwm
1919
class GWPCA: public SpatialMonoscaleAlgorithm, public IMultivariableAnalysis
2020
{
2121
private:
22-
typedef arma::mat (GWPCA::*Solver)(const arma::mat&, arma::cube&, arma::mat&); //!< \~english Calculator to solve \~chinese 模型求解函数
22+
typedef arma::mat (GWPCA::*Solver)(const arma::mat&, arma::cube&, arma::cube&, arma::mat&); //!< \~english Calculator to solve \~chinese 模型求解函数
2323

2424
public: // Constructors and Deconstructors
2525

@@ -62,6 +62,20 @@ class GWPCA: public SpatialMonoscaleAlgorithm, public IMultivariableAnalysis
6262
*/
6363
void setKeepComponents(int k) { mK = k; }
6464

65+
/**
66+
* @brief \~english Get the Robust flag. \~chinese 获取是否使用鲁棒模式。
67+
*
68+
* @return bool \~english Robust flag \~chinese 是否使用鲁棒模式
69+
*/
70+
bool robust() { return mRobust; }
71+
72+
/**
73+
* @brief \~english Set the Robust flag. \~chinese 设置是否使用鲁棒模式。
74+
*
75+
* @param robust \~english Robust flag \~chinese 是否使用鲁棒模式
76+
*/
77+
void setRobust(bool robust) { mRobust = robust; mSolver = robust ? &GWPCA::solveRobustSerial : &GWPCA::solveSerial; }
78+
6579
/**
6680
* @brief \~english Get the Local Principle Values matrix. \~chinese 获取局部主成分值。
6781
*
@@ -86,7 +100,7 @@ class GWPCA: public SpatialMonoscaleAlgorithm, public IMultivariableAnalysis
86100
/**
87101
* @brief \~english Get the Scores matrix. \~chinese 获取得分矩阵。
88102
*
89-
* @return arma::mat \~english Scores matrix \~chinese 得分矩阵
103+
* @return arma::mat \~english Scores matrix \~chinese 得分矩阵1
90104
*/
91105
const arma::cube& scores() { return mScores; }
92106

@@ -105,52 +119,77 @@ class GWPCA: public SpatialMonoscaleAlgorithm, public IMultivariableAnalysis
105119
*
106120
* @param x \~english Symmetric data matrix \~chinese 对称数据矩阵
107121
* @param loadings [out] \~english Out reference to loadings matrix \~chinese 载荷矩阵
122+
* @param scores [out] \~english Out reference to scores matrix \~chinese 得分矩阵
108123
* @param sdev [out] \~english Out reference to standard deviation matrix \~chinese 标准差
109124
* @return arma::mat \~english Principle values matrix \~chinese 主成分值矩阵
110125
*/
111-
arma::mat pca(const arma::mat& x, arma::cube& loadings, arma::mat& sdev)
126+
arma::mat pca(const arma::mat& x, arma::cube& loadings, arma::cube& scores, arma::mat& sdev)
112127
{
113-
return (this->*mSolver)(x, loadings, sdev);
128+
return (this->*mSolver)(x, loadings, scores, sdev);
114129
}
115130

116131
/**
117-
* @brief \~english Serial version of PCA funtion. \~chinese 单线程 PCA 函数。
132+
* @brief \~english Serial version of PCA funtion. \~chinese 单线程 PCA 函数。1
133+
*
134+
* @param x \~english Symmetric data matrix \~chinese 对称数据矩阵1
135+
* @param loadings [out] \~english Out reference to loadings matrix \~chinese 载荷矩阵1
136+
* @param scores [out] \~english Out reference to scores matrix \~chinese 得分矩阵1
137+
* @param sdev [out] \~english Out reference to standard deviation matrix \~chinese 标准差1
138+
* @return arma::mat \~english Principle values matrix \~chinese 主成分值矩阵1
139+
*/
140+
arma::mat solveSerial(const arma::mat& x, arma::cube& loadings, arma::cube& scores, arma::mat& sdev);
141+
142+
/**
143+
* @brief \~english Robust serial version of PCA function. \~chinese 鲁棒单线程 PCA 函数。
118144
*
119145
* @param x \~english Symmetric data matrix \~chinese 对称数据矩阵
120146
* @param loadings [out] \~english Out reference to loadings matrix \~chinese 载荷矩阵
147+
* @param scores [out] \~english Out reference to scores matrix \~chinese 得分矩阵
121148
* @param sdev [out] \~english Out reference to standard deviation matrix \~chinese 标准差
122149
* @return arma::mat \~english Principle values matrix \~chinese 主成分值矩阵
123150
*/
124-
arma::mat solveSerial(const arma::mat& x, arma::cube& loadings, arma::mat& sdev);
151+
arma::mat solveRobustSerial(const arma::mat& x, arma::cube& loadings, arma::cube& scores, arma::mat& sdev);
125152

126153
/**
127154
* @brief \~english Function to carry out weighted PCA. \~chinese 执行加权PCA的函数。
128155
*
129156
* @param x \~english Symmetric data matrix \~chinese 对称数据矩阵
130157
* @param w \~english Weight vector \~chinese 权重向量
131-
* @param V [out] \~english Right orthogonal matrix \~chinese 右边的正交矩阵
132-
* @param d [out] \~english Rectangular diagonal matri \~chinese 矩形对角阵
158+
* @param U [out] \~english Left orthogonal matrix (scores) \~chinese 左正交矩阵(得分)
159+
* @param V [out] \~english Right orthogonal matrix (loadings) \~chinese 右正交矩阵(载荷)
160+
* @param d [out] \~english Rectangular diagonal matrix \~chinese 矩形对角阵
161+
*/
162+
void wpca(const arma::mat& x, const arma::vec& w, arma::mat& U, arma::mat& V, arma::vec & d);
163+
164+
/**
165+
* @brief \~english Function to carry out robust weighted PCA. \~chinese 执行鲁棒加权PCA的函数。
166+
*
167+
* @param x \~english Symmetric data matrix \~chinese 对称数据矩阵
168+
* @param w \~english Weight vector \~chinese 权重向量
169+
* @param U [out] \~english Left orthogonal matrix (scores) \~chinese 左正交矩阵(得分)
170+
* @param V [out] \~english Right orthogonal matrix (loadings) \~chinese 右正交矩阵(载荷)
171+
* @param d [out] \~english Rectangular diagonal matrix \~chinese 矩形对角阵
133172
*/
134-
void wpca(const arma::mat& x, const arma::vec& w, arma::mat& V, arma::vec & d);
173+
void rwpca(const arma::mat& x, const arma::vec& w, arma::mat& U, arma::mat& V, arma::vec & d);
135174

136175
private: // Algorithm Parameters
137176
int mK = 2; //!< \~english Number of components to be kept \~chinese 要保留的主成分数量
138-
// bool mRobust = false;
177+
bool mRobust = false; //!< \~english Robust mode flag \~chinese 鲁棒模式标志
139178

140179
private: // Algorithm Results
141-
arma::mat mLocalPV; //!< \~english Local principle component values \~chinese 局部主成分值
142-
arma::cube mLoadings; //!< \~english Loadings for each component \~chinese 局部载荷矩阵
143-
arma::mat mSDev; //!< \~english Standard Deviation \~chinese 标准差矩阵
144-
arma::cube mScores; //!< \~english Scores for each variable \~chinese 得分矩阵
145-
arma::uvec mWinner; //!< \~english Winner variable at each sample \~chinese 优胜变量索引值
180+
arma::mat mLocalPV; //!< \~english Local principle component values \~chinese 局部主成分值1
181+
arma::cube mLoadings; //!< \~english Loadings for each component \~chinese 局部载荷矩阵1
182+
arma::mat mSDev; //!< \~english Standard Deviation \~chinese 标准差矩阵1
183+
arma::cube mScores; //!< \~english Scores for each variable \~chinese 得分矩阵1
184+
arma::uvec mWinner; //!< \~english Winner variable at each sample \~chinese 优胜变量索引值1
146185

147186
private: // Algorithm Runtime Variables
148-
arma::mat mX; //!< \~english Variable matrix \~chinese 变量矩阵
149-
arma::vec mLatestWt; //!< \~english Latest weigths \~chinese 最新的权重
187+
arma::mat mX; //!< \~english Variable matrix \~chinese 变量矩阵1
188+
arma::vec mLatestWt; //!< \~english Latest weigths \~chinese 最新的权重1
150189

151-
Solver mSolver = &GWPCA::solveSerial; //!< \~english Calculator to solve \~chinese 模型求解函数
190+
Solver mSolver = &GWPCA::solveSerial; //!< \~english Calculator to solve \~chinese 模型求解函数1
152191
};
153192

154193
}
155194

156-
#endif // GWPCA_H
195+
#endif // GWPCA_H

src/gwmodelpp/GWPCA.cpp

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,120 @@ void GWPCA::run()
1010
GWM_LOG_STOP_RETURN(mStatus, void());
1111

1212
GWM_LOG_STAGE("Solving");
13-
mLocalPV = pca(mX, mLoadings, mSDev);
13+
mLocalPV = pca(mX, mLoadings, mScores, mSDev);
1414
GWM_LOG_STOP_RETURN(mStatus, void());
15-
15+
1616
mWinner = index_max(mLoadings.slice(0), 1);
1717
}
1818

19-
mat GWPCA::solveSerial(const mat& x, cube& loadings, mat& sdev)
19+
mat GWPCA::solveSerial(const mat& x, cube& loadings, cube& scores, mat& sdev)
2020
{
2121
uword nDp = mCoords.n_rows, nVar = mX.n_cols;
2222
mat d_all(nVar, nDp, arma::fill::zeros);
23-
vec w0;
23+
2424
loadings = cube(nDp, nVar, mK, arma::fill::zeros);
25+
scores = cube(nDp, mK, nDp, arma::fill::zeros);
26+
2527
for (uword i = 0; i < nDp; i++)
2628
{
2729
GWM_LOG_STOP_BREAK(mStatus);
30+
2831
vec w = mSpatialWeight.weightVector(i);
29-
mat V;
32+
uvec positive = find(w > 0);
33+
vec newWt = w.elem(positive);
34+
mat newX = x.rows(positive);
35+
if (newWt.n_rows <= 5)
36+
{
37+
break;
38+
}
39+
40+
mat U, V;
3041
vec d;
31-
wpca(x, w, V, d);
32-
w0 = w;
42+
wpca(newX, newWt, U, V, d);
43+
44+
mLatestWt = newWt;
3345
d_all.col(i) = d;
46+
3447
for (int j = 0; j < mK; j++)
3548
{
3649
loadings.slice(j).row(i) = arma::trans(V.col(j));
3750
}
51+
52+
mat scorei(nDp, mK, arma::fill::zeros);
53+
for (int j = 0; j < mK; j++)
54+
{
55+
mat score = newX.each_row() % arma::trans(V.col(j));
56+
scorei.col(j) = sum(score, 1);
57+
}
58+
scores.slice(i) = scorei;
59+
3860
GWM_LOG_PROGRESS(i + 1, nDp);
3961
}
62+
4063
d_all = trans(d_all);
41-
mat variance = (d_all / sqrt(sum(w0))) % (d_all / sqrt(sum(w0)));
42-
sdev = sqrt(variance);
64+
mat variance = (d_all / pow(sum(mLatestWt), 0.5)) % (d_all / pow(sum(mLatestWt), 0.5));
65+
sdev = arma::sqrt(variance);
4366
mat pv = variance.cols(0, mK - 1).each_col() % (1.0 / sum(variance, 1)) * 100.0;
4467
return pv;
4568
}
4669

47-
void GWPCA::wpca(const mat& x, const vec& w, mat& V, vec & d)
70+
void GWPCA::wpca(const mat& x, const vec& w, mat& U, mat& V, vec & d)
4871
{
49-
mat xw = x.each_col() % w, U;
72+
mat xw = x.each_col() % w;
5073
mat centerized = (x.each_row() - sum(xw) / sum(w)).each_col() % sqrt(w);
5174
svd(U, d, V, centerized);
5275
}
5376

77+
void GWPCA::rwpca(const mat& x, const vec& w, mat& U, mat& V, vec & d)
78+
{
79+
mat mids = x;
80+
uword medianIdx = (abs(w - 0.5)).index_min();
81+
mids = mids.each_row() - x.row(medianIdx);
82+
mat weighted = mids.each_col() % w;
83+
mat score;
84+
vec tsquared;
85+
princomp(V, score, d, tsquared, weighted);
86+
U = score;
87+
}
88+
89+
mat GWPCA::solveRobustSerial(const mat& x, cube& loadings, cube& scores, mat& sdev)
90+
{
91+
uword nDp = mCoords.n_rows, nVar = mX.n_cols;
92+
mat d_all(nVar, nDp, arma::fill::zeros);
93+
vec w0;
94+
loadings = cube(nDp, nVar, mK, arma::fill::zeros);
95+
scores = cube(nDp, nDp, mK, arma::fill::zeros);
96+
for (uword i = 0; i < nDp; i++)
97+
{
98+
GWM_LOG_STOP_BREAK(mStatus);
99+
vec w = mSpatialWeight.weightVector(i);
100+
uvec positive = find(w > 0);
101+
vec newWt = w.elem(positive);
102+
mat newX = x.rows(positive);
103+
if (newWt.n_rows <= 5)
104+
{
105+
continue;
106+
}
107+
mat U, V;
108+
vec d;
109+
rwpca(newX, newWt, U, V, d);
110+
w0 = newWt;
111+
d_all.col(i) = d;
112+
for (int j = 0; j < mK; j++)
113+
{
114+
loadings.slice(j).row(i) = arma::trans(V.col(j));
115+
mat scoreAll = x.each_row() % arma::trans(V.col(j));
116+
scores.slice(j).col(i) = sum(scoreAll, 1);
117+
}
118+
GWM_LOG_PROGRESS(i + 1, nDp);
119+
}
120+
d_all = trans(d_all);
121+
mat variance = (d_all / pow(sum(w0), 0.5)) % (d_all / pow(sum(w0), 0.5));
122+
sdev = sqrt(variance);
123+
mat pv = variance.cols(0, mK - 1).each_col() % (1.0 / sum(variance, 1)) * 100.0;
124+
return pv;
125+
}
126+
54127
bool GWPCA::isValid()
55128
{
56129
if (SpatialAlgorithm::isValid())

test/CMakeFind/main.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,12 @@ using namespace arma;
66
int main()
77
{
88
mat coords(100, 2, fill::randu);
9-
mat x = join_rows(vec(100, fill::ones), mat(100, 2, fill::randu));
10-
mat betas = mat(100, 3, fill::randu);
11-
vec eps(100, fill::randu);
12-
vec y = sum(x % betas, 1) + eps;
9+
mat x = mat(100, 3, fill::randu);
1310
BandwidthWeight bw(36.0, true, BandwidthWeight::Gaussian);
1411
CRSDistance dist(false);
1512
SpatialWeight sw(&bw, &dist);
16-
GWRBasic algorithm(x, y, coords, sw);
17-
algorithm.fit();
13+
GWPCA algorithm(x, coords, sw);
14+
algorithm.setKeepComponents(2);
15+
algorithm.run();
1816
return 0;
1917
}

0 commit comments

Comments
 (0)