diff --git a/drpcmanager/active_streams_test.go b/drpcmanager/active_streams_test.go index f463b188..c224037c 100644 --- a/drpcmanager/active_streams_test.go +++ b/drpcmanager/active_streams_test.go @@ -22,7 +22,7 @@ func testMuxWriter(t *testing.T) *drpcwire.MuxWriter { } func testStream(t *testing.T, id uint64) *drpcstream.Stream { - return drpcstream.New(context.Background(), id, testMuxWriter(t)) + return drpcstream.New(context.Background(), id, testMuxWriter(t), drpcwire.NewBufferPool()) } func TestActiveStreams_AddAndGet(t *testing.T) { diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 6c77e174..a6a88b03 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -16,7 +16,6 @@ import ( "github.com/zeebo/errs" grpcmetadata "google.golang.org/grpc/metadata" - "storj.io/drpc" "storj.io/drpc/drpcdebug" "storj.io/drpc/drpcmetadata" @@ -61,7 +60,8 @@ type Manager struct { wg sync.WaitGroup // tracks active manageStream goroutines // streams tracks active streams. - streams *activeStreams + streams *activeStreams + recvPool *drpcwire.BufferPool pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream @@ -130,6 +130,7 @@ func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager m.pendingStreams = make(map[uint64]*pendingStream) m.streams = newActiveStreams() + m.recvPool = drpcwire.NewBufferPool() // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) @@ -220,7 +221,7 @@ func (m *Manager) manageReader() { func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { ps, ok := m.pendingStreams[fr.ID.Stream] if !ok { - ps = &pendingStream{pa: drpcwire.NewPacketAssembler()} + ps = &pendingStream{pa: drpcwire.NewPacketAssembler(m.recvPool)} m.pendingStreams[fr.ID.Stream] = ps } pkt, packetReady, err := ps.pa.AppendFrame(fr) @@ -233,7 +234,8 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { // Metadata arrives before invoke; accumulate it and wait for the invoke. if pkt.Kind == drpcwire.KindInvokeMetadata { - meta, err := drpcmetadata.Decode(pkt.Data) + meta, err := drpcmetadata.Decode(*pkt.Data) + m.recvPool.Put(pkt.Data) if err != nil { return err } @@ -243,11 +245,12 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { // Invoke packet completes the sequence. Send to NewServerStream. select { - case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ps.metadata}: + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: *pkt.Data, metadata: ps.metadata}: // Wait for NewServerStream to finish stream creation before reading the // next frame. This guarantees curr is set for subsequent non-invoke // packets. m.pdone.Recv() + m.recvPool.Put(pkt.Data) // TODO: reuse pending stream delete(m.pendingStreams, fr.ID.Stream) case <-m.sigs.term.Signal(): @@ -260,7 +263,9 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { // // newStream creates a stream value with the appropriate configuration for this manager. -func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKind, rpc string) (*drpcstream.Stream, error) { +func (m *Manager) newStream( + ctx context.Context, sid uint64, kind drpc.StreamKind, rpc string, +) (*drpcstream.Stream, error) { opts := m.opts.Stream drpcopts.SetStreamKind(&opts.Internal, kind) drpcopts.SetStreamRPC(&opts.Internal, rpc) @@ -268,7 +273,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) } - stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) + stream := drpcstream.NewWithOptions(ctx, sid, m.wr, m.recvPool, opts) if err := m.streams.Add(sid, stream); err != nil { return nil, err @@ -334,7 +339,9 @@ func (m *Manager) Close() error { } // NewClientStream starts a stream on the managed transport for use by a client. -func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { +func (m *Manager) NewClientStream( + ctx context.Context, rpc string, +) (stream *drpcstream.Stream, err error) { if err := ctx.Err(); err != nil { return nil, err } @@ -344,7 +351,9 @@ func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpc // NewServerStream starts a stream on the managed transport for use by a server. // It does this by waiting for the client to issue an invoke message and // returning the details. -func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { +func (m *Manager) NewServerStream( + ctx context.Context, +) (stream *drpcstream.Stream, rpc string, err error) { select { case <-ctx.Done(): return nil, "", ctx.Err() diff --git a/drpcstream/ring_buffer.go b/drpcstream/ring_buffer.go index 5cb620ab..0eb3e481 100644 --- a/drpcstream/ring_buffer.go +++ b/drpcstream/ring_buffer.go @@ -3,7 +3,11 @@ package drpcstream -import "sync" +import ( + "sync" + + "storj.io/drpc/drpcwire" +) // defaultRingBufferCapacity is the number of messages the ring buffer can // hold before the producer blocks. This decouples the transport reader @@ -15,11 +19,13 @@ const defaultRingBufferCapacity = 256 // ringBuffer is a bounded single-producer / single-consumer FIFO queue for // assembled packet data. It sits between manageReader (producer, calls -// Enqueue) and the application goroutine (consumer, calls Dequeue/Done). +// Enqueue) and the application goroutine (consumer, calls Dequeue). // -// Slots are pre-allocated and reused: each slot's backing array grows via -// append to fit incoming data, then stays at its high-water mark, avoiding -// per-message allocation in steady state. +// Buffers are obtained from a shared BufferPool. Enqueue copies data into a +// pooled buffer, while EnqueueOwned takes ownership of an already-pooled buffer +// without copying; Dequeue returns ownership of that buffer to the caller and +// advances the tail immediately. The caller is responsible for returning the +// buffer to the pool via BufferPool.Put. // // After Close, Dequeue drains any queued messages before returning the close // error. This ensures graceful shutdown (KindClose/KindCloseSend) delivers @@ -28,24 +34,25 @@ type ringBuffer struct { mu sync.Mutex cond sync.Cond - buf [][]byte // ring of byte slices - head int // next write position (producer) - tail int // next read position (consumer) - count int // number of occupied slots + pool *drpcwire.BufferPool // shared pool; nil means allocate fresh each time + buf []*[]byte // ring of pooled buffer pointers + head int // next write position (producer) + tail int // next read position (consumer) + count int // number of occupied slots - held bool // true between Dequeue and Done - err error // terminal error, set by Close + err error // terminal error, set by Close } -func (rb *ringBuffer) init() { +func (rb *ringBuffer) init(pool *drpcwire.BufferPool) { rb.cond.L = &rb.mu - rb.buf = make([][]byte, defaultRingBufferCapacity) + rb.pool = pool + rb.buf = make([]*[]byte, defaultRingBufferCapacity) } -// Enqueue copies data into the next write slot. If the buffer is full, it -// blocks until a slot is freed or the buffer is closed. If the buffer is -// closed, Enqueue returns silently without enqueuing. -func (rb *ringBuffer) Enqueue(data []byte) { +// Enqueue copies data into a pooled buffer and places it in the next write +// slot. If the buffer is full, it blocks until a slot is freed or the buffer +// is closed. If the buffer is closed, Enqueue returns silently. +func (rb *ringBuffer) Enqueue(data *[]byte) { rb.mu.Lock() defer rb.mu.Unlock() @@ -53,19 +60,20 @@ func (rb *ringBuffer) Enqueue(data []byte) { rb.cond.Wait() } if rb.err != nil { + rb.pool.Put(data) return } - rb.buf[rb.head] = append(rb.buf[rb.head][:0], data...) + rb.buf[rb.head] = data rb.head = (rb.head + 1) % len(rb.buf) rb.count++ rb.cond.Broadcast() } -// Dequeue returns the data from the next read slot. If the buffer is empty, -// it blocks until data is available or the buffer is closed. The returned -// slice is valid until Done is called. -func (rb *ringBuffer) Dequeue() ([]byte, error) { +// Dequeue returns the next buffered message. The returned *[]byte is owned +// by the caller; the tail is advanced immediately. If the ring buffer has a +// pool, the caller should return the buffer via BufferPool.Put when done. +func (rb *ringBuffer) Dequeue() (*[]byte, error) { rb.mu.Lock() defer rb.mu.Unlock() @@ -76,37 +84,21 @@ func (rb *ringBuffer) Dequeue() ([]byte, error) { return nil, rb.err } - rb.held = true - return rb.buf[rb.tail], nil -} - -// Done advances the read pointer, making the slot available for reuse. -// It must be called exactly once after each successful Dequeue. -// -// TODO(shubham): remove this method once a shared buffer pool is introduced. -// With a pool, Dequeue will advance the tail immediately and the caller will -// return the buffer to the pool directly. -func (rb *ringBuffer) Done() { - rb.mu.Lock() - defer rb.mu.Unlock() - + b := rb.buf[rb.tail] + rb.buf[rb.tail] = nil rb.tail = (rb.tail + 1) % len(rb.buf) rb.count-- - rb.held = false rb.cond.Broadcast() + + return b, nil } // Close marks the buffer as closed with the given error. All blocked Enqueue -// and Dequeue calls are woken and will return. Close waits for any in-progress -// Dequeue/Done pair to complete before setting the error. Subsequent calls are -// no-ops. +// and Dequeue calls are woken and will return. Subsequent calls are no-ops. func (rb *ringBuffer) Close(err error) { rb.mu.Lock() defer rb.mu.Unlock() - for rb.held { - rb.cond.Wait() - } if rb.err != nil { return } diff --git a/drpcstream/ring_buffer_test.go b/drpcstream/ring_buffer_test.go index 8be9c587..e1a7b039 100644 --- a/drpcstream/ring_buffer_test.go +++ b/drpcstream/ring_buffer_test.go @@ -9,72 +9,79 @@ import ( "testing" "github.com/zeebo/assert" + + "storj.io/drpc/drpcwire" ) +// enqueue mimics the producer: it takes a pooled buffer, fills it, and hands +// ownership to the ring buffer, matching how handlePacket feeds the queue. +func enqueue(rb *ringBuffer, data []byte) { + b := rb.pool.Get() + *b = append(*b, data...) + rb.Enqueue(b) +} + func TestRingBuffer_EnqueueDequeue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) - rb.Enqueue([]byte("hello")) + enqueue(&rb, []byte("hello")) data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("hello")) - rb.Done() + assert.DeepEqual(t, *data, []byte("hello")) } func TestRingBuffer_FIFO(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) - rb.Enqueue([]byte("first")) - rb.Enqueue([]byte("second")) - rb.Enqueue([]byte("third")) + enqueue(&rb, []byte("first")) + enqueue(&rb, []byte("second")) + enqueue(&rb, []byte("third")) for _, want := range []string{"first", "second", "third"} { data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte(want)) - rb.Done() + assert.DeepEqual(t, *data, []byte(want)) } } func TestRingBuffer_DequeueBlocksUntilEnqueue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) got := make(chan []byte, 1) go func() { data, err := rb.Dequeue() assert.NoError(t, err) - got <- data + got <- *data }() - rb.Enqueue([]byte("delayed")) + enqueue(&rb, []byte("delayed")) assert.DeepEqual(t, <-got, []byte("delayed")) - rb.Done() } func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 2) // capacity 2 + rb.pool = drpcwire.NewBufferPool() + rb.buf = make([]*[]byte, 2) // capacity 2 - rb.Enqueue([]byte("a")) - rb.Enqueue([]byte("b")) + enqueue(&rb, []byte("a")) + enqueue(&rb, []byte("b")) // Third enqueue should block until we drain one. done := make(chan struct{}) go func() { - rb.Enqueue([]byte("c")) + enqueue(&rb, []byte("c")) close(done) }() // Drain one slot. data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("a")) - rb.Done() + assert.DeepEqual(t, *data, []byte("a")) // Now the blocked Enqueue should complete. <-done @@ -82,18 +89,16 @@ func TestRingBuffer_EnqueueBlocksWhenFull(t *testing.T) { // Verify remaining items. data, err = rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("b")) - rb.Done() + assert.DeepEqual(t, *data, []byte("b")) data, err = rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("c")) - rb.Done() + assert.DeepEqual(t, *data, []byte("c")) } func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) errch := make(chan error, 1) go func() { @@ -108,13 +113,14 @@ func TestRingBuffer_CloseUnblocksDequeue(t *testing.T) { func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 1) // capacity 1 + rb.pool = drpcwire.NewBufferPool() + rb.buf = make([]*[]byte, 1) // capacity 1 - rb.Enqueue([]byte("fill")) + enqueue(&rb, []byte("fill")) done := make(chan struct{}) go func() { - rb.Enqueue([]byte("blocked")) + enqueue(&rb, []byte("blocked")) close(done) }() @@ -124,16 +130,15 @@ func TestRingBuffer_CloseUnblocksEnqueue(t *testing.T) { func TestRingBuffer_CloseDrainsQueued(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) - rb.Enqueue([]byte("queued")) + enqueue(&rb, []byte("queued")) rb.Close(io.EOF) // Dequeue returns the queued data first. data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("queued")) - rb.Done() + assert.DeepEqual(t, *data, []byte("queued")) // Next Dequeue returns the close error. data, err = rb.Dequeue() @@ -143,7 +148,7 @@ func TestRingBuffer_CloseDrainsQueued(t *testing.T) { func TestRingBuffer_CloseIdempotent(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) rb.Close(io.EOF) rb.Close(io.ErrUnexpectedEOF) // should not overwrite @@ -154,53 +159,30 @@ func TestRingBuffer_CloseIdempotent(t *testing.T) { func TestRingBuffer_EnqueueAfterClose(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) rb.Close(io.EOF) - rb.Enqueue([]byte("dropped")) // should not panic or block + enqueue(&rb, []byte("dropped")) // should not panic or block } func TestRingBuffer_SlotReuse(t *testing.T) { var rb ringBuffer rb.cond.L = &rb.mu - rb.buf = make([][]byte, 2) + rb.pool = drpcwire.NewBufferPool() + rb.buf = make([]*[]byte, 2) // Fill and drain a few rounds to exercise slot reuse. for round := 0; round < 5; round++ { - rb.Enqueue([]byte("data")) + enqueue(&rb, []byte("data")) data, err := rb.Dequeue() assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("data")) - rb.Done() + assert.DeepEqual(t, *data, []byte("data")) } } -func TestRingBuffer_CloseWaitsForHeld(t *testing.T) { - var rb ringBuffer - rb.init() - - rb.Enqueue([]byte("msg")) - - // Dequeue the data but don't call Done yet. - data, err := rb.Dequeue() - assert.NoError(t, err) - assert.DeepEqual(t, data, []byte("msg")) - - closed := make(chan struct{}) - go func() { - rb.Close(io.EOF) - close(closed) - }() - - // Close should be blocked because held is true. - // Call Done to release it. - rb.Done() - <-closed -} - func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { var rb ringBuffer - rb.init() + rb.init(drpcwire.NewBufferPool()) const n = 1000 var wg sync.WaitGroup @@ -209,7 +191,7 @@ func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { go func() { defer wg.Done() for i := 0; i < n; i++ { - rb.Enqueue([]byte{byte(i)}) + enqueue(&rb, []byte{byte(i)}) } }() @@ -218,11 +200,25 @@ func TestRingBuffer_ConcurrentProducerConsumer(t *testing.T) { for i := 0; i < n; i++ { data, err := rb.Dequeue() assert.NoError(t, err) - assert.Equal(t, data[0], byte(i)) - rb.Done() + assert.Equal(t, (*data)[0], byte(i)) } }() wg.Wait() rb.Close(io.EOF) } + +func TestRingBuffer_WithPool(t *testing.T) { + pool := drpcwire.NewBufferPool() + var rb ringBuffer + rb.init(pool) + + enqueue(&rb, []byte("pooled")) + + data, err := rb.Dequeue() + assert.NoError(t, err) + assert.DeepEqual(t, *data, []byte("pooled")) + pool.Put(data) + + rb.Close(io.EOF) +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 1fa0460f..836795c6 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -11,7 +11,6 @@ import ( "sync" "github.com/zeebo/errs" - "storj.io/drpc" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcdebug" @@ -53,6 +52,7 @@ type Stream struct { id drpcwire.ID wr *drpcwire.MuxWriter + pool *drpcwire.BufferPool recvQueue ringBuffer wbuf []byte @@ -78,15 +78,19 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter) *Stream { - return NewWithOptions(ctx, sid, wr, Options{}) +func New( + ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *drpcwire.BufferPool, +) *Stream { + return NewWithOptions(ctx, sid, wr, pool, Options{}) } // NewWithOptions returns a new stream bound to the context with the given // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opts Options) *Stream { +func NewWithOptions( + ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, pool *drpcwire.BufferPool, opts Options, +) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -95,7 +99,9 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opt } } - pa := drpcwire.NewPacketAssembler() + // When a pool is available, assemble directly into pooled buffers so the + // completed packet can be handed off to the recv queue without a copy. + pa := drpcwire.NewPacketAssembler(pool) pa.SetStreamID(sid) s := &Stream{ @@ -108,12 +114,12 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.MuxWriter, opt pa: pa, - id: drpcwire.ID{Stream: sid}, - wr: wr, + id: drpcwire.ID{Stream: sid}, + wr: wr, + pool: pool, } - // initialize the packet buffer - s.recvQueue.init() + s.recvQueue.init(pool) return s } @@ -221,15 +227,24 @@ func (s *Stream) HandleFrame(fr drpcwire.Frame) (err error) { // returns any major errors that should terminate the transport the stream is // operating on. func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { - drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) + drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(*pkt.Data))) s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { + // The assembler handed us ownership of the pooled buffer; enqueue + // it directly without copying. s.recvQueue.Enqueue(pkt.Data) return nil } + // Control and error packets are consumed here and not handed to the recv + // queue; return any pooled buffer once we're done reading pkt.Data. The + // defer runs after the switch below, so data stays valid while in use. + if pkt.Data != nil { + defer s.pool.Put(pkt.Data) + } + s.mu.Lock() defer s.mu.Unlock() @@ -240,7 +255,7 @@ func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { return err case drpcwire.KindError: - err := drpcwire.UnmarshalError(pkt.Data) + err := drpcwire.UnmarshalError(*pkt.Data) s.sigs.send.Set(io.EOF) // in this state, gRPC returns io.EOF on send. s.terminate(err) return nil @@ -414,12 +429,12 @@ func (s *Stream) RawRecv() (data []byte, err error) { s.read.Lock() defer s.read.Unlock() - data, err = s.recvQueue.Dequeue() + b, err := s.recvQueue.Dequeue() if err != nil { return nil, err } - data = append([]byte(nil), data...) - s.recvQueue.Done() + data = append([]byte(nil), *b...) + s.pool.Put(b) return data, nil } @@ -456,12 +471,12 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { s.read.Lock() defer s.read.Unlock() - data, err := s.recvQueue.Dequeue() + b, err := s.recvQueue.Dequeue() if err != nil { return err } - err = enc.Unmarshal(data, msg) - s.recvQueue.Done() + err = enc.Unmarshal(*b, msg) + s.pool.Put(b) return err } diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 9cee8ddf..724fab64 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -117,7 +117,7 @@ func TestStream_StateTransitions(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) @@ -169,7 +169,7 @@ func TestStream_Unblocks(t *testing.T) { } for _, test := range cases { - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -180,7 +180,7 @@ func TestStream_Unblocks(t *testing.T) { func TestStream_ContextCancel(t *testing.T) { ctx := context.Background() mw := testMuxWriter(t) - st := New(ctx, 0, mw) + st := New(ctx, 0, mw, drpcwire.NewBufferPool()) child, cancel := context.WithCancel(st.Context()) defer cancel() @@ -195,7 +195,7 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 0, mw) + st := New(ctx, 0, mw, drpcwire.NewBufferPool()) // Close and Cancel concurrently should not panic or deadlock. errch := make(chan error, 1) @@ -219,7 +219,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { mw := testMuxWriter(t) data := make([]byte, 20) mid := uint64(1) - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { @@ -265,7 +265,7 @@ func TestStream_PacketBufferReuse(t *testing.T) { func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { mw := testMuxWriter(t) for _, messageID := range []uint64{1, 2} { - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) // Close the ring buffer so KindMessage Enqueue doesn't block. st.recvQueue.Close(io.EOF) err := st.HandleFrame(drpcwire.Frame{ @@ -278,7 +278,7 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // Invoke and InvokeMetadata frames are rejected on an already-created stream. func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) err := handleFrame(st, drpcwire.KindInvoke, 1) assert.Error(t, err) @@ -288,7 +288,7 @@ func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) assert.Error(t, err) @@ -299,7 +299,7 @@ func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { // Frames arriving after the stream is terminated are silently ignored. func TestHandleFrame_AfterTerminated(t *testing.T) { mw := testMuxWriter(t) - st := New(context.Background(), 1, mw) + st := New(context.Background(), 1, mw, drpcwire.NewBufferPool()) // Terminate the stream via cancel. st.Cancel(context.Canceled) @@ -317,7 +317,7 @@ func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { defer ctx.Close() mw := testMuxWriter(t) - st := New(ctx, 1, mw) + st := New(ctx, 1, mw, drpcwire.NewBufferPool()) // Launch receiver before sending to avoid Put blocking. recv := make(chan []byte, 1) diff --git a/drpcwire/buffer_pool.go b/drpcwire/buffer_pool.go new file mode 100644 index 00000000..59a81964 --- /dev/null +++ b/drpcwire/buffer_pool.go @@ -0,0 +1,42 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import "sync" + +// BufferPool wraps sync.Pool to provide reusable byte slices for the +// stream receive path. Buffers obtained via Get should be returned via +// Put when no longer needed. Forgetting to Put is safe (GC reclaims) +// but reduces reuse. +type BufferPool struct { + pool sync.Pool +} + +// NewBufferPool returns a new buffer pool. +func NewBufferPool() *BufferPool { + return &BufferPool{ + pool: sync.Pool{ + New: func() interface{} { + b := make([]byte, 0, 4096) + return &b + }, + }, + } +} + +// Get returns a zero-length byte slice from the pool, retaining its +// backing array for reuse. +func (bp *BufferPool) Get() *[]byte { + p := bp.pool.Get().(*[]byte) + *p = (*p)[:0] + return p +} + +// Put returns a buffer to the pool. Nil is safe to pass. +func (bp *BufferPool) Put(b *[]byte) { + if b == nil { + return + } + bp.pool.Put(b) +} diff --git a/drpcwire/packet.go b/drpcwire/packet.go index 447eb40e..fa2c8afd 100644 --- a/drpcwire/packet.go +++ b/drpcwire/packet.go @@ -149,8 +149,7 @@ func AppendFrame(buf []byte, fr Frame) []byte { // Packet is a single message sent by drpc. type Packet struct { - // Data is the payload of the packet. - Data []byte + Data *[]byte // ID is the identifier for the packet. ID ID @@ -168,5 +167,5 @@ type Packet struct { // String returns a human readable form of the packet. func (p Packet) String() string { return fmt.Sprintf("", - p.ID.Stream, p.ID.Message, len(p.Data), p.Kind) + p.ID.Stream, p.ID.Message, len(*p.Data), p.Kind) } diff --git a/drpcwire/packet_assembler.go b/drpcwire/packet_assembler.go index 2cf7fae3..b9583601 100644 --- a/drpcwire/packet_assembler.go +++ b/drpcwire/packet_assembler.go @@ -1,8 +1,6 @@ package drpcwire -import ( - "storj.io/drpc" -) +import "storj.io/drpc" // PacketAssembler assembles frames into complete packets, enforcing wire // protocol invariants: @@ -11,19 +9,27 @@ import ( // - Message IDs must be monotonically increasing. // - Frame kind must not change within a single packet (multi-frame). // +// When constructed with a BufferPool, the assembler assembles directly into a +// pooled buffer and transfers its ownership to the returned packet (via +// Packet.Buf), removing a copy on the receive path. Without a pool it reuses +// its own backing array, and the caller must consume packet.Data before the +// next AppendFrame call. +// // It is not safe for concurrent use. type PacketAssembler struct { + pool *BufferPool pk Packet assembling bool streamInitialized bool } // NewPacketAssembler returns a new PacketAssembler ready to assemble frames. -func NewPacketAssembler() PacketAssembler { +func NewPacketAssembler(pool *BufferPool) PacketAssembler { return PacketAssembler{ pk: Packet{ ID: ID{Stream: 0, Message: 1}, }, + pool: pool, } } @@ -36,6 +42,9 @@ func (pa *PacketAssembler) SetStreamID(streamID uint64) { // Reset clears all assembly state, preparing the assembler for a new stream. func (pa *PacketAssembler) Reset() { + if pa.pk.Data != nil { + pa.pool.Put(pa.pk.Data) + } pa.pk = Packet{ ID: ID{Stream: 0, Message: 1}, } @@ -46,7 +55,7 @@ func (pa *PacketAssembler) Reset() { // AppendFrame adds a frame to the in-progress packet. It returns the completed // packet and true when a frame with Done=true is received. It returns false // when more frames are needed to complete the packet. -func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady bool, err error) { +func (pa *PacketAssembler) AppendFrame(fr Frame) (Packet, bool, error) { // Enforce stream ID consistency: infer from first frame or reject mismatches. if !pa.streamInitialized { pa.pk.ID.Stream = fr.ID.Stream @@ -60,8 +69,11 @@ func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady boo return Packet{}, false, drpc.ProtocolError.New( "message id monotonicity violation: got %v, expected >= %v", fr.ID.Message, pa.pk.ID.Message) } else if fr.ID.Message > pa.pk.ID.Message || !pa.assembling { - // New message: reset the buffer and start assembling. - pa.pk.Data = pa.pk.Data[:0] + if pa.pk.Data == nil { + pa.pk.Data = pa.pool.Get() + } else { + *pa.pk.Data = (*pa.pk.Data)[:0] + } pa.assembling = true pa.pk.ID.Message = fr.ID.Message } else if fr.Kind != pa.pk.Kind { @@ -69,8 +81,9 @@ func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady boo "frame kind changed mid-packet: got %v, expected %v", fr.Kind, pa.pk.Kind) } - // TODO(shubham): add buf reuse - pa.pk.Data = append(pa.pk.Data, fr.Data...) + // Assemble directly into the pooled buffer so the completed packet can + // be handed off down the receive path without another copy. + *pa.pk.Data = append(*pa.pk.Data, fr.Data...) pa.pk.Kind = fr.Kind pa.pk.Control = fr.Control @@ -78,12 +91,9 @@ func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady boo return Packet{}, false, nil } - packet = pa.pk - + packet := pa.pk pa.assembling = false pa.pk.ID.Message = fr.ID.Message + 1 - // Reuse the backing array: the caller must consume packet.Data before the - // next AppendFrame call, as it will be overwritten. - pa.pk.Data = pa.pk.Data[:0] + pa.pk.Data = nil return packet, true, nil } diff --git a/drpcwire/packet_assembler_test.go b/drpcwire/packet_assembler_test.go index 41cf70da..4f555bbd 100644 --- a/drpcwire/packet_assembler_test.go +++ b/drpcwire/packet_assembler_test.go @@ -13,7 +13,7 @@ import ( ) func TestPacketAssembler_WrongStreamID(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) _, _, err := pa.AppendFrame(Frame{ @@ -27,7 +27,7 @@ func TestPacketAssembler_WrongStreamID(t *testing.T) { } func TestPacketAssembler_StreamIDInferredFromFirstFrame(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) // First frame sets the stream ID implicitly. _, _, err := pa.AppendFrame(Frame{ @@ -49,7 +49,7 @@ func TestPacketAssembler_StreamIDInferredFromFirstFrame(t *testing.T) { // A frame with a message ID lower than a previously completed message is rejected. func TestPacketAssembler_MessageMonotonicity(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // m3 completes, next expected becomes 4. @@ -70,7 +70,7 @@ func TestPacketAssembler_MessageMonotonicity(t *testing.T) { // When a higher message ID arrives mid-assembly, the in-progress data is // silently discarded and a new packet begins. func TestPacketAssembler_HigherMsgDiscardsInProgress(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // Start accumulating m1. @@ -86,13 +86,13 @@ func TestPacketAssembler_HigherMsgDiscardsInProgress(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("kept")) + assert.DeepEqual(t, *pkt.Data, []byte("kept")) } // Continuation frames (same message ID, mid-assembly) must carry the same // kind as the first frame. A kind change mid-packet is a protocol error. func TestPacketAssembler_KindChangeWithinPacket(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) _, _, err := pa.AppendFrame(Frame{ @@ -110,7 +110,7 @@ func TestPacketAssembler_KindChangeWithinPacket(t *testing.T) { // Multiple continuation frames for the same message accumulate data correctly. func TestPacketAssembler_MultiFrameDataAccumulation(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) _, ready, err := pa.AppendFrame(Frame{ @@ -130,14 +130,14 @@ func TestPacketAssembler_MultiFrameDataAccumulation(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("hello world")) + assert.DeepEqual(t, *pkt.Data, []byte("hello world")) } // Multi-frame assembly works when the message ID is greater than the initial // expected ID (e.g., on the server side where invoke consumed earlier message // IDs). Continuation frames must accumulate data, not reset on each frame. func TestPacketAssembler_MultiFrameWithSkippedMessageID(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // msg=3 is greater than initial expected message ID=1. @@ -158,12 +158,12 @@ func TestPacketAssembler_MultiFrameWithSkippedMessageID(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("hello world")) + assert.DeepEqual(t, *pkt.Data, []byte("hello world")) } // Once a message completes (done=true), the same message ID is rejected. func TestPacketAssembler_DonePreventsReplay(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // m1 completes → next expected becomes 2. @@ -185,7 +185,7 @@ func TestPacketAssembler_DonePreventsReplay(t *testing.T) { // across messages. A KindMessage followed by a KindClose for the next message // should be accepted without error. func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // Multi-frame message 1 with KindMessage. @@ -199,7 +199,7 @@ func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("abcd")) + assert.DeepEqual(t, *pkt.Data, []byte("abcd")) // Message 2 with a different kind — should not trigger kind check. pkt, ready, err = pa.AppendFrame(Frame{ @@ -212,7 +212,7 @@ func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { // Reset clears all state so the assembler can be reused for a new stream. func TestPacketAssembler_Reset(t *testing.T) { - pa := NewPacketAssembler() + pa := NewPacketAssembler(NewBufferPool()) pa.SetStreamID(1) // Complete a packet on stream 1. @@ -230,6 +230,6 @@ func TestPacketAssembler_Reset(t *testing.T) { }) assert.NoError(t, err) assert.That(t, ready) - assert.DeepEqual(t, pkt.Data, []byte("new")) + assert.DeepEqual(t, *pkt.Data, []byte("new")) assert.Equal(t, pkt.ID.Stream, uint64(2)) } diff --git a/drpcwire/rand_test.go b/drpcwire/rand_test.go index 1eeb4b44..6d0f3304 100644 --- a/drpcwire/rand_test.go +++ b/drpcwire/rand_test.go @@ -75,12 +75,3 @@ func RandFrame() Frame { Done: RandBool(), } } - -func RandPacket() Packet { - kind := RandKind() - return Packet{ - Data: RandBytes(10 * payloadSize[kind]()), - ID: RandID(), - Kind: kind, - } -} diff --git a/drpcwire/split.go b/drpcwire/split.go index 1189615e..3b3c362a 100644 --- a/drpcwire/split.go +++ b/drpcwire/split.go @@ -3,32 +3,6 @@ package drpcwire -// SplitN splits the marshaled form of the Packet into a number of -// frames such that each frame is at most n bytes. It calls -// the callback with every such frame. If n is zero, a reasonable -// default is used. -func SplitN(pkt Packet, n int, cb func(fr Frame) error) error { - for { - fr := Frame{ - Data: pkt.Data, - ID: pkt.ID, - Kind: pkt.Kind, - Control: pkt.Control, - Done: true, - } - - fr.Data, pkt.Data = SplitData(pkt.Data, n) - fr.Done = len(pkt.Data) == 0 - - if err := cb(fr); err != nil { - return err - } - if fr.Done { - return nil - } - } -} - // SplitData is used to split a buffer if it is larger than n bytes. // If n is zero, a reasonable default is used. If n is less than zero // then it does not split. diff --git a/drpcwire/split_test.go b/drpcwire/split_test.go deleted file mode 100644 index 29daf4c4..00000000 --- a/drpcwire/split_test.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package drpcwire - -import ( - "bytes" - "math/rand" - "testing" - - "github.com/zeebo/assert" -) - -func TestSplit(t *testing.T) { - for i := 0; i < 1000; i++ { - pkt, done, n := RandPacket(), false, rand.Intn(10)-1 - if size := rand.Intn(100); size < len(pkt.Data) { - pkt.Data = pkt.Data[:size] - } - - var buf []byte - assert.NoError(t, SplitN(pkt, n, func(fr Frame) error { - assert.That(t, !done) - assert.That(t, len(fr.Data) <= n || - (n == -1 && len(pkt.Data) == len(fr.Data)) || - (n == 0 && len(fr.Data) <= 1024)) - assert.Equal(t, pkt.Kind, fr.Kind) - done = fr.Done - buf = append(buf, fr.Data...) - return nil - })) - - assert.That(t, done) - assert.That(t, bytes.Equal(pkt.Data, buf)) - } -}