Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c711f58
add complex lstsq and its test passes
yuiyuiui Feb 23, 2025
3e90afe
add some rules
yuiyuiui Feb 24, 2025
f136adc
change lstsq to arg_lstsq and add lstsq: A,b -> min ||Ax-b||
yuiyuiui Feb 25, 2025
c6a65fe
add lu decomposition
yuiyuiui Feb 25, 2025
8ecb361
add analytic function for normal matrix and some other improvement
yuiyuiui Feb 26, 2025
04f398f
svd and rsvd is right but does the test of rsvd really make sence?
yuiyuiui Feb 26, 2025
eff8f82
standard LP test passes
yuiyuiui Feb 26, 2025
67a1359
add GMRES
yuiyuiui Mar 2, 2025
d3d8c89
add complex GMRES
yuiyuiui Mar 2, 2025
470c88b
add theory statement of GMRES adjoint
yuiyuiui Mar 3, 2025
5cd7e14
approximate GMRES_BACK plays better, but of coures it falls for less…
yuiyuiui Mar 4, 2025
73e5fcf
add pffaffian
yuiyuiui Mar 4, 2025
aee0191
I will give proof of symeigen/normeigen
yuiyuiui Mar 5, 2025
7b7165b
add normeigen
yuiyuiui Mar 6, 2025
5459d4d
add proof of symeigen and normal eigen
yuiyuiui Mar 6, 2025
3a51e51
delete rule document
yuiyuiui Mar 6, 2025
154656e
add sdp without test
yuiyuiui Mar 9, 2025
1ff0996
realize ad rule for real sdp and tests pass
yuiyuiui Mar 10, 2025
ddefac4
sdp test for A can pass in 1e-1 rtol but falls in 1e-2 rtol
yuiyuiui Mar 10, 2025
5c7083c
add ad rule back for fft
yuiyuiui Mar 11, 2025
7f316a6
add ad for unfft type2, but finite diffinite has a too low accuracy
yuiyuiui Mar 13, 2025
35686c4
save
yuiyuiui Mar 27, 2025
4688e66
add gradient for general analytic matrix function
yuiyuiui Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,33 @@ version = "0.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FINUFFT = "d8beea63-0952-562e-9c6a-8e8ef7364055"
GLPK = "60bf3e95-4087-53dc-ae20-288a0d20c6a6"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
NFFTTools = "7424e34d-94f7-41d6-98a0-85abaf1b6c91"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
SkewLinearAlgebra = "5c889d49-8c60-4500-9d10-5d3a22e2f4b9"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1.25.1"
FFTW = "1.8.1"
FINUFFT = "3.3.1"
IterativeSolvers = "0.9.4"
LinearAlgebra = "1"
NFFT = "0.13.6"
NFFTTools = "0.2.6"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
Expand Down
48 changes: 48 additions & 0 deletions docs/rule_list.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

real/complex test
matrix multiplication | complex done

tensor network |

least sq / arg least sq | complex done

qr (all size) | complex done

symeigen | complex done

nromal eigen | complex done

svd | complex done

rsvd | complex done

schatten norm | complex done

matrix inversion | complex done

det | complex done

lu | complex done

linear equations | compelex done

expmv |

norm matrix analytic function | complex done

Cholesky decomposition | complex done

LP | real done

SDP | real done

GMRES | complex done

Pfaffain | real done

FFT | complex done

UNFFT(Type 1) | complex done

Inverse UNFFT(Type 1)

58 changes: 58 additions & 0 deletions examples/UNFFT.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using NFFT, LinearAlgebra, Random, BackwardsLinalg, IterativeSolvers, NFFTTools

Random.seed!(3)
J = N = 128;
k = range(-0.4, stop=0.4, length=J);
f = randn(ComplexF64, J);
p = plan_nfft(k, N, reltol=1e-9);
A = BackwardsLinalg.A_construct_t2(k);
fhat = p*f;
f1 = A^(-1)*fhat;
norm(A*f1 - fhat)
W = sdc(p, iters = 10);
B = A'*diagm(W)*A;
b = A'*diagm(W)*fhat;
f2 = B\b;
norm(A*f2 - fhat)
B

