diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index b9d175619d65a..95d9e00b0a119 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -52,7 +52,31 @@ import ( // Send/Recv will completely re-establish the connection (unless Close // has been called). type Client struct { - Header http.Header + Header http.Header + + // GetHeaders, if non-nil, returns a fresh set of HTTP headers to send + // on every (re)connect to the DERP server. When non-nil it takes + // precedence over Header. This is useful when a caller needs to inject + // short-lived authentication tokens (e.g. for an authenticating + // reverse proxy in front of DERP) that must be refreshed on each + // reconnect, rather than captured once at startup. The same pattern + // is already used by netcheck.Client.GetDERPHeaders. + // + // Implementations must be cheap and non-blocking: GetHeaders is invoked + // from connect() while the Client's internal mutex is held, so it must + // not call back into this Client (Send, Close, etc.) or acquire the + // lock of any caller that may, in turn, call into this Client (notably + // magicsock.Conn.mu). It should also avoid blocking I/O on the hot + // reconnect path; cache and refresh in the background where possible. + // + // Returning a nil http.Header is treated the same as a missing static + // Header: the connect request goes out without those caller-supplied + // headers (i.e. without auth). Implementations that fail to obtain + // fresh credentials should generally return the most recent known-good + // value rather than nil, so the server can return a clear 401 instead + // of a generic auth failure. + GetHeaders func() http.Header + TLSConfig *tls.Config // optional; nil means default DNSCache *dnscache.Resolver // optional; nil means no caching MeshKey string // optional; for trusted clients @@ -113,6 +137,17 @@ func (c *Client) String() string { return fmt.Sprintf("", c.serverPubKey.ShortString(), c.url) } +// headers returns the HTTP headers to send on the next DERP connection +// attempt. If GetHeaders is set, it is invoked on every call (so callers can +// refresh short-lived tokens). Otherwise the static Header field is used. +// Either may be nil. +func (c *Client) headers() http.Header { + if c.GetHeaders != nil { + return c.GetHeaders() + } + return c.Header +} + // NewRegionClient returns a new DERP-over-HTTP client. It connects lazily. // To trigger a connection, use Connect. // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. @@ -430,7 +465,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien tlsConfig = c.tlsConfig(nil) } c.logf("%s: connecting websocket to %v", caller, urlStr) - conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig, c.Header) + conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig, c.headers()) if err != nil { c.logf("%s: websocket to %v error: %v", caller, urlStr, err) return nil, 0, err @@ -533,8 +568,8 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien if err != nil { return nil, 0, err } - if c.Header != nil { - req.Header = c.Header.Clone() + if h := c.headers(); h != nil { + req.Header = h.Clone() } req.Header.Set("Upgrade", "DERP") req.Header.Set("Connection", "Upgrade") diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 13e8c95599e7d..55e5143f16954 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -324,3 +324,36 @@ func TestForceWebsockets(t *testing.T) { c.Close() } + +func TestClientHeaders(t *testing.T) { + t.Run("nil when neither set", func(t *testing.T) { + c := &Client{} + if got := c.headers(); got != nil { + t.Fatalf("expected nil headers, got %v", got) + } + }) + t.Run("returns Header when GetHeaders is nil", func(t *testing.T) { + want := http.Header{"X-Test": []string{"static"}} + c := &Client{Header: want} + if got := c.headers().Get("X-Test"); got != "static" { + t.Fatalf("expected static header, got %q", got) + } + }) + t.Run("GetHeaders takes precedence and is invoked on every call", func(t *testing.T) { + var calls int + c := &Client{ + Header: http.Header{"X-Test": []string{"static"}}, + GetHeaders: func() http.Header { + calls++ + return http.Header{"X-Test": []string{"dynamic"}} + }, + } + if got := c.headers().Get("X-Test"); got != "dynamic" { + t.Fatalf("expected dynamic header, got %q", got) + } + _ = c.headers() + if calls != 2 { + t.Fatalf("expected GetHeaders to be invoked on every call, got %d calls", calls) + } + }) +} diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index fdea777f6f064..1ce7166afdaa1 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -354,6 +354,9 @@ func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) cha if header != nil { dc.Header = header.Clone() } + if getHeaders := c.derpGetHeaders.Load(); getHeaders != nil { + dc.GetHeaders = *getHeaders + } dc.ForceWebsockets = c.derpForceWebsockets.Load() dialer := c.derpRegionDialer.Load() if dialer != nil { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 1349add573439..6b30da89e50d5 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -171,6 +171,11 @@ type Conn struct { // headers that are passed to the DERP HTTP client derpHeader atomic.Pointer[http.Header] + // derpGetHeaders, if non-nil, is called by the DERP HTTP client on every + // (re)connect to obtain a fresh set of HTTP headers. When non-nil it + // takes precedence over derpHeader. See derphttp.Client.GetHeaders. + derpGetHeaders atomic.Pointer[func() http.Header] + // whether websocket is always used by the DERP HTTP client derpForceWebsockets atomic.Bool @@ -462,6 +467,9 @@ func NewConn(opts Options) (*Conn, error) { PortMapper: c.portMapper, UseDNSCache: true, GetDERPHeaders: func() http.Header { + if getHeaders := c.derpGetHeaders.Load(); getHeaders != nil { + return (*getHeaders)() + } h := c.derpHeader.Load() if h == nil { return nil @@ -1759,6 +1767,17 @@ func (c *Conn) SetDERPHeader(header http.Header) { c.derpHeader.Store(&header) } +// SetDERPGetHeaders sets a callback invoked by the DERP HTTP client on every +// (re)connect to obtain a fresh set of HTTP headers. When non-nil it takes +// precedence over the value set with SetDERPHeader. Pass nil to clear. +func (c *Conn) SetDERPGetHeaders(getHeaders func() http.Header) { + if getHeaders == nil { + c.derpGetHeaders.Store(nil) + return + } + c.derpGetHeaders.Store(&getHeaders) +} + func (c *Conn) SetDERPForceWebsockets(v bool) { c.derpForceWebsockets.Store(v) }