Skip to content

Commit 6afbe5d

Browse files
committed
ba-proxy-agent: close idle connections to mitigate memory leaks
The ba-proxy-agent currently experiences increasing memory consumption, leading to daily or weekly restarts. Previous mitigation attempts were insufficient, and the issue has been isolated to high memory usage on the agent connection side. This CL mitigates the issue by enforcing a timeout on idle connections. Since the root cause remains elusive, forcefully closing unused connections prevents memory accumulation from persistent links. Implementation details: * Introduced `lastActivityTime` property to agent connections, which updates upon usage. * Added a background routine to monitor connection activity. * Configured the routine to explicitly close connections that remain idle for more than 30 seconds.
1 parent d9a3e67 commit 6afbe5d

5 files changed

Lines changed: 181 additions & 26 deletions

File tree

agent/agent.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ var (
7373
injectBanner = flag.String("inject-banner", "", "HTML snippet to inject in served webpages")
7474
bannerHeight = flag.String("banner-height", "40px", "Height of the injected banner. This is ignored if no banner is set.")
7575
shimWebsockets = flag.Bool("shim-websockets", false, "Whether or not to replace websockets with a shim")
76+
websocketShimTimeout = flag.Duration("websocket-shim-timeout", 60*time.Minute, "Timeout for websocket shim connections to expire due to inactivity.")
7677
shimPath = flag.String("shim-path", "", "Path under which to handle websocket shim requests")
7778
healthCheckPath = flag.String("health-check-path", "/", "Path on backend host to issue health checks against. Defaults to the root.")
7879
healthCheckFreq = flag.Int("health-check-interval-seconds", 0, "Wait time in seconds between health checks. Set to zero to disable health checks. Checks disabled by default.")
@@ -126,7 +127,8 @@ func hostProxy(ctx context.Context, host, shimPath string, injectShimCode, force
126127
// restricted to a path prefix not equal to "/" will fail for websocket open requests. Passing in the
127128
// sessionHandler twice allows the websocket handler to ensure that cookies are applied based on the
128129
// correct, restored path.
129-
h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler, metricHandler)
130+
h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler,
131+
metricHandler, *websocketShimTimeout)
130132
if injectShimCode {
131133
shimFunc, err := websockets.ShimBody(shimPath)
132134
if err != nil {

agent/websockets/connection.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ limitations under the License.
1717
package websockets
1818

1919
import (
20+
"context"
2021
"encoding/base64"
2122
"encoding/json"
2223
"errors"
2324
"fmt"
2425
"log"
2526
"net/http"
27+
"sync"
2628
"time"
2729

28-
"context"
29-
3030
"github.com/gorilla/websocket"
3131
)
3232

@@ -57,12 +57,14 @@ func (m *message) Serialize(version int) interface{} {
5757
// and encapsulates it in an API that is a little more amenable to how the server side
5858
// of our websocket shim is implemented.
5959
type Connection struct {
60-
done func() <-chan struct{}
61-
cancel context.CancelFunc
62-
clientMessages chan *message
63-
serverMessages chan *message
64-
protocolVersion int
65-
subprotocol string
60+
done func() <-chan struct{}
61+
cancel context.CancelFunc
62+
clientMessages chan *message
63+
serverMessages chan *message
64+
protocolVersion int
65+
subprotocol string
66+
mu sync.Mutex
67+
lastActivityTime time.Time
6668
}
6769

6870
// This map defines the set of headers that should be stripped from the WS request, as they
@@ -87,6 +89,20 @@ func stripWSHeader(header http.Header) http.Header {
8789
return result
8890
}
8991

92+
// updateActivity updates the last activity timestamp.
93+
func (conn *Connection) updateActivity() {
94+
conn.mu.Lock()
95+
defer conn.mu.Unlock()
96+
conn.lastActivityTime = time.Now()
97+
}
98+
99+
// lastActivity returns the last activity timestamp.
100+
func (conn *Connection) lastActivity() time.Time {
101+
conn.mu.Lock()
102+
defer conn.mu.Unlock()
103+
return conn.lastActivityTime
104+
}
105+
90106
// NewConnection creates and returns a new Connection.
91107
func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) {
92108
ctx, cancel := context.WithCancel(ctx)
@@ -162,11 +178,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
162178
}
163179
}()
164180
return &Connection{
165-
done: ctx.Done,
166-
cancel: cancel,
167-
clientMessages: clientMessages,
168-
serverMessages: serverMessages,
169-
subprotocol: serverConn.Subprotocol(),
181+
done: ctx.Done,
182+
cancel: cancel,
183+
clientMessages: clientMessages,
184+
serverMessages: serverMessages,
185+
subprotocol: serverConn.Subprotocol(),
186+
lastActivityTime: time.Now(),
170187
}, nil
171188
}
172189

