Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions backoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package resilientws

import (
"context"
"math/rand/v2"
"time"
)

// backoff returns the delay to use for the given retry attempt.
//
// When RecBackoffMin is not strictly below RecBackoffMax the result is
// RecBackoffMax. Otherwise the attempt (capped at 30) is turned into a
// doubled-per-attempt delay by calculateBackoff, multiplied by
// RecBackoffFactor (1.5 when the field is zero), passed through applyJitter
// and finally limited by clampBackoff.
func (r *Resws) backoff(attempt int) time.Duration {
	lower, upper := r.RecBackoffMin, r.RecBackoffMax
	if lower >= upper {
		return upper
	}

	factor := r.RecBackoffFactor
	if factor == 0 {
		factor = 1.5
	}

	// Cap the attempt number so the doubling loop stays short.
	const attemptCap = 30
	if attempt > attemptCap {
		attempt = attemptCap
	}

	base := r.calculateBackoff(lower, upper, attempt)
	scaled := time.Duration(float64(base) * factor)
	return r.clampBackoff(r.applyJitter(scaled), lower, upper)
}

// calculateBackoff doubles min once per attempt and returns the result.
// As soon as a doubling exceeds max, or wraps around to a negative value,
// max is returned instead. A non-positive attempt returns min unchanged.
func (r *Resws) calculateBackoff(min, max time.Duration, attempt int) time.Duration {
	delay := min
	for remaining := attempt; remaining > 0; remaining-- {
		delay += delay
		if delay < 0 || delay > max {
			return max
		}
	}
	return delay
}

// applyJitter stretches backoff by a random factor in [1.0, 1.1) when the
// client's BackoffType is BackoffTypeJitter; for any other type backoff is
// returned unchanged.
func (r *Resws) applyJitter(backoff time.Duration) time.Duration {
	if r.BackoffType != BackoffTypeJitter {
		return backoff
	}
	spread := 1 + 0.1*rand.Float64()
	return time.Duration(float64(backoff) * spread)
}

// clampBackoff limits backoff to the inclusive range [min, max] and rounds
// in-range values to the nearest 100ms.
//
// A value already outside the range is replaced by the violated bound
// without rounding. An in-range value is rounded and then clamped once more,
// because rounding alone can cross a bound that is not a multiple of 100ms
// (for example 140ms with min=130ms rounds down to 100ms).
func (r *Resws) clampBackoff(backoff, min, max time.Duration) time.Duration {
	if backoff < min {
		return min
	}
	if backoff > max {
		return max
	}

	rounded := backoff.Round(100 * time.Millisecond)
	switch {
	case rounded < min:
		return min
	case rounded > max:
		return max
	}
	return rounded
}

// getReconnectBackoff reports the delay for the next reconnection attempt.
// The attempt counter is read under backoffMu. With at most one recorded
// attempt the delay is zero; otherwise it is backoff applied to the counter
// minus one.
func (r *Resws) getReconnectBackoff() time.Duration {
	recorded := func() int {
		r.backoffMu.RLock()
		defer r.backoffMu.RUnlock()
		return r.reconnectAttempts
	}()

	if recorded > 1 {
		return r.backoff(recorded - 1)
	}
	return 0 // First connection attempt, no backoff
}

// incrementReconnectAttempts adds one to the reconnection attempt counter and
// stores the current time as the last reconnect time, both under backoffMu.
func (r *Resws) incrementReconnectAttempts() {
	r.backoffMu.Lock()
	defer r.backoffMu.Unlock()

	r.lastReconnectTime = time.Now()
	r.reconnectAttempts++
}

// getReconnectAttempts returns the reconnection attempt counter, read while
// holding backoffMu for reading.
func (r *Resws) getReconnectAttempts() int {
	r.backoffMu.RLock()
	count := r.reconnectAttempts
	r.backoffMu.RUnlock()
	return count
}

// resetReconnectAttempts sets the reconnection attempt counter back to zero
// while holding backoffMu.
func (r *Resws) resetReconnectAttempts() {
	r.backoffMu.Lock()
	defer r.backoffMu.Unlock()

	r.reconnectAttempts = 0
}

// monitorConnectionStability waits for StableConnectionDuration and then, if
// the client still reports itself connected, resets the reconnection attempt
// counter. It returns without resetting anything when ctx is cancelled
// before the timer fires.
func (r *Resws) monitorConnectionStability(ctx context.Context) {
	stableAfter := time.NewTimer(r.StableConnectionDuration)
	defer stableAfter.Stop()

	select {
	case <-ctx.Done():
		return
	case <-stableAfter.C:
	}

	if !r.IsConnected() {
		return
	}

	r.resetReconnectAttempts()
	if r.NonVerbose {
		return
	}
	r.Logger.Debug("Connection stable for %v, reset backoff counter", r.StableConnectionDuration)
}
196 changes: 196 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package resilientws

