Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module ITensorBaseAbstractTreesExt

using AbstractTrees: AbstractTrees
using ITensorBase: AbstractNamedDimsArray, dimnames
using ITensorBase: AbstractITensor, dimnames

# Only print the dimension names when printing with `AbstractTrees.print_tree`.
function AbstractTrees.printnode(io::IO, a::AbstractNamedDimsArray)
function AbstractTrees.printnode(io::IO, a::AbstractITensor)
dimnames_a = "{" * join(map(s -> "\"$s\"", dimnames(a)), ", ") * "}"
print(io, dimnames_a)
return nothing
Expand Down
4 changes: 2 additions & 2 deletions ext/ITensorBaseAdaptExt/ITensorBaseAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module ITensorBaseAdaptExt

using Adapt: Adapt, adapt
using ITensorBase: AbstractNamedDimsArray, denamed, dimnames, nameddims
using ITensorBase: AbstractITensor, denamed, dimnames, nameddims

function Adapt.adapt_structure(to, a::AbstractNamedDimsArray)
function Adapt.adapt_structure(to, a::AbstractITensor)
return nameddims(adapt(to, denamed(a)), dimnames(a))
end

Expand Down
25 changes: 12 additions & 13 deletions ext/ITensorBaseBlockArraysExt/ITensorBaseBlockArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module ITensorBaseBlockArraysExt
using ArrayLayouts: ArrayLayouts
using BlockArrays: Block, BlockRange
using ITensorBase: AbstractNamedDimsArray, AbstractNamedDimsMatrix, AbstractNamedUnitRange,
getindex_named, view_nameddims
using ITensorBase: AbstractITensor, AbstractNamedUnitRange, getindex_named, view_nameddims

function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::Block{1})
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
Expand All @@ -16,39 +15,39 @@ end

const BlockIndex{N} = Union{Block{N}, BlockRange{N}, AbstractVector{<:Block{N}}}

function Base.view(a::AbstractNamedDimsArray, I1::Block{1}, Irest::BlockIndex{1}...)
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
function Base.view(a::AbstractITensor, I1::Block{1}, Irest::BlockIndex{1}...)
# TODO: Use `Derive.@interface ITensorInterface() r[I]` instead.
return view_nameddims(a, I1, Irest...)
end

function Base.view(a::AbstractNamedDimsArray, I::Block)
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
function Base.view(a::AbstractITensor, I::Block)
# TODO: Use `Derive.@interface ITensorInterface() r[I]` instead.
return view_nameddims(a, Tuple(I)...)
end

function Base.view(a::AbstractNamedDimsArray, I1::BlockIndex{1}, Irest::BlockIndex{1}...)
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
function Base.view(a::AbstractITensor, I1::BlockIndex{1}, Irest::BlockIndex{1}...)
# TODO: Use `Derive.@interface ITensorInterface() r[I]` instead.
return view_nameddims(a, I1, Irest...)
end

# Fix ambiguity error.
function Base.getindex(
a::AbstractNamedDimsArray, I1::BlockRange{1}, Irest::BlockRange{1}...
a::AbstractITensor, I1::BlockRange{1}, Irest::BlockRange{1}...
)
return ArrayLayouts.layout_getindex(a, I1, Irest...)
end

# Fix ambiguity errors.
function Base.getindex(a::AbstractNamedDimsArray, I1::Block{1}, Irest...)
function Base.getindex(a::AbstractITensor, I1::Block{1}, Irest...)
return copy(view(a, I1, Irest...))
end
function Base.getindex(a::AbstractNamedDimsMatrix, I1::AbstractVector, I2::Block{1})
function Base.getindex(a::AbstractITensor, I1::AbstractVector, I2::Block{1})
return copy(view(a, I1, I2))
end
function Base.getindex(a::AbstractNamedDimsMatrix, I1::Block{1}, I2::AbstractVector)
function Base.getindex(a::AbstractITensor, I1::Block{1}, I2::AbstractVector)
return copy(view(a, I1, I2))
end
function Base.getindex(a::AbstractNamedDimsArray{<:Any, N}, I::Block{N}) where {N}
function Base.getindex(a::AbstractITensor, I::Block{N}) where {N}
return copy(view(a, I))
end

Expand Down
8 changes: 4 additions & 4 deletions ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
module ITensorBaseMooncakeExt

using ITensorBase: AbstractNamedDimsArray, AbstractNamedUnitRange, blockedperm_nameddims,
using ITensorBase: AbstractITensor, AbstractNamedUnitRange, blockedperm_nameddims,
combine_nameddimsconstructors, dimnames, dimnames_setdiff, inds, name,
nameddimsconstructorof, randname, to_inds
using Mooncake: Mooncake, @zero_derivative, DefaultCtx
using TensorAlgebra: blockedperm

Mooncake.tangent_type(::Type{<:AbstractNamedUnitRange}) = Mooncake.NoTangent

@zero_derivative DefaultCtx Tuple{typeof(blockedperm), AbstractNamedDimsArray, Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(blockedperm), AbstractITensor, Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(blockedperm_nameddims), Any, Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(combine_nameddimsconstructors), Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(dimnames), Any}
Expand All @@ -22,9 +22,9 @@ Mooncake.tangent_type(::Type{<:AbstractNamedUnitRange}) = Mooncake.NoTangent
@zero_derivative DefaultCtx Tuple{typeof(randname), Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(to_inds), Any, Any}

using ITensorBase: AbstractNamedDimsArray, NamedDimsArray, denamed
using ITensorBase: AbstractITensor, ITensor, denamed
using Mooncake: Tangent
function Base.copyto!(dest::NamedDimsArray, src::Tangent)
function Base.copyto!(dest::ITensor, src::Tangent)
# TODO: Account for the `inds` of the Tangent? In other words, is the tangent data
# aligned with the `dest` data?
copyto!(denamed(dest), src.fields.parent)
Expand Down
11 changes: 5 additions & 6 deletions src/ITensorBase.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ITensorBase

export ITensor, Index, NamedDimsArray, aligndims, dimnametype, named, nameddims,
export ITensor, Index, aligndims, dimnametype, named, nameddims,
operator, similar_operator
using Compat: @compat
@compat public to_inds
Expand All @@ -16,16 +16,15 @@ include("abstractnamedarray.jl")
include("namedarray.jl")
include("abstractnamedunitrange.jl")
include("namedunitrange.jl")
include("abstractnameddimsarray.jl")
include("abstractitensor.jl")
include("broadcast.jl")
include("tensoralgebra.jl")
include("linearalgebra.jl")
include("nameddimsarray.jl")
include("nameddimsoperator.jl")
include("itensor.jl")
include("itensoroperator.jl")

# ITensor layer built on the named-array machinery.
# `IndexName` dimname flavor and the `Index` named unit range.
include("index.jl")
include("abstractitensor.jl")
include("quirks.jl")

end
Loading