Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502"
MPSKitModels = "ca635005-6f8c-4cd1-b51d-8491250ef2ab"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -27,7 +28,9 @@ TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
PEPSKitMooncakeExt = "Mooncake"

[compat]
Accessors = "0.1"
Expand All @@ -41,14 +44,14 @@ LoggingExtras = "1"
MPSKit = "0.13.9"
MPSKitModels = "0.4"
MatrixAlgebraKit = "0.6.5"
Mooncake = "0.5.27"
OhMyThreads = "0.7, 0.8"
OptimKit = "0.4"
Printf = "1"
Random = "1"
Statistics = "1"
TensorKit = "0.16.5"
TensorKit = "0.16.5, 0.17"
TensorOperations = "5"
TupleTools = "1.6.0"
VectorInterface = "0.4, 0.5, 0.6"
Zygote = "0.6, 0.7"
VectorInterface = "0.6"
julia = "1.10"
146 changes: 146 additions & 0 deletions ext/PEPSKitMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
module PEPSKitMooncakeExt

using PEPSKit, MPSKit, TensorKit, Mooncake, MatrixAlgebraKit
using PEPSKit: SVDAdjoint, EighAdjoint, QRAdjoint, CTMRGAlgorithm, FixedPointGradient, sdiag_pow
import PEPSKit: real_inner
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, primal, tangent, rrule!!, arrayify, @is_primitive

# `arrayify` overload for `InfinitePEPS`: returns the primal state unchanged together
# with a tangent that has been repackaged as an `InfinitePEPS` of the same eltype.
# The per-site tangents are read from `dψ.fields.A` and converted with the
# tensor-level `Mooncake.arrayify`.
function Mooncake.arrayify(ψ::PEPSKit.InfinitePEPS{T}, dψ) where {T}
    site_tangents = map(ψ.A, dψ.fields.A) do site, dsite
        # only the tangent half of the (primal, tangent) pair is needed here
        return Mooncake.arrayify(site, dsite)[2]
    end
    return ψ, PEPSKit.InfinitePEPS{T}(site_tangents)
end

# Warn when a non-negligible cotangent `dϵ` arrives for the truncation error, because
# the pullbacks in this file do not propagate it. Returns `true` when `abs(dϵ) ≤ tol`
# and otherwise the result of the `@warn` call.
function _warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ))
    if abs(dϵ) ≤ tol
        return true
    end
    return @warn "Pullback ignores non-zero tangents for truncation error"
end

# Algorithm/configuration objects are declared to have no tangent (`NoTangent`), so
# they are treated as non-differentiable data by the rules in this file.
Mooncake.tangent_type(::Type{<:PEPSKit.SVDAdjoint}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:PEPSKit.EighAdjoint}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:PEPSKit.QRAdjoint}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:PEPSKit.CTMRGAlgorithm}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:PEPSKit.FixedPointGradient}) = Mooncake.NoTangent

# Calls matching these signatures are registered via Mooncake's `@zero_derivative`
# in the default context: coordinate bookkeeping, truncation setup and environment
# construction.
# NOTE(review): the `CTMRGEnv` entry covers every constructor call whose first argument
# is an `InfinitePartitionFunction` or `InfinitePEPS` — confirm no differentiable
# dependence on the state is expected there.
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit.eachcoordinate), Any, Any}
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit._next_coordinate), Int, Int}
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit._set_decomposition_truncation), Any, Any}
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit.CTMRGEnv), Union{PEPSKit.InfinitePartitionFunction, PEPSKit.InfinitePEPS}, Vararg}

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), TensorKit.AbstractTensorMap, SVDAdjoint}

# Reverse rule for `svd_trunc(t, alg::SVDAdjoint)` with a `FullPullback` reverse
# algorithm.
#
# Forward pass: compute the compact SVD, truncate it according to `alg.fwd_alg.trunc`
# and return `(Ũ, S̃, Ṽ⁺, truncerror)` wrapped in a `CoDual` with zero-initialised
# tangents.
# Reverse pass: the tangents accumulated on the truncated factors are pulled back
# through the *full* decomposition with `svd_pullback!`, using `inds` to identify the
# retained part. A cotangent on the truncation error is not propagated; a warning is
# emitted if it is non-negligible.
#
# The pullback returns one `NoRData()` per rule argument (function, tensor, algorithm).
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{SVDAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
    # TODO: filter out any decomposition algorithm that doesn't give access to the full spectrum
    t, Δt = arrayify(t_dt)
    alg = primal(alg_dalg)
    # requires access to the full decomposition
    # NOTE(review): `svd_compact!` is the in-place variant and is applied to the primal
    # `t`, which is reused in the pullback — confirm that `t` is left intact.
    U, S, V⁺ = svd_compact!(t, alg.fwd_alg.alg)
    (Ũ, S̃, Ṽ⁺), inds = MatrixAlgebraKit.truncate(svd_trunc!, (U, S, V⁺), alg.fwd_alg.trunc)
    truncerror = MatrixAlgebraKit.truncation_error(diagview(S), inds)

    gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
    USVᴴtrunc = (Ũ, S̃, Ṽ⁺)
    output = (USVᴴtrunc..., truncerror)
    output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
    # tangents of the three tensor outputs (the truncation error is dropped by `front`)
    ΔUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(output_codual))))
    function svd_trunc!_full_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
        _warn_pullback_truncerror(dϵ)
        # accumulates into `Δt`; no rebinding of the captured variable is needed
        MatrixAlgebraKit.svd_pullback!(
            Δt, t, (U, S, V⁺), ΔUSVᴴtrunc, inds;
            gauge_atol = gtol(ΔUSVᴴtrunc), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
        )
        # one rdata entry per argument of the rule: (svd_trunc, t, alg)
        return ntuple(Returns(NoRData()), 3)
    end
    return output_codual, svd_trunc!_full_pullback
end

# Reverse rule for `svd_trunc(t, alg::SVDAdjoint)` with a `TruncPullback` reverse
# algorithm.
#
# Forward pass: call the primal `svd_trunc(t, alg)` and wrap its output
# `(U, S, Vᴴ, ϵ)` in a `CoDual` with zero-initialised tangents.
# Reverse pass: pull the accumulated tangents of the truncated factors back with
# `svd_trunc_pullback!`, forwarding the gauge and degeneracy tolerances configured on
# `alg.rrule_alg` (same convention as the `eigh_trunc` rule in this file). The output
# tangents are zeroed afterwards. A cotangent on the truncation error is not
# propagated; a warning is emitted if it is non-negligible.
#
# The pullback returns one `NoRData()` per rule argument (function, tensor, algorithm).
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{SVDAdjoint{F, R}}) where {F, R <: PEPSKit.TruncPullback}
    t, Δt = arrayify(t_dt)
    alg = primal(alg_dalg)
    gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
    output = svd_trunc(t, alg)

    output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
    function svd_trunc!_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
        _warn_pullback_truncerror(dϵ)
        Utrunc, Strunc, Vᴴtrunc, _ = Mooncake.primal(output_codual)
        dUtrunc_, dStrunc_, dVᴴtrunc_, _ = Mooncake.tangent(output_codual)
        U, dU = arrayify(Utrunc, dUtrunc_)
        S, dS = arrayify(Strunc, dStrunc_)
        Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
        ΔUSVᴴ = (dU, dS, dVᴴ)
        MatrixAlgebraKit.svd_trunc_pullback!(
            Δt, t, (U, S, Vᴴ), ΔUSVᴴ;
            gauge_atol = gtol(ΔUSVᴴ), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
        )
        MatrixAlgebraKit.zero!(dU)
        MatrixAlgebraKit.zero!(dS)
        MatrixAlgebraKit.zero!(dVᴴ)
        return NoRData(), NoRData(), NoRData()
    end
    return output_codual, svd_trunc!_trunc_pullback
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(eigh_trunc), TensorKit.AbstractTensorMap, EighAdjoint}

# Reverse rule for `eigh_trunc(t, alg::EighAdjoint)` with a `FullPullback` reverse
# algorithm.
#
# Forward pass: full Hermitian eigendecomposition, truncated according to
# `alg.fwd_alg.trunc`; the result `(D̃, Ṽ, ϵ)` is returned as a zero-tangent forward
# CoDual.
# Reverse pass: tangents of the truncated factors are pulled back through the full
# decomposition with `eigh_pullback!` and then zeroed. A cotangent on the truncation
# error is not propagated; a warning is emitted if it is non-negligible.
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
    t, dt = arrayify(t_dt)
    alg = primal(alg_dalg)
    gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)

    # full decomposition, followed by truncation
    D, V = eigh_full!(t; alg = alg.fwd_alg.alg)
    truncated, inds = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc)
    ϵ = MatrixAlgebraKit.truncation_error(diagview(D), inds)

    # output CoDual carries (D̃, Ṽ, ϵ); only the tensor tangents are kept for the pullback
    out_codual = Mooncake.zero_fcodual((truncated..., ϵ))
    tensor_tangents = Base.front(Mooncake.tangent(out_codual))
    dtruncated = last.(arrayify.(truncated, tensor_tangents))

    function eigh_trunc!_full_pullback((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
        _warn_pullback_truncerror(dϵ)
        MatrixAlgebraKit.eigh_pullback!(
            dt, t, (D, V), dtruncated, inds;
            gauge_atol = gtol(dtruncated), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
        )
        # since this is allocated in this function this is probably not required
        MatrixAlgebraKit.zero!.(dtruncated)
        return NoRData(), NoRData(), NoRData()
    end
    return out_codual, eigh_trunc!_full_pullback
end

# Reverse rule for `eigh_trunc(t, alg::EighAdjoint)` with a `TruncPullback` reverse
# algorithm.
#
# Forward pass: call the primal `eigh_trunc(t, alg)` and wrap `(D, V, truncerror)` in a
# `CoDual` with zero-initialised tangents.
# Reverse pass: pull the accumulated tangents of the truncated factors back with
# `eigh_trunc_pullback!`, using the gauge and degeneracy tolerances from
# `alg.rrule_alg`, then zero them. A cotangent on the truncation error is not
# propagated; a warning is emitted if it is non-negligible.
#
# The pullback returns one `NoRData()` per rule argument (function, tensor, algorithm).
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.TruncPullback}
    t, dt = arrayify(t_dt)
    alg = primal(alg_dalg)

    output = eigh_trunc(t, alg)
    output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))

    gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
    function eigh_trunc!_trunc_pullback((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
        _warn_pullback_truncerror(dϵ)
        Dtrunc, Vtrunc, _ = Mooncake.primal(output_codual)
        dDtrunc_, dVtrunc_, _ = Mooncake.tangent(output_codual)
        # fresh local names: nothing from the enclosing scope is rebound in the closure
        Dmat, dD = arrayify(Dtrunc, dDtrunc_)
        Vmat, dV = arrayify(Vtrunc, dVtrunc_)
        MatrixAlgebraKit.eigh_trunc_pullback!(
            dt, t, (Dmat, Vmat), (dD, dV);
            gauge_atol = gtol((dD, dV)), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
        )
        MatrixAlgebraKit.zero!(dD) # since this is allocated in this function this is probably not required
        MatrixAlgebraKit.zero!(dV) # since this is allocated in this function this is probably not required
        return ntuple(Returns(NoRData()), 3)
    end
    return output_codual, eigh_trunc!_trunc_pullback
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(left_orth), TensorKit.AbstractTensorMap, QRAdjoint}

# Reverse rule for `left_orth(t, alg::QRAdjoint)`.
#
# Forward pass: call the primal `left_orth(t, alg)` and wrap the factor pair in a
# zero-tangent forward CoDual.
# Reverse pass: pull the accumulated tangents of both factors back with
# `qr_pullback!`, using the gauge tolerance derived from `alg.rrule_alg.verbosity`.
#
# The pullback returns one `NoRData()` per rule argument (function, tensor, algorithm).
# `CoDual{<:QRAdjoint}` is used so the method also matches if `QRAdjoint` is parametric.
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.left_orth)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{<:QRAdjoint})
    t, dt = arrayify(t_dt)
    alg = primal(alg_dalg)

    QR = left_orth(t, alg)
    gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)

    output_codual = Mooncake.zero_fcodual(QR)
    dQ_, dR_ = Mooncake.tangent(output_codual)
    # unpack the primal factors before pairing each with its tangent
    Qfac, Rfac = QR
    Q, dQ = arrayify(Qfac, dQ_)
    R, dR = arrayify(Rfac, dR_)
    dQR = (dQ, dR)
    function left_orth_pullback(::NoRData)
        MatrixAlgebraKit.qr_pullback!(dt, t, (Q, R), dQR; gauge_atol = gtol(dQR))
        return ntuple(Returns(NoRData()), 3)
    end
    return output_codual, left_orth_pullback
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502"
MPSKitModels = "ca635005-6f8c-4cd1-b51d-8491250ef2ab"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0"
PEPSKit = "52969e89-939e-4361-9b68-9bc7cde4bdeb"
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
Expand Down
180 changes: 180 additions & 0 deletions test/mooncake/eigh_wrapper.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
using Test
using Random
using LinearAlgebra
using TensorKit
using Mooncake
using Accessors
using PEPSKit

using MatrixAlgebraKit: TruncatedAlgorithm, diagview

# Gauge-invariant scalar loss built from a (possibly truncated) Hermitian
# eigendecomposition of `A`.
#
# `alg.fwd_alg` is wrapped in a `TruncatedAlgorithm` with the truncation scheme
# `trunc` (no truncation by default). The loss is the real part of the overlap of `R`
# with `V * V'` plus `dot(D, D)`; the overlap with a random tensor `R` is
# gauge-invariant and differentiable.
function lossfun(A, alg, R = randn(space(A)), trunc = notrunc())
    truncated_alg = @set alg.fwd_alg = TruncatedAlgorithm(alg.fwd_alg, trunc)
    D, V = eigh_trunc(project_hermitian(A), truncated_alg)
    overlap = real(dot(R, V * V'))
    spectral_weight = dot(D, D)
    return overlap + spectral_weight
end

# --- shared fixtures for the dense (non-symmetric) tests ---------------------------
# Scalar type of the test tensor, its dimension `n`, and the dimension `χ` retained by
# the truncation scheme.
dtype = ComplexF64
n = 20
χ = 10
trunc = truncspace(ℂ^χ)
# relative tolerance used when comparing gradients obtained from different algorithms
rtol = 1.0e-9
# fixed seed: the random tensors below depend on the order of these statements
Random.seed!(123456789)
r = randn(dtype, ℂ^n, ℂ^n)
r = 0.5 * (r + r') # make r Hermitian
# random tensor entering the loss function; symmetrised in the same way as `r`
R = randn(space(r))
R = 0.5 * (R + R')

# Three algorithm combinations whose losses and gradients are compared below:
#   full_alg  — :QRIteration forward algorithm, :FullPullback reverse rule
#   trunc_alg — :QRIteration forward algorithm, :TruncPullback reverse rule
#   iter_alg  — :Lanczos forward algorithm,     :TruncPullback reverse rule
full_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :FullPullback))
trunc_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :TruncPullback))
iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :TruncPullback))

@testset "Non-truncated eigh" begin
    # Loss closures in the order (full, trunc, iter); no truncation is applied.
    losses = (
        A -> lossfun(A, full_alg, R),
        A -> lossfun(A, trunc_alg, R),
        A -> lossfun(A, iter_alg, R),
    )

    # Build all rules first, then evaluate value and gradient in the same order.
    rules = map(f -> Mooncake.build_rrule(f, r), losses)
    (l_full, g_full), (l_trunc, g_trunc), (l_iter, g_iter) =
        map((rule, f) -> Mooncake.value_and_gradient!!(rule, f, r), rules, losses)

    # All three algorithms must agree on the loss and, pairwise, on the gradient
    # with respect to the input tensor (second gradient entry).
    @test l_full ≈ l_trunc ≈ l_iter
    @test isapprox(g_full[2], g_trunc[2]; rtol = rtol)
    @test isapprox(g_full[2], g_iter[2]; rtol = rtol)
    @test isapprox(g_trunc[2], g_iter[2]; rtol = rtol)
end

@testset "Truncated eigh with χ=$χ" begin
    # Loss closures in the order (full, trunc, iter), all truncating with `trunc`.
    losses = (
        A -> lossfun(A, full_alg, R, trunc),
        A -> lossfun(A, trunc_alg, R, trunc),
        A -> lossfun(A, iter_alg, R, trunc),
    )

    # Build all rules first, then evaluate value and gradient in the same order.
    rules = map(f -> Mooncake.build_rrule(f, r), losses)
    (l_full, g_full), (l_trunc, g_trunc), (l_iter, g_iter) =
        map((rule, f) -> Mooncake.value_and_gradient!!(rule, f, r), rules, losses)

    # Loss values and input-tensor gradients must agree across the algorithms.
    @test l_full ≈ l_trunc ≈ l_iter
    @test isapprox(g_full[2], g_trunc[2]; rtol = rtol)
    @test isapprox(g_full[2], g_iter[2]; rtol = rtol)
    @test isapprox(g_trunc[2], g_iter[2]; rtol = rtol)
end

# Checks how the pullback treats degenerate eigenvalues, for both the FullPullback and
# the TruncPullback reverse rules: a tiny `degeneracy_atol` lets divergent terms
# through, while the default cutoff and a small broadening give matching gradients.
@testset "Truncated eigh broadening for $(alg.rrule_alg)" for alg in [full_alg, trunc_alg]
# build a Hermitian tensor with an exactly degenerate spectrum from `r`
d, v = eigh_full(r)
d.data[1:2:n] .= d.data[2:2:n] # make every eigenvalue two-fold degenerate
r_degen = v * d * v'

# same algorithm with two alternative degeneracy tolerances
no_broadening_no_cutoff_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-30
small_broadening_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-13

only_lossfun = A -> lossfun(A, alg, R, trunc)
no_broadening_lossfun = A -> lossfun(A, no_broadening_no_cutoff_alg, R, trunc)
small_broadening_lossfun = A -> lossfun(A, small_broadening_alg, R, trunc)

only_rrule = Mooncake.build_rrule(only_lossfun, r_degen)
no_broadening_rrule = Mooncake.build_rrule(no_broadening_lossfun, r_degen)
small_broadening_rrule = Mooncake.build_rrule(small_broadening_lossfun, r_degen)

l_only_cutoff, g_only_cutoff = Mooncake.value_and_gradient!!(only_rrule, only_lossfun, r_degen) # cutoff sets degenerate difference to zero
l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = Mooncake.value_and_gradient!!( # degenerate eigenvalue differences lead to divergent contributions
no_broadening_rrule, no_broadening_lossfun, r_degen,
)
l_small_broadening, g_small_broadening = Mooncake.value_and_gradient!!( # broadening smoothens divergent contributions
small_broadening_rrule, small_broadening_lossfun, r_degen,
)

# the forward value does not depend on the reverse-rule tolerance
@test l_only_cutoff ≈ l_no_broadening_no_cutoff ≈ l_small_broadening
@test norm(g_no_broadening_no_cutoff[2] - g_small_broadening[2]) > 1.0e-2 # divergences mess up the gradient
@test g_only_cutoff[2] ≈ g_small_broadening[2] rtol = rtol # cutoff and broadening have similar effect
end

# --- shared fixtures for the Z2-symmetric tests -------------------------------------
# Dimensions of the charge-0 and charge-1 sectors of the test space.
symm_m, symm_n = 18, 24
symm_space = Z2Space(0 => symm_m, 1 => symm_n)
# truncation target: half of the 0-sector and a third of the 1-sector
symm_trspace = truncspace(Z2Space(0 => symm_m ÷ 2, 1 => symm_n ÷ 3))
# Hermitian test tensor; continues the random stream seeded earlier in this file,
# so the values depend on statement order
symm_r = randn(dtype, symm_space, symm_space)
symm_r = 0.5 * (symm_r + symm_r')
# random tensor entering the loss function, symmetrised in the same way
symm_R = randn(dtype, space(symm_r))
symm_R = 0.5 * (symm_R + symm_R')

# Compares the three algorithm combinations on a Z2-symmetric tensor in three phases:
# (1) without truncation, (2) truncated to `symm_trspace`, (3) the Lanczos-based
# algorithm with a modified `fallback_threshold`.
@testset "IterEig of symmetric tensors" begin
# --- phase 1: no truncation ---
full_lossfun = A -> lossfun(A, full_alg, symm_R)
trunc_lossfun = A -> lossfun(A, trunc_alg, symm_R)
iter_lossfun = A -> lossfun(A, iter_alg, symm_R)

full_rrule = Mooncake.build_rrule(full_lossfun, symm_r)
trunc_rrule = Mooncake.build_rrule(trunc_lossfun, symm_r)
iter_rrule = Mooncake.build_rrule(iter_lossfun, symm_r)

l_full, g_full = Mooncake.value_and_gradient!!(full_rrule, full_lossfun, symm_r)
l_trunc, g_trunc = Mooncake.value_and_gradient!!(trunc_rrule, trunc_lossfun, symm_r)
l_iter, g_iter = Mooncake.value_and_gradient!!(iter_rrule, iter_lossfun, symm_r)

# losses and input-tensor gradients (second gradient entry) must agree pairwise
@test l_full ≈ l_trunc ≈ l_iter
@test g_full[2] ≈ g_trunc[2] rtol = rtol
@test g_full[2] ≈ g_iter[2] rtol = rtol
@test g_trunc[2] ≈ g_iter[2] rtol = rtol

# --- phase 2: truncation to `symm_trspace` ---
# the closure and rule names from phase 1 are rebound here
full_lossfun = A -> lossfun(A, full_alg, symm_R, symm_trspace)
trunc_lossfun = A -> lossfun(A, trunc_alg, symm_R, symm_trspace)
iter_lossfun = A -> lossfun(A, iter_alg, symm_R, symm_trspace)

full_rrule = Mooncake.build_rrule(full_lossfun, symm_r)
trunc_rrule = Mooncake.build_rrule(trunc_lossfun, symm_r)
iter_rrule = Mooncake.build_rrule(iter_lossfun, symm_r)

l_full_tr, g_full_tr = Mooncake.value_and_gradient!!(full_rrule, full_lossfun, symm_r)
l_trunc_tr, g_trunc_tr = Mooncake.value_and_gradient!!(trunc_rrule, trunc_lossfun, symm_r)
l_iter_tr, g_iter_tr = Mooncake.value_and_gradient!!(iter_rrule, iter_lossfun, symm_r)
@test l_full_tr ≈ l_trunc_tr ≈ l_iter_tr
@test g_full_tr[2] ≈ g_trunc_tr[2] rtol = rtol
@test g_full_tr[2] ≈ g_iter_tr[2] rtol = rtol
@test g_trunc_tr[2] ≈ g_iter_tr[2] rtol = rtol

# --- phase 3: iterative algorithm with a fallback threshold ---
iter_alg_fallback = @set iter_alg.fwd_alg.fallback_threshold = 0.4 # Do dense decomposition in one block, sparse one in the other
fb_lossfun = A -> lossfun(A, iter_alg_fallback, symm_R, symm_trspace)
fb_rrule = Mooncake.build_rrule(fb_lossfun, symm_r)
l_iter_fb, g_iter_fb = Mooncake.value_and_gradient!!(fb_rrule, fb_lossfun, symm_r)
# the fallback variant is compared against the truncated results of phase 2
@test l_iter_fb ≈ l_trunc_tr ≈ l_full_tr
@test g_full_tr[2] ≈ g_iter_fb[2] rtol = rtol
@test g_trunc_tr[2] ≈ g_iter_fb[2] rtol = rtol
end
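#= TODO(review): disabled testset. It still calls `withgradient` and reads the gradient at index 1; port it to `Mooncake.build_rrule` / `Mooncake.value_and_gradient!!` (gradient at index 2, as in the testsets above) before re-enabling.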
@testset "Truncated symmetric eigh broadening for $(alg.rrule_alg)" for alg in [full_alg, trunc_alg]
d, v = eigh_full(symm_r)
# make every eigenvalue in the 0-sector three-fold degenerate
b0 = diagview(block(d, Z2Irrep(0)))
b0[1:3:symm_m] .= b0[3:3:symm_m]
b0[2:3:symm_m] .= b0[3:3:symm_m]
# make every eigenvalue in the 1-sector two-fold degenerate
b1 = diagview(block(d, Z2Irrep(1)))
b1[1:2:symm_n] .= b1[2:2:symm_n]
symm_r_degen = v * d * v'

no_broadening_no_cutoff_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-30
small_broadening_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-13

l_only_cutoff, g_only_cutoff = withgradient(
A -> lossfun(A, alg, symm_R, symm_trspace), symm_r_degen
) # cutoff sets degenerate difference to zero
l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient( # degenerate singular value differences lead to divergent contributions
A -> lossfun(A, no_broadening_no_cutoff_alg, symm_R, symm_trspace),
symm_r_degen,
)
l_small_broadening, g_small_broadening = withgradient( # broadening smoothens divergent contributions
A -> lossfun(A, small_broadening_alg, symm_R, symm_trspace),
symm_r_degen,
)

@test l_only_cutoff ≈ l_no_broadening_no_cutoff ≈ l_small_broadening
@test norm(g_no_broadening_no_cutoff[1] - g_small_broadening[1]) > 1.0e-2 # divergences mess up the gradient
@test g_only_cutoff[1] ≈ g_small_broadening[1] rtol = rtol # cutoff and broadening have similar effect
end=#
Loading
Loading