Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions conn/bind_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"sync"
"syscall"

"github.com/tailscale/wireguard-go/iobuf"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
Expand Down Expand Up @@ -233,13 +234,14 @@ func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
bufs []iobuf.View,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
// TODO: placeholder until bind implements right-sized buffers.
iobuf.EnsureAllocated(bufs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].Buffers[0] = bufs[i].Bytes
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
Expand Down Expand Up @@ -271,8 +273,8 @@ func (s *StdNetBind) receiveIP(
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
bufs[i].Bytes = bufs[i].Bytes[:msg.N]
if len(bufs[i].Bytes) == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
Expand All @@ -284,14 +286,14 @@ func (s *StdNetBind) receiveIP(
}

func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, eps)
}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, eps)
}
}

Expand Down
8 changes: 4 additions & 4 deletions conn/bind_std_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"testing"

"github.com/tailscale/wireguard-go/iobuf"
"golang.org/x/net/ipv6"
)

Expand All @@ -15,15 +16,14 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
t.Fatal(err)
}
bind.Close()
bufs := make([][]byte, 1)
bufs[0] = make([]byte, 1)
sizes := make([]int, 1)
bufs := make([]iobuf.View, 1)
bufs[0] = iobuf.View{Bytes: make([]byte, 1)}
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(bufs, sizes, eps)
fn(bufs, eps)
}
}

Expand Down
15 changes: 9 additions & 6 deletions conn/bind_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"golang.org/x/sys/windows"

"github.com/tailscale/wireguard-go/conn/winrio"
"github.com/tailscale/wireguard-go/iobuf"
)

const (
Expand Down Expand Up @@ -416,20 +417,22 @@ retry:
return n, &ep, nil
}

func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
func (bind *WinRingBind) receiveIPv4(bufs []iobuf.View, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
iobuf.EnsureAllocated(bufs[:1])
n, ep, err := bind.v4.Receive(bufs[0].Bytes, &bind.isOpen)
bufs[0].Bytes = bufs[0].Bytes[:n]
eps[0] = ep
return 1, err
}

func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
func (bind *WinRingBind) receiveIPv6(bufs []iobuf.View, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
iobuf.EnsureAllocated(bufs[:1])
n, ep, err := bind.v6.Receive(bufs[0].Bytes, &bind.isOpen)
bufs[0].Bytes = bufs[0].Bytes[:n]
eps[0] = ep
return 1, err
}
Expand Down
8 changes: 5 additions & 3 deletions conn/bindtest/bindtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os"

"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/iobuf"
)

type ChannelBind struct {
Expand Down Expand Up @@ -94,13 +95,14 @@ func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil }

func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
return func(bufs []iobuf.View, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
return 0, net.ErrClosed
case rx := <-ch:
copied := copy(bufs[0], rx)
sizes[0] = copied
iobuf.EnsureAllocated(bufs[:1])
n := copy(bufs[0].Bytes, rx)
bufs[0].Bytes = bufs[0].Bytes[:n]
eps[0] = c.target6
return 1, nil
}
Expand Down
15 changes: 8 additions & 7 deletions conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ import (
"reflect"
"runtime"
"strings"

"github.com/tailscale/wireguard-go/iobuf"
)

const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)

// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A ReceiveFunc receives at least one packet from the network into bufs.
// On a successful read it returns the number of elements of bufs and eps
// that should be evaluated. Callers must pass an eps slice with a length
// greater than or equal to the length of bufs. These lengths must not
// exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(bufs []iobuf.View, eps []Endpoint) (n int, err error)

// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
Expand Down
4 changes: 3 additions & 1 deletion conn/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ package conn

import (
"testing"

"github.com/tailscale/wireguard-go/iobuf"
)

func TestPrettyName(t *testing.T) {
var (
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
recvFunc ReceiveFunc = func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { return }
)

const want = "TestPrettyName"
Expand Down
10 changes: 6 additions & 4 deletions device/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package device
import (
"runtime"
"sync"

"github.com/tailscale/wireguard-go/iobuf"
)

// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
Expand Down Expand Up @@ -90,7 +92,7 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
}

func (device *Device) needsInboundQueueFinalizer() bool {
return device.pool.messageBuffers.HasAccounting() ||
return iobuf.HasAccounting() ||
device.pool.inboundElements.HasAccounting() ||
device.pool.inboundElementsContainer.HasAccounting()
}
Expand All @@ -101,7 +103,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
case elemsContainer := <-q.c:
elemsContainer.filling.Wait()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
elem.buffer.Release()
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
Expand Down Expand Up @@ -131,7 +133,7 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
}

