From 7ddfa0c9d7faad4edc20e00c1ada54d08ebaf95e Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 17 Jun 2026 11:53:56 -0400 Subject: [PATCH 1/3] Lower linear broadcasting through bipermutedimsopadd! copyto! and copy now run the denamed broadcast through tryflattenlinear, and the linear case lowers to permutedimsopadd! on the name-aligned operands, where each operand's PermutedDimsArray alignment is unwrapped down to the backend's own bipermutedimsopadd!. Nonlinear broadcasts fall back to the existing denamed Base broadcast. aligneddims now aligns through the overloadable TensorAlgebra.permuteddims hook, and the FunctionImplementations dependency is dropped (the dead ImplementationStyle trait that routed the old lazy permute is removed). Co-Authored-By: Claude Opus 4.8 --- Project.toml | 10 ++++++---- src/abstractnameddimsarray.jl | 21 +++------------------ src/broadcast.jl | 24 ++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index 0eb0784b..05c0e345 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NamedDimsArrays" uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" -version = "0.15.8" +version = "0.15.9" authors = ["ITensor developers and contributors"] [workspace] @@ -10,7 +10,6 @@ projects = ["benchmark", "dev", "docs", "examples", "test"] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -26,6 +25,10 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +[sources.TensorAlgebra] +rev = "mf/permuteddims-unwrap" +url = "https://github.com/ITensor/TensorAlgebra.jl" + [extensions] NamedDimsArraysAbstractTreesExt = "AbstractTrees" NamedDimsArraysAdaptExt = "Adapt" @@ -39,13 +42,12 @@ ArrayLayouts = "1.11" BlockArrays = "1.3" Compat = "4.16" FillArrays = "1.13" -FunctionImplementations = "0.4" LinearAlgebra = "1.10" Mooncake = "0.4.202, 0.5" OrderedCollections = "1.6" Random = "1.10" SimpleTraits = "0.9.4" -TensorAlgebra = "0.9.5" +TensorAlgebra = "0.9.6" TupleTools = "1.6" TypeParameterAccessors = "0.4" VectorInterface = "0.5, 0.6" diff --git a/src/abstractnameddimsarray.jl b/src/abstractnameddimsarray.jl index e2b96118..7a3e9395 100644 --- a/src/abstractnameddimsarray.jl +++ b/src/abstractnameddimsarray.jl @@ -1,6 +1,6 @@ -using FunctionImplementations: FunctionImplementations as FI using LinearAlgebra: LinearAlgebra using Random: Random +using TensorAlgebra: permuteddims using TypeParameterAccessors: unspecify_type_parameters # Some of the interface is inspired by: @@ -9,20 +9,11 @@ using TypeParameterAccessors: unspecify_type_parameters # https://github.com/mcabbott/NamedPlus.jl # https://pytorch.org/docs/stable/named_tensor.html -abstract type AbstractNamedDimsArrayImplementationStyle <: -FI.AbstractArrayImplementationStyle end - -struct NamedDimsArrayImplementationStyle <: AbstractNamedDimsArrayImplementationStyle end - abstract type AbstractNamedDimsArray{T, N} <: AbstractArray{T, N} end const AbstractNamedDimsVector{T} = AbstractNamedDimsArray{T, 1} const AbstractNamedDimsMatrix{T} = AbstractNamedDimsArray{T, 2} -function FI.ImplementationStyle(type::Type{<:AbstractNamedDimsArray}) - return NamedDimsArrayImplementationStyle() -end - dimnames(a::AbstractNamedDimsArray) = throw(MethodError(dimnames, a)) function dimnames(a::AbstractNamedDimsArray, dim::Int) return dimnames(a)[dim] @@ -54,13 +45,7 @@ dimnametype(type::Type{<:AbstractNamedDimsArray}) = throw(MethodError(dimnametyp # Unwrapping the names (`NamedDimsArrays.jl` interface). # TODO: Use `IsNamed` trait? denamed(a::AbstractNamedDimsArray) = throw(MethodError(denamed, a)) -function denamed(a::AbstractNamedDimsArray, inds) - # Skip the lazy `PermutedDimsArray` wrap when the requested order already - # matches `a`'s; compare via `Tuple` because `LittleSet ==` is - # set-equality and would mask a permutation. - Tuple(name.(inds)) == Tuple(dimnames(a)) && return denamed(a) - return denamed(aligneddims(a, inds)) -end +denamed(a::AbstractNamedDimsArray, inds) = denamed(aligneddims(a, inds)) dename(a::AbstractNamedDimsArray, inds) = denamed(aligndims(a, inds)) # Output the named axes/indices of the named dims array. @@ -706,7 +691,7 @@ function aligneddims(a::AbstractArray, dims) ) ) return nameddimsconstructorof(a)( - FI.permuteddims(denamed(a), perm), new_dimnames + permuteddims(denamed(a), perm), new_dimnames ) end diff --git a/src/broadcast.jl b/src/broadcast.jl index 05e4bdee..01391446 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -100,6 +100,14 @@ function broadcasted_denamed(bc::Broadcasted, inds) return broadcasted(bc.f, Base.Fix2(broadcasted_denamed, inds).(bc.args)...) end +# A bare (unnamed) array operand, used as an allocation prototype so a broadcast +# result inherits the operands' backend (e.g. graded) rather than a lazy permuted +# wrapper's `similar` (which can drop the backend). +denamed_prototype(bc::Broadcasted) = denamed_prototype(bc.args...) +denamed_prototype(arg::AbstractNamedDimsArray, args...) = denamed(arg) +denamed_prototype(arg::Broadcasted, args...) = denamed_prototype(arg.args..., args...) +denamed_prototype(arg, args...) = denamed_prototype(args...) + function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax) inds_a = name.(ax) bc_denamed = broadcasted_denamed(bc, inds_a) @@ -120,7 +128,14 @@ function Base.copy(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}) # the output element type at runtime with widening. inds_dest = inds(bc) bc_denamed = broadcasted_denamed(bc, inds_dest) - dest_denamed = copy(bc_denamed) + lb = TA.tryflattenlinear(bc_denamed) + if isnothing(lb) + dest_denamed = copy(bc_denamed) + else + dest_axes = denamed.(Tuple(axes(bc))) + dest_denamed = similar(denamed_prototype(bc), eltype(lb), dest_axes) + copyto!(dest_denamed, lb) + end return nameddimstype(bc.style)(dest_denamed, inds_dest) end @@ -128,6 +143,11 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractNamedDimsAr dest_denamed = denamed(dest) inds_dest = inds(dest) bc_denamed = broadcasted_denamed(bc, inds_dest) - copyto!(dest_denamed, bc_denamed) + lb = TA.tryflattenlinear(bc_denamed) + if isnothing(lb) + copyto!(dest_denamed, bc_denamed) + else + copyto!(dest_denamed, lb) + end return dest end From 93b6822180d866fd66bf7bc873c7d685d3f13faa Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 17 Jun 2026 18:03:34 -0400 Subject: [PATCH 2/3] Comment the linear-broadcast lowering branches Co-Authored-By: Claude Opus 4.8 --- src/broadcast.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/broadcast.jl b/src/broadcast.jl index 01391446..e4030343 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -130,8 +130,11 @@ function Base.copy(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}) bc_denamed = broadcasted_denamed(bc, inds_dest) lb = TA.tryflattenlinear(bc_denamed) if isnothing(lb) + # Not a linear combination: ordinary fused broadcast. dest_denamed = copy(bc_denamed) else + # Linear: lower to bipermutedimsopadd!. Allocate from an operand so the + # result keeps the backend, using the backend's result axes (not `lb`'s). dest_axes = denamed.(Tuple(axes(bc))) dest_denamed = similar(denamed_prototype(bc), eltype(lb), dest_axes) copyto!(dest_denamed, lb) @@ -145,8 +148,10 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractNamedDimsAr bc_denamed = broadcasted_denamed(bc, inds_dest) lb = TA.tryflattenlinear(bc_denamed) if isnothing(lb) + # Not a linear combination: ordinary fused broadcast. copyto!(dest_denamed, bc_denamed) else + # Linear: lower to bipermutedimsopadd! into the existing dest. copyto!(dest_denamed, lb) end return dest From aeb0e5fc7e6848e202152743336e20276c3b0c77 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 17 Jun 2026 18:52:57 -0400 Subject: [PATCH 3/3] Remove TensorAlgebra [sources] pin Co-Authored-By: Claude Opus 4.8 --- Project.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Project.toml b/Project.toml index 05c0e345..d8ae3852 100644 --- a/Project.toml +++ b/Project.toml @@ -25,10 +25,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -[sources.TensorAlgebra] -rev = "mf/permuteddims-unwrap" -url = "https://github.com/ITensor/TensorAlgebra.jl" - [extensions] NamedDimsArraysAbstractTreesExt = "AbstractTrees" NamedDimsArraysAdaptExt = "Adapt"