import (
"context"
"fmt"
"net/http"
"time"

"github.com/gorilla/websocket"
)

// connect dials the WebSocket server until a connection is established or
// the client context is cancelled.
//
// Any previous per-connection context is cancelled and its workers are
// awaited before a fresh one is created. Each failed dial is handed to
// handleDialFailure together with the current delay; the delay starts at
// RecBackoffMin and is afterwards taken from backoff with an increasing
// attempt number. The loop ends when handleConnectionSuccess returns true;
// when it returns false the loop dials again.
func (r *Resws) connect() {
	r.cancelExistingConn()
	r.connWg.Wait()
	r.setupConnContext()

	delay := r.getRecBackoffMin()
	failures := 0

	for {
		// Non-blocking cancellation check before every dial.
		select {
		case <-r.ctx.Done():
			return
		default:
		}

		conn, resp, err := r.dialer.Dial(r.url, r.Headers)
		if err == nil {
			if r.handleConnectionSuccess(conn) {
				return
			}
			continue
		}

		r.handleDialFailure(resp, err, delay)
		delay = r.backoff(failures)
		failures++
	}
}

// cancelExistingConn invokes the stored per-connection cancel function, if
// one is set. The call is made while holding r.mu.
func (r *Resws) cancelExistingConn() {
	r.mu.Lock()
	defer r.mu.Unlock()

	if cancel := r.connCancel; cancel != nil {
		cancel()
	}
}

// setupConnContext derives a fresh cancellable per-connection context from
// the client context and stores it, with its cancel function, under r.mu.
func (r *Resws) setupConnContext() {
	r.mu.Lock()
	defer r.mu.Unlock()

	connCtx, cancel := context.WithCancel(r.ctx)
	r.connCtx = connCtx
	r.connCancel = cancel
}

// getRecBackoffMin returns RecBackoffMin, read while holding r.mu for
// reading.
func (r *Resws) getRecBackoffMin() time.Duration {
	r.mu.RLock()
	lowest := r.RecBackoffMin
	r.mu.RUnlock()
	return lowest
}

// handleConnectionSuccess finishes setting up a freshly dialed connection.
//
// It stores conn and the connect time under r.mu, logs the success unless
// NonVerbose is set, signals and marks the client as connected, starts the
// per-connection workers and emits EventConnected. The return value is the
// result of retrySubscribeIfNeeded.
func (r *Resws) handleConnectionSuccess(conn *websocket.Conn) bool {
	func() {
		r.mu.Lock()
		defer r.mu.Unlock()
		r.lastConnect = time.Now()
		r.Conn = conn
	}()

	if verbose := !r.NonVerbose; verbose {
		r.Logger.Info("Connection was successfully established: %s", r.url)
	}

	r.signalConnected()
	r.setIsConnected(true)
	r.startConnectionWorkers()
	r.emitEvent(Event{Type: EventConnected})

	subscribed := r.retrySubscribeIfNeeded()
	return subscribed
}

// startConnectionWorkers launches the goroutines that serve one connection.
//
// The stability monitor and the message-queue processor are always started;
// the heartbeat only when PingHandler is set and the reader only when
// MessageHandler is set. Every goroutine is registered with connWg and is
// handed r.connCtx.
func (r *Resws) startConnectionWorkers() {
	// spawn runs work in its own goroutine, tracked by connWg.
	spawn := func(work func()) {
		r.connWg.Add(1)
		go func() {
			defer r.connWg.Done()
			work()
		}()
	}

	spawn(func() { r.monitorConnectionStability(r.connCtx) })
	spawn(func() { r.processMessageQueue(r.connCtx) })

	if r.PingHandler != nil {
		spawn(func() { r.heartbeat(r.connCtx) })
	}
	if r.MessageHandler != nil {
		spawn(func() { r.reader(r.connCtx) })
	}
}

// retrySubscribeIfNeeded runs the subscribe retry loop when a
// SubscribeHandler is configured. It returns true when there is no handler
// or the loop succeeded; when the loop fails it calls CloseAndReconnect and
// returns false.
func (r *Resws) retrySubscribeIfNeeded() bool {
	if r.SubscribeHandler == nil {
		return true
	}

	err := r.retrySubscribeHandler()
	if err == nil {
		return true
	}

	r.CloseAndReconnect()
	return false
}