A*diagm(W)*A'

g1 = A'^(-1)*f
norm(A*g1 - f)
C = A*diagm(W)*A'
c = A*diagm(W)*fhat
g2 = C\c
norm(A*g2 - fhat)

f3 = gmres(B,b; reltol=1e-8, abstol=1e-8, verbose=true)
norm(A*f3 - fhat)

sdc(p,iters = 10)


###########
Random.seed!(3)
N = 128
k = rand(N) .- 0.5
A = BackwardsLinalg.A_construct_t2(k)
f = randn(ComplexF64, N)
fhat = NFFT.nfft(k,f)


f1 = A^(-1)*fhat
error1 = norm(A*f1 - fhat)

f2 = gmres(A'*A, A'*fhat;reltol=1e-8, abstol=1e-8, verbose=true)

error2 = norm(A*f2 - fhat)



A = rand(128,128) + 128*I
b = rand(128)

x = gmres(A, b; reltol=1e-8, abstol=1e-8, verbose=true)


norm(A*x - b)
78 changes: 78 additions & 0 deletions examples/gmres.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using Zygote, LinearAlgebra, BackwardsLinalg, Random


T = Float64
Random.seed!(3)
n = 200
A = rand(T, n, n) + n/64 * I
b = rand(T, n)

x= BackwardsLinalg.gmres(A, b)
norm(A*x-b)
x̄ = rand(T ,n)
BackwardsLinalg.my_gmres(A,b)[2]

# ========


e1 = zeros(T, k + 1)
e1[1] = 1.0
m, n = size(A)
mask = ones(T, k + 1, k)
for j ∈ 1:k
for i ∈ j+2:k+1
mask[i, j] = 0.0
end
end


x0 = zeros(n)
r0 = b - A * x0
W = hcat([A^(i - 1) * r0 for i in 1:k+1]...)
Q,R = BackwardsLinalg.qr(W)

H0 = Q' * A * Q[:, 1:k]
H1 = H0 .* mask
r0e = R[1,1] * e1
y = BackwardsLinalg.arg_lstsq(H1, r0e)
x1 = x0 + Q[:, 1:k] * y



norm(A*x1-b)
norm(x1-x)/norm(x)



# --------------------

x0 = zeros(n)
r0 = b - A * x0
W = hcat([A^(i - 1) * r0 for i in 1:k+1]...)
Q, R = BackwardsLinalg.qr(W)
Q = Q[:,1:k+1]
β = R[1,1]
H0 = R[1:k+1,2:k+1]

r0e = β * e1
y = BackwardsLinalg.arg_lstsq(H0, r0e)
x1 = x0 + Q[:, 1:k] * y
x

norm(A*x1-b)
norm(x1-x)/norm(x)

function tf(A)
B = copy(A)
B = 2*B
return sum(abs2.(B))
end

A =rand(3,3)

tf(A)
gradient(tf,A)

A = rand(100,5)
res = LinearAlgebra.qr(A)
Matrix(res.Q)
42 changes: 42 additions & 0 deletions examples/illness of iunftt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using BackwardsLinalg, NFFT, LinearAlgebra,Random
using Plots,Zygote

Random.seed!(3)
N = 16
k = rand(N) .- 0.5
f = randn(ComplexF64, N)
A = BackwardsLinalg.A_construct_t2(k)
cond(A)
loss(x)=sum(abs2.(A*x))
η = 1e-5
step = 1000

J = zeros(step)
for i in 1:step
J[i] = norm(gradient(loss,f)[1])
f = f - η*gradient(loss,f)[1]
end
plot(J)


