Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAMDGPUExt = "AMDGPU"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -43,19 +47,21 @@ AMDGPU = "2"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Enzyme = "0.13.157"
EnzymeTestUtils = "0.2.8"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.7"
MatrixAlgebraKit = "0.6.8"
Mooncake = "0.5.27"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.7"
TensorOperations = "5.5"
TensorOperations = "5.5.2"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5, 0.6"
VectorInterface = "0.4.8, 0.5"
cuTENSOR = "6"
julia = "1.10"
16 changes: 8 additions & 8 deletions ext/TensorKitChainRulesCoreExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function ChainRulesCore.rrule(
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk let
ipA = invperm(linearize(pA))
pdA = _repartition(ipA, A)
pdA = TO.repartition(ipA, numout(A))
TA = promote_add(ΔC, α)
# TODO: allocator
_dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false))
Expand All @@ -40,7 +40,7 @@ function ChainRulesCore.rrule(
_dα = tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
tΔC, (trivtuple(TO.numind(pA)), ()), false,
tΔC, (TO.trivialpermutation(TO.numind(pA)), ()), false,
((), ()), One(), ba...
)
)
Expand Down Expand Up @@ -76,11 +76,11 @@ function ChainRulesCore.rrule(
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipAB = invperm(linearize(pAB))
pΔC = _repartition(ipAB, TO.numout(pA))
pΔC = TO.repartition(ipAB, pA)

dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk let
ipA = _repartition(invperm(linearize(pA)), A)
ipA = TO.repartition(invperm(linearize(pA)), numout(A))
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
Expand All @@ -105,7 +105,7 @@ function ChainRulesCore.rrule(
projectA(_dA)
end
dB = @thunk let
ipB = _repartition(invperm(linearize(pB)), B)
ipB = TO.repartition(invperm(linearize(pB)), numout(B))
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
Expand Down Expand Up @@ -167,11 +167,11 @@ function ChainRulesCore.rrule(
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk let
ip = invperm((linearize(p)..., q[1]..., q[2]...))
pdA = _repartition(ip, A)
pdA = TO.repartition(ip, numout(A))
E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA))
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
pE = ((), trivtuple(TO.numind(q)))
pΔC = (trivtuple(TO.numind(p)), ())
pE = ((), TO.trivialpermutation(TO.numind(q)))
pΔC = (TO.trivialpermutation(TO.numind(p)), ())
TA = promote_scale(ΔC, α)
# TODO: allocator
_dA = tensoralloc_contract(TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val(false))
Expand Down
20 changes: 0 additions & 20 deletions ext/TensorKitChainRulesCoreExt/utility.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,3 @@
# Utility
# -------
trivtuple(N) = ntuple(identity, N)

Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
return TupleTools.getindices(p, trivtuple(N₁)),
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
end
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
return _repartition(linearize(p), N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
return _repartition(p, TensorKit.numout(t))
end

TensorKit.block(t::ZeroTangent, c::Sector) = t

ChainRulesCore.ProjectTo(::T) where {T <: AbstractTensorMap} = ProjectTo{T}()
Expand Down
16 changes: 16 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")

end
Loading
Loading