From a30525468f6c83bdc600d7afe83b9418918da5b3 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Fri, 17 Apr 2026 16:23:35 +0530 Subject: [PATCH 1/2] drpcstream: introduce shared BufferPool for ring buffer Add a BufferPool backed by sync.Pool that is shared across all streams within a Manager. The ring buffer now obtains buffers from the pool on Enqueue and transfers ownership to the caller on Dequeue, which advances the tail immediately. This removes the two-step Dequeue/Done protocol and simplifies Close (no longer needs to wait for held buffers). The pool is a required parameter in the Stream constructor, created once per Manager and passed to all streams it creates. Co-Authored-By: Claude Opus 4.6 --- drpcmanager/active_streams_test.go | 2 +- drpcmanager/manager.go | 6 +- drpcstream/buffer_pool.go | 42 ++++++++++++++ drpcstream/ring_buffer.go | 71 ++++++++++------------- drpcstream/ring_buffer_test.go | 90 +++++++++++++----------------- drpcstream/stream.go | 27 ++++----- drpcstream/stream_test.go | 20 +++---- 7 files changed, 139 insertions(+), 119 deletions(-) create mode 100644 drpcstream/buffer_pool.go diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index f463b18..d2e5942 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -22,7 +22,7 @@ func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { } func testStream(t *testing.T, id uint64) *drpcstream.Stream { - return drpcstream.New(context.Background(), id, testMuxWriter(t)) + return drpcstream.New(context.Background(), id, testMuxWriter(t), drpcstream.NewBufferPool()) } func TestActiveStreams_AddAndGet(t *testing.T) { diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 6c77e17..c319ff6 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -61,7 +61,8 @@ type Manager struct { wg sync.WaitGroup // tracks active manageStream goroutines // streams tracks active streams. 
- streams *activeStreams + streams *activeStreams + recvPool *drpcstream.BufferPool pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream @@ -130,6 +131,7 @@ func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager m.pendingStreams = make(map[uint64]*pendingStream) m.streams = newActiveStreams() + m.recvPool = drpcstream.NewBufferPool() // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) @@ -268,7 +270,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) } - stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) + stream := drpcstream.NewWithOptions(ctx, sid, m.wr, m.recvPool, opts) if err := m.streams.Add(sid, stream); err != nil { return nil, err diff --git a/drpcstream/buffer_pool.go b/drpcstream/buffer_pool.go new file mode 100644 index 0000000..0785e51 --- /dev/null +++ b/drpcstream/buffer_pool.go @@ -0,0 +1,42 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcstream + +import "sync" + +// BufferPool wraps sync.Pool to provide reusable byte slices for the +// stream receive path. Buffers obtained via Get should be returned via +// Put when no longer needed. Forgetting to Put is safe (GC reclaims) +// but reduces reuse. +type BufferPool struct { + pool sync.Pool +} + +// NewBufferPool returns a new buffer pool. +func NewBufferPool() *BufferPool { + return &BufferPool{ + pool: sync.Pool{ + New: func() interface{} { + b := make([]byte, 0, 4096) + return &b + }, + }, + } +} + +// Get returns a zero-length byte slice from the pool, retaining its +// backing array for reuse. +func (bp *BufferPool) Get() *[]byte { + p := bp.pool.Get().(*[]byte) + *p = (*p)[:0] + return p +} + +// Put returns a buffer to the pool. Nil is safe to pass. 
+func (bp *BufferPool) Put(b *[]byte) { + if b == nil { + return + } + bp.pool.Put(b) +} diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go index 5cb620a..3056617 100644 --- a/drpcstream/ring_buffer.go +++ b/drpcstream/ring_buffer.go @@ -15,11 +15,12 @@ const defaultRingBufferCapacity = 256 // ringBuffer is a bounded single-producer / single-consumer FIFO queue for // assembled packet data. It sits between manageReader (producer, calls -// Enqueue) and the application goroutine (consumer, calls Dequeue/Done). +// Enqueue) and the application goroutine (consumer, calls Dequeue). // -// Slots are pre-allocated and reused: each slot's backing array grows via -// append to fit incoming data, then stays at its high-water mark, avoiding -// per-message allocation in steady state. +// Buffers are obtained from a shared BufferPool. Enqueue copies data into a +// pooled buffer; Dequeue returns ownership of that buffer to the caller and +// advances the tail immediately. The caller is responsible for returning the +// buffer to the pool via BufferPool.Put. // // After Close, Dequeue drains any queued messages before returning the close // error. 
This ensures graceful shutdown (KindClose/KindCloseSend) delivers @@ -28,23 +29,24 @@ type ringBuffer struct { mu sync.Mutex cond sync.Cond - buf [][]byte // ring of byte slices - head int // next write position (producer) - tail int // next read position (consumer) - count int // number of occupied slots + pool *BufferPool // shared pool; nil means allocate fresh each time + buf []*[]byte // ring of pooled buffer pointers + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots - held bool // true between Dequeue and Done - err error // terminal error, set by Close + err error // terminal error, set by Close } -func (rb *ringBuffer) init() { +func (rb *ringBuffer) init(pool *BufferPool) { rb.cond.L = &rb.mu - rb.buf = make([][]byte, defaultRingBufferCapacity) + rb.pool = pool + rb.buf = make([]*[]byte, defaultRingBufferCapacity) } -// Enqueue copies data into the next write slot. If the buffer is full, it -// blocks until a slot is freed or the buffer is closed. If the buffer is -// closed, Enqueue returns silently without enqueuing. +// Enqueue copies data into a pooled buffer and places it in the next write +// slot. If the buffer is full, it blocks until a slot is freed or the buffer +// is closed. If the buffer is closed, Enqueue returns silently. func (rb *ringBuffer) Enqueue(data []byte) { rb.mu.Lock() defer rb.mu.Unlock() @@ -56,16 +58,19 @@ func (rb *ringBuffer) Enqueue(data []byte) { return } - rb.buf[rb.head] = append(rb.buf[rb.head][:0], data...) + b := rb.pool.Get() + *b = append(*b, data...) + + rb.buf[rb.head] = b rb.head = (rb.head + 1) % len(rb.buf) rb.count++ rb.cond.Broadcast() } -// Dequeue returns the data from the next read slot. If the buffer is empty, -// it blocks until data is available or the buffer is closed. The returned -// slice is valid until Done is called. -func (rb *ringBuffer) Dequeue() ([]byte, error) { +// Dequeue returns the next buffered message. 
The returned *[]byte is owned +// by the caller; the tail is advanced immediately. If the ring buffer has a +// pool, the caller should return the buffer via BufferPool.Put when done. +func (rb *ringBuffer) Dequeue() (*[]byte, error) { rb.mu.Lock() defer rb.mu.Unlock() @@ -76,37 +81,21 @@ func (rb *ringBuffer) Dequeue() ([]byte, error) { return nil, rb.err } - rb.held = true - return rb.buf[rb.tail], nil -} - -// Done advances the read pointer, making the slot available for reuse. -// It must be called exactly once after each successful Dequeue. -// -// TODO(shubham): remove this method once a shared buffer pool is introduced. -// With a pool, Dequeue will advance the tail immediately and the caller will -// return the buffer to the pool directly. -func (rb *ringBuffer) Done() { - rb.mu.Lock() - defer rb.mu.Unlock() - + b := rb.buf[rb.tail] + rb.buf[rb.tail] = nil rb.tail = (rb.tail + 1) % len(rb.buf) rb.count-- - rb.held = false rb.cond.Broadcast() + + return b, nil } // Close marks the buffer as closed with the given error. All blocked Enqueue -// and Dequeue calls are woken and will return. Close waits for any in-progress -// Dequeue/Done pair to complete before setting the error. Subsequent calls are -// no-ops. +// and Dequeue calls are woken and will return. Subsequent calls are no-ops. 
func (rb *ringBuffer) Close(err error) { rb.mu.Lock() defer rb.mu.Unlock() - for rb.held { - rb.cond.Wait() - } if rb.err != nil { return } diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go index 8be9c58..d62d633 100644 --- a/drpcstream/ring_buffer_test.go +++ b/drpcstream/ring_buffer_test.go @@ -13,19 +13,18 @@ import ( func TestRingBuffer_EnqueueDequeue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Enqueue([]byte("hello")) data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("hello")) - rb.Done() + assert.DeepEqual(t, *data, []byte("hello")) } func TestRingBuffer_FIFO(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Enqueue([]byte("first")) rb.Enqueue([]byte("second")) @@ -34,31 +33,30 @@ func TestRingBuffer_FIFO(t *testing.T) { for _, want := range []string{"first", "second", "third"} { data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte(want)) - rb.Done() + assert.DeepEqual(t, *data, []byte(want)) } } func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) got := make(chan []byte, 1) go func() { data, err := rb.Dequeue() assert.NoError(t, err) - got <- data + got <- *data }() rb.Enqueue([]byte("delayed")) assert.DeepEqual(t, <-got, []byte("delayed")) - rb.Done() } func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 2) // capacity 2 + rb.pool = NewBufferPool() + rb.buf = make([]*[]byte, 2) // capacity 2 rb.Enqueue([]byte("a")) rb.Enqueue([]byte("b")) @@ -73,8 +71,7 @@ func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { // Drain one slot. data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("a")) - rb.Done() + assert.DeepEqual(t, *data, []byte("a")) // Now the blocked Enqueue should complete. 
<-done @@ -82,18 +79,16 @@ func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { // Verify remaining items. data, err = rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("b")) - rb.Done() + assert.DeepEqual(t, *data, []byte("b")) data, err = rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("c")) - rb.Done() + assert.DeepEqual(t, *data, []byte("c")) } func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) errch := make(chan error, 1) go func() { @@ -108,7 +103,8 @@ func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 1) // capacity 1 + rb.pool = NewBufferPool() + rb.buf = make([]*[]byte, 1) // capacity 1 rb.Enqueue([]byte("fill")) @@ -124,7 +120,7 @@ func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { func TestRingBuffer_CloseDrainsQueued(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Enqueue([]byte("queued")) rb.Close(io.EOF) @@ -132,8 +128,7 @@ func TestRingBuffer_CloseDrainsQueued(t *testing.T) { // Dequeue returns the queued data first. data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("queued")) - rb.Done() + assert.DeepEqual(t, *data, []byte("queued")) // Next Dequeue returns the close error. 
data, err = rb.Dequeue() @@ -143,7 +138,7 @@ func TestRingBuffer_CloseDrainsQueued(t *testing.T) { func TestRingBuffer_CloseIdempotent(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Close(io.EOF) rb.Close(io.ErrUnexpectedEOF) // should not overwrite @@ -154,7 +149,7 @@ func TestRingBuffer_CloseIdempotent(t *testing.T) { func TestRingBuffer_EnqueueAfterClose(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) rb.Close(io.EOF) rb.Enqueue([]byte("dropped")) // should not panic or block @@ -163,44 +158,21 @@ func TestRingBuffer_EnqueueAfterClose(t *testing.T) { func TestRingBuffer_SlotReuse(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 2) + rb.pool = NewBufferPool() + rb.buf = make([]*[]byte, 2) // Fill and drain a few rounds to exercise slot reuse. for round := 0; round < 5; round++ { rb.Enqueue([]byte("data")) data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("data")) - rb.Done() + assert.DeepEqual(t, *data, []byte("data")) } } -func TestRingBuffer_CloseWaitsForHeld(t *testing.T) { - var rb ringBuffer - rb.init() - - rb.Enqueue([]byte("msg")) - - // Dequeue the data but don't call Done yet. - data, err := rb.Dequeue() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("msg")) - - closed := make(chan struct{}) - go func() { - rb.Close(io.EOF) - close(closed) - }() - - // Close should be blocked because held is true. - // Call Done to release it. 
- rb.Done() - <-closed -} - func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(NewBufferPool()) const n = 1000 var wg sync.WaitGroup @@ -218,11 +190,25 @@ func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { for i := 0; i < n; i++ { data, err := rb.Dequeue() assert.NoError(t, err) - assert.Equal(t, data[0], byte(i)) - rb.Done() + assert.Equal(t, (*data)[0], byte(i)) } }() wg.Wait() rb.Close(io.EOF) } + +func TestRingBuffer_WithPool(t *testing.T) { + pool := NewBufferPool() + var rb ringBuffer + rb.init(pool) + + rb.Enqueue([]byte("pooled")) + + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, *data, []byte("pooled")) + pool.Put(data) + + rb.Close(io.EOF) +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 1fa0460..d8545cf 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -53,6 +53,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.MuxWriter + pool *BufferPool recvQueue ringBuffer wbuf []byte @@ -78,15 +79,15 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter) *Stream { - return NewWithOptions(ctx, sid, wr, Options{}) +func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPool) *Stream { + return NewWithOptions(ctx, sid, wr, pool, Options{}) } // NewWithOptions returns a new stream bound to the context with the given // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. 
-func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opts Options) *Stream { +func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPool, opts Options) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -108,12 +109,12 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opt pa: pa, - id: drpcwire.ID{Stream: sid}, - wr: wr, + id: drpcwire.ID{Stream: sid}, + wr: wr, + pool: pool, } - // initialize the packet buffer - s.recvQueue.init() + s.recvQueue.init(pool) return s } @@ -414,12 +415,12 @@ func (s *Stream) RawRecv() (data []byte, err error) { s.read.Lock() defer s.read.Unlock() - data, err = s.recvQueue.Dequeue() + b, err := s.recvQueue.Dequeue() if err != nil { return nil, err } - data = append([]byte(nil), data...) - s.recvQueue.Done() + data = append([]byte(nil), *b...) + s.pool.Put(b) return data, nil } @@ -456,12 +457,12 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { s.read.Lock() defer s.read.Unlock() - data, err := s.recvQueue.Dequeue() + b, err := s.recvQueue.Dequeue() if err != nil { return err } - err = enc.Unmarshal(data, msg) - s.recvQueue.Done() + err = enc.Unmarshal(*b, msg) + s.pool.Put(b) return err } diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 9cee8dd..f54b509 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -117,7 +117,7 @@ func TestStream_StateTransitions(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) @@ -169,7 +169,7 @@ func TestStream_Unblocks(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) 
assert.NoError(t, test.Op(st)) @@ -180,7 +180,7 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() mw := testMuxWriter(t) - st := New(ctx, 0, mw) + st := New(ctx, 0, mw, NewBufferPool()) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -195,7 +195,7 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 0, mw) + st := New(ctx, 0, mw, NewBufferPool()) // Close and Cancel concurrently should not panic or deadlock. errch := make(chan error, 1) @@ -219,7 +219,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { mw := testMuxWriter(t) data := make([]byte, 20) mid := uint64(1) - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { @@ -265,7 +265,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) // Close the ring buffer so KindMessage Enqueue doesn't block. st.recvQueue.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ @@ -278,7 +278,7 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // Invoke and InvokeMetadata frames are rejected on an already-created stream. 
func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) err := handleFrame(st, drpcwire.KindInvoke, 1) assert.Error(t, err) @@ -288,7 +288,7 @@ func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) assert.Error(t, err) @@ -299,7 +299,7 @@ func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { // Frames arriving after the stream is terminated are silently ignored. func TestHandleFrame_AfterTerminated(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, NewBufferPool()) // Terminate the stream via cancel. st.Cancel(context.Canceled) @@ -317,7 +317,7 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, NewBufferPool()) // Launch receiver before sending to avoid Put blocking. recv := make(chan []byte, 1) From 652af36b914dd18de7bad12f5762266a4cf40ecb Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 8 Jun 2026 13:19:55 +0000 Subject: [PATCH 2/2] drpcwire: assemble packets into pooled buffers to drop a receive-path copy Receiving a message used to copy its payload twice: PacketAssembler appended each frame into its own backing array, and the ring buffer's Enqueue then copied that array into a pooled buffer. The second copy existed only to move the data into a buffer the consumer could own. The assembler now takes its buffer from the shared BufferPool and assembles frames directly into it, then hands ownership of that buffer to the completed packet. 
Enqueue stores the pooled pointer as is, and RawRecv/MsgRecv return the buffer to the pool once the message is consumed. This removes one full copy per received message. To let ownership flow through the ring buffer, Packet.Data is now a *[]byte instead of []byte. BufferPool moves from drpcstream to drpcwire because the assembler lives in drpcwire and drpcwire cannot import drpcstream. NewPacketAssembler now requires a pool, and the manager passes its shared recvPool to both the stream assembler and the pending invoke-stream assembler. Every completed packet now holds a pooled buffer, so each path returns it. handlePacket enqueues message buffers and hands ownership to the consumer, and defers Put for control and error packets after their data is read. The manager Puts the buffer after decoding metadata, and after the invoke is consumed; the latter is safe because NewServerStream copies the rpc name out before pdone fires. Also drop SplitN, which had no remaining users. Stream writes split frames through SplitData directly. 
Co-Authored-By: Claude Opus 4.8 --- drpcmanager/active_streams_test.go | 2 +- drpcmanager/manager.go | 25 ++++++---- drpcstream/ring_buffer.go | 29 ++++++----- drpcstream/ring_buffer_test.go | 64 ++++++++++++++----------- drpcstream/stream.go | 28 ++++++++--- drpcstream/stream_test.go | 20 ++++---- {drpcstream => drpcwire}/buffer_pool.go | 2 +- drpcwire/packet.go | 5 +- drpcwire/packet_assembler.go | 38 +++++++++------ drpcwire/packet_assembler_test.go | 30 ++++++------ drpcwire/rand_test.go | 9 ---- drpcwire/split.go | 26 ---------- drpcwire/split_test.go | 36 -------------- 13 files changed, 143 insertions(+), 171 deletions(-) rename {drpcstream => drpcwire}/buffer_pool.go (97%) delete mode 100644 drpcwire/split_test.go diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index d2e5942..c224037 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -22,7 +22,7 @@ func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { } func testStream(t *testing.T, id uint64) *drpcstream.Stream { - return drpcstream.New(context.Background(), id, testMuxWriter(t), drpcstream.NewBufferPool()) + return drpcstream.New(context.Background(), id, testMuxWriter(t), drpcwire.NewBufferPool()) } func TestActiveStreams_AddAndGet(t *testing.T) { diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index c319ff6..a6a88b0 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -16,7 +16,6 @@ import ( "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" - "storj.io/drpc" "storj.io/drpc/drpcdebug" "storj.io/drpc/drpcmetadata" @@ -62,7 +61,7 @@ type Manager struct { // streams tracks active streams. 
streams *activeStreams - recvPool *drpcstream.BufferPool + recvPool *drpcwire.BufferPool pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream @@ -131,7 +130,7 @@ func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager m.pendingStreams = make(map[uint64]*pendingStream) m.streams = newActiveStreams() - m.recvPool = drpcstream.NewBufferPool() + m.recvPool = drpcwire.NewBufferPool() // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) @@ -222,7 +221,7 @@ func (m *Manager) manageReader() { func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { ps, ok := m.pendingStreams[fr.ID.Stream] if !ok { - ps = &pendingStream{pa: drpcwire.NewPacketAssembler()} + ps = &pendingStream{pa: drpcwire.NewPacketAssembler(m.recvPool)} m.pendingStreams[fr.ID.Stream] = ps } pkt, packetReady, err := ps.pa.AppendFrame(fr) @@ -235,7 +234,8 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { // Metadata arrives before invoke; accumulate it and wait for the invoke. if pkt.Kind == drpcwire.KindInvokeMetadata { - meta, err := drpcmetadata.Decode(pkt.Data) + meta, err := drpcmetadata.Decode(*pkt.Data) + m.recvPool.Put(pkt.Data) if err != nil { return err } @@ -245,11 +245,12 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { // Invoke packet completes the sequence. Send to NewServerStream. select { - case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ps.metadata}: + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: *pkt.Data, metadata: ps.metadata}: // Wait for NewServerStream to finish stream creation before reading the // next frame. This guarantees curr is set for subsequent non-invoke // packets. 
m.pdone.Recv() + m.recvPool.Put(pkt.Data) // TODO: reuse pending stream delete(m.pendingStreams, fr.ID.Stream) case <-m.sigs.term.Signal(): @@ -262,7 +263,9 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { // // newStream creates a stream value with the appropriate configuration for this manager. -func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKind, rpc string) (*drpcstream.Stream, error) { +func (m *Manager) newStream( + ctx context.Context, sid uint64, kind drpc.StreamKind, rpc string, +) (*drpcstream.Stream, error) { opts := m.opts.Stream drpcopts.SetStreamKind(&opts.Internal, kind) drpcopts.SetStreamRPC(&opts.Internal, rpc) @@ -336,7 +339,9 @@ func (m *Manager) Close() error { } // NewClientStream starts a stream on the managed transport for use by a client. -func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { +func (m *Manager) NewClientStream( + ctx context.Context, rpc string, +) (stream *drpcstream.Stream, err error) { if err := ctx.Err(); err != nil { return nil, err } @@ -346,7 +351,9 @@ func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpc // NewServerStream starts a stream on the managed transport for use by a server. // It does this by waiting for the client to issue an invoke message and // returning the details. 
-func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { +func (m *Manager) NewServerStream( + ctx context.Context, +) (stream *drpcstream.Stream, rpc string, err error) { select { case <-ctx.Done(): return nil, "", ctx.Err() diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go index 3056617..0eb3e48 100644 --- a/drpcstream/ring_buffer.go +++ b/drpcstream/ring_buffer.go @@ -3,7 +3,11 @@ package drpcstream -import "sync" +import ( + "sync" + + "storj.io/drpc/drpcwire" +) // defaultRingBufferCapacity is the number of messages the ring buffer can // hold before the producer blocks. This decouples the transport reader @@ -18,7 +22,8 @@ const defaultRingBufferCapacity = 256 // Enqueue) and the application goroutine (consumer, calls Dequeue). // -// Buffers are obtained from a shared BufferPool. Enqueue copies data into a -// pooled buffer; Dequeue returns ownership of that buffer to the caller and +// Buffers are obtained from a shared BufferPool. Enqueue takes ownership of an +// already-pooled buffer and stores the pointer without copying; Dequeue +// returns ownership of that buffer to the caller and // advances the tail immediately. The caller is responsible for returning the // buffer to the pool via BufferPool.Put. 
// @@ -29,16 +34,16 @@ type ringBuffer struct { mu sync.Mutex cond sync.Cond - pool *BufferPool // shared pool; nil means allocate fresh each time - buf []*[]byte // ring of pooled buffer pointers - head int // next write position (producer) - tail int // next read position (consumer) - count int // number of occupied slots + pool *drpcwire.BufferPool // shared pool; must be non-nil + buf []*[]byte // ring of pooled buffer pointers + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots err error // terminal error, set by Close } -func (rb *ringBuffer) init(pool *BufferPool) { +func (rb *ringBuffer) init(pool *drpcwire.BufferPool) { rb.cond.L = &rb.mu rb.pool = pool rb.buf = make([]*[]byte, defaultRingBufferCapacity) @@ -47,7 +52,7 @@ func (rb *ringBuffer) init(pool *BufferPool) { -// Enqueue copies data into a pooled buffer and places it in the next write -// slot. If the buffer is full, it blocks until a slot is freed or the buffer -// is closed. If the buffer is closed, Enqueue returns silently. +// Enqueue takes ownership of the pooled buffer data and stores it in the next +// write slot without copying. If the ring is full, it blocks until a slot is +// freed or the ring is closed; once closed, data is returned to the pool. -func (rb *ringBuffer) Enqueue(data []byte) { +func (rb *ringBuffer) Enqueue(data *[]byte) { rb.mu.Lock() defer rb.mu.Unlock() @@ -55,13 +60,11 @@ func (rb *ringBuffer) Enqueue(data []byte) { rb.cond.Wait() } if rb.err != nil { + rb.pool.Put(data) return } - b := rb.pool.Get() - *b = append(*b, data...) - - rb.buf[rb.head] = b + rb.buf[rb.head] = data rb.head = (rb.head + 1) % len(rb.buf) rb.count++ rb.cond.Broadcast() diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go index d62d633..e1a7b03 100644 --- a/drpcstream/ring_buffer_test.go +++ b/drpcstream/ring_buffer_test.go @@ -9,13 +9,23 @@ import ( "testing" "github.com/zeebo/assert" + + "storj.io/drpc/drpcwire" ) +// enqueue mimics the producer: it takes a pooled buffer, fills it, and hands +// ownership to the ring buffer, matching how handlePacket feeds the queue. 
+func enqueue(rb *ringBuffer, data []byte) { + b := rb.pool.Get() + *b = append(*b, data...) + rb.Enqueue(b) +} + func TestRingBuffer_EnqueueDequeue(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) - rb.Enqueue([]byte("hello")) + enqueue(&rb, []byte("hello")) data, err := rb.Dequeue() assert.NoError(t, err) @@ -24,11 +34,11 @@ func TestRingBuffer_EnqueueDequeue(t *testing.T) { func TestRingBuffer_FIFO(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) - rb.Enqueue([]byte("first")) - rb.Enqueue([]byte("second")) - rb.Enqueue([]byte("third")) + enqueue(&rb, []byte("first")) + enqueue(&rb, []byte("second")) + enqueue(&rb, []byte("third")) for _, want := range []string{"first", "second", "third"} { data, err := rb.Dequeue() @@ -39,7 +49,7 @@ func TestRingBuffer_FIFO(t *testing.T) { func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) got := make(chan []byte, 1) go func() { @@ -48,23 +58,23 @@ func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { got <- *data }() - rb.Enqueue([]byte("delayed")) + enqueue(&rb, []byte("delayed")) assert.DeepEqual(t, <-got, []byte("delayed")) } func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.pool = NewBufferPool() + rb.pool = drpcwire.NewBufferPool() rb.buf = make([]*[]byte, 2) // capacity 2 - rb.Enqueue([]byte("a")) - rb.Enqueue([]byte("b")) + enqueue(&rb, []byte("a")) + enqueue(&rb, []byte("b")) // Third enqueue should block until we drain one. 
done := make(chan struct{}) go func() { - rb.Enqueue([]byte("c")) + enqueue(&rb, []byte("c")) close(done) }() @@ -88,7 +98,7 @@ func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) errch := make(chan error, 1) go func() { @@ -103,14 +113,14 @@ func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.pool = NewBufferPool() + rb.pool = drpcwire.NewBufferPool() rb.buf = make([]*[]byte, 1) // capacity 1 - rb.Enqueue([]byte("fill")) + enqueue(&rb, []byte("fill")) done := make(chan struct{}) go func() { - rb.Enqueue([]byte("blocked")) + enqueue(&rb, []byte("blocked")) close(done) }() @@ -120,9 +130,9 @@ func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { func TestRingBuffer_CloseDrainsQueued(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) - rb.Enqueue([]byte("queued")) + enqueue(&rb, []byte("queued")) rb.Close(io.EOF) // Dequeue returns the queued data first. 
@@ -138,7 +148,7 @@ func TestRingBuffer_CloseDrainsQueued(t *testing.T) { func TestRingBuffer_CloseIdempotent(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) rb.Close(io.EOF) rb.Close(io.ErrUnexpectedEOF) // should not overwrite @@ -149,21 +159,21 @@ func TestRingBuffer_CloseIdempotent(t *testing.T) { func TestRingBuffer_EnqueueAfterClose(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) rb.Close(io.EOF) - rb.Enqueue([]byte("dropped")) // should not panic or block + enqueue(&rb, []byte("dropped")) // should not panic or block } func TestRingBuffer_SlotReuse(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.pool = NewBufferPool() + rb.pool = drpcwire.NewBufferPool() rb.buf = make([]*[]byte, 2) // Fill and drain a few rounds to exercise slot reuse. for round := 0; round < 5; round++ { - rb.Enqueue([]byte("data")) + enqueue(&rb, []byte("data")) data, err := rb.Dequeue() assert.NoError(t, err) assert.DeepEqual(t, *data, []byte("data")) @@ -172,7 +182,7 @@ func TestRingBuffer_SlotReuse(t *testing.T) { func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { var rb ringBuffer - rb.init(NewBufferPool()) + rb.init(drpcwire.NewBufferPool()) const n = 1000 var wg sync.WaitGroup @@ -181,7 +191,7 @@ func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { go func() { defer wg.Done() for i := 0; i < n; i++ { - rb.Enqueue([]byte{byte(i)}) + enqueue(&rb, []byte{byte(i)}) } }() @@ -199,11 +209,11 @@ func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { } func TestRingBuffer_WithPool(t *testing.T) { - pool := NewBufferPool() + pool := drpcwire.NewBufferPool() var rb ringBuffer rb.init(pool) - rb.Enqueue([]byte("pooled")) + enqueue(&rb, []byte("pooled")) data, err := rb.Dequeue() assert.NoError(t, err) diff --git a/drpcstream/stream.go b/drpcstream/stream.go index d8545cf..836795c 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -11,7 +11,6 @@ 
import ( "sync" "github.com/zeebo/errs" - "storj.io/drpc" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcdebug" @@ -53,7 +52,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.MuxWriter - pool *BufferPool + pool *drpcwire.BufferPool recvQueue ringBuffer wbuf []byte @@ -79,7 +78,9 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPool) *Stream { +func New( + ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *drpcwire.BufferPool, +) *Stream { return NewWithOptions(ctx, sid, wr, pool, Options{}) } @@ -87,7 +88,9 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPo // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *BufferPool, opts Options) *Stream { +func NewWithOptions( + ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *drpcwire.BufferPool, opts Options, +) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -96,7 +99,9 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, poo } } - pa := drpcwire.NewPacketAssembler() + // When a pool is available, assemble directly into pooled buffers so the + // completed packet can be handed off to the recv queue without a copy. 
+ pa := drpcwire.NewPacketAssembler(pool) pa.SetStreamID(sid) s := &Stream{ @@ -222,15 +227,24 @@ func (s *Stream) HandleFrame(fr drpcwire.Frame) (err error) { // returns any major errors that should terminate the transport the stream is // operating on. func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { - drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) + drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(*pkt.Data))) s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { + // The assembler handed us ownership of the pooled buffer; enqueue + // it directly without copying. s.recvQueue.Enqueue(pkt.Data) return nil } + // Control and error packets are consumed here and not handed to the recv + // queue; return any pooled buffer once we're done reading pkt.Data. The + // defer runs after the switch below, so data stays valid while in use. + if pkt.Data != nil { + defer s.pool.Put(pkt.Data) + } + s.mu.Lock() defer s.mu.Unlock() @@ -241,7 +255,7 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { return err case drpcwire.KindError: - err := drpcwire.UnmarshalError(pkt.Data) + err := drpcwire.UnmarshalError(*pkt.Data) s.sigs.send.Set(io.EOF) // in this state, gRPC returns io.EOF on send. 
s.terminate(err) return nil diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index f54b509..724fab6 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -117,7 +117,7 @@ func TestStream_StateTransitions(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw, NewBufferPool()) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) @@ -169,7 +169,7 @@ func TestStream_Unblocks(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw, NewBufferPool()) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -180,7 +180,7 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() mw := testMuxWriter(t) - st := New(ctx, 0, mw, NewBufferPool()) + st := New(ctx, 0, mw, drpcwire.NewBufferPool()) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -195,7 +195,7 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 0, mw, NewBufferPool()) + st := New(ctx, 0, mw, drpcwire.NewBufferPool()) // Close and Cancel concurrently should not panic or deadlock. 
errch := make(chan error, 1) @@ -219,7 +219,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { mw := testMuxWriter(t) data := make([]byte, 20) mid := uint64(1) - st := New(ctx, 1, mw, NewBufferPool()) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { @@ -265,7 +265,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { - st := New(context.Background(), 1, mw, NewBufferPool()) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) // Close the ring buffer so KindMessage Enqueue doesn't block. st.recvQueue.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ @@ -278,7 +278,7 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // Invoke and InvokeMetadata frames are rejected on an already-created stream. func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw, NewBufferPool()) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) err := handleFrame(st, drpcwire.KindInvoke, 1) assert.Error(t, err) @@ -288,7 +288,7 @@ func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw, NewBufferPool()) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) assert.Error(t, err) @@ -299,7 +299,7 @@ func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { // Frames arriving after the stream is terminated are silently ignored. func TestHandleFrame_AfterTerminated(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw, NewBufferPool()) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) // Terminate the stream via cancel. 
st.Cancel(context.Canceled) @@ -317,7 +317,7 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 1, mw, NewBufferPool()) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) // Launch receiver before sending to avoid Put blocking. recv := make(chan []byte, 1) diff --git a/drpcstream/buffer_pool.go b/drpcwire/buffer_pool.go similarity index 97% rename from drpcstream/buffer_pool.go rename to drpcwire/buffer_pool.go index 0785e51..59a8196 100644 --- a/drpcstream/buffer_pool.go +++ b/drpcwire/buffer_pool.go @@ -1,7 +1,7 @@ // Copyright (C) 2026 Cockroach Labs. // See LICENSE for copying information. -package drpcstream +package drpcwire import "sync" diff --git a/drpcwire/packet.go b/drpcwire/packet.go index 447eb40..fa2c8af 100644 --- a/drpcwire/packet.go +++ b/drpcwire/packet.go @@ -149,8 +149,7 @@ func AppendFrame(buf []byte, fr Frame) []byte { // Packet is a single message sent by drpc. type Packet struct { - // Data is the payload of the packet. - Data []byte + Data *[]byte // ID is the identifier for the packet. ID ID @@ -168,5 +167,5 @@ type Packet struct { // String returns a human readable form of the packet. func (p Packet) String() string { return fmt.Sprintf("", - p.ID.Stream, p.ID.Message, len(p.Data), p.Kind) + p.ID.Stream, p.ID.Message, len(*p.Data), p.Kind) } diff --git a/drpcwire/packet_assembler.go b/drpcwire/packet_assembler.go index 2cf7fae..b958360 100644 --- a/drpcwire/packet_assembler.go +++ b/drpcwire/packet_assembler.go @@ -1,8 +1,6 @@ package drpcwire -import ( - "storj.io/drpc" -) +import "storj.io/drpc" // PacketAssembler assembles frames into complete packets, enforcing wire // protocol invariants: @@ -11,19 +9,27 @@ import ( // - Message IDs must be monotonically increasing. // - Frame kind must not change within a single packet (multi-frame). 
// +When constructed with a BufferPool, the assembler assembles directly into a +// pooled buffer and transfers its ownership to the returned packet (via +// Packet.Data), removing a copy on the receive path. A non-nil pool is +// required: every packet buffer is obtained from it, and the receiver of a +// completed packet owns packet.Data and should Put it back when done. +// // It is not safe for concurrent use. type PacketAssembler struct { + pool *BufferPool pk Packet assembling bool streamInitialized bool } // NewPacketAssembler returns a new PacketAssembler ready to assemble frames. -func NewPacketAssembler() PacketAssembler { +func NewPacketAssembler(pool *BufferPool) PacketAssembler { return PacketAssembler{ pk: Packet{ ID: ID{Stream: 0, Message: 1}, }, + pool: pool, } } @@ -36,6 +42,9 @@ func (pa *PacketAssembler) SetStreamID(streamID uint64) { // Reset clears all assembly state, preparing the assembler for a new stream. func (pa *PacketAssembler) Reset() { + if pa.pk.Data != nil { + pa.pool.Put(pa.pk.Data) + } pa.pk = Packet{ ID: ID{Stream: 0, Message: 1}, } @@ -46,7 +55,7 @@ func (pa *PacketAssembler) Reset() { // AppendFrame adds a frame to the in-progress packet. It returns the completed // packet and true when a frame with Done=true is received. It returns false // when more frames are needed to complete the packet. -func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady bool, err error) { +func (pa *PacketAssembler) AppendFrame(fr Frame) (Packet, bool, error) { // Enforce stream ID consistency: infer from first frame or reject mismatches. if !pa.streamInitialized { pa.pk.ID.Stream = fr.ID.Stream @@ -60,8 +69,11 @@ func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady boo return Packet{}, false, drpc.ProtocolError.New( "message id monotonicity violation: got %v, expected >= %v", fr.ID.Message, pa.pk.ID.Message) } else if fr.ID.Message > pa.pk.ID.Message || !pa.assembling { - // New message: reset the buffer and start assembling.
- pa.pk.Data = pa.pk.Data[:0] + if pa.pk.Data == nil { + pa.pk.Data = pa.pool.Get() + } else { + *pa.pk.Data = (*pa.pk.Data)[:0] + } pa.assembling = true pa.pk.ID.Message = fr.ID.Message } else if fr.Kind != pa.pk.Kind { @@ -69,8 +81,9 @@ func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady boo "frame kind changed mid-packet: got %v, expected %v", fr.Kind, pa.pk.Kind) } - // TODO(shubham): add buf reuse - pa.pk.Data = append(pa.pk.Data, fr.Data...) + // Assemble directly into the pooled buffer so the completed packet can + // be handed off down the receive path without another copy. + *pa.pk.Data = append(*pa.pk.Data, fr.Data...) pa.pk.Kind = fr.Kind pa.pk.Control = fr.Control @@ -78,12 +91,9 @@ func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady boo return Packet{}, false, nil } - packet = pa.pk - + packet := pa.pk pa.assembling = false pa.pk.ID.Message = fr.ID.Message + 1 - // Reuse the backing array: the caller must consume packet.Data before the - // next AppendFrame call, as it will be overwritten. - pa.pk.Data = pa.pk.Data[:0] + pa.pk.Data = nil return packet, true, nil } diff --git a/drpcwire/packet_assembler_test.go b/drpcwire/packet_assembler_test.go index 41cf70d..4f555bb 100644 --- a/drpcwire/packet_assembler_test.go +++ b/drpcwire/packet_assembler_test.go @@ -13,7 +13,7 @@ import ( ) func TestPacketAssembler_WrongStreamID(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) _, _, err := pa.AppendFrame(Frame{ @@ -27,7 +27,7 @@ func TestPacketAssembler_WrongStreamID(t *testing.T) { } func TestPacketAssembler_StreamIDInferredFromFirstFrame(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) // First frame sets the stream ID implicitly. 
_, _, err := pa.AppendFrame(Frame{ @@ -49,7 +49,7 @@ func TestPacketAssembler_StreamIDInferredFromFirstFrame(t *testing.T) { // A frame with a message ID lower than a previously completed message is rejected. func TestPacketAssembler_MessageMonotonicity(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // m3 completes, next expected becomes 4. @@ -70,7 +70,7 @@ func TestPacketAssembler_MessageMonotonicity(t *testing.T) { // When a higher message ID arrives mid-assembly, the in-progress data is // silently discarded and a new packet begins. func TestPacketAssembler_HigherMsgDiscardsInProgress(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // Start accumulating m1. @@ -86,13 +86,13 @@ func TestPacketAssembler_HigherMsgDiscardsInProgress(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("kept")) + assert.DeepEqual(t, *pkt.Data, []byte("kept")) } // Continuation frames (same message ID, mid-assembly) must carry the same // kind as the first frame. A kind change mid-packet is a protocol error. func TestPacketAssembler_KindChangeWithinPacket(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) _, _, err := pa.AppendFrame(Frame{ @@ -110,7 +110,7 @@ func TestPacketAssembler_KindChangeWithinPacket(t *testing.T) { // Multiple continuation frames for the same message accumulate data correctly. 
func TestPacketAssembler_MultiFrameDataAccumulation(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) _, ready, err := pa.AppendFrame(Frame{ @@ -130,14 +130,14 @@ func TestPacketAssembler_MultiFrameDataAccumulation(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("hello world")) + assert.DeepEqual(t, *pkt.Data, []byte("hello world")) } // Multi-frame assembly works when the message ID is greater than the initial // expected ID (e.g., on the server side where invoke consumed earlier message // IDs). Continuation frames must accumulate data, not reset on each frame. func TestPacketAssembler_MultiFrameWithSkippedMessageID(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // msg=3 is greater than initial expected message ID=1. @@ -158,12 +158,12 @@ func TestPacketAssembler_MultiFrameWithSkippedMessageID(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("hello world")) + assert.DeepEqual(t, *pkt.Data, []byte("hello world")) } // Once a message completes (done=true), the same message ID is rejected. func TestPacketAssembler_DonePreventsReplay(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // m1 completes → next expected becomes 2. @@ -185,7 +185,7 @@ func TestPacketAssembler_DonePreventsReplay(t *testing.T) { // across messages. A KindMessage followed by a KindClose for the next message // should be accepted without error. func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // Multi-frame message 1 with KindMessage. 
@@ -199,7 +199,7 @@ func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("abcd")) + assert.DeepEqual(t, *pkt.Data, []byte("abcd")) // Message 2 with a different kind — should not trigger kind check. pkt, ready, err = pa.AppendFrame(Frame{ @@ -212,7 +212,7 @@ func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { // Reset clears all state so the assembler can be reused for a new stream. func TestPacketAssembler_Reset(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // Complete a packet on stream 1. @@ -230,6 +230,6 @@ func TestPacketAssembler_Reset(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("new")) + assert.DeepEqual(t, *pkt.Data, []byte("new")) assert.Equal(t, pkt.ID.Stream, uint64(2)) } diff --git a/drpcwire/rand_test.go b/drpcwire/rand_test.go index 1eeb4b4..6d0f330 100644 --- a/drpcwire/rand_test.go +++ b/drpcwire/rand_test.go @@ -75,12 +75,3 @@ func RandFrame() Frame { Done: RandBool(), } } - -func RandPacket() Packet { - kind := RandKind() - return Packet{ - Data: RandBytes(10 * payloadSize[kind]()), - ID: RandID(), - Kind: kind, - } -} diff --git a/drpcwire/split.go b/drpcwire/split.go index 1189615..3b3c362 100644 --- a/drpcwire/split.go +++ b/drpcwire/split.go @@ -3,32 +3,6 @@ package drpcwire -// SplitN splits the marshaled form of the Packet into a number of -// frames such that each frame is at most n bytes. It calls -// the callback with every such frame. If n is zero, a reasonable -// default is used. 
-func SplitN(pkt Packet, n int, cb func(fr Frame) error) error { - for { - fr := Frame{ - Data: pkt.Data, - ID: pkt.ID, - Kind: pkt.Kind, - Control: pkt.Control, - Done: true, - } - - fr.Data, pkt.Data = SplitData(pkt.Data, n) - fr.Done = len(pkt.Data) == 0 - - if err := cb(fr); err != nil { - return err - } - if fr.Done { - return nil - } - } -} - // SplitData is used to split a buffer if it is larger than n bytes. // If n is zero, a reasonable default is used. If n is less than zero // then it does not split. diff --git a/drpcwire/split_test.go b/drpcwire/split_test.go deleted file mode 100644 index 29daf4c..0000000 --- a/drpcwire/split_test.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "bytes" - "math/rand" - "testing" - - "github.com/zeebo/assert" -) - -func TestSplit(t *testing.T) { - for i := 0; i < 1000; i++ { - pkt, done, n := RandPacket(), false, rand.Intn(10)-1 - if size := rand.Intn(100); size < len(pkt.Data) { - pkt.Data = pkt.Data[:size] - } - - var buf []byte - assert.NoError(t, SplitN(pkt, n, func(fr Frame) error { - assert.That(t, !done) - assert.That(t, len(fr.Data) <= n || - (n == -1 && len(pkt.Data) == len(fr.Data)) || - (n == 0 && len(fr.Data) <= 1024)) - assert.Equal(t, pkt.Kind, fr.Kind) - done = fr.Done - buf = append(buf, fr.Data...) - return nil - })) - - assert.That(t, done) - assert.That(t, bytes.Equal(pkt.Data, buf)) - } -}