// handleDialFailure records a failed dial and waits before the next attempt.
//
// Under r.mu it clears the stored connection, stores the handshake response
// and the error, and sets shouldReconnect. It then marks the client as
// disconnected. EventError and EventReconnecting are emitted only when the
// matching callback (onErrorFn / onReconnectingFn) is set; the planned delay
// is logged unless NonVerbose is set.
//
// The wait lasts for backoff but ends early when the client context is
// cancelled, so that shutting the client down is not held up by a pending
// retry delay.
func (r *Resws) handleDialFailure(resp *http.Response, err error, backoff time.Duration) {
	r.mu.Lock()
	r.Conn = nil
	r.httpResp = resp
	r.lastErr = err
	r.shouldReconnect = true
	r.mu.Unlock()

	r.setIsConnected(false)

	if r.onErrorFn != nil {
		r.emitEvent(Event{Type: EventError, Error: err})
	}
	if !r.NonVerbose {
		r.Logger.Info("Will reconnect in %v", backoff)
	}
	if r.onReconnectingFn != nil {
		r.emitEvent(Event{Type: EventReconnecting, Data: backoff})
	}

	if backoff <= 0 {
		return // nothing to wait for
	}

	// Cancellable replacement for time.Sleep(backoff).
	timer := time.NewTimer(backoff)
	defer timer.Stop()

	select {
	case <-r.ctx.Done():
	case <-timer.C:
	}
}

// retrySubscribeHandler calls SubscribeHandler until it succeeds.
//
// Before every call the per-connection context is checked and its error is
// returned if it is already done. After a failed call the next delay is
// taken from backoff and passed, with RecBackoffMax and the attempt number,
// to handleSubscribeError; a non-nil result from that helper ends the loop
// and is returned. On success nil is returned.
func (r *Resws) retrySubscribeHandler() error {
	_, ceiling := r.getSubscribeBackoffConfig()

	for attempt := 0; ; attempt++ {
		select {
		case <-r.connCtx.Done():
			return r.connCtx.Err()
		default:
		}

		err := r.SubscribeHandler()
		if err == nil {
			if !r.NonVerbose {
				r.Logger.Info("Subscribe handler executed successfully")
			}
			return nil
		}

		delay := r.backoff(attempt)
		if retryErr := r.handleSubscribeError(err, delay, ceiling, attempt); retryErr != nil {
			return retryErr
		}
	}
}

// getSubscribeBackoffConfig returns RecBackoffMin and RecBackoffMax, both
// read while holding r.mu for reading.
func (r *Resws) getSubscribeBackoffConfig() (min, max time.Duration) {
	r.mu.RLock()
	min, max = r.RecBackoffMin, r.RecBackoffMax
	r.mu.RUnlock()
	return min, max
}

// handleSubscribeError reports a failed subscribe call and waits before the
// next retry.
//
// The error is always logged and emitted as EventError. Once at least one
// retry has happened and the delay has reached max, the error is wrapped and
// returned to stop retrying. Otherwise the function waits for backoff and
// returns nil, or returns the per-connection context's error if that context
// is done first.
func (r *Resws) handleSubscribeError(err error, backoff, max time.Duration, attempt int) error {
	r.Logger.Error("Subscribe handler failed: %v", err)
	r.emitEvent(Event{Type: EventError, Error: err})

	exhausted := attempt > 0 && backoff >= max
	if exhausted {
		return fmt.Errorf("subscribe handler failed after max retries: %w", err)
	}

	if verbose := !r.NonVerbose; verbose {
		r.Logger.Info("Retrying subscribe handler in %v", backoff)
	}

	elapsed := time.After(backoff)
	select {
	case <-elapsed:
		return nil
	case <-r.connCtx.Done():
		return r.connCtx.Err()
	}
}
51 changes: 51 additions & 0 deletions heartbeat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package resilientws

import (
"context"
"time"

"github.com/gorilla/websocket"
)

// heartbeat pings the server once per PingInterval.
//
// It installs the pong handler first and then, on every tick, sends a ping
// via sendPing and invokes PingHandler. The loop stops when ctx is cancelled
// or when sendPing reports failure.
func (r *Resws) heartbeat(ctx context.Context) {
	pingTick := time.NewTicker(r.PingInterval)
	defer pingTick.Stop()

	r.setupPongHandler()

	for {
		select {
		case <-pingTick.C:
			if sent := r.sendPing(); !sent {
				return
			}
			r.PingHandler()
		case <-ctx.Done():
			return
		}
	}
}

// setupPongHandler arms a read deadline of PongTimeout on the current
// connection and installs a pong handler that pushes the deadline out again
// on every pong. It does nothing when there is no connection or PongTimeout
// is not positive. Errors from SetReadDeadline are discarded.
func (r *Resws) setupPongHandler() {
	conn := r.getConn()
	if conn == nil {
		return
	}
	if r.PongTimeout <= 0 {
		return
	}

	extendDeadline := func() {
		_ = conn.SetReadDeadline(time.Now().Add(r.PongTimeout))
	}

	extendDeadline()
	conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})
}

// sendPing writes a ping control frame with an empty payload to the current
// connection, using a write deadline of PingInterval from now. It returns
// false when there is no connection or the write fails.
func (r *Resws) sendPing() bool {
	conn := r.getConn()
	if conn == nil {
		return false
	}

	deadline := time.Now().Add(r.PingInterval)
	return conn.WriteControl(websocket.PingMessage, []byte{}, deadline) == nil
}
Loading