diff --git a/cmd/server.go b/cmd/server.go index 82bde3c..9df1e62 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "database/sql" "fmt" "net" "strings" @@ -21,6 +22,9 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/redis/go-redis/v9" "google.golang.org/grpc" + "google.golang.org/grpc/health" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" + healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" ) @@ -44,6 +48,8 @@ type Server struct { // Dev and logging Reflection bool `env:"SOFTPLAYER_REFLECTION" default:"false"` DevMode bool `env:"SOFTPLAYER_DEV_MODE" default:"false"` + // HealthChecks + HealthCheckInterval time.Duration `env:"SOFTPLAYER_HEALTH_CHECK_INTERVAL" default:"5s"` } // Run the grpc backend server @@ -121,6 +127,33 @@ func (cmd *Server) Run(ctx context.Context) error { tokens.RegisterTokensServiceServer(grpcServer, v1.NewTokensServer(tokenCtrl, authController)) tokens.RegisterPublicTokensServiceServer(grpcServer, v1.NewPublicTokensServer(tokenCtrl, authController)) + healthcheck := health.NewServer() + healthgrpc.RegisterHealthServer(grpcServer, healthcheck) + + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + + // Example checks + dbOK := checkDatabase(db) + redisOK := checkRedis(rdb) + + status := healthpb.HealthCheckResponse_SERVING + + if !dbOK || !redisOK { + status = healthpb.HealthCheckResponse_NOT_SERVING + } + + healthcheck.SetServingStatus( + "", + status, + ) + + } + }() + info := grpcServer.GetServiceInfo() tokenCtrl.SetGRPCInfo(info) tokenCtrl.SetRules() @@ -149,6 +182,10 @@ func selectorRequireAuth(ctx context.Context, callMeta interceptors.CallMeta) bo return false } + if serviceName == "Health" { + return false + } + if strings.Contains(serviceName, "ServerReflection") { return false } @@ -159,3 +196,33 @@ func selectorRequireAuth(ctx context.Context, callMeta interceptors.CallMeta) bo return true } + +func checkDatabase(db *sql.DB) bool { + ctx, cancel := context.WithTimeout( + context.Background(), + 2*time.Second, + ) + defer cancel() + + // Fast connectivity check + if err := db.PingContext(ctx); err != nil { + return false + } + + return true +} + +func checkRedis(rdb *redis.Client) bool { + ctx, cancel := context.WithTimeout( + context.Background(), + 2*time.Second, + ) + defer cancel() + + err := rdb.Ping(ctx).Err() + if err != nil { + return false + } + + return true +}