Skip to content

Commit a50858d

Browse files
authored
Trivial tensors fast path into TensorOperations machinery (#463)
1 parent 69d3c3b commit a50858d

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

src/tensors/indexmanipulations.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,7 @@ Base.@deprecate(
553553
if p[1] === codomainind(tsrc) && p[2] === domainind(tsrc)
554554
add!(tdst, tsrc, α, β)
555555
else
556-
I = sectortype(tdst)
557-
if I === Trivial
556+
if has_array_view(tdst) && has_array_view(tsrc)
558557
TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend, allocator)
559558
else
560559
ntasks = use_threaded_transform(tdst, transformer) ? get_num_transformer_threads() : 1

src/tensors/tensoroperations.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ function _canonicalize(p::IndexTuple, t::AbstractTensorMap)
3333
return (p₁, p₂)
3434
end
3535

36+
# Whether a tensor can be viewed as a single contiguous array, such that
37+
# the fusiontree machinery and act directly on the `t[]` view.
38+
has_array_view(t) = has_array_view(typeof(t))
39+
has_array_view(::Type) = false
40+
has_array_view(::Type{T}) where {T <: TensorMap} = sectortype(T) === Trivial
41+
has_array_view(::Type{T}) where {T <: AdjointTensorMap} = has_array_view(parenttype(T))
42+
3643
# tensoradd!
3744
function TO.tensoradd!(
3845
C::AbstractTensorMap,
@@ -43,9 +50,9 @@ function TO.tensoradd!(
4350
if conjA
4451
A′ = adjoint(A)
4552
pA′ = adjointtensorindices(A, _canonicalize(pA, C))
46-
permute!(C, A′, pA′, α, β, backend)
53+
permute!(C, A′, pA′, α, β, backend, allocator)
4754
else
48-
permute!(C, A, _canonicalize(pA, C), α, β, backend)
55+
permute!(C, A, _canonicalize(pA, C), α, β, backend, allocator)
4956
end
5057
return C
5158
end
@@ -125,6 +132,10 @@ function TO.tensorcontract!(
125132
)
126133
pAB′ = _canonicalize(pAB, C)
127134
@boundscheck spacecheck_contract(C, A, pA, conjA, B, pB, conjB, pAB′)
135+
if has_array_view(C) && has_array_view(A) && has_array_view(B)
136+
TO.tensorcontract!(C[], A[], pA, conjA, B[], pB, conjB, pAB′, α, β, backend, allocator)
137+
return C
138+
end
128139
if conjA && conjB
129140
A′ = A'
130141
pA′ = adjointtensorindices(A, pA)
@@ -219,7 +230,7 @@ function trace_permute!(
219230
q₁ = $(q₁), q₂ = $(q₂)"))
220231
end
221232

222-
if I === Trivial
233+
if has_array_view(tdst) && has_array_view(tsrc)
223234
TO.tensortrace!(tdst[], tsrc[], (p₁, p₂), (q₁, q₂), false, α, β, backend)
224235
else
225236
_trace_permute!(FusionStyle(I), tdst, tsrc, (p₁, p₂), (q₁, q₂), α, β, backend)

0 commit comments

Comments
 (0)