Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/TensorKitMooncakeExt/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ _field_symbol(t, ::Val{F}) where {F} = _field_symbol(t, F)

# frules
_frule_getfield_common(t_dt::Dual{<:DiagOrTensorMap}, field_sym::Symbol) =
Dual(getfield(primal(t), field_sym), field_sym === :data ? tangent(t).data : NoFData())
Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoTangent())

Mooncake.frule!!(::Dual{typeof(Mooncake.lgetfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::Dual) =
_frule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)))
Expand Down
1 change: 1 addition & 0 deletions ext/TensorKitMooncakeExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent

@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.adjoint), HomSpace}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace}
Expand Down
93 changes: 4 additions & 89 deletions ext/TensorKitMooncakeExt/vectorinterface.jl

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if something like this works:

@static if pkgversion(VectorInterface) < v"0.6.0"
    # previous implementation
end

This would allow us to continue supporting old + new version of VectorInterface.

For the current implementation, I guess this might break the rrule implementations that were already there if the correct version of VectorInterface is not loaded. In that case, I would maybe just bite the bullet and actually restrict to VectorInterface v0.6 and just deal with that?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, I can try it tomorrow!

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having to support the old versions again leads to having duplicated rules all over the place, though, one advantage of doing it the current way is being able to cut quite a few lines from the extension.

Original file line number Diff line number Diff line change
@@ -1,89 +1,4 @@
@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number}

function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
α = primal(α_Δα)

# primal call
C_cache = copy(C)
scale!(C, α)

function scale_pullback(::NoRData)
copy!(C, C_cache)
Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData()
scale!(ΔC, conj(α))
return NoRData(), NoRData(), Δαr
end

return C_ΔC, scale_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number}

function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α = primal(α_Δα)

# primal call
C_cache = copy(C)
scale!(C, A, α)

function scale_pullback(::NoRData)
copy!(C, C_cache)
add!(ΔA, ΔC, conj(α))
Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData()
zerovector!(ΔC)
return NoRData(), NoRData(), NoRData(), Δαr
end

return C_ΔC, scale_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number}

function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α = primal(α_Δα)
β = primal(β_Δβ)

# primal call
C_cache = copy(C)
add!(C, A, α, β)

function add_pullback(::NoRData)
copy!(C, C_cache)

Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData()
Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData()
add!(ΔA, ΔC, conj(α))
scale!(ΔC, conj(β))

return NoRData(), NoRData(), NoRData(), Δαr, Δβr
end

return C_ΔC, add_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap}

function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap})
# prepare arguments
A, ΔA = arrayify(A_ΔA)
B, ΔB = arrayify(B_ΔB)

# primal call
s = inner(A, B)

function inner_pullback(Δs)
add!(ΔA, B, conj(Δs))
add!(ΔB, A, Δs)
return NoRData(), NoRData(), NoRData()
end

return CoDual(s, NoFData()), inner_pullback
end
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number}
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number}
@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number}
@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap}
38 changes: 19 additions & 19 deletions test/mooncake/vectorinterface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Mooncake
using VectorInterface, Mooncake
using Random


mode = Mooncake.ReverseMode
rng = Random.default_rng()

spacelist = ad_spacelist(fast_tests)
Expand All @@ -17,20 +15,22 @@ eltypes = (Float64, ComplexF64)

C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
α = randn(T)
β = randn(T)

Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode)

Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode)

Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode)
for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero())
Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol)
Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol)
Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol)
Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol)
Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol)
Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol)

Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol)
Comment thread
kshyatt marked this conversation as resolved.
Mooncake.TestUtils.test_rule(rng, add!, C', A'; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, copy(C'), A', α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C', copy(A'), α, β; atol, rtol)

Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol)
Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol)
end
end
Loading