From 58023a9035c02e339aa86cf1a879fea67df07cba Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Jun 2026 15:29:10 +0200 Subject: [PATCH 01/11] Forward and reverse Enzyme tests and rules for linalg --- Project.toml | 12 +- ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl | 16 ++ ext/TensorKitEnzymeExt/linalg.jl | 262 +++++++++++++++++++ ext/TensorKitEnzymeExt/utility.jl | 80 ++++++ ext/TensorKitEnzymeTestUtilsExt.jl | 66 +++++ test/Project.toml | 2 + test/enzyme-linalg/inv.jl | 25 ++ test/enzyme-linalg/mul.jl | 30 +++ test/enzyme-linalg/norm.jl | 23 ++ test/enzyme-linalg/tr.jl | 33 +++ 10 files changed, 546 insertions(+), 3 deletions(-) create mode 100644 ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl create mode 100644 ext/TensorKitEnzymeExt/linalg.jl create mode 100644 ext/TensorKitEnzymeExt/utility.jl create mode 100644 ext/TensorKitEnzymeTestUtilsExt.jl create mode 100644 test/enzyme-linalg/inv.jl create mode 100644 test/enzyme-linalg/mul.jl create mode 100644 test/enzyme-linalg/norm.jl create mode 100644 test/enzyme-linalg/tr.jl diff --git a/Project.toml b/Project.toml index 3d0abc9b1..f8a50d723 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" +TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = 
"Mooncake" @@ -43,10 +47,12 @@ AMDGPU = "2" CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" +Enzyme = "0.13.146" +EnzymeTestUtils = "0.2.7" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.7" +MatrixAlgebraKit = "0.6.8" Mooncake = "0.5.27" OhMyThreads = "0.8.0" Printf = "1" @@ -54,8 +60,8 @@ Random = "1" ScopedValues = "1.3.0" Strided = "2" TensorKitSectors = "0.3.7" -TensorOperations = "5.5" +TensorOperations = "5.5.2" TupleTools = "1.5" -VectorInterface = "0.4.8, 0.5, 0.6" +VectorInterface = "0.4.8, 0.5" cuTENSOR = "6" julia = "1.10" diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..7f448f9e3 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,16 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK +using VectorInterface +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using MatrixAlgebraKit +using TupleTools +using Random: AbstractRNG + +include("utility.jl") +include("linalg.jl") + +end diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl new file mode 100644 index 000000000..2e61c3bca --- /dev/null +++ b/ext/TensorKitEnzymeExt/linalg.jl @@ -0,0 +1,262 @@ +# Shared +# ------ +# Can Enzyme do this itself? Apparently not... +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation, + β::Annotation, + ) where {RT} + cacheC = !isa(β, Const) && copy(C.val) + cacheA = !isa(B, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + cacheB = !isa(A, Const) && EnzymeRules.overwritten(config)[4] ? 
copy(B.val) : nothing
+    AB = if !isa(α, Const)
+        AB = A.val * B.val
+        add!(C.val, AB, α.val, β.val)
+        AB
+    else
+        mul!(C.val, A.val, B.val, α.val, β.val)
+        nothing
+    end
+    primal = EnzymeRules.needs_primal(config) ? C.val : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
+    cache = (cacheC, cacheA, cacheB, AB)
+    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(mul!)},
+        ::Type{RT},
+        cache,
+        C::Annotation{<:AbstractTensorMap},
+        A::Annotation{<:AbstractTensorMap},
+        B::Annotation{<:AbstractTensorMap},
+        α::Annotation{<:Number},
+        β::Annotation{<:Number},
+    ) where {RT}
+    if C isa Const # only a Const C carries no cotangent; mul! returns its result through C, so RT alone cannot decide this
+        Δα = isa(α, Const) ? nothing : zero(α.val)
+        Δβ = isa(β, Const) ? nothing : zero(β.val)
+        return (nothing, nothing, nothing, Δα, Δβ)
+    end
+    cacheC, cacheA, cacheB, AB = cache
+    Cval = something(cacheC, C.val)
+    Aval = something(cacheA, A.val)
+    Bval = something(cacheB, B.val)
+
+    !isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val))
+    !isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val))
+    Δαr = pullback_dα(α, C, AB)
+    Δβr = pullback_dβ(β, C, Cval)
+    !isa(C, Const) && pullback_dC!(C.dval, β.val) # keep last: the pullbacks above read C.dval before it is rescaled in place
+
+    return (nothing, nothing, nothing, Δαr, Δβr)
+end
+
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfigWidth{1},
+        func::Const{typeof(mul!)},
+        ::Type{RT},
+        C::Annotation{<:AbstractTensorMap},
+        A::Annotation{<:AbstractTensorMap},
+        B::Annotation{<:AbstractTensorMap},
+        α::Annotation{<:Number},
+        β::Annotation{<:Number},
+    ) where {RT}
+    # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
+    if !isa(C, Const)
+        scale!(C.dval, β.val)
+        !isa(β, Const) && add!(C.dval, C.val, β.dval)
+        !isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval)
+        !isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val)
+        !isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val)
+    end
+    mul!(C.val, 
A.val, B.val, α.val, β.val)
+    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
+        return C
+    elseif EnzymeRules.needs_primal(config)
+        return C.val
+    elseif EnzymeRules.needs_shadow(config)
+        return C.dval
+    else
+        return nothing
+    end
+end
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(tr)},
+        ::Type{RT},
+        A::Annotation{<:AbstractTensorMap},
+    ) where {RT}
+    ret = func.val(A.val)
+    primal = EnzymeRules.needs_primal(config) ? ret : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
+    cache = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(tr)},
+        dret::Active,
+        cache,
+        A::Annotation{<:AbstractTensorMap},
+    )
+    Aval = something(cache, A.val)
+    Δtrace = dret.val
+    if !isa(A, Const)
+        for (_, b) in blocks(A.dval)
+            TensorKit.diagview(b) .+= Δtrace
+        end
+    end
+    return (nothing,)
+end
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(tr)},
+        ::Type{<:Const},
+        cache,
+        A::Annotation{<:AbstractTensorMap},
+    )
+    return (nothing,)
+end
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfigWidth{1},
+        func::Const{typeof(tr)}, # func precedes the return annotation, as in the other forward rules; the swapped order never dispatched
+        ::Type{RT},
+        A::Annotation{<:AbstractTensorMap},
+    ) where {RT}
+    y = EnzymeRules.needs_primal(config) ? 
tr(A.val) : nothing
+    Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const)
+        tr(A.dval)
+    elseif EnzymeRules.needs_shadow(config)
+        zero(scalartype(A.val)) # A is Const on this branch and has no dval to query
+    else
+        nothing
+    end
+    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
+        return Duplicated(y, Δy)
+    elseif EnzymeRules.needs_primal(config)
+        return y
+    elseif EnzymeRules.needs_shadow(config)
+        return Δy
+    else
+        return nothing
+    end
+end
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(norm)},
+        ::Type{RT},
+        A::Annotation{<:AbstractTensorMap},
+        p::Const{<:Real},
+    ) where {RT}
+    p.val == 2 || error("currently only implemented for p = 2")
+    ret = func.val(A.val, p.val)
+    primal = EnzymeRules.needs_primal(config) ? ret : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
+    cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+    cache = (ret, cacheA)
+    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(norm)},
+        dret::Active,
+        cache,
+        A::Annotation{<:AbstractTensorMap},
+        p::Const{<:Real},
+    )
+    n, cacheA = cache
+    Δn = dret.val
+    p.val == 2 || error("currently only implemented for p = 2")
+    Aval = something(cacheA, A.val)
+    if !isa(A, Const)
+        x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
+        add!(A.dval, Aval, x) # use the cached primal: A.val may have been overwritten since the forward pass
+    end
+    return (nothing, nothing)
+end
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(norm)},
+        ::Type{<:Const},
+        cache,
+        A::Annotation{<:AbstractTensorMap},
+        p::Const{<:Real},
+    )
+    return (nothing, nothing)
+end
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfigWidth{1},
+        func::Const{typeof(norm)},
+        ::Type{RT},
+        A::Annotation{<:AbstractTensorMap},
+        p::Const{<:Real},
+    ) where {RT}
+    y = norm(A.val, p.val)
+    Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const)
+        real(dot(A.val, A.dval)) * pinv(y)
+    elseif 
EnzymeRules.needs_shadow(config)
+        zero(y) # A is Const on this branch and has no dval; the tangent must share the real type of y
+    else
+        nothing
+    end
+    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
+        return Duplicated(y, Δy)
+    elseif EnzymeRules.needs_primal(config)
+        return y
+    elseif EnzymeRules.needs_shadow(config)
+        return Δy
+    else
+        return nothing
+    end
+end
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(inv)},
+        ::Type{RT},
+        A::Annotation{<:AbstractTensorMap},
+    ) where {RT}
+    ret = inv(A.val)
+    primal = EnzymeRules.needs_primal(config) ? ret : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
+    cache = (ret, shadow)
+    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(inv)},
+        ::Type{RT},
+        cache,
+        A::Annotation{<:AbstractTensorMap},
+    ) where {RT}
+    Ainv, ΔAinv = cache
+    !isa(A, Const) && ΔAinv !== nothing && mul!(A.dval, Ainv' * ΔAinv, Ainv', -1, One()) # the cached shadow is nothing when none was requested
+    return (nothing,)
+end
+
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfigWidth{1},
+        func::Const{typeof(inv)},
+        ::Type{RT},
+        A::Annotation{<:AbstractTensorMap},
+    ) where {RT}
+    Ainv = inv(A.val)
+    ΔAinv = !isa(A, Const) ? 
scale!(Ainv * A.dval * Ainv, -1) : make_zero(Ainv) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(Ainv, ΔAinv) + elseif EnzymeRules.needs_primal(config) + return Ainv + elseif EnzymeRules.needs_shadow(config) + return ΔAinv + else + return nothing + end +end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..03ade424a --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,80 @@ +# Projection +# ---------- +pullback_dα(α::Const, C::Const, A) = nothing +pullback_dα(α::Const, C::Annotation, A) = nothing +pullback_dα(α::Annotation, C::Const, A) = zero(α.val) +pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval)) + +pullback_dβ(β::Const, C::Const, Ccache) = nothing +pullback_dβ(β::Const, C::Annotation, Ccache) = nothing +pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val) +pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval)) + +pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. 
+""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ + +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true +@inline 
EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true + +@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeTestUtilsExt.jl b/ext/TensorKitEnzymeTestUtilsExt.jl new file mode 100644 index 000000000..4a1f393b1 --- /dev/null +++ b/ext/TensorKitEnzymeTestUtilsExt.jl @@ -0,0 +1,66 @@ +module TensorKitEnzymeTestUtilsExt + +using TensorKit +using EnzymeTestUtils +using EnzymeTestUtils: Enzyme +import EnzymeTestUtils: to_vec, from_vec, rand_tangent + +function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x)) + if has_seen || is_const + x_vec = Float32[] + else + vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)] + x_vec, back = to_vec(vec_of_vecs) + seen_vecs[x] = x_vec + end + function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + is_const && return x + + x_new = similar(x) + xvec_of_vecs = back(x_vec_new) + for (i, 
(c, b)) in enumerate(blocks(x_new)) + scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c)) + end + if Core.Typeof(x_new) != Core.Typeof(x) + x_new = Core.Typeof(x)(x_new) + end + seen_xs[x] = x_new + return x_new + end + return x_vec, TensorMap_from_vec +end +function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(parent(t), seen_vecs) + return parent_vec, adjoint ∘ parent_t +end +function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs) + return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t +end + +# generate random tangents for testing +function EnzymeTestUtils.rand_tangent(rng, t::TensorMap) + return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t)) +end + +function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap) + return adjoint(rand_tangent(rng, parent(t))) +end + +function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap) + return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1)) +end + +function EnzymeTestUtils.map_fields_recursive(f::typeof(Base.copyto!), y::TensorKit.SortedVectorDict{K, V}, x::TensorKit.SortedVectorDict{K, V}) where {K, V} + copyto!(y.keys, x.keys) + copyto!(y.values, x.values) + return y +end + +end diff --git a/test/Project.toml b/test/Project.toml index 18af8af80..5252ff1f4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = 
"c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/enzyme-linalg/inv.jl b/test/enzyme-linalg/inv.jl new file mode 100644 index 000000000..68aa37adb --- /dev/null +++ b/test/enzyme-linalg/inv.jl @@ -0,0 +1,25 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - LinearAlgebra (inv):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + @testset "inv: TD $TD" for TD in (Const, Duplicated) + EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D2, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D3, TD); atol, rtol) + end + end +end diff --git a/test/enzyme-linalg/mul.jl b/test/enzyme-linalg/mul.jl new file mode 100644 index 000000000..d2882dceb --- /dev/null +++ b/test/enzyme-linalg/mul.jl @@ -0,0 +1,30 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (mul):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + @testset "mul: TC $TC, TA $TA, TB $TB" for TC in (Duplicated,), TA in (Duplicated,), TB in (Duplicated,) + @testset "Tα $Tα, Tβ $Tβ" for Tα 
in (Active, Const), Tβ in (Active, Const) + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) + end + @testset "Tα $Tα, Tβ $Tβ" for Tα in (Duplicated, Const), Tβ in (Duplicated, Const) + EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) + end + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + end + end +end diff --git a/test/enzyme-linalg/norm.jl b/test/enzyme-linalg/norm.jl new file mode 100644 index 000000000..8a288035a --- /dev/null +++ b/test/enzyme-linalg/norm.jl @@ -0,0 +1,23 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (norm):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T), TC $TC" for V in spacelist, T in eltypes, TC in (Const, Duplicated) + atol = default_tol(T) + rtol = default_tol(T) + C = randn(T, V[1] ⊗ V[2] ← V[5]) + for RT in (Const, Active) + EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) + EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) + end + for RT in (Const, Duplicated) + EnzymeTestUtils.test_forward(norm, RT, (C, TC), (2, Const); atol, rtol) + EnzymeTestUtils.test_forward(norm, RT, (C', TC), (2, Const); atol, rtol) + end + end +end diff --git a/test/enzyme-linalg/tr.jl b/test/enzyme-linalg/tr.jl new file mode 100644 index 000000000..1ba0c2df7 --- /dev/null +++ b/test/enzyme-linalg/tr.jl @@ -0,0 +1,33 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +is_ci = get(ENV, "CI", "false") == "true" + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +rRTs = is_ci ? (Active,) : (Const, Active) +fRTs = is_ci ? 
(Duplicated,) : (Const, Duplicated) +TDs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (tr):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + @testset "tr reverse: RT $RT, TD $TD" for RT in rRTs, TD in TDs + EnzymeTestUtils.test_reverse(tr, RT, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D2, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D3, TD); atol, rtol) + end + @testset "tr forward: RT $RT, TD $TD" for RT in fRTs, TD in TDs + EnzymeTestUtils.test_forward(tr, RT, (D1, TD); atol, rtol) + EnzymeTestUtils.test_forward(tr, RT, (D2, TD); atol, rtol) + EnzymeTestUtils.test_forward(tr, RT, (D3, TD); atol, rtol) + end + end +end From 2554d8d62d340bf421b37ae88825a1b03cafe41e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 08:12:57 +0200 Subject: [PATCH 02/11] Try to cut down on ci times --- test/enzyme-linalg/inv.jl | 5 ++++- test/enzyme-linalg/mul.jl | 8 ++++++-- test/enzyme-linalg/norm.jl | 8 ++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/test/enzyme-linalg/inv.jl b/test/enzyme-linalg/inv.jl index 68aa37adb..73807ec77 100644 --- a/test/enzyme-linalg/inv.jl +++ b/test/enzyme-linalg/inv.jl @@ -6,6 +6,9 @@ using Random spacelist = ad_spacelist(fast_tests) eltypes = (Float64, ComplexF64) +is_ci = get(ENV, "CI", "false") == "true" +TDs = is_ci ? 
(Duplicated,) : (Const, Duplicated) + @timedtestset "Enzyme - LinearAlgebra (inv):" begin @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) @@ -13,7 +16,7 @@ eltypes = (Float64, ComplexF64) D1 = randn(T, V[1] ← V[1]) D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) - @testset "inv: TD $TD" for TD in (Const, Duplicated) + @testset "inv: TD $TD" for TD in TDs EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) diff --git a/test/enzyme-linalg/mul.jl b/test/enzyme-linalg/mul.jl index d2882dceb..c4918d8b9 100644 --- a/test/enzyme-linalg/mul.jl +++ b/test/enzyme-linalg/mul.jl @@ -6,6 +6,10 @@ using Random spacelist = ad_spacelist(fast_tests) eltypes = (Float64, ComplexF64) +is_ci = get(ENV, "CI", "false") == "true" +rTs = is_ci ? (Active,) : (Const, Active) +fTs = is_ci ? 
(Duplicated,) : (Const, Duplicated) + @timedtestset verbose = true "Enzyme - LinearAlgebra (mul):" begin @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) @@ -17,10 +21,10 @@ eltypes = (Float64, ComplexF64) α = randn(T) β = randn(T) @testset "mul: TC $TC, TA $TA, TB $TB" for TC in (Duplicated,), TA in (Duplicated,), TB in (Duplicated,) - @testset "Tα $Tα, Tβ $Tβ" for Tα in (Active, Const), Tβ in (Active, Const) + @testset "Tα $Tα, Tβ $Tβ" for Tα in rTs, Tβ in rTs EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) end - @testset "Tα $Tα, Tβ $Tβ" for Tα in (Duplicated, Const), Tβ in (Duplicated, Const) + @testset "Tα $Tα, Tβ $Tβ" for Tα in fTs, Tβ in fTs EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) end EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) diff --git a/test/enzyme-linalg/norm.jl b/test/enzyme-linalg/norm.jl index 8a288035a..21f52a416 100644 --- a/test/enzyme-linalg/norm.jl +++ b/test/enzyme-linalg/norm.jl @@ -6,16 +6,20 @@ using Random spacelist = ad_spacelist(fast_tests) eltypes = (Float64, ComplexF64) +is_ci = get(ENV, "CI", "false") == "true" +rRTs = is_ci ? (Active,) : (Const, Active) +fRTs = is_ci ? 
(Duplicated,) : (Const, Duplicated) + @timedtestset verbose = true "Enzyme - LinearAlgebra (norm):" begin @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T), TC $TC" for V in spacelist, T in eltypes, TC in (Const, Duplicated) atol = default_tol(T) rtol = default_tol(T) C = randn(T, V[1] ⊗ V[2] ← V[5]) - for RT in (Const, Active) + for RT in rRTs EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) end - for RT in (Const, Duplicated) + for RT in fRTs EnzymeTestUtils.test_forward(norm, RT, (C, TC), (2, Const); atol, rtol) EnzymeTestUtils.test_forward(norm, RT, (C', TC), (2, Const); atol, rtol) end From 3c2139d9f117cbf596fd49d382b43a385cbebd7d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 08:18:26 +0200 Subject: [PATCH 03/11] Formatter --- test/enzyme-linalg/inv.jl | 2 +- test/enzyme-linalg/norm.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/enzyme-linalg/inv.jl b/test/enzyme-linalg/inv.jl index 73807ec77..8e8920f46 100644 --- a/test/enzyme-linalg/inv.jl +++ b/test/enzyme-linalg/inv.jl @@ -16,7 +16,7 @@ TDs = is_ci ? (Duplicated,) : (Const, Duplicated) D1 = randn(T, V[1] ← V[1]) D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) - @testset "inv: TD $TD" for TD in TDs + @testset "inv: TD $TD" for TD in TDs EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) diff --git a/test/enzyme-linalg/norm.jl b/test/enzyme-linalg/norm.jl index 21f52a416..634d1ca3e 100644 --- a/test/enzyme-linalg/norm.jl +++ b/test/enzyme-linalg/norm.jl @@ -15,7 +15,7 @@ fRTs = is_ci ? 
(Duplicated,) : (Const, Duplicated) atol = default_tol(T) rtol = default_tol(T) C = randn(T, V[1] ⊗ V[2] ← V[5]) - for RT in rRTs + for RT in rRTs EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) end From 5357ee2614bcd082813f556b49d52cc0aee6d0f3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 09:17:50 +0200 Subject: [PATCH 04/11] Fix tensor for norm --- test/enzyme-linalg/norm.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/enzyme-linalg/norm.jl b/test/enzyme-linalg/norm.jl index 634d1ca3e..d12dccc1a 100644 --- a/test/enzyme-linalg/norm.jl +++ b/test/enzyme-linalg/norm.jl @@ -14,7 +14,7 @@ fRTs = is_ci ? (Duplicated,) : (Const, Duplicated) @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T), TC $TC" for V in spacelist, T in eltypes, TC in (Const, Duplicated) atol = default_tol(T) rtol = default_tol(T) - C = randn(T, V[1] ⊗ V[2] ← V[5]) + C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') for RT in rRTs EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) From b6b7e4ee2e7285398d1a15ac60c958da86b9df8b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 08:38:44 -0400 Subject: [PATCH 05/11] Fix space for mul also --- test/enzyme-linalg/mul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/enzyme-linalg/mul.jl b/test/enzyme-linalg/mul.jl index c4918d8b9..0d4043791 100644 --- a/test/enzyme-linalg/mul.jl +++ b/test/enzyme-linalg/mul.jl @@ -15,8 +15,8 @@ fTs = is_ci ? 
(Duplicated,) : (Const, Duplicated) atol = default_tol(T) rtol = default_tol(T) - C = randn(T, V[1] ⊗ V[2] ← V[5]) - A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') + A = randn(T, codomain(C) ← V[5]' ⊗ V[4]') B = randn(T, domain(A) ← domain(C)) α = randn(T) β = randn(T) From 14afb58953575608bbcaedd10752dee83138365c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Jun 2026 09:37:52 +0200 Subject: [PATCH 06/11] Try to cut down on CI burden some more --- test/enzyme-linalg/inv.jl | 16 +++++++++------- test/enzyme-linalg/mul.jl | 6 ++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/test/enzyme-linalg/inv.jl b/test/enzyme-linalg/inv.jl index 8e8920f46..455776d24 100644 --- a/test/enzyme-linalg/inv.jl +++ b/test/enzyme-linalg/inv.jl @@ -13,15 +13,17 @@ TDs = is_ci ? (Duplicated,) : (Const, Duplicated) @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) rtol = default_tol(T) - D1 = randn(T, V[1] ← V[1]) - D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) - D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) @testset "inv: TD $TD" for TD in TDs - EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) - EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) + if !is_ci + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D2, TD); atol, rtol) + end + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) - EnzymeTestUtils.test_forward(inv, TD, (D1, TD); atol, rtol) - EnzymeTestUtils.test_forward(inv, TD, (D2, TD); atol, rtol) EnzymeTestUtils.test_forward(inv, TD, (D3, TD); atol, rtol) end end diff --git 
a/test/enzyme-linalg/mul.jl b/test/enzyme-linalg/mul.jl index 0d4043791..295a06334 100644 --- a/test/enzyme-linalg/mul.jl +++ b/test/enzyme-linalg/mul.jl @@ -27,8 +27,10 @@ fTs = is_ci ? (Duplicated,) : (Const, Duplicated) @testset "Tα $Tα, Tβ $Tβ" for Tα in fTs, Tβ in fTs EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) end - EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) - EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + if !is_ci + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + end end end end From 9420158bf4cef99a5f9b5be1d2279fbabbfb514d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 13 Jun 2026 11:50:50 +0200 Subject: [PATCH 07/11] Use some common logic now in TO --- .../tensoroperations.jl | 16 +++--- ext/TensorKitChainRulesCoreExt/utility.jl | 20 ------- ext/TensorKitEnzymeExt/utility.jl | 51 ------------------ .../indexmanipulations.jl | 8 +-- ext/TensorKitMooncakeExt/linalg.jl | 14 ++--- ext/TensorKitMooncakeExt/planaroperations.jl | 8 +-- ext/TensorKitMooncakeExt/tensoroperations.jl | 22 ++++---- ext/TensorKitMooncakeExt/utility.jl | 53 ------------------- ext/TensorKitMooncakeExt/vectorinterface.jl | 8 +-- src/TensorKit.jl | 5 ++ src/auxiliary/ad.jl | 21 ++++++++ test/mooncake/tangent.jl | 2 +- 12 files changed, 65 insertions(+), 163 deletions(-) create mode 100644 src/auxiliary/ad.jl diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index dfa8eb72c..82bd9b578 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -23,7 +23,7 @@ function ChainRulesCore.rrule( dC = β === Zero() ? 
ZeroTangent() : @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = invperm(linearize(pA)) - pdA = _repartition(ipA, A) + pdA = TO.repartition(ipA, numout(A)) TA = promote_add(ΔC, α) # TODO: allocator _dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false)) @@ -40,7 +40,7 @@ function ChainRulesCore.rrule( _dα = tensorscalar( tensorcontract( A, ((), linearize(pA)), !conjA, - tΔC, (trivtuple(TO.numind(pA)), ()), false, + tΔC, (TO.trivialpermutation(TO.numind(pA)), ()), false, ((), ()), One(), ba... ) ) @@ -76,11 +76,11 @@ function ChainRulesCore.rrule( function pullback(ΔC′) ΔC = unthunk(ΔC′) ipAB = invperm(linearize(pAB)) - pΔC = _repartition(ipAB, TO.numout(pA)) + pΔC = TO.repartition(ipAB, pA) dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let - ipA = _repartition(invperm(linearize(pA)), A) + ipA = TO.repartition(invperm(linearize(pA)), numout(A)) conjΔC = conjA conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) @@ -105,7 +105,7 @@ function ChainRulesCore.rrule( projectA(_dA) end dB = @thunk let - ipB = _repartition(invperm(linearize(pB)), B) + ipB = TO.repartition(invperm(linearize(pB)), numout(B)) conjΔC = conjB conjA′ = conjB ? conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) @@ -167,11 +167,11 @@ function ChainRulesCore.rrule( dC = β === Zero() ? 
ZeroTangent() : @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) - pdA = _repartition(ip, A) + pdA = TO.repartition(ip, numout(A)) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - pE = ((), trivtuple(TO.numind(q))) - pΔC = (trivtuple(TO.numind(p)), ()) + pE = ((), TO.trivialpermutation(TO.numind(q))) + pΔC = (TO.trivialpermutation(TO.numind(p)), ()) TA = promote_scale(ΔC, α) # TODO: allocator _dA = tensoralloc_contract(TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val(false)) diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index 85a444422..6896ff894 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -1,23 +1,3 @@ -# Utility -# ------- -trivtuple(N) = ntuple(identity, N) - -Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) - length(p) >= N₁ || - throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return TupleTools.getindices(p, trivtuple(N₁)), - TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) -end -Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) - return _repartition(linearize(p), N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} - return _repartition(p, N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) - return _repartition(p, TensorKit.numout(t)) -end - TensorKit.block(t::ZeroTangent, c::Sector) = t ChainRulesCore.ProjectTo(::T) where {T <: AbstractTensorMap} = ProjectTo{T}() diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl index 03ade424a..cfaf42751 100644 --- a/ext/TensorKitEnzymeExt/utility.jl +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -12,57 +12,6 @@ pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inn 
pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) -""" - project_scalar(x::Number, dx::Number) - -Project a computed tangent `dx` onto the correct tangent type for `x`. -For example, we might compute a complex `dx` but only require the real part. -""" -project_scalar(x::Number, dx::Number) = oftype(x, dx) -project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) - -# in-place multiplication and accumulation which might project to (real) -# TODO: this could probably be done without allocating -function project_mul!(C, A, B, α) - TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) - return if !(TC <: Real) && scalartype(C) <: Real - add!(C, real(mul!(zerovector(C, TC), A, B, α))) - else - mul!(C, A, B, α, One()) - end -end -function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) - TA = TensorKit.promote_permute(A) - TB = TensorKit.promote_permute(B) - TC = TO.promote_contract(TA, TB, scalartype(α)) - - return if scalartype(C) <: Real && !(TC <: Real) - add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) - else - TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) - end -end - -# IndexTuple utility -# ------------------ -trivtuple(N) = ntuple(identity, N) - -Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) - length(p) >= N₁ || - throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return TupleTools.getindices(p, trivtuple(N₁)), - TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) -end -Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) - return _repartition(linearize(p), N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} - return _repartition(p, N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) - return _repartition(p, TensorKit.numout(t)) -end - 
# Ignore derivatives # ------------------ diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 1fccdd9e6..c0c9aecdc 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -45,7 +45,7 @@ for transform in (:permute, :transpose) # ΔA ip = invperm(linearize(p)) - pΔA = _repartition(ip, A) + pΔA = TO.repartition(ip, numout(A)) TC = VectorInterface.promote_scale(ΔC, α) if scalartype(ΔA) <: Real && !(TC <: Real) @@ -57,7 +57,7 @@ for transform in (:permute, :transpose) end ΔAr = NoRData() - Δαr = isnothing(Ap) ? NoRData() : project_scalar(α, inner(Ap, ΔC)) + Δαr = isnothing(Ap) ? NoRData() : TO.project_scalar(α, inner(Ap, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() @@ -113,7 +113,7 @@ function Mooncake.rrule!!( # ΔA ip = invperm(linearize(p)) - pΔA = _repartition(ip, A) + pΔA = TO.repartition(ip, numout(A)) ilevels = TupleTools.permute(levels, linearize(p)) TC = VectorInterface.promote_scale(ΔC, α) if scalartype(ΔA) <: Real && !(TC <: Real) @@ -125,7 +125,7 @@ function Mooncake.rrule!!( end ΔAr = NoRData() - Δαr = isnothing(Ap) ? NoRData() : project_scalar(α, inner(Ap, ΔC)) + Δαr = isnothing(Ap) ? NoRData() : TO.project_scalar(α, inner(Ap, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 943d1ea46..96d72a7bc 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,7 +1,7 @@ # Shared # ------ pullback_dC!(ΔC, β) = (scale!(ΔC, conj(β)); return NoRData()) -pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() +pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? 
TO.project_scalar(β, inner(C, ΔC)) : NoRData() @is_primitive DefaultCtx Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} @@ -27,11 +27,11 @@ function Mooncake.rrule!!( function mul_pullback(::NoRData) copy!(C, C_cache) - project_mul!(ΔA, ΔC, B', conj(α)) - project_mul!(ΔB, A', ΔC, conj(α)) + TK.project_mul!(ΔA, ΔC, B', conj(α)) + TK.project_mul!(ΔB, A', ΔC, conj(α)) ΔAr = NoRData() ΔBr = NoRData() - Δαr = isnothing(AB) ? NoRData() : project_scalar(α, inner(AB, ΔC)) + Δαr = isnothing(AB) ? NoRData() : TO.project_scalar(α, inner(AB, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) @@ -54,10 +54,10 @@ function Mooncake.frule!!( add!(ΔC, C, Δβ) end if !isa(Δα, Mooncake.NoTangent) - project_mul!(ΔC, A, B, Δα) + TK.project_mul!(ΔC, A, B, Δα) end - project_mul!(ΔC, ΔA, B, α) - project_mul!(ΔC, A, ΔB, α) + TK.project_mul!(ΔC, ΔA, B, α) + TK.project_mul!(ΔC, A, ΔB, α) mul!(C, A, B, α, β) return C_ΔC end diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index abbef5004..48d21a78f 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -59,17 +59,17 @@ # ) # if length(q[1]) == 0 # ip = invperm(linearize(p)) -# pΔA = _repartition(ip, A) +# pΔA = TK._repartition(ip, A) # TK.transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) # return NoRData() # end # # if length(q[1]) == 1 # # ip = invperm((p[1]..., q[2]..., p[2]..., q[1]...)) -# # pdA = _repartition(ip, A) +# # pdA = TK._repartition(ip, A) # # E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) # # twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) -# # # pE = ((), trivtuple(TO.numind(q))) -# # # pΔC = (trivtuple(TO.numind(p)), ()) +# # # pE = ((), TK.trivtuple(TO.numind(q))) +# # # pΔC = (TK.trivtuple(TO.numind(p)), ()) # # TensorKit.planaradd!(ΔA, ΔC ⊗ E, pdA, conj(α), One(), backend, allocator) # # return NoRData() # # end diff --git 
a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 5f47a5260..affde47cf 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -54,7 +54,7 @@ function Mooncake.rrule!!( ΔBr = blas_contract_pullback_ΔB!( ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) # this typically returns NoRData() - Δαr = isnothing(AB) ? NoRData() : project_scalar(α, inner(AB, ΔC)) + Δαr = isnothing(AB) ? NoRData() : TO.project_scalar(α, inner(AB, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() @@ -103,8 +103,8 @@ function blas_contract_pullback_ΔA!( ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) - pΔC = _repartition(ipAB, TO.numout(pA)) - ipA = _repartition(invperm(linearize(pA)), A) + pΔC = TO.repartition(ipAB, pA) + ipA = TO.repartition(invperm(linearize(pA)), numout(A)) tB = twist( B, @@ -114,7 +114,7 @@ function blas_contract_pullback_ΔA!( ); copy = false ) - project_contract!( + TK.project_contract!( ΔA, ΔC, pΔC, false, tB, reverse(pB), true, @@ -128,8 +128,8 @@ function blas_contract_pullback_ΔB!( ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) - pΔC = _repartition(ipAB, TO.numout(pA)) - ipB = _repartition(invperm(linearize(pB)), B) + pΔC = TO.repartition(ipAB, pA) + ipB = TO.repartition(invperm(linearize(pB)), numout(B)) tA = twist( A, @@ -139,7 +139,7 @@ function blas_contract_pullback_ΔB!( ); copy = false ) - project_contract!( + TK.project_contract!( ΔB, tA, reverse(pA), true, ΔC, pΔC, false, @@ -193,7 +193,7 @@ function Mooncake.rrule!!( ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData() - Δαr = isnothing(At) ? NoRData() : project_scalar(α, inner(At, ΔC)) + Δαr = isnothing(At) ? 
NoRData() : TO.project_scalar(α, inner(At, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() @@ -240,11 +240,11 @@ function trace_permute_pullback_ΔA!( ΔA, ΔC, A, p, q, α, backend ) ip = invperm((linearize(p)..., q[1]..., q[2]...)) - pdA = _repartition(ip, A) + pdA = TO.repartition(ip, numout(A)) E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - pE = ((), trivtuple(TO.numind(q))) - pΔC = (trivtuple(TO.numind(p)), ()) + pE = ((), TO.trivialpermutation(TO.numind(q))) + pΔC = (TO.trivialpermutation(TO.numind(p)), ()) TO.tensorproduct!( ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend ) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ceb32d867..ddd6df6dc 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -2,59 +2,6 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{T}) where {T <: Number} = Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData -# Projection -# ---------- -""" - project_scalar(x::Number, dx::Number) - -Project a computed tangent `dx` onto the correct tangent type for `x`. -For example, we might compute a complex `dx` but only require the real part. 
-""" -project_scalar(x::Number, dx::Number) = oftype(x, dx) -project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) - -# in-place multiplication and accumulation which might project to (real) -# TODO: this could probably be done without allocating -function project_mul!(C, A, B, α) - TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) - return if !(TC <: Real) && scalartype(C) <: Real - add!(C, real(mul!(zerovector(C, TC), A, B, α))) - else - mul!(C, A, B, α, One()) - end -end -function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) - TA = TensorKit.promote_permute(A) - TB = TensorKit.promote_permute(B) - TC = TO.promote_contract(TA, TB, scalartype(α)) - - return if scalartype(C) <: Real && !(TC <: Real) - add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) - else - TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) - end -end - -# IndexTuple utility -# ------------------ -trivtuple(N) = ntuple(identity, N) - -Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) - length(p) >= N₁ || - throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return TupleTools.getindices(p, trivtuple(N₁)), - TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) -end -Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) - return _repartition(linearize(p), N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} - return _repartition(p, N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) - return _repartition(p, TensorKit.numout(t)) -end - # Ignore derivatives # ------------------ diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl index a6f2db85f..a37b28f3b 100644 --- a/ext/TensorKitMooncakeExt/vectorinterface.jl +++ 
b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -11,7 +11,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens function scale_pullback(::NoRData) copy!(C, C_cache) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData() + Δαr = _needs_tangent(α) ? TO.project_scalar(α, inner(C, ΔC)) : NoRData() scale!(ΔC, conj(α)) return NoRData(), NoRData(), Δαr end @@ -34,7 +34,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens function scale_pullback(::NoRData) copy!(C, C_cache) add!(ΔA, ΔC, conj(α)) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() + Δαr = _needs_tangent(α) ? TO.project_scalar(α, inner(A, ΔC)) : NoRData() zerovector!(ΔC) return NoRData(), NoRData(), NoRData(), Δαr end @@ -58,8 +58,8 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensor function add_pullback(::NoRData) copy!(C, C_cache) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() - Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() + Δαr = _needs_tangent(α) ? TO.project_scalar(α, inner(A, ΔC)) : NoRData() + Δβr = _needs_tangent(β) ? 
TO.project_scalar(β, inner(C, ΔC)) : NoRData() add!(ΔA, ΔC, conj(α)) scale!(ΔC, conj(β)) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 6a2828588..0622d6e6a 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -280,4 +280,9 @@ include("planar/macros.jl") @specialize include("planar/planaroperations.jl") +# include some AD specific things at the end +# once all types have been declared +# ------------------------ +include("auxiliary/ad.jl") + end diff --git a/src/auxiliary/ad.jl b/src/auxiliary/ad.jl new file mode 100644 index 000000000..df6556fc2 --- /dev/null +++ b/src/auxiliary/ad.jl @@ -0,0 +1,21 @@ +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end diff --git a/test/mooncake/tangent.jl b/test/mooncake/tangent.jl index 7b7c9ab4d..7f9088a12 100644 --- a/test/mooncake/tangent.jl +++ b/test/mooncake/tangent.jl @@ -2,7 +2,7 @@ using Test, TestExtras using TensorKit using Mooncake using Random -using JET, AllocCheck +using AllocCheck mode = Mooncake.ReverseMode rng = Random.default_rng() From 30694496c7695c77b7a7adea810645a13896cb73 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 13 Jun 2026 12:29:53 +0200 Subject: [PATCH 08/11] Touch ups --- 
ext/TensorKitEnzymeExt/linalg.jl | 10 +++++----- ext/TensorKitMooncakeExt/linalg.jl | 10 +++++----- src/auxiliary/ad.jl | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl index 2e61c3bca..c475905ab 100644 --- a/ext/TensorKitEnzymeExt/linalg.jl +++ b/ext/TensorKitEnzymeExt/linalg.jl @@ -49,8 +49,8 @@ function EnzymeRules.reverse( Aval = something(cacheA, A.val) Bval = something(cacheB, B.val) - !isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val)) - !isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val)) + !isa(A, Const) && !isa(C, Const) && TK.project_mul!(A.dval, C.dval, Bval', conj(α.val), One()) + !isa(B, Const) && !isa(C, Const) && TK.project_mul!(B.dval, Aval', C.dval, conj(α.val), One()) Δαr = pullback_dα(α, C, AB) Δβr = pullback_dβ(β, C, Cval) !isa(C, Const) && pullback_dC!(C.dval, β.val) @@ -72,9 +72,9 @@ function EnzymeRules.forward( if !isa(C, Const) scale!(C.dval, β.val) !isa(β, Const) && add!(C.dval, C.val, β.dval) - !isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval) - !isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val) - !isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val) + !isa(α, Const) && TK.project_mul!(C.dval, A.val, B.val, α.dval, One()) + !isa(A, Const) && TK.project_mul!(C.dval, A.dval, B.val, α.val, One()) + !isa(B, Const) && TK.project_mul!(C.dval, A.val, B.dval, α.val, One()) end mul!(C.val, A.val, B.val, α.val, β.val) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 96d72a7bc..ebcf922bd 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -27,8 +27,8 @@ function Mooncake.rrule!!( function mul_pullback(::NoRData) copy!(C, C_cache) - TK.project_mul!(ΔA, ΔC, B', conj(α)) - TK.project_mul!(ΔB, A', ΔC, conj(α)) + 
TK.project_mul!(ΔA, ΔC, B', conj(α), One()) + TK.project_mul!(ΔB, A', ΔC, conj(α), One()) ΔAr = NoRData() ΔBr = NoRData() Δαr = isnothing(AB) ? NoRData() : TO.project_scalar(α, inner(AB, ΔC)) @@ -54,10 +54,10 @@ function Mooncake.frule!!( add!(ΔC, C, Δβ) end if !isa(Δα, Mooncake.NoTangent) - TK.project_mul!(ΔC, A, B, Δα) + TK.project_mul!(ΔC, A, B, Δα, One()) end - TK.project_mul!(ΔC, ΔA, B, α) - TK.project_mul!(ΔC, A, ΔB, α) + TK.project_mul!(ΔC, ΔA, B, α, One()) + TK.project_mul!(ΔC, A, ΔB, α, One()) mul!(C, A, B, α, β) return C_ΔC end diff --git a/src/auxiliary/ad.jl b/src/auxiliary/ad.jl index df6556fc2..c4b849f9e 100644 --- a/src/auxiliary/ad.jl +++ b/src/auxiliary/ad.jl @@ -1,11 +1,11 @@ # in-place multiplication and accumulation which might project to (real) # TODO: this could probably be done without allocating -function project_mul!(C, A, B, α) +function project_mul!(C, A, B, α, β = One()) TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) return if !(TC <: Real) && scalartype(C) <: Real add!(C, real(mul!(zerovector(C, TC), A, B, α))) else - mul!(C, A, B, α, One()) + mul!(C, A, B, α, β) end end function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) From 6a66e4ccccc8342b3ac854c754f99eb1ca04f98c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 13 Jun 2026 13:12:38 +0200 Subject: [PATCH 09/11] Restore JET --- test/mooncake/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mooncake/tangent.jl b/test/mooncake/tangent.jl index 7f9088a12..7b7c9ab4d 100644 --- a/test/mooncake/tangent.jl +++ b/test/mooncake/tangent.jl @@ -2,7 +2,7 @@ using Test, TestExtras using TensorKit using Mooncake using Random -using AllocCheck +using JET, AllocCheck mode = Mooncake.ReverseMode rng = Random.default_rng() From a670e110cd7974da534f1f41709697f428036bfe Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 13 Jun 2026 16:18:43 +0200 Subject: [PATCH 10/11] Missed some project scalars --- 
ext/TensorKitEnzymeExt/utility.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl index cfaf42751..46a8c8304 100644 --- a/ext/TensorKitEnzymeExt/utility.jl +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -3,12 +3,12 @@ pullback_dα(α::Const, C::Const, A) = nothing pullback_dα(α::Const, C::Annotation, A) = nothing pullback_dα(α::Annotation, C::Const, A) = zero(α.val) -pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval)) +pullback_dα(α::Annotation, C::Annotation, A) = TO.project_scalar(α.val, inner(A, C.dval)) pullback_dβ(β::Const, C::Const, Ccache) = nothing pullback_dβ(β::Const, C::Annotation, Ccache) = nothing pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val) -pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval)) +pullback_dβ(β::Annotation, C::Annotation, Ccache) = TO.project_scalar(β.val, inner(Ccache, C.dval)) pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) From d76d135ac6a9842cea2736623440db99c10e8bc9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 13 Jun 2026 16:19:13 +0200 Subject: [PATCH 11/11] Up Enzyme compat --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f8a50d723..530d76d71 100644 --- a/Project.toml +++ b/Project.toml @@ -47,8 +47,8 @@ AMDGPU = "2" CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" -Enzyme = "0.13.146" -EnzymeTestUtils = "0.2.7" +Enzyme = "0.13.157" +EnzymeTestUtils = "0.2.8" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1"