With Zygote version 0.6.40, the gradient of the following function is evaluated correctly:
'''
using TensorKit  # provides the @tensor macro (re-exported from TensorOperations)

function nrmsqrt(A)
    return @tensor res = A[1,2] * A'[2,1]
end
'''
However, with version 0.6.49 (the latest) this no longer works and produces the following error:
'''
Need an adjoint for constructor TensorKit.AdjointTensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing}. Gradient is of type TensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::Zygote.Jnew{TensorKit.AdjointTensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing}, Nothing, false})(Δ::TensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:327
[3] (::Zygote.var"#2100#back#224"{Zygote.Jnew{TensorKit.AdjointTensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing}, Nothing, false}})(Δ::TensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ~/.julia/packages/TensorKit/KbNYI/src/tensors/adjoint.jl:11 [inlined]
[5] (::typeof(∂(TensorKit.AdjointTensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing})))(Δ::TensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/TensorKit/KbNYI/src/tensors/adjoint.jl:11 [inlined]
[7] (::typeof(∂(TensorKit.AdjointTensorMap)))(Δ::TensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/TensorKit/KbNYI/src/tensors/adjoint.jl:18 [inlined]
[9] (::typeof(∂(adjoint)))(Δ::TensorMap{ComplexSpace, 3, 1, Trivial, Matrix{ComplexF64}, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
...
'''
Surprisingly, replacing A'[2,1] with conj(A)[1,2] works, even with the latest version of Zygote.
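For reference, the workaround written out in full looks roughly like this. It is only a minimal sketch: the name `nrmsqrt_conj` is made up for illustration, and the only change from the function above is the second factor. The index order flips because, for a two-index tensor, `A'` exchanges the two indices of `A` while `conj(A)` leaves them in place, so both forms should contract `A[1,2]` with its own complex conjugate.

```julia
using TensorKit  # provides the @tensor macro

# Hypothetical name; same contraction as nrmsqrt, but written with conj(A)
# instead of A', so the indices of the second factor are [1,2] rather than [2,1].
function nrmsqrt_conj(A)
    return @tensor res = A[1,2] * conj(A)[1,2]
end
```

It is called in the same way as `nrmsqrt`. Presumably the difference is that this form never goes through the `adjoint` call that appears in the stack trace above.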
Best wishes,
Erik