-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathclient_interceptor.go
More file actions
351 lines (320 loc) · 11.3 KB
/
client_interceptor.go
File metadata and controls
351 lines (320 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
package gorums
import (
"context"
"slices"
"sync"
"github.com/relab/gorums/internal/stream"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"
)
// QuorumInterceptor is middleware for quorum calls. It receives the call's
// ClientCtx (request, configuration and call metadata) together with the
// response iterator produced so far (next), and returns the iterator that the
// next stage will consume. Wrapping next lets an interceptor observe or
// transform node responses, or change how they are aggregated; interceptors
// compose by chaining.
//
// Type parameters:
//   - Req:  the request message type sent to the nodes
//   - Resp: the response message type returned by each node
//
// A custom interceptor is an ordinary generic function, for example:
//
//	func LoggingInterceptor[Req, Resp proto.Message](
//		ctx *gorums.ClientCtx[Req, Resp],
//		next gorums.ResponseSeq[Resp],
//	) gorums.ResponseSeq[Resp] {
//		return func(yield func(gorums.NodeResponse[Resp]) bool) {
//			for resp := range next {
//				log.Printf("Response from node %d", resp.NodeID)
//				if !yield(resp) {
//					return
//				}
//			}
//		}
//	}
type QuorumInterceptor[Req, Resp msg] func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp]
// ClientCtx carries the state of a single quorum call or multicast and is
// handed to every interceptor. It embeds the call's context and gives access
// to the request, the configuration and metadata about the call.
type ClientCtx[Req, Resp msg] struct {
	context.Context

	// config is the set of nodes this call is sent to.
	config Configuration
	// request is the request message the call was created with.
	request Req
	// method is the name of the RPC method being invoked.
	method string
	// msgID identifies this call; every node's message carries the same ID.
	msgID uint64
	// replyChan receives per-node replies and errors. It is nil for
	// fire-and-forget multicast calls.
	replyChan chan NodeResponse[*stream.Message]

	// reqTransforms are per-node request transformations registered by
	// interceptors; they are applied in registration order at dispatch time.
	reqTransforms []func(Req, *Node) Req

	// responseSeq yields node responses. Interceptors wrap this iterator to
	// alter what the caller eventually sees.
	responseSeq ResponseSeq[Resp]

	// streaming marks a streaming call (used for correctable streams).
	streaming bool
	// oneway marks a one-way call (used for multicast).
	oneway bool

	// sendOnce guarantees the request is dispatched exactly once, the first
	// time the response iterator runs. Deferring dispatch gives interceptors
	// a chance to register request transformations beforehand.
	sendOnce sync.Once
}
// sendNow dispatches the request to the nodes the first time it is called;
// every later call is a no-op because dispatch is guarded by sendOnce.
func (c *ClientCtx[Req, Resp]) sendNow() {
	c.sendOnce.Do(c.send)
}
// newQuorumCallClientCtx builds the ClientCtx for a two-way quorum call.
// Quorum calls always get a reply channel. For streaming calls the channel
// has ten slots per node and the unbounded streaming iterator is used;
// otherwise it has one slot per node and the default iterator is used.
// Finally, the given interceptors are applied in order.
func newQuorumCallClientCtx[Req, Resp msg](
	ctx *ConfigContext,
	req Req,
	method string,
	streaming bool,
	interceptors []any,
) *ClientCtx[Req, Resp] {
	// Streaming calls can yield more than one reply per node (their iterator
	// has no upper bound), so the reply channel is given extra headroom.
	const streamingBufferFactor = 10

	config := ctx.Configuration()
	bufSize := config.Size()
	if streaming {
		bufSize *= streamingBufferFactor
	}

	cc := &ClientCtx[Req, Resp]{
		Context:   ctx,
		config:    config,
		request:   req,
		method:    method,
		msgID:     config.nextMsgID(),
		streaming: streaming,
		replyChan: make(chan NodeResponse[*stream.Message], bufSize),
	}

	if streaming {
		cc.responseSeq = cc.streamingResponseSeq()
	} else {
		cc.responseSeq = cc.defaultResponseSeq()
	}
	cc.applyInterceptors(interceptors)
	return cc
}
// newMulticastClientCtx builds the ClientCtx for a one-way multicast, which
// never yields response values. With waitForSend set (blocking send), a reply
// channel with one slot per node is created. Fire-and-forget calls leave the
// channel nil, which means no router entry is registered for them.
func newMulticastClientCtx[Req msg](
	ctx *ConfigContext,
	req Req,
	method string,
	waitForSend bool,
	interceptors []any,
) *ClientCtx[Req, *emptypb.Empty] {
	config := ctx.Configuration()
	cc := &ClientCtx[Req, *emptypb.Empty]{
		Context: ctx,
		config:  config,
		request: req,
		method:  method,
		msgID:   config.nextMsgID(),
		oneway:  true,
	}
	if waitForSend {
		cc.replyChan = make(chan NodeResponse[*stream.Message], config.Size())
	}
	cc.responseSeq = cc.defaultResponseSeq()
	cc.applyInterceptors(interceptors)
	return cc
}
// -------------------------------------------------------------------------
// ClientCtx Methods
// -------------------------------------------------------------------------
// Request returns the request message this call was created with.
func (c *ClientCtx[Req, Resp]) Request() Req {
	return c.request
}
// Config returns the configuration, i.e. the set of nodes, this call targets.
func (c *ClientCtx[Req, Resp]) Config() Configuration {
	return c.config
}
// Method returns the name of the RPC method this call invokes.
func (c *ClientCtx[Req, Resp]) Method() string {
	return c.method
}
// Nodes returns the nodes that make up this call's configuration.
func (c *ClientCtx[Req, Resp]) Nodes() []*Node {
	return c.config.Nodes()
}
// Node looks up the node with the given ID in this call's configuration.
// It returns nil when no node has that ID.
func (c *ClientCtx[Req, Resp]) Node(id uint32) *Node {
	nodes := c.config.Nodes()
	if i := slices.IndexFunc(nodes, func(n *Node) bool { return n.ID() == id }); i >= 0 {
		return nodes[i]
	}
	return nil
}
// Size reports how many nodes this call's configuration contains.
func (c *ClientCtx[Req, Resp]) Size() int {
	return c.config.Size()
}
// reportNodeError delivers err as the response of node nodeID on replyChan.
// Fire-and-forget calls have no reply channel, so the error is dropped.
func (c *ClientCtx[Req, Resp]) reportNodeError(nodeID uint32, err error) {
	if c.replyChan == nil {
		return
	}
	c.replyChan <- NodeResponse[*stream.Message]{NodeID: nodeID, Err: err}
}
// enqueue wraps msg in a stream.Request and passes it to n.Enqueue. The
// per-call fields (context, streaming/oneway flags and reply channel) are
// taken from c, so callers only supply the message.
func (c *ClientCtx[Req, Resp]) enqueue(n *Node, msg *stream.Message) {
	req := stream.Request{
		Ctx:          c.Context,
		Msg:          msg,
		Streaming:    c.streaming,
		Oneway:       c.oneway,
		ResponseChan: c.replyChan,
	}
	n.Enqueue(req)
}
// applyInterceptors chains the given interceptors onto c.responseSeq in slice
// order. Each interceptor receives the sequence produced by the previous one,
// and the result is stored back in c.responseSeq. Because dispatch is deferred
// until iteration, interceptors may also register request transformations on
// c while being applied.
//
// Each element must be either a QuorumInterceptor[Req, Resp] or a plain
// function with the same signature, such as a generic interceptor function
// that was stored in the slice without being converted. Any other value is a
// programming error and panics.
func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) {
	seq := c.responseSeq
	for _, ic := range interceptors {
		switch interceptor := ic.(type) {
		case QuorumInterceptor[Req, Resp]:
			seq = interceptor(c, seq)
		case func(*ClientCtx[Req, Resp], ResponseSeq[Resp]) ResponseSeq[Resp]:
			// Same signature, but the unnamed func type: a bare type
			// assertion to QuorumInterceptor would reject it.
			seq = interceptor(c, seq)
		default:
			// Unsupported type (including nil). This checked assertion always
			// fails here and panics with a runtime message naming both the
			// actual and the expected type.
			_ = ic.(QuorumInterceptor[Req, Resp])
		}
	}
	c.responseSeq = seq
}
// send dispatches the request to every node. When no per-node transformations
// are registered, the request is marshaled once and the same message goes to
// all nodes (sendShared). Otherwise, each node gets its own transformed and
// separately marshaled message (sendWithPerNodeTransformation).
func (c *ClientCtx[Req, Resp]) send() {
	if len(c.reqTransforms) > 0 {
		c.sendWithPerNodeTransformation()
		return
	}
	c.sendShared()
}
// sendShared marshals the request once and enqueues that single message to
// every node. If marshaling fails, the failure is identical for every node,
// so the same error is reported for each of them and nothing is enqueued.
func (c *ClientCtx[Req, Resp]) sendShared() {
	shared, err := stream.NewMessage(c.Context, c.msgID, c.method, c.request)
	for _, node := range c.config {
		if err != nil {
			c.reportNodeError(node.ID(), err)
			continue
		}
		c.enqueue(node, shared)
	}
}
// sendWithPerNodeTransformation builds a separate message for each node with
// transformAndMarshal and enqueues it to that node. Nodes for which
// transformAndMarshal returns nil are skipped; the reason (ErrSkipNode or a
// marshal error) has already been reported for them.
func (c *ClientCtx[Req, Resp]) sendWithPerNodeTransformation() {
	for _, node := range c.config {
		if m := c.transformAndMarshal(node); m != nil {
			c.enqueue(node, m)
		}
	}
}
// transformAndMarshal runs every registered request transformation for node n,
// in registration order, and marshals the result into a stream.Message.
// It returns nil when the node must be skipped, in either of two cases:
//   - The transformed request is not a valid proto message; ErrSkipNode is
//     reported for n.
//   - Marshaling fails; the marshal error is reported for n.
func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message {
	req := c.request
	for _, transform := range c.reqTransforms {
		req = transform(req, n)
	}

	// A non-proto, nil or invalid result (such as a typed nil pointer) is a
	// transformation's way of saying "do not send to this node".
	pm, isProto := any(req).(proto.Message)
	if !isProto || pm == nil || !pm.ProtoReflect().IsValid() {
		c.reportNodeError(n.ID(), ErrSkipNode)
		return nil
	}

	out, err := stream.NewMessage(c.Context, c.msgID, c.method, req)
	if err != nil {
		c.reportNodeError(n.ID(), err)
		return nil
	}
	return out
}
// defaultResponseSeq returns an iterator that dispatches the request (once)
// and then yields at most c.Size() responses, i.e. no more than the number of
// nodes. Iteration ends early when the consumer stops ranging or the context
// is canceled.
//
// Fire-and-forget multicast calls have a nil replyChan, so no reply can ever
// arrive. In that case the iterator returns right after dispatching. Without
// this check it would block on the nil channel until the context is canceled,
// possibly forever.
func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] {
	return func(yield func(NodeResponse[Resp]) bool) {
		// Dispatch lazily on first iteration so interceptors can register
		// request transformations first; sendOnce makes re-iteration safe.
		c.sendNow()
		if c.replyChan == nil {
			return // fire-and-forget: nothing to receive
		}
		for range c.Size() {
			select {
			case r := <-c.replyChan:
				if !yield(mapToCallResponse[Resp](r)) {
					return // consumer stopped iteration
				}
			case <-c.Done():
				return // context canceled
			}
		}
	}
}
// streamingResponseSeq returns an iterator for streaming calls. It dispatches
// the request (once) and then yields replies as they arrive. There is no upper
// bound on how many replies it yields; it stops only when the consumer stops
// ranging or the context is canceled.
func (c *ClientCtx[Req, Resp]) streamingResponseSeq() ResponseSeq[Resp] {
	return func(yield func(NodeResponse[Resp]) bool) {
		// Dispatch lazily so interceptors can register request transforms first.
		c.sendNow()
		for {
			var reply NodeResponse[*stream.Message]
			select {
			case <-c.Done():
				return // context canceled
			case reply = <-c.replyChan:
			}
			if !yield(mapToCallResponse[Resp](reply)) {
				return // consumer stopped iteration
			}
		}
	}
}
// -------------------------------------------------------------------------
// Interceptors (Middleware)
// -------------------------------------------------------------------------
// MapRequest returns an interceptor that registers fn as a per-node request
// transformation. When several are chained, their transforms run in order.
//
// fn receives the request (possibly already transformed by an earlier
// interceptor) and the target node, and returns the request to send to that
// node. If fn returns nil or an invalid message, ErrSkipNode is reported for
// that node and nothing is sent to it. Registering any transformation also
// means the request is marshaled separately for each node instead of once.
// A nil fn is ignored. The response sequence is passed through unchanged.
func MapRequest[Req, Resp msg](fn func(Req, *Node) Req) QuorumInterceptor[Req, Resp] {
	return func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] {
		if fn == nil {
			return next
		}
		ctx.reqTransforms = append(ctx.reqTransforms, fn)
		return next
	}
}
// MapResponse returns an interceptor that applies a per-node response
// transformation.
//
// fn receives a node's response value and that node, and returns the value to
// pass on. Only successful responses from nodes found in the configuration
// are transformed. Error responses, and responses whose NodeID is not in the
// configuration, are forwarded unchanged. A nil fn leaves the sequence as is.
func MapResponse[Req, Resp msg](fn func(Resp, *Node) Resp) QuorumInterceptor[Req, Resp] {
	return func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] {
		if fn == nil {
			return next
		}
		// apply transforms a single response, or returns it untouched when it
		// carries an error or comes from an unknown node.
		apply := func(resp NodeResponse[Resp]) NodeResponse[Resp] {
			if resp.Err != nil {
				return resp
			}
			node := ctx.Node(resp.NodeID)
			if node == nil {
				return resp
			}
			resp.Value = fn(resp.Value, node)
			return resp
		}
		return func(yield func(NodeResponse[Resp]) bool) {
			for resp := range next {
				if !yield(apply(resp)) {
					return
				}
			}
		}
	}
}