diff --git a/backoff.go b/backoff.go new file mode 100644 index 0000000..eae0c72 --- /dev/null +++ b/backoff.go @@ -0,0 +1,112 @@ +package resilientws + +import ( + "context" + "math/rand/v2" + "time" +) + +func (r *Resws) backoff(attempt int) time.Duration { + min := r.RecBackoffMin + max := r.RecBackoffMax + + if min >= max { + return max + } + + backoffFactor := r.RecBackoffFactor + if backoffFactor == 0 { + backoffFactor = 1.5 + } + + if attempt > 30 { + attempt = 30 + } + + backoff := r.calculateBackoff(min, max, attempt) + backoff = time.Duration(float64(backoff) * backoffFactor) + backoff = r.applyJitter(backoff) + + return r.clampBackoff(backoff, min, max) +} + +func (r *Resws) calculateBackoff(min, max time.Duration, attempt int) time.Duration { + backoff := min + for i := 0; i < attempt; i++ { + backoff *= 2 + if backoff > max || backoff < 0 { + return max + } + } + return backoff +} + +func (r *Resws) applyJitter(backoff time.Duration) time.Duration { + if r.BackoffType == BackoffTypeJitter { + backoff = time.Duration(float64(backoff) * (1 + 0.1*rand.Float64())) + } + return backoff +} + +func (r *Resws) clampBackoff(backoff, min, max time.Duration) time.Duration { + if backoff < min { + return min + } + if backoff > max { + return max + } + return backoff.Round(100 * time.Millisecond) +} + +// getReconnectBackoff calculates backoff duration for reconnection attempts +func (r *Resws) getReconnectBackoff() time.Duration { + r.backoffMu.RLock() + attempts := r.reconnectAttempts + r.backoffMu.RUnlock() + + if attempts <= 1 { + return 0 // First connection attempt, no backoff + } + + return r.backoff(attempts - 1) +} + +// incrementReconnectAttempts increments the reconnection attempt counter +func (r *Resws) incrementReconnectAttempts() { + r.backoffMu.Lock() + r.reconnectAttempts++ + r.lastReconnectTime = time.Now() + r.backoffMu.Unlock() +} + +// getReconnectAttempts returns the current reconnection attempt count +func (r *Resws) getReconnectAttempts() int { + r.backoffMu.RLock() + defer r.backoffMu.RUnlock() + return r.reconnectAttempts +} + +// resetReconnectAttempts resets the reconnection attempt counter +func (r *Resws) resetReconnectAttempts() { + r.backoffMu.Lock() + r.reconnectAttempts = 0 + r.backoffMu.Unlock() +} + +// monitorConnectionStability monitors connection stability and resets backoff after stable period +func (r *Resws) monitorConnectionStability(ctx context.Context) { + timer := time.NewTimer(r.StableConnectionDuration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return + case <-timer.C: + if r.IsConnected() { + r.resetReconnectAttempts() + if !r.NonVerbose { + r.Logger.Debug("Connection stable for %v, reset backoff counter", r.StableConnectionDuration) + } + } + } +} diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..e8f933a --- /dev/null +++ b/connection.go @@ -0,0 +1,196 @@ +package resilientws + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +// connect establishes a connection to the WebSocket server +func (r *Resws) connect() { + r.cancelExistingConn() + r.connWg.Wait() + r.setupConnContext() + + recBackoff := r.getRecBackoffMin() + attempt := 0 + + for { + select { + case <-r.ctx.Done(): + return + default: + conn, resp, err := r.dialer.Dial(r.url, r.Headers) + if err != nil { + r.handleDialFailure(resp, err, recBackoff) + recBackoff = r.backoff(attempt) + attempt++ + continue + } + + if r.handleConnectionSuccess(conn) { + return + } + } + } +} + +func (r *Resws) cancelExistingConn() { + r.mu.Lock() + if r.connCancel != nil { + r.connCancel() + } + r.mu.Unlock() +} + +func (r *Resws) setupConnContext() { + r.mu.Lock() + r.connCtx, r.connCancel = context.WithCancel(r.ctx) + r.mu.Unlock() +} + +func (r *Resws) getRecBackoffMin() time.Duration { + r.mu.RLock() + defer r.mu.RUnlock() + return r.RecBackoffMin +} + +func (r *Resws) handleConnectionSuccess(conn *websocket.Conn) bool { + r.mu.Lock() + r.Conn = conn + r.lastConnect = time.Now() + r.mu.Unlock() + + if !r.NonVerbose { + r.Logger.Info("Connection was successfully established: %s", r.url) + } + + r.signalConnected() + r.setIsConnected(true) + r.startConnectionWorkers() + r.emitEvent(Event{Type: EventConnected}) + + return r.retrySubscribeIfNeeded() +} + +func (r *Resws) startConnectionWorkers() { + r.connWg.Add(1) + go func() { + defer r.connWg.Done() + r.monitorConnectionStability(r.connCtx) + }() + + r.connWg.Add(1) + go func() { + defer r.connWg.Done() + r.processMessageQueue(r.connCtx) + }() + + if r.PingHandler != nil { + r.connWg.Add(1) + go func() { + defer r.connWg.Done() + r.heartbeat(r.connCtx) + }() + } + + if r.MessageHandler != nil { + r.connWg.Add(1) + go func() { + defer r.connWg.Done() + r.reader(r.connCtx) + }() + } +} + +func (r *Resws) retrySubscribeIfNeeded() bool { + if r.SubscribeHandler == nil { + return true + } + + if err := r.retrySubscribeHandler(); err != nil { + r.CloseAndReconnect() + return false + } + + return true +} + +func (r *Resws) handleDialFailure(resp *http.Response, err error, backoff time.Duration) { + r.mu.Lock() + r.Conn = nil + r.httpResp = resp + r.lastErr = err + r.shouldReconnect = true + r.mu.Unlock() + + r.setIsConnected(false) + + if r.onErrorFn != nil { + r.emitEvent(Event{Type: EventError, Error: err}) + } + if !r.NonVerbose { + r.Logger.Info("Will reconnect in %v", backoff) + } + if r.onReconnectingFn != nil { + r.emitEvent(Event{Type: EventReconnecting, Data: backoff}) + } + + time.Sleep(backoff) +} + +func (r *Resws) retrySubscribeHandler() error { + _, max := r.getSubscribeBackoffConfig() + attempt := 0 + + for { + select { + case <-r.connCtx.Done(): + return r.connCtx.Err() + default: + } + + if err := r.SubscribeHandler(); err != nil { + nextBackoff := r.backoff(attempt) + if retryErr := r.handleSubscribeError(err, nextBackoff, max, attempt); retryErr != nil { + return retryErr + } + attempt++ + continue + } + + if !r.NonVerbose { + r.Logger.Info("Subscribe handler executed successfully") + } + return nil + } +} + +func (r *Resws) getSubscribeBackoffConfig() (min, max time.Duration) { + r.mu.RLock() + defer r.mu.RUnlock() + return r.RecBackoffMin, r.RecBackoffMax +} + +func (r *Resws) handleSubscribeError(err error, backoff, max time.Duration, attempt int) error { + r.Logger.Error("Subscribe handler failed: %v", err) + r.emitEvent(Event{Type: EventError, Error: err}) + + if attempt > 0 && backoff >= max { + return fmt.Errorf("subscribe handler failed after max retries: %w", err) + } + + if !r.NonVerbose { + r.Logger.Info("Retrying subscribe handler in %v", backoff) + } + + select { + case <-r.connCtx.Done(): + return r.connCtx.Err() + case <-time.After(backoff): + return nil + } +} diff --git a/heartbeat.go b/heartbeat.go new file mode 100644 index 0000000..ee3896f --- /dev/null +++ b/heartbeat.go @@ -0,0 +1,51 @@ +package resilientws + +import ( + "context" + "time" + + "github.com/gorilla/websocket" +) + +// heartbeat sends ping messages to the server to keep the connection alive +func (r *Resws) heartbeat(ctx context.Context) { + ticker := time.NewTicker(r.PingInterval) + defer ticker.Stop() + + r.setupPongHandler() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if !r.sendPing() { + return + } + r.PingHandler() + } + } +} + +func (r *Resws) setupPongHandler() { + conn := r.getConn() + if conn == nil || r.PongTimeout <= 0 { + return + } + + _ = conn.SetReadDeadline(time.Now().Add(r.PongTimeout)) + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(r.PongTimeout)) + return nil + }) +} + +func (r *Resws) sendPing() bool { + conn := r.getConn() + if conn == nil { + return false + } + + err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(r.PingInterval)) + return err == nil +} diff --git a/io.go b/io.go new file mode 100644 index 0000000..8b2519c --- /dev/null +++ b/io.go @@ -0,0 +1,323 @@ +package resilientws + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/gorilla/websocket" +) + +// lockedWrite serialises writes through writeMu and applies WriteDeadline when configured. +// Callers must not hold r.mu when calling this. +func (r *Resws) lockedWrite(conn *websocket.Conn, fn func() error) error { + r.mu.RLock() + deadline := r.WriteDeadline + r.mu.RUnlock() + + r.writeMu.Lock() + defer r.writeMu.Unlock() + if deadline > 0 { + conn.SetWriteDeadline(time.Now().Add(deadline)) + } + return fn() +} + +// Send sends a message to the WebSocket server with a fallback queue +func (r *Resws) Send(msg []byte) error { + conn := r.getConn() + if conn != nil { + err := r.lockedWrite(conn, func() error { + return conn.WriteMessage(websocket.TextMessage, msg) + }) + if err == nil { + return nil + } + } + + return r.queueMessage(msg) +} + +func (r *Resws) queueMessage(msg []byte) error { + r.messageQueueMu.Lock() + defer r.messageQueueMu.Unlock() + + if len(r.messageQueue) >= r.MessageQueueSize { + return fmt.Errorf("message queue is full") + } + + r.messageQueue = append(r.messageQueue, msg) + + if r.Logger == nil { + r.Logger = &defaultLogger{} + } + r.Logger.Debug("Message queued for later delivery") + return nil +} + +// SendJSON sends a JSON message to the WebSocket server with a fallback queue +func (r *Resws) SendJSON(v any) (err error) { + conn := r.getConn() + if conn != nil { + err = r.lockedWrite(conn, func() error { + return conn.WriteJSON(v) + }) + if err == nil { + return nil + } + } + + return r.queueMessageJSON(v) +} + +func (r *Resws) queueMessageJSON(v any) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + r.messageQueueMu.Lock() + defer r.messageQueueMu.Unlock() + + if len(r.messageQueue) >= r.MessageQueueSize { + return fmt.Errorf("message queue is full") + } + + r.messageQueue = append(r.messageQueue, b) + + if r.Logger == nil { + r.Logger = &defaultLogger{} + } + r.Logger.Debug("Message queued for later delivery") + return nil +} + +// processMessageQueue processes messages from the queue and sends them to the WebSocket server +func (r *Resws) processMessageQueue(ctx context.Context) { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + conn, ok := r.getConnForQueue() + if !ok { + continue + } + + msg, ok := r.dequeueMessage() + if !ok { + continue + } + + r.sendQueuedMessage(conn, msg) + } + } +} + +func (r *Resws) getConnForQueue() (*websocket.Conn, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.Conn, r.isConnected && r.Conn != nil +} + +func (r *Resws) dequeueMessage() ([]byte, bool) { + r.messageQueueMu.Lock() + defer r.messageQueueMu.Unlock() + + if len(r.messageQueue) == 0 { + return nil, false + } + + msg := r.messageQueue[0] + r.messageQueue = r.messageQueue[1:] + return msg, true +} + +func (r *Resws) sendQueuedMessage(conn *websocket.Conn, msg []byte) { + err := r.lockedWrite(conn, func() error { + return conn.WriteMessage(websocket.TextMessage, msg) + }) + + if err != nil { + r.Logger.Error("Failed to send queued message: %v", err) + r.requeueMessage(msg) + return + } + + r.Logger.Debug("Successfully sent queued message") +} + +func (r *Resws) requeueMessage(msg []byte) { + r.messageQueueMu.Lock() + defer r.messageQueueMu.Unlock() + + if len(r.messageQueue) < r.MessageQueueSize { + r.messageQueue = append(r.messageQueue, msg) + r.Logger.Debug("Requeued failed message") + return + } + + r.Logger.Error("Failed to requeue message: queue full") +} + +// reader loops and reads messages from the WebSocket connection and emits them as events +func (r *Resws) reader(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + default: + conn := r.getConn() + if conn == nil { + return + } + + if r.ReadDeadline > 0 { + conn.SetReadDeadline(time.Now().Add(r.ReadDeadline)) + } + + msgType, msg, err := conn.ReadMessage() + if r.handleReadError(err, conn) { + return + } + + if !r.handleMessage(msgType, msg) { + return + } + } + } +} + +func (r *Resws) handleReadError(err error, conn *websocket.Conn) bool { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.ClosePolicyViolation) { + return true + } + + if err == nil { + return false + } + + r.mu.Lock() + reconnect := r.shouldReconnect + if r.Conn == conn { + r.Conn = nil + } + r.mu.Unlock() + + r.emitEvent(Event{Type: EventError, Error: err}) + + if reconnect { + time.Sleep(100 * time.Millisecond) + r.CloseAndReconnect() + } + + return true +} + +// handleMessage dispatches the message and returns false if the reader loop should stop. +func (r *Resws) handleMessage(msgType int, msg []byte) bool { + switch msgType { + case websocket.TextMessage, websocket.BinaryMessage: + r.MessageHandler(msgType, msg) + case websocket.CloseMessage: + return false + } + return true +} + +// ReadMessage manually reads a message from the websocket connection +func (r *Resws) ReadMessage() (msgType int, msg []byte, err error) { + conn, ok := r.checkConnection() + if !ok { + return 0, nil, errNotConnected + } + + msgType, msg, err = conn.ReadMessage() + if r.handleReadCloseError(err) { + return msgType, msg, nil + } + + r.handleReadReconnect(err) + + return +} + +func (r *Resws) handleReadCloseError(err error) bool { + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + r.Close() + return true + } + return false +} + +func (r *Resws) handleReadReconnect(err error) { + if err == nil { + return + } + + r.mu.Lock() + reconnect := r.shouldReconnect + r.mu.Unlock() + + if reconnect { + r.CloseAndReconnect() + } +} + +// ReadJSON manually reads a JSON message from the websocket connection +func (r *Resws) ReadJSON(v any) (err error) { + conn, ok := r.checkConnection() + if !ok { + return errNotConnected + } + + err = conn.ReadJSON(v) + if r.handleReadCloseError(err) { + return + } + + r.handleReadReconnect(err) + + return +} + +// WriteMessage manually writes a message to the websocket connection +func (r *Resws) WriteMessage(msgType int, msg []byte) (err error) { + conn, ok := r.checkConnection() + if !ok { + return errNotConnected + } + + err = r.lockedWrite(conn, func() error { + return conn.WriteMessage(msgType, msg) + }) + + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + r.Close() + } + + return +} + +// WriteJSON manually writes a JSON message to the websocket connection +func (r *Resws) WriteJSON(v any) (err error) { + conn, ok := r.checkConnection() + if !ok { + return errNotConnected + } + + err = r.lockedWrite(conn, func() error { + return conn.WriteJSON(v) + }) + + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + r.Close() + } + + return +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..3149e3f --- /dev/null +++ b/logger.go @@ -0,0 +1,23 @@ +package resilientws + +import "log" + +type Logger interface { + Debug(msg string, args ...any) + Info(msg string, args ...any) + Error(msg string, args ...any) +} + +type defaultLogger struct{} + +func (l *defaultLogger) Debug(msg string, args ...any) { + log.Printf("DEBUG: "+msg, args...) +} + +func (l *defaultLogger) Info(msg string, args ...any) { + log.Printf("INFO: "+msg, args...) +} + +func (l *defaultLogger) Error(msg string, args ...any) { + log.Printf("ERROR: "+msg, args...) +} diff --git a/resilientws.go b/resilientws.go index 95d3910..327f6d6 100644 --- a/resilientws.go +++ b/resilientws.go @@ -27,12 +27,7 @@ package resilientws import ( "context" - "crypto/tls" - "encoding/json" - "errors" "fmt" - "log" - "math/rand/v2" "net/http" "net/url" "strings" @@ -42,226 +37,40 @@ import ( "github.com/gorilla/websocket" ) -type Resws struct { - // RecBackoffMin is the minimum backoff duration between reconnection attempts - RecBackoffMin time.Duration - - // RecBackoffMax is the maximum backoff duration between reconnection attempts - RecBackoffMax time.Duration - - // RecBackoffFactor is the factor by which the backoff duration is multiplied - RecBackoffFactor float64 - - // BackoffType is the type of backoff to use - BackoffType BackoffType - - // StableConnectionDuration is the duration a connection must be stable before resetting backoff - StableConnectionDuration time.Duration - - // Handshake timeout - HandshakeTimeout time.Duration - - // Headers to be sent with the connection - Headers http.Header - - // Ping interval - PingInterval time.Duration - - // Pong timeout - PongTimeout time.Duration - - // Read deadline - ReadDeadline time.Duration - - // Write deadline - WriteDeadline time.Duration - - // Message queue size - MessageQueueSize int - - // TLS configuration - TLSConfig *tls.Config - - // Proxy configuration - Proxy func(*http.Request) (*url.URL, error) - - // Logger - Logger Logger - - // Non-verbose mode - NonVerbose bool - - // Subscribe handler - SubscribeHandler func() error - - // Message handler - MessageHandler func(int, []byte) - - // Ping handler - PingHandler func() - - url string - dialer *websocket.Dialer - httpResp *http.Response - mu sync.RWMutex - messageQueue [][]byte - messageQueueMu sync.Mutex - isConnected bool - lastConnect time.Time - lastErr error - shouldReconnect bool - connectedCh chan struct{} - connOnce *sync.Once - closeOnce sync.Once - - // Backoff state that persists across reconnections - reconnectAttempts int - lastReconnectTime time.Time - backoffMu sync.RWMutex - - // Context for connection management - ctx context.Context - cancel context.CancelFunc - connCtx context.Context - connCancel context.CancelFunc - - connWg sync.WaitGroup - - // Event handlers - onReconnectingFn func(time.Duration) - onConnectedFn func(string) - onErrorFn func(error) - - *websocket.Conn -} - -type Logger interface { - Debug(msg string, args ...any) - Info(msg string, args ...any) - Error(msg string, args ...any) -} - -type defaultLogger struct{} - -func (l *defaultLogger) Debug(msg string, args ...any) { - log.Printf("DEBUG: "+msg, args...) -} - -func (l *defaultLogger) Info(msg string, args ...any) { - log.Printf("INFO: "+msg, args...) -} - -func (l *defaultLogger) Error(msg string, args ...any) { - log.Printf("ERROR: "+msg, args...) -} - -type BackoffType int - -const ( - BackoffTypeJitter BackoffType = iota - BackoffTypeFixed -) - -type Event struct { - Type EventType - Message []byte - MessageType int - Data any - Error error -} - -type EventType int - -const ( - EventMessage EventType = iota - EventConnected - EventReconnecting - EventError - EventClose -) - -var errNotConnected = errors.New("websocket: not connected") - -// setDefaultConfig sets the default configuration for the WebSocket client -func (r *Resws) setDefaultConfig() { - // shouldReconnect is a flag to prevent reconnecting when Close() is called - r.shouldReconnect = true - - if r.RecBackoffMin == 0 { - r.RecBackoffMin = 1000 * time.Millisecond - } - if r.RecBackoffMax == 0 { - r.RecBackoffMax = 30 * time.Second - } - if r.RecBackoffFactor == 0 { - r.RecBackoffFactor = 1.5 - } - if r.HandshakeTimeout == 0 { - r.HandshakeTimeout = 2 * time.Second - } - if r.StableConnectionDuration == 0 { - r.StableConnectionDuration = 30 * time.Second - } - if r.Logger == nil { - r.Logger = &defaultLogger{} - } - if r.PingInterval == 0 { - r.PingInterval = 15 * time.Second - } - r.dialer = &websocket.Dialer{ - TLSClientConfig: r.TLSConfig, - Proxy: r.Proxy, - HandshakeTimeout: r.getHandshakeTimeout(), - } -} - -// setURL sets the URL of the WebSocket server -func (r *Resws) setURL(urlStr string) { - r.url = urlStr -} - -// parseURL parses the URL of the WebSocket server -func (r *Resws) parseURL(urlStr string) (string, error) { - if strings.TrimSpace(urlStr) == "" { - return "", fmt.Errorf("url cannot be empty") - } +// Dial establishes a connection to the WebSocket server +func (r *Resws) Dial(urlStr string) { + r.ctx, r.cancel = context.WithCancel(context.Background()) - if len(urlStr) < 5 { - return "", fmt.Errorf("url too short") + parsed, err := r.parseURL(urlStr) + if err != nil { + r.lastErr = err + r.emitEvent(Event{Type: EventError, Error: err}) + return } - u, err := url.Parse(urlStr) + r.mu.Lock() + r.connectedCh = make(chan struct{}, 1) + r.connOnce = new(sync.Once) + r.closeOnce = sync.Once{} + r.mu.Unlock() - if err != nil { - return "", fmt.Errorf("url: %s", err.Error()) - } + r.setURL(parsed) + r.setDefaultConfig() - if u.Scheme != "ws" && u.Scheme != "wss" { - return "", fmt.Errorf("url: websocket uris must start with ws or wss scheme") - } + go r.connect() - if u.User != nil { - return "", fmt.Errorf("url: user name and password are not allowed in websocket URIs") - } + timer := time.NewTimer(r.getHandshakeTimeout()) + defer timer.Stop() - return urlStr, nil -} + r.mu.RLock() + connectedCh := r.connectedCh + r.mu.RUnlock() -// emitEvent emits an event to the event handlers -func (r *Resws) emitEvent(event Event) { - switch event.Type { - case EventConnected: - if r.onConnectedFn != nil { - r.onConnectedFn(r.url) - } - case EventReconnecting: - if r.onReconnectingFn != nil { - r.onReconnectingFn(event.Data.(time.Duration)) - } - case EventError: - if r.onErrorFn != nil { - r.onErrorFn(event.Error) - } + select { + case <-timer.C: + return + case <-connectedCh: + return } } @@ -289,59 +98,6 @@ func (r *Resws) OnReconnecting(fn func(time.Duration)) { r.onReconnectingFn = fn } -// setIsConnected sets the connection state -func (r *Resws) setIsConnected(isConnected bool) { - r.mu.Lock() - defer r.mu.Unlock() - - r.isConnected = isConnected -} - -// signalConnected signals the connection state -func (r *Resws) signalConnected() { - r.mu.Lock() - defer r.mu.Unlock() - - r.connOnce.Do(func() { - select { - case r.connectedCh <- struct{}{}: - default: - } - }) -} - -// IsConnected returns the connection state -func (r *Resws) IsConnected() bool { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.isConnected -} - -// LastConnectTime returns the last connection time -func (r *Resws) LastConnectTime() time.Time { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.lastConnect -} - -// LastError returns the last error -func (r *Resws) LastError() error { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.lastErr -} - -// GetHTTPResponse returns the HTTP response -func (r *Resws) GetHTTPResponse() *http.Response { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.httpResp -} - // Close closes the connection func (r *Resws) Close() { r.closeOnce.Do(func() { @@ -382,22 +138,22 @@ func (r *Resws) Close() { func (r *Resws) CloseAndReconnect() { r.mu.Lock() conn := r.Conn - if conn != nil { - // Attempt graceful close - deadline := time.Now().Add(250 * time.Millisecond) - _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), deadline) - time.Sleep(100 * time.Millisecond) - _ = conn.Close() - r.Conn = nil - } + r.Conn = nil if r.connCancel != nil { r.connCancel() } - r.connectedCh = make(chan struct{}, 1) r.connOnce = new(sync.Once) r.mu.Unlock() + if conn != nil { + // Attempt graceful close outside the lock — WriteControl is concurrent-safe + deadline := time.Now().Add(250 * time.Millisecond) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), deadline) + time.Sleep(100 * time.Millisecond) + _ = conn.Close() + } + r.setIsConnected(false) // Wait for the connection to close @@ -421,582 +177,148 @@ func (r *Resws) CloseAndReconnect() { go r.connect() } -// getHandshakeTimeout returns the handshake timeout -func (r *Resws) getHandshakeTimeout() time.Duration { +// IsConnected returns the connection state +func (r *Resws) IsConnected() bool { r.mu.RLock() defer r.mu.RUnlock() - return r.HandshakeTimeout + return r.isConnected } -// Dial establishes a connection to the WebSocket server -func (r *Resws) Dial(url string) { - r.ctx, r.cancel = context.WithCancel(context.Background()) - - urlStr, err := r.parseURL(url) - if err != nil { - r.lastErr = err - r.emitEvent(Event{Type: EventError, Error: err}) - return - } - - r.mu.Lock() - r.connectedCh = make(chan struct{}, 1) - r.connOnce = new(sync.Once) - r.closeOnce = sync.Once{} - r.mu.Unlock() - - r.setURL(urlStr) - r.setDefaultConfig() - - go r.connect() - - timer := time.NewTimer(r.getHandshakeTimeout()) - defer timer.Stop() - +// LastConnectTime returns the last connection time +func (r *Resws) LastConnectTime() time.Time { r.mu.RLock() - connectedCh := r.connectedCh - r.mu.RUnlock() + defer r.mu.RUnlock() - select { - case <-timer.C: - return - case <-connectedCh: - return - } + return r.lastConnect } -// connect establishes a connection to the WebSocket server -func (r *Resws) connect() { - r.mu.Lock() - if r.connCancel != nil { - r.connCancel() - } - r.mu.Unlock() - - r.connWg.Wait() - - r.mu.Lock() - r.connCtx, r.connCancel = context.WithCancel(r.ctx) - r.mu.Unlock() - +// LastError returns the last error +func (r *Resws) LastError() error { r.mu.RLock() - recBackoff := r.RecBackoffMin - r.mu.RUnlock() - - attempt := 0 - - for { - select { - case <-r.ctx.Done(): - return - default: - conn, resp, err := r.dialer.Dial(r.url, r.Headers) - if err != nil { - r.handleDialFailure(resp, err, recBackoff) - recBackoff = r.backoff(attempt) - attempt++ - continue - } - - r.mu.Lock() - r.Conn = conn - r.lastConnect = time.Now() - r.mu.Unlock() - - if !r.NonVerbose { - r.Logger.Info("Connection was successfully established: %s", r.url) - } - - r.signalConnected() - r.setIsConnected(true) - - // Start connection stability monitor - r.connWg.Add(1) - go func() { - defer r.connWg.Done() - r.monitorConnectionStability(r.connCtx) - }() - - // Start message queue processor - r.connWg.Add(1) - go func() { - defer r.connWg.Done() - r.processMessageQueue(r.connCtx) - }() - - if r.PingHandler != nil { - r.connWg.Add(1) - go func() { - defer r.connWg.Done() - r.heartbeat(r.connCtx) - }() - } - if r.MessageHandler != nil { - r.connWg.Add(1) - go func() { - defer r.connWg.Done() - r.reader(r.connCtx) - }() - } + defer r.mu.RUnlock() - r.emitEvent(Event{Type: EventConnected}) + return r.lastErr +} - // Retry subscribe handler with backoff after connection is fully established - if r.SubscribeHandler != nil { - if err := r.retrySubscribeHandler(); err != nil { - r.CloseAndReconnect() - return - } - } +// GetHTTPResponse returns the HTTP response +func (r *Resws) GetHTTPResponse() *http.Response { + r.mu.RLock() + defer r.mu.RUnlock() - return - } - } + return r.httpResp } -func (r *Resws) handleDialFailure(resp *http.Response, err error, backoff time.Duration) { - r.mu.Lock() - r.Conn = nil - r.httpResp = resp - r.lastErr = err +// setDefaultConfig sets the default configuration for the WebSocket client +func (r *Resws) setDefaultConfig() { + // shouldReconnect is a flag to prevent reconnecting when Close() is called r.shouldReconnect = true - r.mu.Unlock() - - r.setIsConnected(false) - if r.onErrorFn != nil { - r.emitEvent(Event{Type: EventError, Error: err}) - } - if !r.NonVerbose { - r.Logger.Info("Will reconnect in %v", backoff) + if r.RecBackoffMin == 0 { + r.RecBackoffMin = 1000 * time.Millisecond } - if r.onReconnectingFn != nil { - r.emitEvent(Event{Type: EventReconnecting, Data: backoff}) + if r.RecBackoffMax == 0 { + r.RecBackoffMax = 30 * time.Second } - - time.Sleep(backoff) -} - -func (r *Resws) retrySubscribeHandler() error { - r.mu.RLock() - backoff := r.RecBackoffMin - max := r.RecBackoffMax - r.mu.RUnlock() - - attempt := 0 - - for { - select { - case <-r.connCtx.Done(): - return r.connCtx.Err() - default: - } - - if err := r.SubscribeHandler(); err != nil { - r.Logger.Error("Subscribe handler failed: %v", err) - r.emitEvent(Event{Type: EventError, Error: err}) - - if backoff >= max { - return fmt.Errorf("subscribe handler failed after max retries: %w", err) - } - - if !r.NonVerbose { - r.Logger.Info("Retrying subscribe handler in %v", backoff) - } - - select { - case <-r.connCtx.Done(): - return r.connCtx.Err() - case <-time.After(backoff): - } - - backoff = r.backoff(attempt) - attempt++ - } else { - if !r.NonVerbose { - r.Logger.Info("Subscribe handler executed successfully") - } - return nil - } + if r.RecBackoffFactor == 0 { + r.RecBackoffFactor = 1.5 } -} - -// Send sends a message to the WebSocket server with a fallback queue -func (r *Resws) Send(msg []byte) error { - r.mu.RLock() - conn := r.Conn - r.mu.RUnlock() - - // If we have a connection, try to send directly - if conn != nil { - if r.WriteDeadline > time.Duration(0) { - conn.SetWriteDeadline(time.Now().Add(r.WriteDeadline)) - } - r.mu.Lock() - err := conn.WriteMessage(websocket.TextMessage, msg) - r.mu.Unlock() - if err == nil { - return nil - } + if r.HandshakeTimeout == 0 { + r.HandshakeTimeout = 2 * time.Second } - - // If no connection or send failed, queue the message - r.messageQueueMu.Lock() - if len(r.messageQueue) >= r.MessageQueueSize { - r.messageQueueMu.Unlock() - return fmt.Errorf("message queue is full") + if r.StableConnectionDuration == 0 { + r.StableConnectionDuration = 30 * time.Second } - r.messageQueue = append(r.messageQueue, msg) - r.messageQueueMu.Unlock() - if r.Logger == nil { r.Logger = &defaultLogger{} } - r.Logger.Debug("Message queued for later delivery") - return nil -} - -func (r *Resws) SendJSON(v any) (err error) { - r.mu.RLock() - conn := r.Conn - r.mu.RUnlock() - - // If we have a connection, try to send directly - if conn != nil { - if r.WriteDeadline > time.Duration(0) { - conn.SetWriteDeadline(time.Now().Add(r.WriteDeadline)) - } - r.mu.Lock() - err = conn.WriteJSON(v) - r.mu.Unlock() - if err == nil { - return nil - } - } - - // If no connection or send failed, queue the message - r.messageQueueMu.Lock() - if len(r.messageQueue) >= r.MessageQueueSize { - r.messageQueueMu.Unlock() - return fmt.Errorf("message queue is full") - } - b, err := json.Marshal(v) - if err != nil { - r.messageQueueMu.Unlock() - return err + if r.PingInterval == 0 { + r.PingInterval = 15 * time.Second } - r.messageQueue = append(r.messageQueue, b) - r.messageQueueMu.Unlock() - - if r.Logger == nil { - r.Logger = &defaultLogger{} + r.dialer = &websocket.Dialer{ + TLSClientConfig: r.TLSConfig, + Proxy: r.Proxy, + HandshakeTimeout: r.getHandshakeTimeout(), } - r.Logger.Debug("Message queued for later delivery") - return nil } -// processMessageQueue processes messages from the queue and sends them to the WebSocket server -func (r *Resws) processMessageQueue(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - r.mu.RLock() - conn := r.Conn - isConnected := r.isConnected - if !isConnected || conn == nil { - r.mu.RUnlock() - continue - } - r.mu.RUnlock() - - r.messageQueueMu.Lock() - queueLen := len(r.messageQueue) - if queueLen == 0 { - r.messageQueueMu.Unlock() - continue - } - - // Get the first message - msg := r.messageQueue[0] - // Remove it from the queue - r.messageQueue = r.messageQueue[1:] - r.messageQueueMu.Unlock() - - // Try to send the message - r.mu.RLock() - if r.WriteDeadline > time.Duration(0) { - conn.SetWriteDeadline(time.Now().Add(r.WriteDeadline)) - } - r.mu.RUnlock() - err := conn.WriteMessage(websocket.TextMessage, msg) - if err != nil { - r.Logger.Error("Failed to send queued message: %v", err) - // If we failed to send, try to requeue the message - r.messageQueueMu.Lock() - if len(r.messageQueue) < r.MessageQueueSize { - r.messageQueue = append(r.messageQueue, msg) - r.Logger.Debug("Requeued failed message") - } else { - r.Logger.Error("Failed to requeue message: queue full") - } - r.messageQueueMu.Unlock() - } else { - r.Logger.Debug("Successfully sent queued message") - } - } - } +// setURL sets the URL of the WebSocket server +func (r *Resws) setURL(urlStr string) { + r.url = urlStr } -// reader loops and reads messages from the WebSocket connection and emits them as events -func (r *Resws) reader(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - default: - r.mu.RLock() - conn := r.Conn - r.mu.RUnlock() - - if conn == nil { - return - } - if r.ReadDeadline > time.Duration(0) { - conn.SetReadDeadline(time.Now().Add(r.ReadDeadline)) - } - msgType, msg, err := conn.ReadMessage() - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.ClosePolicyViolation) { - return - } - if err != nil { - r.mu.Lock() - reconnect := r.shouldReconnect - if r.Conn == conn { - r.Conn = nil - } - r.mu.Unlock() - r.emitEvent(Event{Type: EventError, Error: err}) - if reconnect { - // Wait for a momment before reconnecting - time.Sleep(100 * time.Millisecond) - r.CloseAndReconnect() - } - return - } - switch msgType { - case websocket.TextMessage, websocket.BinaryMessage: - r.MessageHandler(msgType, msg) - case websocket.CloseMessage: - return - } - } +// parseURL parses the URL of the WebSocket server +func (r *Resws) parseURL(urlStr string) (string, error) { + if strings.TrimSpace(urlStr) == "" { + return "", fmt.Errorf("url cannot be empty") } -} -// ReadMessage manually reads a message from the websocket connection -func (r *Resws) ReadMessage() (msgType int, msg []byte, err error) { - err = errNotConnected - if !r.IsConnected() { - return - } - msgType, msg, err = r.Conn.ReadMessage() - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - r.Close() - return msgType, msg, nil - } - if err != nil { - r.mu.Lock() - reconnect := r.shouldReconnect - r.mu.Unlock() - if reconnect { - r.CloseAndReconnect() - } + if len(urlStr) < 5 { + return "", fmt.Errorf("url too short") } - return -} + u, err := url.Parse(urlStr) -// ReadJSON manually reads a JSON message from the websocket connection -func (r *Resws) ReadJSON(v any) (err error) { - err = errNotConnected - if !r.IsConnected() { - return - } - err = r.Conn.ReadJSON(v) - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - r.Close() - return - } if err != nil { - r.mu.Lock() - reconnect := r.shouldReconnect - r.mu.Unlock() - if reconnect { - r.CloseAndReconnect() - } + return "", fmt.Errorf("url: %s", err.Error()) } - return -} - -// WriteMessage manually writes a message to the websocket connection -func (r *Resws) WriteMessage(msgType int, msg []byte) (err error) { - err = errNotConnected - if !r.IsConnected() { - return - } - r.mu.Lock() - err = r.Conn.WriteMessage(msgType, msg) - r.mu.Unlock() - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - r.Close() - return + if u.Scheme != "ws" && u.Scheme != "wss" { + return "", fmt.Errorf("url: websocket uris must start with ws or wss scheme") } - return -} - -// WriteJSON manually writes a JSON message to the websocket connection -func (r *Resws) WriteJSON(v any) (err error) { - err = errNotConnected - if !r.IsConnected() { - return - } - r.mu.Lock() - err = r.Conn.WriteJSON(v) - r.mu.Unlock() - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - r.Close() - return + if u.User != nil { + return "", fmt.Errorf("url: user name and password are not allowed in websocket URIs") } - return + return urlStr, nil } -// heartbeat sends ping messages to the server to keep the connection alive -func (r *Resws) heartbeat(ctx context.Context) { - ticker := time.NewTicker(r.PingInterval) - defer ticker.Stop() - - r.mu.RLock() - conn := r.Conn - r.mu.RUnlock() - - if conn != nil && r.PongTimeout > 0 { - _ = conn.SetReadDeadline(time.Now().Add(r.PongTimeout)) - conn.SetPongHandler(func(appData string) error { - _ = conn.SetReadDeadline(time.Now().Add(r.PongTimeout)) - return nil - }) - } - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - r.mu.RLock() - currentConn := r.Conn - r.mu.RUnlock() - - if currentConn == nil { - return - } - - if err := currentConn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(r.PingInterval)); err != nil { - return +// emitEvent emits an event to the event handlers +func (r *Resws) emitEvent(event Event) { + switch event.Type { + case EventConnected: + if r.onConnectedFn != nil { + r.onConnectedFn(r.url) + } + case EventReconnecting: + if r.onReconnectingFn != nil { + if d, ok := event.Data.(time.Duration); ok { + r.onReconnectingFn(d) } - r.PingHandler() + } + case EventError: + if r.onErrorFn != nil { + r.onErrorFn(event.Error) } } } -func (r *Resws) backoff(attempt int) time.Duration { - min := r.RecBackoffMin - max := r.RecBackoffMax - if min >= max { - return max - } - backoffFactor := r.RecBackoffFactor - if backoffFactor == 0 { - backoffFactor = 1.5 - } - - if attempt > 30 { - attempt = 30 - } - backoff := min * time.Duration(1< max { - return max - } - - return backoff.Round(100 * time.Millisecond) -} - -// getReconnectBackoff calculates backoff duration for reconnection attempts -func (r *Resws) getReconnectBackoff() time.Duration { - r.backoffMu.RLock() - attempts := r.reconnectAttempts - r.backoffMu.RUnlock() - - if attempts <= 1 { - return 0 // First connection attempt, no backoff - } - - return r.backoff(attempts - 1) -} +// setIsConnected sets the connection state +func (r *Resws) setIsConnected(isConnected bool) { + r.mu.Lock() + defer r.mu.Unlock() -// incrementReconnectAttempts increments the reconnection attempt counter -func (r *Resws) incrementReconnectAttempts() { - r.backoffMu.Lock() - r.reconnectAttempts++ - r.lastReconnectTime = time.Now() - r.backoffMu.Unlock() + r.isConnected = isConnected } -// getReconnectAttempts returns the current reconnection attempt count -func (r *Resws) getReconnectAttempts() int { - r.backoffMu.RLock() - defer r.backoffMu.RUnlock() - return r.reconnectAttempts -} +// signalConnected signals the connection state +func (r *Resws) signalConnected() { + r.mu.Lock() + defer r.mu.Unlock() -// resetReconnectAttempts resets the reconnection attempt counter -func (r *Resws) resetReconnectAttempts() { - r.backoffMu.Lock() - r.reconnectAttempts = 0 - r.backoffMu.Unlock() + r.connOnce.Do(func() { + select { + case r.connectedCh <- struct{}{}: + default: + } + }) } -// monitorConnectionStability monitors connection stability and resets backoff after stable period -func (r *Resws) monitorConnectionStability(ctx context.Context) { - timer := time.NewTimer(r.StableConnectionDuration) - defer timer.Stop() +// getHandshakeTimeout returns the handshake timeout +func (r *Resws) getHandshakeTimeout() time.Duration { + r.mu.RLock() + defer r.mu.RUnlock() - select { - case <-ctx.Done(): - return - case <-timer.C: - // Connection has been stable, reset reconnect attempts - if r.IsConnected() { - r.resetReconnectAttempts() - if !r.NonVerbose { - r.Logger.Debug("Connection stable for %v, reset backoff counter", r.StableConnectionDuration) - } - } - } + return r.HandshakeTimeout } diff --git a/state.go b/state.go new file mode 100644 index 0000000..fd7b8ed --- /dev/null +++ b/state.go @@ -0,0 +1,17 @@ +package resilientws + +import "github.com/gorilla/websocket" + +// getConn returns the current connection under a read lock. +func (r *Resws) getConn() *websocket.Conn { + r.mu.RLock() + defer r.mu.RUnlock() + return r.Conn +} + +// checkConnection snapshots conn and connected state atomically. +func (r *Resws) checkConnection() (*websocket.Conn, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + return r.Conn, r.isConnected && r.Conn != nil +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..93ecf56 --- /dev/null +++ b/types.go @@ -0,0 +1,134 @@ +package resilientws + +import ( + "context" + "crypto/tls" + "errors" + "net/http" + "net/url" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type Resws struct { + // RecBackoffMin is the minimum backoff duration between reconnection attempts + RecBackoffMin time.Duration + + // RecBackoffMax is the maximum backoff duration between reconnection attempts + RecBackoffMax time.Duration + + // RecBackoffFactor is the factor by which the backoff duration is multiplied + RecBackoffFactor float64 + + // BackoffType is the type of backoff to use + BackoffType BackoffType + + // StableConnectionDuration is the duration a connection must be stable before resetting backoff + StableConnectionDuration time.Duration + + // Handshake timeout + HandshakeTimeout time.Duration + + // Headers to be sent with the connection + Headers http.Header + + // Ping interval + PingInterval time.Duration + + // Pong timeout + PongTimeout time.Duration + + // Read deadline + ReadDeadline time.Duration + + // Write deadline + WriteDeadline time.Duration + + // Message queue size + MessageQueueSize int + + // TLS configuration + TLSConfig *tls.Config + + // Proxy configuration + Proxy func(*http.Request) (*url.URL, error) + + // Logger + Logger Logger + + // Non-verbose mode + NonVerbose bool + + // Subscribe handler + SubscribeHandler func() error + + // Message handler + MessageHandler func(int, []byte) + + // Ping handler + PingHandler func() + + url string + dialer *websocket.Dialer + httpResp *http.Response + mu sync.RWMutex + messageQueue [][]byte + messageQueueMu sync.Mutex + writeMu sync.Mutex + isConnected bool + lastConnect time.Time + lastErr error + shouldReconnect bool + connectedCh chan struct{} + connOnce *sync.Once + closeOnce sync.Once + + // Backoff state that persists across reconnections + reconnectAttempts int + lastReconnectTime time.Time + backoffMu sync.RWMutex + + // Context for connection management + ctx context.Context + cancel context.CancelFunc + connCtx context.Context + connCancel context.CancelFunc + + connWg sync.WaitGroup + + // Event handlers + onReconnectingFn func(time.Duration) + onConnectedFn func(string) + onErrorFn func(error) + + *websocket.Conn +} + +type BackoffType int + +const ( + BackoffTypeJitter BackoffType = iota + BackoffTypeFixed +) + +type Event struct { + Type EventType + Message []byte + MessageType int + Data any + Error error +} + +type EventType int + +const ( + EventMessage EventType = iota + EventConnected + EventReconnecting + EventError + EventClose +) + +var errNotConnected = errors.New("websocket: not connected")