diff --git a/examples/jld2_checkpoint.jl b/examples/jld2_checkpoint.jl
new file mode 100644
index 0000000..9c46b84
--- /dev/null
+++ b/examples/jld2_checkpoint.jl
@@ -0,0 +1,106 @@
+"""
+Demo: checkpointing and resuming LBFGS optimization with JLD2.
+
+Usage (from the repo root):
+    julia --project=. examples/jld2_checkpoint.jl
+
+What it shows:
+  1. Run LBFGS with a `checkpoint` callback that saves state to a JLD2 file
+     every N iterations.
+  2. Interrupt early (via `shouldstop`) to simulate a crashed job.
+  3. Load the last checkpoint from disk and resume to full convergence.
+  4. Verify the resumed result matches a reference run.
+"""
+
+using OptimKit
+using LinearAlgebra
+using JLD2  # `import Pkg; Pkg.add("JLD2")` if not yet installed
+
+# ---------------------------------------------------------------------------
+# Problem: minimise f(x) = ½ (x-y)ᵀ A (x-y)
+# ---------------------------------------------------------------------------
+function make_fg(A, y)
+    function fg(x)
+        r = x - y
+        g = A * r
+        f = dot(r, g) / 2
+        return f, g
+    end
+    return fg
+end
+
+# Reproducible random problem
+import Random; Random.seed!(42)
+n = 50
+y = randn(n)
+A = let B = randn(n, n); B'B + 5I end  # positive-definite, well-conditioned
+fg = make_fg(A, y)
+x₀ = randn(n)
+alg = LBFGS(; gradtol=1e-12, verbosity=0)
+
+# ---------------------------------------------------------------------------
+# Reference: run to convergence (ground truth)
+# ---------------------------------------------------------------------------
+x_ref, f_ref, _, _, _ = optimize(fg, x₀, alg)
+println("Reference: f* = $f_ref, ‖x*-y‖ = $(norm(x_ref - y))")
+
+# ---------------------------------------------------------------------------
+# Helper: build a checkpoint callback that saves to `filepath` every
+# `save_every` completed iterations using JLD2.
+# ---------------------------------------------------------------------------
+function make_jld2_checkpoint(filepath::String; save_every::Int=1)
+    function checkpoint(state::LBFGSState)
+        if mod(state.numiter, save_every) == 0
+            jldsave(filepath; state)
+            # Uncomment the line below to see checkpoint progress:
+            # println(" [checkpoint] saved iter $(state.numiter), f=$(state.f)")
+        end
+    end
+    return checkpoint
+end
+
+# ---------------------------------------------------------------------------
+# Phase 1: run for up to 10 iterations, saving a checkpoint after each one
+# ---------------------------------------------------------------------------
+checkpoint_file = tempname() * ".jld2"
+
+checkpoint_cb = make_jld2_checkpoint(checkpoint_file; save_every=1)
+stop_at_10 = (x, f, g, numfg, numiter, t) -> numiter >= 10
+
+x_part, f_part, _, numfg_part, history_part =
+    optimize(fg, x₀, alg;
+             checkpoint = checkpoint_cb,
+             shouldstop = stop_at_10,
+             hasconverged = (x, f, g, ng) -> ng <= 1e-12)
+
+println("\nPhase 1 done: $(size(history_part,1)-1) iterations, f = $f_part")
+println("Checkpoint file: $checkpoint_file ($(round(filesize(checkpoint_file)/1024, digits=1)) KB)")
+
+# ---------------------------------------------------------------------------
+# Phase 2: load checkpoint and resume to convergence
+# ---------------------------------------------------------------------------
+state_loaded = jldopen(checkpoint_file, "r") do file
+    file["state"]
+end
+
+println("\nLoaded checkpoint: numiter=$(state_loaded.numiter), numfg=$(state_loaded.numfg)")
+println(" fhistory length = $(length(state_loaded.fhistory)) (should be numiter+1)")
+println(" H length = $(length(state_loaded.H)) (LBFGS memory used)")
+
+x_resumed, f_resumed, _, numfg_resumed, history_resumed =
+    optimize(fg, state_loaded, alg)
+
+println("\nPhase 2 done: total $(size(history_resumed,1)-1) iterations, f = $f_resumed")
+println(" numfg (total) = $numfg_resumed")
+
+# ---------------------------------------------------------------------------
+# Sanity checks (`@assert` takes no keywords, so the tolerance goes through `isapprox`)
+# ---------------------------------------------------------------------------
+@assert isapprox(x_resumed, x_ref; rtol=1e-8) "resumed solution differs from reference"
+@assert isapprox(f_resumed, f_ref; rtol=1e-8) "resumed f* differs from reference"
+@assert history_resumed[1:size(history_part,1), :] ≈ history_part "history mismatch"
+
+println("\n✓ All checks passed — resumed result matches reference run.")
+
+# Clean up temp file
+rm(checkpoint_file)
diff --git a/src/OptimKit.jl b/src/OptimKit.jl
index 6415200..cf657ec 100644
--- a/src/OptimKit.jl
+++ b/src/OptimKit.jl
@@ -129,6 +129,7 @@ const lbfgs = LBFGS()
 
 export optimize, gd, cg, lbfgs, optimtest
 export GradientDescent, ConjugateGradient, LBFGS