@@ -184,6 +201,7 @@ func (conn *Connection) Close() {
184201
//
185202
// The returned error value is non-nill if the connection has been closed.
186203
func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error {
204+
conn.updateActivity()
187205
var clientMessage *message
188206
if textMsg, ok := msg.(string); ok {
189207
clientMessage = &message{
@@ -236,7 +254,9 @@ func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool
236254
//
237255
// The returned []string value is nil if the error is non-nil, or if the method
238256
// times out while waiting for a server message.
239-
func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
257+
func (conn *Connection) ReadServerMessages(readTimeout time.Duration) ([]interface{}, error) {
258+
conn.updateActivity()
259+
defer conn.updateActivity()
240260
var msgs []interface{}
241261
select {
242262
case serverMsg, ok := <-conn.serverMessages:
@@ -257,7 +277,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
257277
return msgs, nil
258278
}
259279
}
260-
case <-time.After(time.Second * 20):
280+
case <-time.After(readTimeout):
261281
return nil, nil
262282
}
263283
}

agent/websockets/shim.go

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package websockets
1818

1919
import (
2020
"bytes"
21+
"context"
2122
"encoding/json"
2223
"fmt"
2324
"io"
@@ -31,8 +32,8 @@ import (
3132
"sync"
3233
"sync/atomic"
3334
"text/template"
35+
"time"
3436

35-
"context"
3637
"github.com/google/inverting-proxy/agent/metrics"
3738
)
3839

@@ -320,9 +321,33 @@ func (c *connectionErrorHandler) ReportError(err error) {
320321
}
321322
}
322323

323-
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler {
324+
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler {
324325
var connections sync.Map
325326
var sessionCount uint64
327+
328+
// Background goroutine to clean up inactive websocket shim connections.
329+
go func() {
330+
ticker := time.NewTicker(min(timeout, 30*time.Second))
331+
defer ticker.Stop()
332+
for {
333+
select {
334+
case <-ctx.Done():
335+
return
336+
case <-ticker.C:
337+
connections.Range(func(key, value any) bool {
338+
sessionID := key.(string)
339+
conn := value.(*Connection)
340+
if time.Since(conn.lastActivity()) > timeout {
341+
log.Printf("Closing inactive websocket shim session %q after timeout", sessionID)
342+
conn.Close()
343+
connections.Delete(sessionID)
344+
}
345+
return true // Continue iteration
346+
})
347+
}
348+
}
349+
}()
350+
326351
mux := http.NewServeMux()
327352
errorHandler := &connectionErrorHandler{}
328353
openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -351,9 +376,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
351376
}
352377
}
353378
resp := &sessionMessage{
354-
ID: sessionID,
355-
Message: targetURL.String(),
356-
Version: conn.protocolVersion,
379+
ID: sessionID,
380+
Message: targetURL.String(),
381+
Version: conn.protocolVersion,
357382
Subprotocol: conn.Subprotocol(),
358383
}
359384
respBytes, err := json.Marshal(resp)
@@ -512,7 +537,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
512537
metricHandler.WriteResponseCodeMetric(statusCode)
513538
return
514539
}
515-
serverMsgs, err := conn.ReadServerMessages()
540+
serverMsgs, err := conn.ReadServerMessages(min(20*time.Second, timeout/2))
516541
if err != nil {
517542
statusCode := http.StatusBadRequest
518543
errorMessage := fmt.Sprintf("attempt to read data from a closed session: %q", msg.ID)
@@ -548,11 +573,11 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
548573
// openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original
549574
// targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it
550575
// is finished processing the request.
551-
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) {
576+
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) {
552577
mux := http.NewServeMux()
553578
if shimPath != "" {
554579
shimPath = path.Clean("/"+shimPath) + "/"
555-
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler)
580+
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler, timeout)
556581
mux.Handle(shimPath, shimServer)
557582
}
558583
mux.Handle("/", wrapped)

