Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NamedDimsArrays"
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
version = "0.15.8"
version = "0.15.9"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand All @@ -10,7 +10,6 @@ projects = ["benchmark", "dev", "docs", "examples", "test"]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -39,13 +38,12 @@ ArrayLayouts = "1.11"
BlockArrays = "1.3"
Compat = "4.16"
FillArrays = "1.13"
FunctionImplementations = "0.4"
LinearAlgebra = "1.10"
Mooncake = "0.4.202, 0.5"
OrderedCollections = "1.6"
Random = "1.10"
SimpleTraits = "0.9.4"
TensorAlgebra = "0.9.5"
TensorAlgebra = "0.9.6"
TupleTools = "1.6"
TypeParameterAccessors = "0.4"
VectorInterface = "0.5, 0.6"
Expand Down
21 changes: 3 additions & 18 deletions src/abstractnameddimsarray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using FunctionImplementations: FunctionImplementations as FI
using LinearAlgebra: LinearAlgebra
using Random: Random
using TensorAlgebra: permuteddims
using TypeParameterAccessors: unspecify_type_parameters

# Some of the interface is inspired by:
Expand All @@ -9,20 +9,11 @@ using TypeParameterAccessors: unspecify_type_parameters
# https://github.com/mcabbott/NamedPlus.jl
# https://pytorch.org/docs/stable/named_tensor.html

abstract type AbstractNamedDimsArrayImplementationStyle <:
FI.AbstractArrayImplementationStyle end

struct NamedDimsArrayImplementationStyle <: AbstractNamedDimsArrayImplementationStyle end

abstract type AbstractNamedDimsArray{T, N} <: AbstractArray{T, N} end

const AbstractNamedDimsVector{T} = AbstractNamedDimsArray{T, 1}
const AbstractNamedDimsMatrix{T} = AbstractNamedDimsArray{T, 2}

function FI.ImplementationStyle(type::Type{<:AbstractNamedDimsArray})
return NamedDimsArrayImplementationStyle()
end

dimnames(a::AbstractNamedDimsArray) = throw(MethodError(dimnames, a))
function dimnames(a::AbstractNamedDimsArray, dim::Int)
return dimnames(a)[dim]
Expand Down Expand Up @@ -54,13 +45,7 @@ dimnametype(type::Type{<:AbstractNamedDimsArray}) = throw(MethodError(dimnametyp
# Unwrapping the names (`NamedDimsArrays.jl` interface).
# TODO: Use `IsNamed` trait?
denamed(a::AbstractNamedDimsArray) = throw(MethodError(denamed, a))
function denamed(a::AbstractNamedDimsArray, inds)
# Skip the lazy `PermutedDimsArray` wrap when the requested order already
# matches `a`'s; compare via `Tuple` because `LittleSet ==` is
# set-equality and would mask a permutation.
Tuple(name.(inds)) == Tuple(dimnames(a)) && return denamed(a)
return denamed(aligneddims(a, inds))
end
denamed(a::AbstractNamedDimsArray, inds) = denamed(aligneddims(a, inds))
dename(a::AbstractNamedDimsArray, inds) = denamed(aligndims(a, inds))

# Output the named axes/indices of the named dims array.
Expand Down Expand Up @@ -706,7 +691,7 @@ function aligneddims(a::AbstractArray, dims)
)
)
return nameddimsconstructorof(a)(
FI.permuteddims(denamed(a), perm), new_dimnames
permuteddims(denamed(a), perm), new_dimnames
)
end

Expand Down
29 changes: 27 additions & 2 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ function broadcasted_denamed(bc::Broadcasted, inds)
return broadcasted(bc.f, Base.Fix2(broadcasted_denamed, inds).(bc.args)...)
end

# A bare (unnamed) array operand, used as an allocation prototype so a broadcast
# result inherits the operands' backend (e.g. graded) rather than a lazy permuted
# wrapper's `similar` (which can drop the backend).
denamed_prototype(bc::Broadcasted) = denamed_prototype(bc.args...)
denamed_prototype(arg::AbstractNamedDimsArray, args...) = denamed(arg)
denamed_prototype(arg::Broadcasted, args...) = denamed_prototype(arg.args..., args...)
denamed_prototype(arg, args...) = denamed_prototype(args...)

function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax)
inds_a = name.(ax)
bc_denamed = broadcasted_denamed(bc, inds_a)
Expand All @@ -120,14 +128,31 @@ function Base.copy(bc::Broadcasted{<:AbstractNamedDimsArrayStyle})
# the output element type at runtime with widening.
inds_dest = inds(bc)
bc_denamed = broadcasted_denamed(bc, inds_dest)
dest_denamed = copy(bc_denamed)
lb = TA.tryflattenlinear(bc_denamed)
if isnothing(lb)
# Not a linear combination: ordinary fused broadcast.
dest_denamed = copy(bc_denamed)
else
# Linear: lower to bipermutedimsopadd!. Allocate from an operand so the
# result keeps the backend, using the backend's result axes (not `lb`'s).
dest_axes = denamed.(Tuple(axes(bc)))
dest_denamed = similar(denamed_prototype(bc), eltype(lb), dest_axes)
copyto!(dest_denamed, lb)
end
return nameddimstype(bc.style)(dest_denamed, inds_dest)
end

function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractNamedDimsArrayStyle})
dest_denamed = denamed(dest)
inds_dest = inds(dest)
bc_denamed = broadcasted_denamed(bc, inds_dest)
copyto!(dest_denamed, bc_denamed)
lb = TA.tryflattenlinear(bc_denamed)
if isnothing(lb)
# Not a linear combination: ordinary fused broadcast.
copyto!(dest_denamed, bc_denamed)
else
# Linear: lower to bipermutedimsopadd! into the existing dest.
copyto!(dest_denamed, lb)
end
return dest
end
Loading