+export LBFGSState
 export FletcherReeves, HestenesStiefel, PolakRibiere, HagerZhang, DaiYuan
 export HagerZhangLineSearch
 
diff --git a/src/lbfgs.jl b/src/lbfgs.jl
index 68bc51d..20bdc34 100644
--- a/src/lbfgs.jl
+++ b/src/lbfgs.jl
@@ -1,5 +1,5 @@
 """
-    LBFGS(m::Int = 8; 
+    LBFGS(m::Int = 8;
           acceptfirst::Bool = true,
           maxiter::Int=MAXITER[], # 1_000_000
           gradtol::Real=GRADTOL[], # 1e-8
@@ -53,9 +53,63 @@ function LBFGS(m::Int=8;
     return LBFGS(m, maxiter, gradtol, acceptfirst, verbosity, linesearch)
 end
 
+"""
+    LBFGSState
+
+Captures the complete state of an LBFGS optimization, enabling checkpointing and
+warm-starting. Instances are handed to the `checkpoint` callback passed to
+[`optimize`](@ref), and can be passed back as the starting point to resume optimization.
+
+## Fields
+- `x`: Current parameter values
+- `f`: Current function value
+- `g`: Current gradient
+- `H`: Current LBFGS inverse Hessian approximation (`LBFGSInverseHessian`)
+- `numfg`: Cumulative number of function/gradient evaluations so far
+- `numiter`: Cumulative number of completed iterations
+- `fhistory`: History of function values (initial value plus one entry per iteration)
+- `normgradhistory`: History of gradient norms (initial value plus one entry per iteration)
+
+## Example
+
+Periodic checkpointing using `Serialization` from the standard library:
+
+```julia
+using Serialization, OptimKit
+
+checkpoint_fn = state -> serialize("checkpoint.jls", state)
+x, f, g, numfg, history = optimize(fg, x0, LBFGS(); checkpoint=checkpoint_fn)
+
+# resume from the last checkpoint
+state = deserialize("checkpoint.jls")
+x, f, g, numfg, history = optimize(fg, state, LBFGS())
+```
+
+!!! note
+    The `LBFGSState` struct stores references to the arrays `x`, `g`, and the vectors
+    inside `H`. When using GPU arrays or other non-standard backends, ensure your
+    serialization method handles those array types correctly.
+
+!!! note
+    When resuming, the `shouldstop` callback receives the *cumulative* `numfg` and
+    `numiter` values, while its elapsed time `t` restarts at zero. Pass a custom
+    `shouldstop` if you need a fixed number of *additional* iterations.
+"""
+struct LBFGSState{X,G,F<:Real,H}
+    x::X
+    f::F
+    g::G
+    H::H
+    numfg::Int
+    numiter::Int
+    fhistory::Vector{F}
+    normgradhistory::Vector{F}
+end
+
 function optimize(fg, x, alg::LBFGS;
                   precondition=_precondition,
                   (finalize!)=_finalize!,
+                  checkpoint=nothing,
                   shouldstop=DefaultShouldStop(alg.maxiter),
                   hasconverged=DefaultHasConverged(alg.gradtol),
                   retract=_retract, inner=_inner, (transport!)=_transport!,
@@ -70,15 +124,66 @@ function optimize(fg, x, alg::LBFGS;
     normgrad = sqrt(innergg)
     fhistory = [f]
     normgradhistory = [normgrad]
-    t = time() - t₀
-    _hasconverged = hasconverged(x, f, g, normgrad)
-    _shouldstop = shouldstop(x, f, g, numfg, numiter, t)
 
     TangentType = typeof(g)
     ScalarType = typeof(innergg)
     m = alg.m
     H = LBFGSInverseHessian(m, TangentType[], TangentType[], ScalarType[])
 
+    return _lbfgs_loop!(fg, x, f, g, H, numfg, numiter, normgrad, fhistory,
+                        normgradhistory, t₀, alg,
+                        precondition, finalize!, checkpoint,
+                        shouldstop, hasconverged,
+                        retract, inner, transport!, scale!, add!,
+                        isometrictransport)
+end
+
+"""
+    optimize(fg, state::LBFGSState, alg::LBFGS; kwargs...) -> x, f, g, numfg, history
+
+Resume an LBFGS optimization from a previously saved [`LBFGSState`](@ref). All keyword
+arguments are the same as for the standard `optimize` call. The `numfg`, `numiter`,
+`fhistory`, and `normgradhistory` are continued from the checkpoint; the returned
+`history` matrix covers the full run including prior iterations.
+"""
+function optimize(fg, state::LBFGSState, alg::LBFGS;
+                  precondition=_precondition,
+                  (finalize!)=_finalize!,
+                  checkpoint=nothing,
+                  shouldstop=DefaultShouldStop(alg.maxiter),
+                  hasconverged=DefaultHasConverged(alg.gradtol),
+                  retract=_retract, inner=_inner, (transport!)=_transport!,
+                  (scale!)=_scale!, (add!)=_add!,
+                  isometrictransport=(transport! == _transport! && inner == _inner))
+    t₀ = time()
+    x = state.x
+    f = state.f
+    g = deepcopy(state.g)  # copy like H: the loop may update gradients in place
+    H = deepcopy(state.H)
+    numfg = state.numfg
+    numiter = state.numiter
+    normgrad = state.normgradhistory[end]
+    fhistory = copy(state.fhistory)
+    normgradhistory = copy(state.normgradhistory)
+
+    return _lbfgs_loop!(fg, x, f, g, H, numfg, numiter, normgrad, fhistory,
+                        normgradhistory, t₀, alg,
+                        precondition, finalize!, checkpoint,
+                        shouldstop, hasconverged,
+                        retract, inner, transport!, scale!, add!,
+                        isometrictransport)
+end
+
+function _lbfgs_loop!(fg, x, f, g, H, numfg, numiter, normgrad, fhistory, normgradhistory,
+                      t₀, alg::LBFGS,
+                      precondition, finalize!, checkpoint,
+                      shouldstop, hasconverged,
+                      retract, inner, transport!, scale!, add!, isometrictransport)
+    verbosity = alg.verbosity
+    t = time() - t₀
+    _hasconverged = hasconverged(x, f, g, normgrad)
+    _shouldstop = shouldstop(x, f, g, numfg, numiter, t)
+
     verbosity >= 2 &&
         @info @sprintf("LBFGS: initializing with f = %.12e, ‖∇f‖ = %.4e", f, normgrad)
 
@@ -122,13 +227,12 @@ function optimize(fg, x, alg::LBFGS;
         _hasconverged = hasconverged(x, f, g, normgrad)
         _shouldstop = shouldstop(x, f, g, numfg, numiter, t)
 
-        # check stopping criteria and print info
-        if _hasconverged || _shouldstop
-            break
+        # print iteration info if continuing (preserves original verbosity behavior)
+        if !(_hasconverged || _shouldstop)
+            verbosity >= 3 &&
+                @info @sprintf("LBFGS: iter %4d, Δt %s: f = %.12e, ‖∇f‖ = %.4e, α = %.2e, m = %d, nfg = %d",
+                               numiter, format_time(Δt), f, normgrad, α, length(H), nfg)
         end
-        verbosity >= 3 &&
-            @info @sprintf("LBFGS: iter %4d, Δt %s: f = %.12e, ‖∇f‖ = %.4e, α = %.2e, m = %d, nfg = %d",
-                           numiter, format_time(Δt), f, normgrad, α, length(H), nfg)
 
         # transport gprev, ηprev and vectors in Hessian approximation to x
         gprev = transport!(gprev, xprev, ηprev, α, x)
@@ -189,6 +293,16 @@ function optimize(fg, x, alg::LBFGS;
             ρ = innerss / innersy
             push!(H, (scale!(s, 1 / norms), scale!(y, 1 / norms), ρ))
         end
+
+        # checkpoint after H is updated; called every iteration including the last
+        if !isnothing(checkpoint)
+            checkpoint(LBFGSState(x, f, g, H, numfg, numiter, fhistory, normgradhistory))
+        end
+
+        # break after checkpoint so the final state is always captured
+        if _hasconverged || _shouldstop
+            break
+        end
     end
     if _hasconverged
         verbosity >= 2 &&
diff --git a/test/runtests.jl b/test/runtests.jl
index 4904773..0ef74e3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -92,6 +92,60 @@ algorithms = (GradientDescent, ConjugateGradient, LBFGS)
     @test f < 1e-12
 end
 
+@testset "LBFGS checkpoint and resume" begin
+    n = 20
+    y = randn(n)
+    A = let B = randn(n, n); B' * B + I end
+    fg = quadraticproblem(A, y)
+    x₀ = randn(n)
+    alg = LBFGS(; verbosity=0, gradtol=1e-12, maxiter=10_000_000)
+
+    # Run to full convergence as ground truth
+    x_full, f_full, g_full, numfg_full, history_full = optimize(fg, x₀, alg)
+
+    # Run with early stopping after 5 iterations and collect checkpoint
+    saved_states = LBFGSState[]
+    checkpoint_fn = state -> push!(saved_states, state)
+    stop_after_5 = (x, f, g, numfg, numiter, t) -> numiter >= 5
+    converged_1e12 = (x, f, g, normgrad) -> normgrad <= 1e-12
+    x_part, f_part, g_part, numfg_part, history_part =
+        optimize(fg, x₀, alg; checkpoint=checkpoint_fn, shouldstop=stop_after_5,
+                 hasconverged=converged_1e12)
+
+    # Checkpoint is called once per completed iteration
+    @test length(saved_states) == 5
+
+    # Checkpoint state at iteration 5 matches optimize's returned state
+    state5 = saved_states[end]
+    @test state5.numiter == 5
+    @test state5.x ≈ x_part
+    @test state5.f ≈ f_part
+    @test state5.numfg == numfg_part
+    @test length(state5.fhistory) == 6  # initial + 5 iterations
+    @test length(state5.normgradhistory) == 6
+
+    # Resume from checkpoint and run to convergence; result must match full run
+    x_resumed, f_resumed, g_resumed, numfg_resumed, history_resumed =
+        optimize(fg, state5, alg)
+    @test x_resumed ≈ x_full rtol = 1e-10
+    @test f_resumed ≈ f_full rtol = 1e-10
+
+    # Resumed history prepends the prior run's history
+    @test size(history_resumed, 1) == size(history_full, 1)
+    @test history_resumed[1:6, :] ≈ history_part  # first 6 rows identical to partial run
+
+    # Resume with additional checkpoint continues counting from previous numiter
+    extra_states = LBFGSState[]
+    stop_after_3_more = (x, f, g, numfg, numiter, t) -> numiter >= state5.numiter + 3
+    optimize(fg, state5, alg;
+             checkpoint=state -> push!(extra_states, state),
+             shouldstop=stop_after_3_more,
+             hasconverged=converged_1e12)
+    @test length(extra_states) == 3
+    @test extra_states[1].numiter == 6
+    @test extra_states[end].numiter == 8
+end
+
 @testset "Aqua" verbose = true begin
     using Aqua
     Aqua.test_all(OptimKit)