diff --git a/client/client.go b/client/client.go index 7af896a9..6536b77d 100644 --- a/client/client.go +++ b/client/client.go @@ -9,6 +9,7 @@ import ( "errors" "io" "math" + "slices" "sync" "time" @@ -21,7 +22,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/protobuf/types/known/emptypb" ) // ID is the identifier for a client. @@ -31,17 +31,70 @@ type qspec struct { faulty int } -func (q *qspec) ExecCommandQF(_ *clientpb.Command, signatures map[uint32]*emptypb.Empty) (*emptypb.Empty, bool) { - if len(signatures) < q.faulty+1 { +// leastOverlapSet returns the set of uint64 values that appear in all slices. +func leastOverlapSet(slices [][]uint64) []uint64 { + if len(slices) == 0 { + return []uint64{} + } + + // Count occurrences of each value across all slices + occurrences := make(map[uint64]int) + for _, slice := range slices { + seen := make(map[uint64]bool) + for _, val := range slice { + if !seen[val] { + occurrences[val]++ + seen[val] = true + } + } + } + + // Find values that appear in all slices + result := []uint64{} + numSlices := len(slices) + for val, count := range occurrences { + if count == numSlices { + result = append(result, val) + } + } + + return result +} + +func (q *qspec) CommandStatusQF(command *clientpb.Command, replies map[uint32]*clientpb.CommandStatusResponse) (*clientpb.CommandStatusResponse, bool) { + if len(replies) < q.faulty+1 { return nil, false } - return &emptypb.Empty{}, true + successfulHighestCmds := make([]uint64, 0, len(replies)) + commandCount := make(map[uint64]int) + failedCommandIdSets := make([][]uint64, 0, len(replies)) + for _, reply := range replies { + successfulHighestCmds = append(successfulHighestCmds, reply.GetHighestSequenceNumber()) + _, ok := commandCount[reply.GetHighestSequenceNumber()] + if !ok { + commandCount[reply.GetHighestSequenceNumber()] = 1 + continue + } + commandCount[reply.GetHighestSequenceNumber()]++ + failedCommandIdSets = append(failedCommandIdSets, reply.GetFailedSequenceNumbers()) + } + leastOverlapFailedCmds := leastOverlapSet(failedCommandIdSets) + slices.Sort(successfulHighestCmds) + slices.Reverse(successfulHighestCmds) + for _, cmd := range successfulHighestCmds { + if commandCount[cmd] >= q.faulty+1 { + return &clientpb.CommandStatusResponse{ + HighestSequenceNumber: cmd, + FailedSequenceNumbers: leastOverlapFailedCmds, + }, true + } + } + return nil, false } type pendingCmd struct { sequenceNumber uint64 sendTime time.Time - promise *clientpb.AsyncEmpty cancelCtx context.CancelFunc } @@ -65,19 +118,19 @@ type Client struct { logger logging.Logger id ID - mut sync.Mutex - mgr *clientpb.Manager - gorumsConfig *clientpb.Configuration - payloadSize uint32 - highestCommitted uint64 // highest sequence number acknowledged by the replicas - pendingCmds chan pendingCmd - cancel context.CancelFunc - done chan struct{} - reader io.ReadCloser - limiter *rate.Limiter - stepUp float64 - stepUpInterval time.Duration - timeout time.Duration + mut sync.Mutex + mgr *clientpb.Manager + gorumsConfig *clientpb.Configuration + payloadSize uint32 + highestCommitted uint64 // highest sequence number acknowledged by the replicas + cancel context.CancelFunc + done chan struct{} + reader io.ReadCloser + limiter *rate.Limiter + stepUp float64 + stepUpInterval time.Duration + timeout time.Duration + failedCommandIdSet *BitSet } // New returns a new Client. @@ -92,17 +145,16 @@ func New( logger: logger, id: id, - pendingCmds: make(chan pendingCmd, conf.MaxConcurrent), - highestCommitted: 1, - done: make(chan struct{}), - reader: conf.Input, - payloadSize: conf.PayloadSize, - limiter: rate.NewLimiter(rate.Limit(conf.RateLimit), 1), - stepUp: conf.RateStep, - stepUpInterval: conf.RateStepInterval, - timeout: conf.Timeout, + highestCommitted: 1, + done: make(chan struct{}), + reader: conf.Input, + payloadSize: conf.PayloadSize, + limiter: rate.NewLimiter(rate.Limit(conf.RateLimit), 1), + stepUp: conf.RateStep, + stepUpInterval: conf.RateStepInterval, + timeout: conf.Timeout, + failedCommandIdSet: NewBitSet(5000000), // assuming max 5 million commands for now } - var creds credentials.TransportCredentials if conf.TLS { creds = credentials.NewClientTLSFromCert(conf.RootCAs, "") @@ -125,6 +177,7 @@ func (c *Client) Connect(replicas []hotstuff.ReplicaInfo) (err error) { } c.gorumsConfig, err = c.mgr.NewConfiguration(&qspec{faulty: hotstuff.NumFaulty(len(replicas))}, gorums.WithNodeMap(nodes)) if err != nil { + c.logger.Error("unable to create the configuration in client") c.mgr.Close() return err } @@ -182,6 +235,7 @@ func (c *Client) Stop() { } func (c *Client) close() { + // Signal the command handler to stop fetching statuses before closing the manager. c.mgr.Close() err := c.reader.Close() if err != nil { @@ -197,7 +251,6 @@ func (c *Client) sendCommands(ctx context.Context) error { nextLogTime = time.Now().Add(time.Second) ) -loop: for ctx.Err() == nil { // step up the rate limiter @@ -236,23 +289,12 @@ loop: SequenceNumber: num, Data: data[:n], } - - ctx, cancel := context.WithTimeout(ctx, c.timeout) - promise := c.gorumsConfig.ExecCommand(ctx, cmd) - pending := pendingCmd{sequenceNumber: num, sendTime: time.Now(), promise: promise, cancelCtx: cancel} - + c.gorumsConfig.ExecCommand(context.Background(), cmd) num++ - select { - case c.pendingCmds <- pending: - case <-ctx.Done(): - break loop - } - if time.Now().After(nextLogTime) { c.logger.Infof("%d commands sent so far", num) nextLogTime = time.Now().Add(time.Second) } - } return nil } @@ -262,38 +304,30 @@ loop: // acknowledged in the order that they were sent. func (c *Client) handleCommands(ctx context.Context) (executed, failed, timeout int) { for { - var ( - cmd pendingCmd - ok bool - ) + statusRefresher := time.NewTicker(100 * time.Millisecond) select { - case cmd, ok = <-c.pendingCmds: - if !ok { - return + case <-statusRefresher.C: + commandStatus, err := c.gorumsConfig.CommandStatus(ctx, &clientpb.Command{ + ClientID: uint32(c.id), + }) + if err != nil { + c.logger.Error("Failed to get command status: ", err) + continue } + c.mut.Lock() + if c.highestCommitted < commandStatus.HighestSequenceNumber { + c.highestCommitted = commandStatus.HighestSequenceNumber + } + for _, failedSeqNum := range commandStatus.FailedSequenceNumbers { + c.failedCommandIdSet.Add(failedSeqNum) + } + failed = c.failedCommandIdSet.Count() + executed = int(c.highestCommitted) - failed + timeout = 0 + c.mut.Unlock() case <-ctx.Done(): return } - _, err := cmd.promise.Get() - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - c.logger.Debug("Command timed out.") - timeout++ - } else if !errors.Is(err, context.Canceled) { - c.logger.Debugf("Did not get enough replies for command: %v\n", err) - failed++ - } - } else { - executed++ - } - c.mut.Lock() - if cmd.sequenceNumber > c.highestCommitted { - c.highestCommitted = cmd.sequenceNumber - } - c.mut.Unlock() - - duration := time.Since(cmd.sendTime) - c.eventLoop.AddEvent(LatencyMeasurementEvent{Latency: duration}) } } @@ -301,3 +335,59 @@ func (c *Client) handleCommands(ctx context.Context) (executed, failed, timeout type LatencyMeasurementEvent struct { Latency time.Duration } + +// BitSet is a space-efficient set for uint64 values +type BitSet struct { + mut sync.Mutex + bits []uint64 +} + +func NewBitSet(maxVal uint64) *BitSet { + size := (maxVal / 64) + 1 + return &BitSet{ + bits: make([]uint64, size), + } +} + +func (bs *BitSet) Add(val uint64) { + bs.mut.Lock() + defer bs.mut.Unlock() + + index := val / 64 + offset := val % 64 + if index < uint64(len(bs.bits)) { + bs.bits[index] |= (1 << offset) + } +} + +func (bs *BitSet) Contains(val uint64) bool { + bs.mut.Lock() + defer bs.mut.Unlock() + + index := val / 64 + offset := val % 64 + if index < uint64(len(bs.bits)) { + return (bs.bits[index] & (1 << offset)) != 0 + } + return false +} + +func (bs *BitSet) Count() int { + bs.mut.Lock() + defer bs.mut.Unlock() + + count := 0 + for _, word := range bs.bits { + count += popcount(word) + } + return count +} + +func popcount(x uint64) int { + count := 0 + for x != 0 { + x &= x - 1 + count++ + } + return count +} diff --git a/internal/proto/clientpb/client.pb.go b/internal/proto/clientpb/client.pb.go index 6ed28beb..14da714e 100644 --- a/internal/proto/clientpb/client.pb.go +++ b/internal/proto/clientpb/client.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.4 +// protoc v6.30.2 // source: internal/proto/clientpb/client.proto package clientpb @@ -23,6 +23,58 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type CommandStatusResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + HighestSequenceNumber uint64 `protobuf:"varint,1,opt,name=highestSequenceNumber,proto3" json:"highestSequenceNumber,omitempty"` + FailedSequenceNumbers []uint64 `protobuf:"varint,2,rep,packed,name=failedSequenceNumbers,proto3" json:"failedSequenceNumbers,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CommandStatusResponse) Reset() { + *x = CommandStatusResponse{} + mi := &file_internal_proto_clientpb_client_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CommandStatusResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CommandStatusResponse) ProtoMessage() {} + +func (x *CommandStatusResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_clientpb_client_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CommandStatusResponse.ProtoReflect.Descriptor instead. +func (*CommandStatusResponse) Descriptor() ([]byte, []int) { + return file_internal_proto_clientpb_client_proto_rawDescGZIP(), []int{0} +} + +func (x *CommandStatusResponse) GetHighestSequenceNumber() uint64 { + if x != nil { + return x.HighestSequenceNumber + } + return 0 +} + +func (x *CommandStatusResponse) GetFailedSequenceNumbers() []uint64 { + if x != nil { + return x.FailedSequenceNumbers + } + return nil +} + // Command is the request that is sent to the HotStuff replicas with the data to // be executed. type Command struct { @@ -36,7 +88,7 @@ type Command struct { func (x *Command) Reset() { *x = Command{} - mi := &file_internal_proto_clientpb_client_proto_msgTypes[0] + mi := &file_internal_proto_clientpb_client_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -48,7 +100,7 @@ func (x *Command) String() string { func (*Command) ProtoMessage() {} func (x *Command) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_clientpb_client_proto_msgTypes[0] + mi := &file_internal_proto_clientpb_client_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -61,7 +113,7 @@ func (x *Command) ProtoReflect() protoreflect.Message { // Deprecated: Use Command.ProtoReflect.Descriptor instead. func (*Command) Descriptor() ([]byte, []int) { - return file_internal_proto_clientpb_client_proto_rawDescGZIP(), []int{0} + return file_internal_proto_clientpb_client_proto_rawDescGZIP(), []int{1} } func (x *Command) GetClientID() uint32 { @@ -95,7 +147,7 @@ type Batch struct { func (x *Batch) Reset() { *x = Batch{} - mi := &file_internal_proto_clientpb_client_proto_msgTypes[1] + mi := &file_internal_proto_clientpb_client_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -107,7 +159,7 @@ func (x *Batch) String() string { func (*Batch) ProtoMessage() {} func (x *Batch) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_clientpb_client_proto_msgTypes[1] + mi := &file_internal_proto_clientpb_client_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -120,7 +172,7 @@ func (x *Batch) ProtoReflect() protoreflect.Message { // Deprecated: Use Batch.ProtoReflect.Descriptor instead. func (*Batch) Descriptor() ([]byte, []int) { - return file_internal_proto_clientpb_client_proto_rawDescGZIP(), []int{1} + return file_internal_proto_clientpb_client_proto_rawDescGZIP(), []int{2} } func (x *Batch) GetCommands() []*Command { @@ -134,15 +186,19 @@ var File_internal_proto_clientpb_client_proto protoreflect.FileDescriptor const file_internal_proto_clientpb_client_proto_rawDesc = "" + "\n" + - "$internal/proto/clientpb/client.proto\x12\bclientpb\x1a\fgorums.proto\x1a\x1bgoogle/protobuf/empty.proto\"a\n" + + "$internal/proto/clientpb/client.proto\x12\bclientpb\x1a\fgorums.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x83\x01\n" + + "\x15CommandStatusResponse\x124\n" + + "\x15highestSequenceNumber\x18\x01 \x01(\x04R\x15highestSequenceNumber\x124\n" + + "\x15failedSequenceNumbers\x18\x02 \x03(\x04R\x15failedSequenceNumbers\"a\n" + "\aCommand\x12\x1a\n" + "\bClientID\x18\x01 \x01(\rR\bClientID\x12&\n" + "\x0eSequenceNumber\x18\x02 \x01(\x04R\x0eSequenceNumber\x12\x12\n" + "\x04Data\x18\x03 \x01(\fR\x04Data\"6\n" + "\x05Batch\x12-\n" + - "\bCommands\x18\x01 \x03(\v2\x11.clientpb.CommandR\bCommands2L\n" + - "\x06Client\x12B\n" + - "\vExecCommand\x12\x11.clientpb.Command\x1a\x16.google.protobuf.Empty\"\b\xa0\xb5\x18\x01ะต\x18\x01B3Z1github.com/relab/hotstuff/internal/proto/clientpbb\x06proto3" + "\bCommands\x18\x01 \x03(\v2\x11.clientpb.CommandR\bCommands2\x93\x01\n" + + "\x06Client\x12>\n" + + "\vExecCommand\x12\x11.clientpb.Command\x1a\x16.google.protobuf.Empty\"\x04\x98\xb5\x18\x01\x12I\n" + + "\rCommandStatus\x12\x11.clientpb.Command\x1a\x1f.clientpb.CommandStatusResponse\"\x04\xa0\xb5\x18\x01B3Z1github.com/relab/hotstuff/internal/proto/clientpbb\x06proto3" var ( file_internal_proto_clientpb_client_proto_rawDescOnce sync.Once @@ -156,18 +212,21 @@ func file_internal_proto_clientpb_client_proto_rawDescGZIP() []byte { return file_internal_proto_clientpb_client_proto_rawDescData } -var file_internal_proto_clientpb_client_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_internal_proto_clientpb_client_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_internal_proto_clientpb_client_proto_goTypes = []any{ - (*Command)(nil), // 0: clientpb.Command - (*Batch)(nil), // 1: clientpb.Batch - (*emptypb.Empty)(nil), // 2: google.protobuf.Empty + (*CommandStatusResponse)(nil), // 0: clientpb.CommandStatusResponse + (*Command)(nil), // 1: clientpb.Command + (*Batch)(nil), // 2: clientpb.Batch + (*emptypb.Empty)(nil), // 3: google.protobuf.Empty } var file_internal_proto_clientpb_client_proto_depIdxs = []int32{ - 0, // 0: clientpb.Batch.Commands:type_name -> clientpb.Command - 0, // 1: clientpb.Client.ExecCommand:input_type -> clientpb.Command - 2, // 2: clientpb.Client.ExecCommand:output_type -> google.protobuf.Empty - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type + 1, // 0: clientpb.Batch.Commands:type_name -> clientpb.Command + 1, // 1: clientpb.Client.ExecCommand:input_type -> clientpb.Command + 1, // 2: clientpb.Client.CommandStatus:input_type -> clientpb.Command + 3, // 3: clientpb.Client.ExecCommand:output_type -> google.protobuf.Empty + 0, // 4: clientpb.Client.CommandStatus:output_type -> clientpb.CommandStatusResponse + 3, // [3:5] is the sub-list for method output_type + 1, // [1:3] is the sub-list for method input_type 1, // [1:1] is the sub-list for extension type_name 1, // [1:1] is the sub-list for extension extendee 0, // [0:1] is the sub-list for field type_name @@ -184,7 +243,7 @@ func file_internal_proto_clientpb_client_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_proto_clientpb_client_proto_rawDesc), len(file_internal_proto_clientpb_client_proto_rawDesc)), NumEnums: 0, - NumMessages: 2, + NumMessages: 3, NumExtensions: 0, NumServices: 1, }, diff --git a/internal/proto/clientpb/client.proto b/internal/proto/clientpb/client.proto index 05badc4b..b64958a3 100644 --- a/internal/proto/clientpb/client.proto +++ b/internal/proto/clientpb/client.proto @@ -12,11 +12,20 @@ service Client { // ExecCommand sends a command to all replicas and waits for valid signatures // from f+1 replicas rpc ExecCommand(Command) returns (google.protobuf.Empty) { - option (gorums.quorumcall) = true; - option (gorums.async) = true; + option (gorums.multicast) = true; + } + + rpc CommandStatus(Command) returns (CommandStatusResponse) { + option (gorums.quorumcall) = true; } } +message CommandStatusResponse { + uint64 highestSequenceNumber = 1; + repeated uint64 failedSequenceNumbers = 2; +} + + // Command is the request that is sent to the HotStuff replicas with the data to // be executed. message Command { diff --git a/internal/proto/clientpb/client_gorums.pb.go b/internal/proto/clientpb/client_gorums.pb.go index cd75440d..cf8fa150 100644 --- a/internal/proto/clientpb/client_gorums.pb.go +++ b/internal/proto/clientpb/client_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.10.0-devel -// protoc v6.33.4 +// protoc v6.30.2 // source: internal/proto/clientpb/client.proto package clientpb @@ -150,75 +150,84 @@ type Node struct { *gorums.RawNode } +// ClientClient is the client interface for the Client service. +type ClientClient interface { + ExecCommand(ctx context.Context, in *Command, opts ...gorums.CallOption) + CommandStatus(ctx context.Context, in *Command) (resp *CommandStatusResponse, err error) +} + +// enforce interface compliance +var _ ClientClient = (*Configuration)(nil) + +// Reference imports to suppress errors if they are not otherwise used. +var _ emptypb.Empty + // ExecCommand sends a command to all replicas and waits for valid signatures // from f+1 replicas -func (c *Configuration) ExecCommand(ctx context.Context, in *Command) *AsyncEmpty { +func (c *Configuration) ExecCommand(ctx context.Context, in *Command, opts ...gorums.CallOption) { cd := gorums.QuorumCallData{ Message: in, Method: "clientpb.Client.ExecCommand", } - cd.QuorumFunction = func(req proto.Message, replies map[uint32]proto.Message) (proto.Message, bool) { - r := make(map[uint32]*emptypb.Empty, len(replies)) - for k, v := range replies { - r[k] = v.(*emptypb.Empty) - } - return c.qspec.ExecCommandQF(req.(*Command), r) - } - - fut := c.RawConfiguration.AsyncCall(ctx, cd) - return &AsyncEmpty{fut} -} -// ClientClient is the client interface for the Client service. -type ClientClient interface { - ExecCommand(ctx context.Context, in *Command) *AsyncEmpty + c.RawConfiguration.Multicast(ctx, cd, opts...) } -// enforce interface compliance -var _ ClientClient = (*Configuration)(nil) - // QuorumSpec is the interface of quorum functions for Client. type QuorumSpec interface { gorums.ConfigOption - // ExecCommandQF is the quorum function for the ExecCommand - // asynchronous quorum call method. The in parameter is the request object - // supplied to the ExecCommand method at call time, and may or may not + // CommandStatusQF is the quorum function for the CommandStatus + // quorum call method. The in parameter is the request object + // supplied to the CommandStatus method at call time, and may or may not // be used by the quorum function. If the in parameter is not needed // you should implement your quorum function with '_ *Command'. - ExecCommandQF(in *Command, replies map[uint32]*emptypb.Empty) (*emptypb.Empty, bool) + CommandStatusQF(in *Command, replies map[uint32]*CommandStatusResponse) (*CommandStatusResponse, bool) +} + +// CommandStatus is a quorum call invoked on all nodes in configuration c, +// with the same argument in, and returns a combined result. +func (c *Configuration) CommandStatus(ctx context.Context, in *Command) (resp *CommandStatusResponse, err error) { + cd := gorums.QuorumCallData{ + Message: in, + Method: "clientpb.Client.CommandStatus", + } + cd.QuorumFunction = func(req proto.Message, replies map[uint32]proto.Message) (proto.Message, bool) { + r := make(map[uint32]*CommandStatusResponse, len(replies)) + for k, v := range replies { + r[k] = v.(*CommandStatusResponse) + } + return c.qspec.CommandStatusQF(req.(*Command), r) + } + + res, err := c.RawConfiguration.QuorumCall(ctx, cd) + if err != nil { + return nil, err + } + return res.(*CommandStatusResponse), err } // Client is the server-side API for the Client Service type ClientServer interface { - ExecCommand(ctx gorums.ServerCtx, request *Command) (response *emptypb.Empty, err error) + ExecCommand(ctx gorums.ServerCtx, request *Command) + CommandStatus(ctx gorums.ServerCtx, request *Command) (response *CommandStatusResponse, err error) } func RegisterClientServer(srv *gorums.Server, impl ClientServer) { srv.RegisterHandler("clientpb.Client.ExecCommand", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Command](in) - resp, err := impl.ExecCommand(ctx, req) + impl.ExecCommand(ctx, req) + return nil, nil + }) + srv.RegisterHandler("clientpb.Client.CommandStatus", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { + req := gorums.AsProto[*Command](in) + resp, err := impl.CommandStatus(ctx, req) return gorums.NewResponseMessage(in.GetMetadata(), resp), err }) } -type internalEmpty struct { +type internalCommandStatusResponse struct { nid uint32 - reply *emptypb.Empty + reply *CommandStatusResponse err error } - -// AsyncEmpty is a async object for processing replies. -type AsyncEmpty struct { - *gorums.Async -} - -// Get returns the reply and any error associated with the called method. -// The method blocks until a reply or error is available. -func (f *AsyncEmpty) Get() (*emptypb.Empty, error) { - resp, err := f.Async.Get() - if err != nil { - return nil, err - } - return resp.(*emptypb.Empty), err -} diff --git a/server/clientio.go b/server/clientio.go index 8f88871e..54f786f0 100644 --- a/server/clientio.go +++ b/server/clientio.go @@ -10,26 +10,94 @@ import ( "github.com/relab/hotstuff/core/eventloop" "github.com/relab/hotstuff/core/logging" "github.com/relab/hotstuff/internal/proto/clientpb" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" ) -// ClientIO serves a client. +const ( + // hotstuffFailedCommandLength is the maximum number of failed sequence numbers to track per client + hotstuffFailedCommandLength = 100 +) + +// clientStatusWindow tracks command statuses for a single client. +// It stores the highest successfully executed sequence number and a list of failed sequence numbers. +type clientStatusWindow struct { + // HighestSuccess is the highest successfully executed sequence number for this client + HighestSuccess uint64 + // FailedCmds is the list of recent failed sequence numbers (up to maxFailed entries) + FailedCmds []uint64 +} + +// CommandStatusTracker stores per-client command status information. +// It tracks the highest successful sequence number and recent failed sequence numbers for each client. +type CommandStatusTracker struct { + // clientWindows maps ClientID to its status window + clientWindows map[uint32]*clientStatusWindow + // maxFailed is the maximum number of failed sequence numbers to track per client + maxFailed int +} + +// ensureWindow returns the status window for clientID, creating it if it doesn't exist. +func (cst *CommandStatusTracker) ensureWindow(clientID uint32) *clientStatusWindow { + w, ok := cst.clientWindows[clientID] + if !ok { + w = &clientStatusWindow{ + FailedCmds: make([]uint64, 0, cst.maxFailed), + } + cst.clientWindows[clientID] = w + } + return w +} + +// addFailed records a failed sequence number in the window. +// If the window is full (maxFailed entries), it removes the oldest entry before adding the new one. +func (cst *CommandStatusTracker) addFailed(client uint32, seq uint64) { + w := cst.ensureWindow(client) + // If we have reached maxFailed, remove the oldest entry + if len(w.FailedCmds) >= cst.maxFailed { + w.FailedCmds = w.FailedCmds[1:] + } + w.FailedCmds = append(w.FailedCmds, seq) +} + +// NewCommandStatusTracker creates a new CommandStatusTracker with default settings. +func NewCommandStatusTracker() *CommandStatusTracker { + return &CommandStatusTracker{ + clientWindows: make(map[uint32]*clientStatusWindow), + maxFailed: hotstuffFailedCommandLength, + } +} + +// setSuccess updates the highest successfully executed sequence number for a client. +// It only updates if the new sequence number is higher than the current highest. +func (cst *CommandStatusTracker) setSuccess(clientID uint32, seqNum uint64) { + w := cst.ensureWindow(clientID) + if seqNum < w.HighestSuccess { + return + } + w.HighestSuccess = seqNum +} + +// GetClientStatuses returns the status window for a given client. +// If the client doesn't exist, it creates and returns an empty window. +func (cst *CommandStatusTracker) GetClientStatuses(clientID uint32) *clientStatusWindow { + return cst.ensureWindow(clientID) +} + +// ClientIO serves client requests and manages command execution tracking. type ClientIO struct { logger logging.Logger cmdCache *clientpb.CommandCache - mut sync.Mutex - srv *gorums.Server - awaitingCmds map[clientpb.MessageID]chan<- error - hash hash.Hash - cmdCount uint32 + mut sync.Mutex + srv *gorums.Server + hash hash.Hash + cmdCount uint32 - lastExecutedSeqNum map[uint32]uint64 // highest executed sequence number per client ID + lastExecutedSeqNum map[uint32]uint64 // tracks the highest executed sequence number per client ID + statusTracker *CommandStatusTracker // tracks command execution status (success/failure) per client } -// NewClientIO returns a new client IO server. +// NewClientIO creates and returns a new ClientIO server instance. +// It registers the server with the event loop to handle Execute and Abort events. func NewClientIO( el *eventloop.EventLoop, logger logging.Logger, @@ -40,10 +108,10 @@ func NewClientIO( logger: logger, cmdCache: cmdCache, - awaitingCmds: make(map[clientpb.MessageID]chan<- error), srv: gorums.NewServer(srvOpts...), hash: sha256.New(), lastExecutedSeqNum: make(map[uint32]uint64), + statusTracker: NewCommandStatusTracker(), } clientpb.RegisterClientServer(srv.srv, srv) eventloop.Register(el, func(event clientpb.ExecuteEvent) { @@ -55,6 +123,7 @@ func NewClientIO( return srv } +// StartOnListener starts the gRPC server on the provided listener. func (srv *ClientIO) StartOnListener(lis net.Listener) { go func() { err := srv.srv.Serve(lis) @@ -64,72 +133,74 @@ func (srv *ClientIO) StartOnListener(lis net.Listener) { }() } +// Stop stops the gRPC server. func (srv *ClientIO) Stop() { srv.srv.Stop() } +// Hash returns the current hash of all executed commands. func (srv *ClientIO) Hash() hash.Hash { return srv.hash } +// CmdCount returns the total number of executed commands. func (srv *ClientIO) CmdCount() uint32 { return srv.cmdCount } -func (srv *ClientIO) ExecCommand(ctx gorums.ServerCtx, cmd *clientpb.Command) (*emptypb.Empty, error) { - id := cmd.ID() - errChan := make(chan error) - - srv.mut.Lock() - srv.awaitingCmds[id] = errChan - srv.mut.Unlock() - +// ExecCommand receives a command from a client and adds it to the command cache. +func (srv *ClientIO) ExecCommand(ctx gorums.ServerCtx, cmd *clientpb.Command) { srv.cmdCache.Add(cmd) ctx.Release() - err := <-errChan - return &emptypb.Empty{}, err } +// Exec executes a batch of commands, updating the hash and command count. +// It skips duplicate commands and marks successful executions in the status tracker. func (srv *ClientIO) Exec(batch *clientpb.Batch) { for _, cmd := range batch.GetCommands() { - id := cmd.ID() srv.mut.Lock() if srv.isDuplicate(cmd) { srv.logger.Info("duplicate command found") - srv.completeCommand(id, status.Error(codes.Aborted, "command already executed")) srv.mut.Unlock() continue } srv.lastExecutedSeqNum[cmd.ClientID] = cmd.SequenceNumber + // Mark command as executed in status tracker + srv.statusTracker.setSuccess(cmd.ClientID, cmd.SequenceNumber) _, _ = srv.hash.Write(cmd.Data) srv.cmdCount++ - srv.completeCommand(id, nil) srv.mut.Unlock() } srv.logger.Debugf("Hash: %.8x", srv.hash.Sum(nil)) } +// Abort marks a batch of commands as failed in the status tracker. func (srv *ClientIO) Abort(batch *clientpb.Batch) { for _, cmd := range batch.GetCommands() { srv.mut.Lock() - srv.completeCommand(cmd.ID(), status.Error(codes.Aborted, "blockchain was forked")) + // Mark command as aborted in status tracker + srv.statusTracker.addFailed(cmd.ClientID, cmd.SequenceNumber) srv.mut.Unlock() } } -// isDuplicate return true if the command has already been executed. +// isDuplicate returns true if the command has already been executed. // The caller must hold srv.mut.Lock(). func (srv *ClientIO) isDuplicate(cmd *clientpb.Command) bool { seqNum, ok := srv.lastExecutedSeqNum[cmd.ClientID] return ok && seqNum >= cmd.SequenceNumber } -// completeCommand sends an error or nil to the awaiting client's error channel. -// The caller must hold srv.mut.Lock(). -func (srv *ClientIO) completeCommand(id clientpb.MessageID, err error) { - if errChan, ok := srv.awaitingCmds[id]; ok { - errChan <- err - delete(srv.awaitingCmds, id) - } +// CommandStatus returns the execution status for a given command. +// It returns the highest executed sequence number and list of failed sequence numbers for the client. +func (srv *ClientIO) CommandStatus(_ gorums.ServerCtx, in *clientpb.Command) (resp *clientpb.CommandStatusResponse, err error) { + srv.mut.Lock() + defer srv.mut.Unlock() + CommandStatus := srv.statusTracker.GetClientStatuses(in.ClientID) + + return &clientpb.CommandStatusResponse{ + HighestSequenceNumber: CommandStatus.HighestSuccess, + FailedSequenceNumbers: CommandStatus.FailedCmds, + }, nil } diff --git a/server/clientio_test.go b/server/clientio_test.go new file mode 100644 index 00000000..723a76bd --- /dev/null +++ b/server/clientio_test.go @@ -0,0 +1,88 @@ +package server + +import ( + "testing" +) + +func TestCommandStatusTracker_SetSuccess(t *testing.T) { + tr := NewCommandStatusTracker() + clientID := uint32(1) + + tr.setSuccess(clientID, 10) + w := tr.GetClientStatuses(clientID) + + if w.HighestSuccess != 10 { + t.Fatalf("HighestSuccess = %d; want 10", w.HighestSuccess) + } + + // Setting lower sequence number should not update + tr.setSuccess(clientID, 5) + w = tr.GetClientStatuses(clientID) + if w.HighestSuccess != 10 { + t.Fatalf("HighestSuccess = %d; want 10", w.HighestSuccess) + } + + // Setting higher should update + tr.setSuccess(clientID, 20) + w = tr.GetClientStatuses(clientID) + if w.HighestSuccess != 20 { + t.Fatalf("HighestSuccess = %d; want 20", w.HighestSuccess) + } +} + +func TestCommandStatusTracker_AddFailed(t *testing.T) { + tr := NewCommandStatusTracker() + tr.maxFailed = 3 // Set small limit for testing + clientID := uint32(2) + + tr.addFailed(clientID, 10) + tr.addFailed(clientID, 15) + tr.addFailed(clientID, 20) + + w := tr.GetClientStatuses(clientID) + if len(w.FailedCmds) != 3 { + t.Fatalf("FailedCmds length = %d; want 3", len(w.FailedCmds)) + } + + // Adding one more should drop the oldest + tr.addFailed(clientID, 25) + w = tr.GetClientStatuses(clientID) + if len(w.FailedCmds) != 3 { + t.Fatalf("FailedCmds length = %d; want 3", len(w.FailedCmds)) + } + + // Should contain 15, 20, 25 (10 should be dropped) + expected := []uint64{15, 20, 25} + for i, seq := range w.FailedCmds { + if seq != expected[i] { + t.Fatalf("FailedCmds[%d] = %d; want %d", i, seq, expected[i]) + } + } +} + +func TestCommandStatusTracker_GetClientStatuses(t *testing.T) { + tr := NewCommandStatusTracker() + clientID := uint32(3) + + // New client should have empty window + w := tr.GetClientStatuses(clientID) + if w.HighestSuccess != 0 { + t.Fatalf("HighestSuccess = %d; want 0", w.HighestSuccess) + } + if len(w.FailedCmds) != 0 { + t.Fatalf("FailedCmds length = %d; want 0", len(w.FailedCmds)) + } + + // Add some data + tr.setSuccess(clientID, 50) + tr.addFailed(clientID, 52) + tr.addFailed(clientID, 55) + + w = tr.GetClientStatuses(clientID) + if w.HighestSuccess != 50 { + t.Fatalf("HighestSuccess = %d; want 50", w.HighestSuccess) + } + if len(w.FailedCmds) != 2 { + t.Fatalf("FailedCmds length = %d; want 2", len(w.FailedCmds)) + } +}