diff --git a/pkg/raftstore/consensus/consensus.go b/pkg/raftstore/consensus/consensus.go index 2541b9141..252e5baae 100644 --- a/pkg/raftstore/consensus/consensus.go +++ b/pkg/raftstore/consensus/consensus.go @@ -6,8 +6,10 @@ import ( "fmt" "net/http" "net/url" + "time" "github.com/interuss/dss/pkg/logging" + params "github.com/interuss/dss/pkg/raftstore/params" "github.com/interuss/stacktrace" "go.etcd.io/etcd/client/pkg/v3/types" "go.etcd.io/etcd/server/v3/etcdserver/api/rafthttp" @@ -17,43 +19,89 @@ import ( "go.uber.org/zap" ) -const ( - defaultClusterID uint64 = 1 -) - type Consensus struct { logger *zap.Logger + id uint64 node raft.Node - transport *rafthttp.Transport - server *http.Server + removedPeers map[uint64]bool + transport *rafthttp.Transport + server *http.Server storage *storage - errorC chan error + + confState raftpb.ConfState + snapshotIndex uint64 + appliedIndex uint64 } -func NewConsensus(ctx context.Context, logger *zap.Logger, nodeID uint64, peers map[uint64]*url.URL, dataDir string, snapshotCatchupEntries uint64) (*Consensus, error) { - storage, _, err := newStorage(ctx, logger.With(zap.String("component", "storage")), dataDir, nodeID, snapshotCatchupEntries) +func NewConsensus(ctx context.Context, logger *zap.Logger, peers map[uint64]*url.URL, connectParams params.ConnectParameters) (*Consensus, error) { + storage, old, err := newStorage(ctx, logger.With(zap.String("component", "storage")), connectParams.DataDir, connectParams.ID, connectParams.SnapshotCatchupEntries) if err != nil { return nil, stacktrace.Propagate(err, "failed to initialize storage") } + var node raft.Node + config := connectParams.RaftConfig(storage) + if old { + node = raft.RestartNode(config) + } else { + node = raft.StartNode(config, peersList(peers)) + } + consensus := &Consensus{ logger: logging.WithValuesFromContext(ctx, logger), + id: connectParams.ID, + node: node, + + removedPeers: make(map[uint64]bool), + storage: storage, - errorC: make(chan error, 1), } - err = consensus.initTransport(ctx, nodeID, defaultClusterID, peers) + err = consensus.initTransport(ctx, connectParams.ID, connectParams.ClusterID, peers) if err != nil { return nil, stacktrace.Propagate(err, "failed to initialize transport") } + snap, err := consensus.storage.Snapshot() + if err != nil { + return nil, stacktrace.Propagate(err, "failed to get snapshot from storage") + } + + consensus.confState = snap.Metadata.ConfState + consensus.snapshotIndex = snap.Metadata.Index + consensus.appliedIndex = snap.Metadata.Index + + go func() { + err := consensus.handleReady(connectParams.TickInterval) + if err != nil { + consensus.logger.Error("handleReady exited with error, shutting down consensus", zap.Error(err)) + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if shutdownErr := consensus.server.Shutdown(shutdownCtx); shutdownErr != nil { + consensus.logger.Error("failed to shutdown http server", zap.Error(shutdownErr)) + } + + consensus.transport.Stop() + consensus.node.Stop() + }() + return consensus, nil } +func peersList(peers map[uint64]*url.URL) []raft.Peer { + result := make([]raft.Peer, 0, len(peers)) + for id := range peers { + result = append(result, raft.Peer{ID: id}) + } + return result +} + func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID uint64, peers map[uint64]*url.URL) error { nodeIDStr := fmt.Sprintf("%d", nodeID) @@ -64,7 +112,7 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID Raft: c, ServerStats: v2stats.NewServerStats(nodeIDStr, nodeIDStr), LeaderStats: v2stats.NewLeaderStats(c.logger, nodeIDStr), - ErrorC: c.errorC, + ErrorC: make(chan error), } err := transport.Start() @@ -95,7 +143,7 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID err := c.server.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { c.logger.Error("http server error", zap.Error(err)) - c.errorC <- err + c.transport.ErrorC <- err } }() @@ -103,6 +151,99 @@ func (c *Consensus) initTransport(ctx context.Context, nodeID uint64, clusterID return nil } +// handleReady processes the Ready channel of the Raft node and applies committed entries to the state machine +func (c *Consensus) handleReady(tickInterval time.Duration) error { + ticker := time.NewTicker(tickInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.node.Tick() + case err := <-c.transport.ErrorC: + return stacktrace.Propagate(err, "transport error") + case rd, ok := <-c.node.Ready(): + if !ok { + return stacktrace.NewError("could not read from Ready(), shutting down handler") + } + + err := c.storage.handleReceivedState(rd.Snapshot, rd.HardState, rd.Entries) + if err != nil { + return stacktrace.Propagate(err, "failed to handle received snapshot") + } + + if !raft.IsEmptySnap(rd.Snapshot) { + if rd.Snapshot.Metadata.Index <= c.appliedIndex { + return stacktrace.NewError("snapshot index %d is not greater than applied index %d", rd.Snapshot.Metadata.Index, c.appliedIndex) + } + + err = c.dispatchSnapshot(rd.Snapshot.Data) + if err != nil { + return stacktrace.Propagate(err, "failed to dispatch snapshot") + } + + c.confState = rd.Snapshot.Metadata.ConfState + c.snapshotIndex = rd.Snapshot.Metadata.Index + c.appliedIndex = rd.Snapshot.Metadata.Index + } + + c.transport.Send(c.checkUpdateConfState(rd.Messages)) + + entries, err := c.entriesToApply(rd.CommittedEntries) + if err != nil { + return stacktrace.Propagate(err, "failed to get entries to apply") + } + + err = c.publishEntries(entries) + if err != nil { + return stacktrace.Propagate(err, "failed to publish entries") + } + + c.node.Advance() + } + } +} + +// TODO implement +func (c *Consensus) publishEntries(_ []raftpb.Entry) error { + return nil +} + +// TODO implement +func (c *Consensus) dispatchSnapshot(_ []byte) error { + return nil +} + +func (c *Consensus) entriesToApply(entries []raftpb.Entry) ([]raftpb.Entry, error) { + if len(entries) == 0 { + return entries, nil + } + + result := make([]raftpb.Entry, 0) + + firstIdx := entries[0].Index + if firstIdx > c.appliedIndex+1 { + return nil, stacktrace.NewError("first index of committed entry[%d] should <= progress.appliedIndex[%d]+1", firstIdx, c.appliedIndex) + } + + if c.appliedIndex-firstIdx+1 < uint64(len(entries)) { + result = entries[c.appliedIndex-firstIdx+1:] + } + + return result, nil +} + +// checkUpdateConfState checks if any of the messages to be sent contain a snapshot +// and updates the ConfState in the snapshot as it could be outdated. +func (c *Consensus) checkUpdateConfState(msgs []raftpb.Message) []raftpb.Message { + for _, msg := range msgs { + if msg.Type == raftpb.MsgSnap { + msg.Snapshot.Metadata.ConfState = c.confState + } + } + return msgs +} + // Process implements the rafthttp.Raft interface. func (c *Consensus) Process(ctx context.Context, m raftpb.Message) error { return c.node.Step(ctx, m) diff --git a/pkg/raftstore/params/params.go b/pkg/raftstore/params/params.go index a135813d2..d5906dba4 100644 --- a/pkg/raftstore/params/params.go +++ b/pkg/raftstore/params/params.go @@ -5,8 +5,10 @@ import ( "net/url" "strconv" "strings" + "time" "github.com/interuss/stacktrace" + "go.etcd.io/raft/v3" ) const ( @@ -14,7 +16,14 @@ const ( // the default Raft related parameters are the same as the default values used by etcd for the moment. // TODO - review and adjust these parameters as needed based on testing and performance tuning. - defaultSnapshotCatchupEntries = 10000 + defaultSnapshotCatchupEntries = 5000 + defaultSnapshotIntervalEntries = 10000 + defaultTickInterval = 100 * time.Millisecond + + defaultElectionTick = 10 + defaultHeartbeatTick = 1 + defaultMaxSizePerMsg = 1024 * 1024 + defaultMaxInflightMsgs = 4096 / 8 ) type ( @@ -28,11 +37,19 @@ type ( // a full snapshot from the leader. It must not be deleted while the node is running or // across restarts unless the node is being permanently shut down. // If the directory is lost, the node will recover by receiving a snapshot from the leader. - DataDir string + DataDir string + ClusterID uint64 // SnapshotCatchupEntries is the number of entries for a slow follower to catch-up after compacting. // This gives the follower a buffer of entries while avoiding the need to send a full snapshot. - SnapshotCatchupEntries uint64 + SnapshotCatchupEntries uint64 + SnapshotIntervalEntries uint64 + TickInterval time.Duration + + ElectionTick int + HeartbeatTick int + MaxSizePerMsg uint64 + MaxInflightMsgs int } ) @@ -70,16 +87,35 @@ func (c ConnectParameters) PeerMap() (map[uint64]*url.URL, error) { return peers, nil } +func (c ConnectParameters) RaftConfig(storage raft.Storage) *raft.Config { + return &raft.Config{ + ID: c.ID, + ElectionTick: c.ElectionTick, + HeartbeatTick: c.HeartbeatTick, + MaxSizePerMsg: c.MaxSizePerMsg, + MaxInflightMsgs: c.MaxInflightMsgs, + Storage: storage, + } +} + var ( connectParameters ConnectParameters ) func init() { flag.Uint64Var(&connectParameters.ID, "raft_node_id", 0, "raft node ID for this instance (must be non-zero and unique within the cluster)") + flag.Uint64Var(&connectParameters.ClusterID, "raft_cluster_id", 1, "id of the cluster, used to isolate different Raft clusters running in the same network (must be the same for all nodes in the cluster)") flag.StringVar(&connectParameters.Peers, "raft_peers", "", `comma-separated "nodeID=peerURL" pairs for all cluster members, including the current node, e.g. "1=http://node1:9021,2=http://node2:9021,3=http://node3:9021"`) flag.StringVar(&connectParameters.DataDir, "raft_datadir", defaultDataDir, "directory for raft data (WAL segments and snapshots), required for restarts. These should not be deleted while the node is running or across restarts unless the node is being permanently shut down.") flag.Uint64Var(&connectParameters.SnapshotCatchupEntries, "raft_snapshot_catchup_entries", defaultSnapshotCatchupEntries, "number of entries for a slow follower to catch-up after compacting") + flag.Uint64Var(&connectParameters.SnapshotIntervalEntries, "raft_snapshot_interval_entries", defaultSnapshotIntervalEntries, "number of entries between snapshots") + flag.DurationVar(&connectParameters.TickInterval, "raft_tick_interval", defaultTickInterval, "interval between raft ticks, controls the logical clock of the Raft node and thus the timing of elections and heartbeats") + + flag.IntVar(&connectParameters.ElectionTick, "raft_election_tick", defaultElectionTick, "number of ticks without a leader heartbeat before a follower starts an election") + flag.IntVar(&connectParameters.HeartbeatTick, "raft_heartbeat_tick", defaultHeartbeatTick, "number of ticks between leader heartbeats (must be less than raft_election_tick)") + flag.Uint64Var(&connectParameters.MaxSizePerMsg, "raft_max_size_per_msg", defaultMaxSizePerMsg, "max bytes per raft message (0 = unlimited)") + flag.IntVar(&connectParameters.MaxInflightMsgs, "raft_max_inflight_msgs", defaultMaxInflightMsgs, "max number of in-flight messages") } // GetConnectParameters returns a ConnectParameters instance that gets populated from well-known CLI flags. diff --git a/pkg/raftstore/store.go b/pkg/raftstore/store.go index b602d31c3..3ef62c157 100644 --- a/pkg/raftstore/store.go +++ b/pkg/raftstore/store.go @@ -31,7 +31,7 @@ func Init[R any](ctx context.Context, logger *zap.Logger, newRepo func() R) (*St return } - sharedConsensus, sharedConsensusErr = consensus.NewConsensus(ctx, logger, params.ID, peers, params.DataDir, params.SnapshotCatchupEntries) + sharedConsensus, sharedConsensusErr = consensus.NewConsensus(ctx, logger, peers, params) if sharedConsensusErr != nil { sharedConsensusErr = stacktrace.Propagate(sharedConsensusErr, "failed to initialize consensus") }