agent/websockets/websockets_test.go

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@ import (
2323
"errors"
2424
"fmt"
2525
"io"
26+
"io/ioutil"
2627
"net/http"
2728
"net/http/httptest"
2829
"net/url"
2930
"path"
3031
"strings"
3132
"sync"
3233
"testing"
34+
"time"
3335

3436
"github.com/google/go-cmp/cmp"
3537
"github.com/google/go-cmp/cmp/cmpopts"
@@ -239,7 +241,7 @@ func TestShimHandlers(t *testing.T) {
239241
openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler {
240242
return h
241243
}
242-
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil)
244+
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil, 60*time.Second)
243245
if err != nil {
244246
t.Fatalf("Failure creating the websocket shim proxy: %+v", err)
245247
}
@@ -354,3 +356,107 @@ func TestShimHandlers(t *testing.T) {
354356
}
355357
}
356358
}
359+
360+
func TestShimPolling(t *testing.T) {
361+
ctx, cancel := context.WithCancel(context.Background())
362+
defer cancel()
363+
// Setup a fake backend that accepts websocket connections but sends no messages.
364+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
365+
upgrader := websocket.Upgrader{}
366+
conn, err := upgrader.Upgrade(w, r, nil)
367+
if err != nil {
368+
t.Logf("Failed to upgrade websocket: %v", err)
369+
return
370+
}
371+
defer conn.Close()
372+
// Keep connection open until client closes it.
373+
for {
374+
if _, _, err := conn.NextReader(); err != nil {
375+
break
376+
}
377+
}
378+
}))
379+
defer backend.Close()
380+
381+
backendURL, err := url.Parse(backend.URL)
382+
if err != nil {
383+
t.Fatalf("Failed to parse backend URL: %v", err)
384+
}
385+
386+
shimPath := "/shim/"
387+
idleTimeout := 3 * time.Second
388+
shim := createShimChannel(
389+
ctx,
390+
backendURL.Host,
391+
shimPath,
392+
false,
393+
func(h http.Handler, m *metrics.MetricHandler) http.Handler { return h },
394+
false,
395+
nil,
396+
idleTimeout,
397+
)
398+
shimServer := httptest.NewServer(shim)
399+
defer shimServer.Close()
400+
401+
// 1. Open a websocket connection via the shim.
402+
openURL := shimServer.URL + shimPath + "open"
403+
resp, err := http.Post(openURL, "text/plain", strings.NewReader(backendURL.String()))
404+
if err != nil {
405+
t.Fatalf("Failed to open shim connection: %v", err)
406+
}
407+
if resp.StatusCode != http.StatusOK {
408+
t.Fatalf("Failed to open shim connection, status: %d", resp.StatusCode)
409+
}
410+
body, err := ioutil.ReadAll(resp.Body)
411+
if err != nil {
412+
t.Fatalf("Failed to read open response body: %v", err)
413+
}
414+
resp.Body.Close()
415+
var openResp sessionMessage
416+
if err := json.Unmarshal(body, &openResp); err != nil {
417+
t.Fatalf("Failed to unmarshal open response: %v", err)
418+
}
419+
sessionID := openResp.ID
420+
if sessionID == "" {
421+
t.Fatal("No sessionID in open response")
422+
}
423+
424+
// 2. Poll repeatedly without any messages being sent.
425+
pollURL := shimServer.URL + shimPath + "poll"
426+
pollReq := fmt.Sprintf(`{"id": %q}`, sessionID)
427+
timeout := time.Now().Add(idleTimeout + 1*time.Second)
428+
for time.Now().Before(timeout) {
429+
resp, err := http.Post(pollURL, "application/json", strings.NewReader(pollReq))
430+
if err != nil {
431+
t.Fatalf("Failed to poll shim connection: %v", err)
432+
}
433+
resp.Body.Close()
434+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusRequestTimeout {
435+
t.Fatalf("Unexpected status code during polling: %d", resp.StatusCode)
436+
}
437+
// Sleep for half of read timeout to simulate polling faster than idle timeout.
438+
time.Sleep(500 * time.Millisecond)
439+
}
440+
441+
// 3. After idleTimeout + 1s, one more poll should succeed.
442+
resp, err = http.Post(pollURL, "application/json", strings.NewReader(pollReq))
443+
if err != nil {
444+
t.Fatalf("Failed to poll shim connection: %v", err)
445+
}
446+
defer resp.Body.Close()
447+
if resp.StatusCode != http.StatusRequestTimeout {
448+
t.Errorf("Polling after idle timeout got status %d, want %d", resp.StatusCode, http.StatusRequestTimeout)
449+
}
450+
451+
// 4. Close connection.
452+
closeURL := shimServer.URL + shimPath + "close"
453+
closeReq := fmt.Sprintf(`{"id": %q}`, sessionID)
454+
resp, err = http.Post(closeURL, "application/json", strings.NewReader(closeReq))
455+
if err != nil {
456+
t.Fatalf("Failed to close shim connection: %v", err)
457+
}
458+
defer resp.Body.Close()
459+
if resp.StatusCode != http.StatusOK {
460+
t.Errorf("Close shim connection got status %d, want %d", resp.StatusCode, http.StatusOK)
461+
}
462+
}

