From 8bb47022f48b6f4363b010598c617550ea0df3ab Mon Sep 17 00:00:00 2001 From: Mateo Presa Castro Date: Sun, 7 Dec 2025 13:53:58 +0100 Subject: [PATCH] fix: transfer leader on close --- kv/kv.go | 17 +++++++++++++---- mokv/mokv.go | 5 +++++ server/server_test.go | 4 ++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/kv/kv.go b/kv/kv.go index ecd06c4..3921d37 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -204,14 +204,20 @@ func (kv *KV) Join(id, addr string) error { // Leave removes a node from the Raft cluster. func (kv *KV) Leave(id string) error { - removeFuture := kv.raft.RemoveServer(raft.ServerID(id), 0, 0) - return removeFuture.Error() + future := kv.raft.RemoveServer(raft.ServerID(id), 0, 0) + return future.Error() } // Close gracefully shuts down the Raft node. func (kv *KV) Close() error { - f := kv.raft.Shutdown() - return f.Error() + if kv.isLeader() { + future := kv.raft.LeadershipTransfer() + if err := future.Error(); err != nil { + return err + } + } + future := kv.raft.Shutdown() + return future.Error() } // WaitForLeader blocks until the Raft node detects a leader in the cluster or times out. @@ -230,6 +236,9 @@ func (kv *KV) WaitForLeader(timeout time.Duration) error { } } } +func (kv *KV) isLeader() bool { + return kv.raft.State() == raft.Leader +} // apply encodes the request and submits it to Raft. func (kv *KV) apply(reqType RequestType, req proto.Message) (any, error) { diff --git a/mokv/mokv.go b/mokv/mokv.go index 5d02e60..a6f6f96 100644 --- a/mokv/mokv.go +++ b/mokv/mokv.go @@ -47,6 +47,7 @@ type Storer interface { Set(key string, value []byte) error Delete(key string) error List() <-chan []byte + Close() error } // GetEnv defines a function signature for retrieving environment variables. @@ -227,6 +228,10 @@ func (r *MOKV) close(ctx context.Context) error { errs = append(errs, fmt.Errorf("membership leave error: %w", err)) } + if err := r.kv.Close(); err != nil { + errs = append(errs, fmt.Errorf("KV close error: %w", err)) + } + r.grpcServer.GracefulStop() if err := r.meterProvider.Shutdown(ctx); err != nil { diff --git a/server/server_test.go b/server/server_test.go index aa07aaa..a85a3a8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -29,6 +29,10 @@ func (t *testKV) GetServers() ([]*api.Server, error) { return nil, nil } +func (t *testKV) Close() error { + return nil +} + // setupTestServer creates and starts a test server, returning cleanup function func setupTestServer(t *testing.T) (api.KVClient, func()) { t.Helper()