func (device *Device) needsOutboundQueueFinalizer() bool {
return device.pool.messageBuffers.HasAccounting() ||
return iobuf.HasAccounting() ||
device.pool.outboundElements.HasAccounting() ||
device.pool.outboundElementsContainer.HasAccounting()
}
Expand All @@ -142,7 +144,7 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
case elemsContainer := <-q.c:
elemsContainer.filling.Wait()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
elem.buffer.Release()
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
Expand Down
13 changes: 9 additions & 4 deletions device/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ package device
import (
"testing"

"github.com/tailscale/wireguard-go/iobuf"
"github.com/tailscale/wireguard-go/waitpool"
)

func TestAutodrainingQueueFinalizerNeedTracksPoolAccounting(t *testing.T) {
unbounded := func() *waitpool.WaitPool { return waitpool.New(0, func() any { return nil }) }
bounded := func() *waitpool.WaitPool { return waitpool.New(1, func() any { return nil }) }

// Force the default raw pool unbounded for the bulk of the test.
origPool := iobuf.DefaultRawPool
iobuf.DefaultRawPool = iobuf.NewRawPool(0)
t.Cleanup(func() { iobuf.DefaultRawPool = origPool })

device := &Device{}
device.pool.messageBuffers = unbounded()
device.pool.inboundElements = unbounded()
device.pool.inboundElementsContainer = unbounded()
device.pool.outboundElements = unbounded()
Expand Down Expand Up @@ -47,11 +52,11 @@ func TestAutodrainingQueueFinalizerNeedTracksPoolAccounting(t *testing.T) {
}

device.pool.outboundElementsContainer = unbounded()
device.pool.messageBuffers = bounded()
iobuf.DefaultRawPool = iobuf.NewRawPool(1)
if !device.needsInboundQueueFinalizer() {
t.Fatal("bounded message buffer pool should need inbound queue finalizer")
t.Fatal("bounded raw buffer pool should need inbound queue finalizer")
}
if !device.needsOutboundQueueFinalizer() {
t.Fatal("bounded message buffer pool should need outbound queue finalizer")
t.Fatal("bounded raw buffer pool should need outbound queue finalizer")
}
}
8 changes: 5 additions & 3 deletions device/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package device

import (
"time"

"github.com/tailscale/wireguard-go/iobuf"
)

/* Specification constants */
Expand All @@ -27,9 +29,9 @@ const (
)

const (
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
MaxMessageSize = MaxSegmentSize // maximum size of transport message
MaxContentSize = MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
MaxMessageSize = iobuf.MaxSegmentSize // maximum size of transport message
MaxContentSize = iobuf.MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content
)

/* Implementation constants */
Expand Down
1 change: 0 additions & 1 deletion device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ type Device struct {
pool struct {
inboundElementsContainer *waitpool.WaitPool
outboundElementsContainer *waitpool.WaitPool
messageBuffers *waitpool.WaitPool
inboundElements *waitpool.WaitPool
outboundElements *waitpool.WaitPool
}
Expand Down
3 changes: 2 additions & 1 deletion device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/conn/bindtest"
"github.com/tailscale/wireguard-go/iobuf"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/tun/tuntest"
)
Expand Down Expand Up @@ -437,7 +438,7 @@ type fakeTUNDeviceSized struct {
}

func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
func (t *fakeTUNDeviceSized) Read(bufs []iobuf.View, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
Expand Down
21 changes: 6 additions & 15 deletions device/pools.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,23 @@
package device

import (
"github.com/tailscale/wireguard-go/iobuf"
"github.com/tailscale/wireguard-go/waitpool"
)

func (device *Device) PopulatePools() {
device.pool.inboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any {
device.pool.outboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = waitpool.New(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
device.pool.inboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any {
return new(QueueInboundElement)
})
device.pool.outboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any {
device.pool.outboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any {
return new(QueueOutboundElement)
})
}
Expand Down Expand Up @@ -55,14 +53,6 @@ func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsConta
device.pool.outboundElementsContainer.Put(c)
}

func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}

func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg)
}

func (device *Device) GetInboundElement() *QueueInboundElement {
return device.pool.inboundElements.Get().(*QueueInboundElement)
}
Expand All @@ -78,5 +68,6 @@ func (device *Device) GetOutboundElement() *QueueOutboundElement {

func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
elem.clearPointers()
elem.nonce = 0
device.pool.outboundElements.Put(elem)
}
10 changes: 4 additions & 6 deletions device/queueconstants_android.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import "github.com/tailscale/wireguard-go/conn"
/* Reduce memory consumption for Android */

const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2200
PreallocatedBuffersPerPool = 4096
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
)
Loading