testing/websockets/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"net/http"
3232
"net/http/httputil"
3333
"net/url"
34+
"time"
3435

3536
"github.com/google/inverting-proxy/agent/metrics"
3637
"github.com/google/inverting-proxy/agent/websockets"
@@ -48,6 +49,7 @@ var (
4849
monitoringEndpoint = flag.String("monitoring-endpoint", "staging-monitoring.sandbox.googleapis.com:443", "The endpoint to which to write metrics. Eg: monitoring.googleapis.com corresponds to Cloud Monarch.")
4950
monitoringResourceType = flag.String("monitoring-resource-type", "gce_instance", "The monitoring resource type. Eg: gce_instance")
5051
monitoringResourceLabels = flag.String("monitoring-resource-labels", "instance-id=fake-instance-id,instance-zone=us-west1-a", "Comma separated key value pairs for the purpose of monitoring configuration. Eg: 'instance-id=my-instance-id,instance-zone=us-west1-a")
52+
websocketShimTimeout = flag.Duration("websocket-shim-timeout", 60*time.Minute, "Timeout for websocket shim connections to expire due to inactivity.")
5153
)
5254

5355
func main() {
@@ -69,7 +71,7 @@ func main() {
6971
}
7072

7173
backendProxy := httputil.NewSingleHostReverseProxy(backendURL)
72-
shimmingProxy, err := websockets.Proxy(context.Background(), backendProxy, backendURL.Host, *shimPath, true, *enableWebsocketInjection, func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { return h }, metricHandler)
74+
shimmingProxy, err := websockets.Proxy(context.Background(), backendProxy, backendURL.Host, *shimPath, true, *enableWebsocketInjection, func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { return h }, metricHandler, *websocketShimTimeout)
7375
if err != nil {
7476
log.Fatalf("Failure starting the websocket-shimming proxy: %v", err)
7577
}

0 commit comments

Comments
 (0)