Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions src/factorizations/blocklanczos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ function initialize(
X₁ = Block(map(Base.Fix2(scale, one(α)), X₀.vec))

# Orthogonalization of the initial block
_, good_idx = block_qr!(X₁, iter.qr_tol)
_, good_idx, _ = block_qr!(X₁, iter.qr_tol)
X₁ = X₁[good_idx]
V = OrthonormalBasis(X₁.vec)
bs = length(X₁) # block size of the first block
Expand Down Expand Up @@ -205,9 +205,16 @@ function expand!(
R = state.R[1:(state.R_size)]
bs = length(R)
V = state.V
Rcopy = copy(R)

# Calculate the new basis and B
B, good_idx = block_qr!(R, iter.qr_tol)
B, good_idx, is_drift = block_qr!(R, iter.qr_tol)
if is_drift # Prevent column subspace of R from drifting caused by an excessively small β in block_qr!
block_reorthogonalize!(R, V)
_, good_idx, is_drift = block_qr!(R, iter.qr_tol)
Comment thread
yuiyuiui marked this conversation as resolved.
B = block_inner(R[good_idx], Rcopy) # Make sure R = XB
end

bs_next = length(good_idx)
push!(V, R[good_idx])
state.H[(k + 1):(k + bs_next), (k - bs + 1):k] = view(B, 1:bs_next, 1:bs)
Expand All @@ -221,7 +228,7 @@ function expand!(
Mnext, 1:bs_next,
1:bs_next
)
state.R.vec[1:bs_next] .= Rnext.vec
state.R.vec[1:bs_next] = Rnext.vec
state.norm_R = norm(Rnext)
state.k += bs_next
state.R_size = bs_next
Expand All @@ -237,7 +244,6 @@ function block_lanczosrecurrence(
)
# Apply the operator and calculate the M. Get Xnext and Mnext.
bs, bs_prev = size(B)
S = eltype(B)
k = length(V)
X = Block(V[(k - bs + 1):k])
AX = apply(operator, X)
Expand Down Expand Up @@ -295,14 +301,17 @@ This function performs a QR factorization of a block of abstract vectors using t
It takes as input a block of abstract vectors and a tolerance parameter, which is used to determine whether a vector is considered numerically zero.
The operation is performed in-place, transforming the input block into a block of orthonormal vectors.

The function returns a matrix of size `(r, p)` and a vector of indices goodidx. Here, `p` denotes the number of input vectors,
The function returns a matrix of size `(r, p)`, a vector of indices `goodidx` and a boolean flag `is_drift`. Here, `p` denotes the number of input vectors,
and `r` is the numerical rank of the input block. The matrix represents the upper-triangular factor of the QR decomposition,
restricted to the `r` linearly independent components. The vector `goodidx` contains the indices of the non-zero
(i.e., numerically independent) vectors in the orthonormalized block.
If a small value of β (the norm of a vector after first orthogonalization) is detected, the function will carry out an additional
reorthogonalization step to further ensure the input block vectors are orthonormalized.
In such cases, is_drift is set to true to indicate potential numerical instability.
"""
function block_qr!(block::Block, tol::Real)
n = length(block)
rank_shrink = false
is_drift = false
idx = trues(n)
r₁₁ = inner(block[1], block[1])
R = zeros(typeof(r₁₁), n, n)
Expand All @@ -312,28 +321,33 @@ function block_qr!(block::Block, tol::Real)
block[1] = scale!!(block[1], 1 / β)
else
block[1] = zerovector!!(block[1])
rank_shrink = true
idx[1] = false
end
@inbounds for j in 2:n
for j in 2:n
# first MGS
for i in 1:(j - 1)
R[i, j] = inner(block[i], block[j])
block[j] = add!!(block[j], block[i], -R[i, j])
end
β = norm(block[j])
if β > tol
R[j, j] = β
block[j] = scale!!(block[j], 1 / β)
else

if tol < β < 100 * tol # DGKS reorthogonalization
is_drift = true
for i in 1:(j - 1)
δ = inner(block[i], block[j])
R[i, j] += δ
block[j] = add!!(block[j], block[i], -δ)
end
β = norm(block[j])
end
if β < tol
block[j] = zerovector!!(block[j])
rank_shrink = true
idx[j] = false
else
R[j, j] = β
block[j] = scale!!(block[j], 1 / β)
end
end
if rank_shrink
good_idx = findall(idx)
return R[good_idx, :], good_idx
else
return R, collect(Int, 1:n)
end
good_idx = findall(idx)
return R[good_idx, :], good_idx, is_drift
end
Loading
Loading