using LinearAlgebra,Plots,Random
Random.seed!(3)
M = 20
cond_num = zeros(M)
T = 1000000
for i in 1: M
cond0 = 0.0
for t in 1:T
if t%1000 == 0
println(i)
end
A = rand(i,i)
cond0 += cond(A)
end
cond_num[i] = cond0/T
end
plot(1:M,cond_num)



cond_num[2]
7 changes: 7 additions & 0 deletions examples/normeigen.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using BackwardsLinalg

A = rand(ComplexF64,3,3)

A += A'

BackwardsLinalg.normeigen(A)[1]
64 changes: 64 additions & 0 deletions examples/sdp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using JuMP, SCS, LinearAlgebra, Random, Test, BackwardsLinalg

# 定义数据
Random.seed!(123) # 设置随机种子
n = 4 # 矩阵的维度

# 生成对称的目标矩阵 C
C = rand(n, n)
C[2, 3] += 0.1
C = (C + C') / 2 # 确保对称性

# 生成对称的约束矩阵 A1 和 A2
A1 = rand(n, n)
A1 = (A1 + A1') / 2 # 确保对称性

A2 = rand(n, n)
A2 = (A2 + A2') / 2 # 确保对称性

# 生成约束的右侧值 b1 和 b2
b1 = tr(A1 * I(n)) # 约束 1 的右侧值 b1,确保可行
b2 = tr(A2 * I(n)) # 约束 2 的右侧值 b2,确保可行

# 使用 JuMP + SCS 求解
model = Model(SCS.Optimizer);
@variable(model, X[1:n, 1:n], PSD); # 定义半正定矩阵 X
@objective(model, Min, tr(C * X)) ; # 目标是最小化 tr(C * X)
@constraint(model, tr(A1 * X) == b1) ; # 约束 1: tr(A1 * X) = b1
@constraint(model, tr(A2 * X) == b2) ; # 约束 2: tr(A2 * X) = b2
set_silent(model)
optimize!(model); # 求解问题

# 检查求解状态并输出结果
if termination_status(model) == MOI.OPTIMAL
println("JuMP + SCS 结果:")
println("目标函数值: ", objective_value(model))
println("最优解 X:")
println(value.(X))
else
println("JuMP + SCS 求解失败")
println("求解状态: ", termination_status(model))
end

# 计算最优解 X 的特征值和特征向量
E, U = eigen(value.(X))
println("最优解 X 的特征值:")
println(E)
println("最优解 X 的特征向量:")
println(U)


#-------------------
A = [A1,A2]
b = [b1,b2]

X = BackwardsLinalg.sdp(C,A,b)
X = Matrix(X)
tr(C*X)
X̄ = rand(4, 4)
X̄ = (X̄ + X̄') / 2
Ā,b̄ = BackwardsLinalg.sdp_backward(C,A,b,X,X̄)




21 changes: 21 additions & 0 deletions src/BackwardsLinalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module BackwardsLinalg

using ChainRulesCore; import ChainRulesCore: rrule
using LinearAlgebra; import LinearAlgebra: ldiv!
using JuMP, GLPK, Zygote, SkewLinearAlgebra, SCS, FFTW, NFFT, NFFTTools

struct ZeroAdder end
Base.:+(a, zero::ZeroAdder) = a
Expand All @@ -10,11 +11,31 @@ Base.:-(a, zero::ZeroAdder) = a
Base.:-(zero::ZeroAdder, a) = -a
Base.:-(zero::ZeroAdder) = zero



include("qr.jl")
include("svd.jl")
include("lstsq.jl")
include("rsvd.jl")
include("symeigen.jl")
include("norm_anlfunc.jl")
include("cls.jl")
include("det.jl")
include("inv.jl")
include("lneq.jl")
include("lp.jl")
include("sdp.jl")
include("lu.jl")
include("mxmul.jl")
include("scha_norm.jl")
include("gmres.jl")
include("pf.jl")
include("normeigen.jl")
include("fft.jl")
include("unfft.jl")
include("matrix_func.jl")

include("chainrules.jl")


end
Loading