package controllers import ( "context" "crypto/sha256" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "regexp" "testing" "time" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/token" "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "google.golang.org/grpc" ) type TokenController struct { DB *sql.DB HashCost int16 ServiceInfo map[string]grpc.ServiceInfo rules []rule Redis *redis.Client } // Services that are not available for tokens var DisabledServicesRegex = []string{".*Accounts.*", ".*Tokens.*"} // Errors var ( ErrServerError = errors.New("internal server error") ErrUserTokenMismatch = errors.New("user doesn't match the token") ErrBadToken = errors.New("bad token") ) type TokenData struct { UUID string Name string UserID string CreatedAt time.Time LastUsedAt time.Time RevokedAt time.Time ExpiresAt time.Time GeneratedAt time.Time Scopes map[string][]string } type Scopes struct{} // Set the grpc info, must happen after all the service are initialized func (ctrl *TokenController) SetGRPCInfo(info map[string]grpc.ServiceInfo) { ctrl.ServiceInfo = info } func (ctrl *TokenController) SetRules() { rules := []rule{ {re: regexp.MustCompile(`.*Tokens.*`)}, {re: regexp.MustCompile(`.*Accounts.*`)}, {re: regexp.MustCompile(`.*Reflection.*`)}, } ctrl.rules = rules } // Each token operation must first verify that the current user // is allowed to manipulate the token. func (ctrl *TokenController) VerifyTokenOwner(ctx context.Context, userID, tokenID string) error { log := logger.FromContext(ctx).WithValues("uuid", tokenID, "user_id", userID) log.V(2).Info("Verifying the token owner") // First try to get from the redis redisKey := fmt.Sprintf("token:%s", tokenID) realUserID := ctrl.Redis.Get(ctx, redisKey).Val() // If not found in cache, get from postgres if realUserID == "" { query := "SELECT user_id FROM tokens WHERE uuid = $1;" if err := ctrl.DB.QueryRowContext(ctx, query, tokenID).Scan( &realUserID, ); err != nil { log.Error(err, "Couldn't get user_id for a token") return ErrServerError } } if realUserID != userID { return ErrUserTokenMismatch } err := ctrl.Redis.Set(ctx, redisKey, realUserID, time.Hour) if err != nil { log.Info("Couldn't write to cache", "error", err) } return nil } // Create a new token, store its hash in the database and return the token value func (ctrl *TokenController) Create(ctx context.Context, data *TokenData) (string, error) { id := uuid.NewString() log := logger.FromContext(ctx).WithValues("uuid", id) log.V(2).Info("Creating a new token") tokenValue, err := token.GenerateToken() if err != nil { log.Error(err, "Couldn't create a token") return "", ErrServerError } tokenHash := hashSHA256(tokenValue) scopesJson, err := json.Marshal(data.Scopes) if err != nil { log.Error(err, "Couldn't marshal permissions into json") return "", ErrServerError } query := ` INSERT INTO tokens (uuid, description, token_hash, user_id, scopes, created_at, generated_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)` if _, err := ctrl.DB.Query( query, id, data.Name, tokenHash, data.UserID, scopesJson, time.Now(), time.Now(), data.ExpiresAt, ); err != nil { log.Error(err, "Couldn't insert a token in the database") return "", ErrServerError } return tokenValue, nil } // Update token name or permissions, other changes are ignored by this method func (ctrl *TokenController) Update(ctx context.Context, data *TokenData) error { log := logger.FromContext(ctx).WithValues("uuid", data.UUID) log.V(2).Info("Updating a token") scopesJson, err := json.Marshal(data.Scopes) if err != nil { log.Error(err, "Couldn't marshal permissions into json") return ErrServerError } query := "UPDATE tokens SET description = $1, scopes = $2 WHERE uuid = $3;" if _, err := ctrl.DB.Query(query, data.Name, scopesJson, data.UUID); err != nil { log.Error(err, "Couldn't update a token in the database") return ErrServerError } return nil } // ForceExpiration of a token, so it can no longer be used func (ctrl *TokenController) ForceExpiration(ctx context.Context, id string) error { log := logger.FromContext(ctx).WithValues("uuid", id) log.V(2).Info("Forcing a token expiration") query := "UPDATE tokens SET revoked_at = $1 WHERE uuid = $2;" if _, err := ctrl.DB.Query(query, time.Now(), id); err != nil { log.Error(err, "Couldn't update a token in the database") return ErrServerError } return nil } // Regenerate a token and get a new value func (ctrl *TokenController) Regenerate(ctx context.Context, id string) (string, error) { log := logger.FromContext(ctx).WithValues("uuid", id) log.V(2).Info("Regenerating a token") tokenValue, err := token.GenerateToken() if err != nil { log.Error(err, "Couldn't create a token") return "", ErrServerError } tokenHash := hashSHA256(tokenValue) query := ` UPDATE tokens SET token_hash = $1, generated_at = $2, expires_at = NOW() + (expires_at - generated_at), WHERE uuid = $3;` if _, err := ctrl.DB.Query(query, tokenHash, time.Now(), id); err != nil { log.Error(err, "Couldn't insert a token in the database") return "", ErrServerError } return "", nil } // Get an existing token data func (ctrl *TokenController) Get(ctx context.Context, id, userID string) (*TokenData, error) { log := logger.FromContext(ctx).WithValues("uuid", id, "user_id", userID) log.V(2).Info("Regenerating a token") query := ` SELECT uuid, description, generated_at, expires_at, last_used_at, revoked_at, created_at FROM tokens WHERE uuid = $1 AND user_id = $2` token := &TokenData{} var generatedAt sql.NullTime var expiresAt sql.NullTime var revokedAt sql.NullTime var lastUsedAt sql.NullTime var createdAt sql.NullTime if err := ctrl.DB.QueryRowContext(ctx, query, id, userID).Scan( &token.UUID, &token.Name, &generatedAt, &expiresAt, &lastUsedAt, &revokedAt, &createdAt, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, err } log.Error(err, "Couldn't find a token") return nil, ErrServerError } token.GeneratedAt = generatedAt.Time token.ExpiresAt = expiresAt.Time token.RevokedAt = revokedAt.Time token.LastUsedAt = lastUsedAt.Time token.CreatedAt = createdAt.Time return token, nil } // List all available token func (ctrl *TokenController) List(ctx context.Context, userID string) ([]TokenData, error) { log := logger.FromContext(ctx).WithValues("user_id", userID) log.V(2).Info("Regenerating a token") result := []TokenData{} query := ` SELECT uuid, description, generated_at, expires_at, revoked_at, last_used_at, created_at FROM tokens WHERE user_id = $1` rows, err := ctrl.DB.QueryContext(ctx, query, userID) if err != nil { log.Error(err, "Couldn't list tokens") return nil, ErrServerError } defer rows.Close() if err := rows.Err(); err != nil { log.Error(err, "Couldn't list tokens") return nil, ErrServerError } for rows.Next() { var t TokenData var generatedAt sql.NullTime var expiresAt sql.NullTime var revokedAt sql.NullTime var lastUsedAt sql.NullTime var createdAt sql.NullTime err := rows.Scan( &t.UUID, &t.Name, &generatedAt, &expiresAt, &revokedAt, &lastUsedAt, &createdAt, ) t.GeneratedAt = generatedAt.Time t.ExpiresAt = expiresAt.Time t.RevokedAt = revokedAt.Time t.LastUsedAt = lastUsedAt.Time t.CreatedAt = createdAt.Time if err != nil { log.Error(err, "Couldn't write token into a struct") return nil, err } result = append(result, t) } return result, nil } // Lis all available permissions func (ctrl *TokenController) ListPermissions(ctx context.Context) (result map[string][]string) { result = map[string][]string{} for key, val := range ctrl.ServiceInfo { if shouldSkip(key, ctrl.rules) { continue } var services []string for _, svc := range val.Methods { services = append(services, svc.Name) } result[key] = services } return } type rule struct { re *regexp.Regexp } func shouldSkip(s string, rules []rule) bool { for _, r := range rules { if r.re.MatchString(s) { return true } } return false } type TokenAuthResult struct { UserID string Scope string } func (ctrl *TokenController) AuthenticateWithToken(ctx context.Context, token string) (*TokenAuthResult, error) { log := logger.FromContext(ctx) log.V(2).Info("Authenticating with a token") query := ` SELECT user_id, scopes, expires_at, revoked_at FROM tokens WHERE token_hash = $1` var userID string var expiresAt sql.NullTime var revokedAt sql.NullTime var scope string if err := ctrl.DB.QueryRowContext(ctx, query, hashSHA256(token)).Scan( &userID, &scope, &expiresAt, &revokedAt, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, err } log.Error(err, "Couldn't find a token") return nil, ErrServerError } if revokedAt.Valid { return nil, ErrBadToken } if expiresAt.Time.Before(time.Now()) { return nil, ErrBadToken } result := &TokenAuthResult{ UserID: userID, Scope: scope, } return result, nil } func hashSHA256(s string) string { hash := sha256.Sum256([]byte(s)) return hex.EncodeToString(hash[:]) } func TestUnitHashPersistence(t *testing.T) { password := "qwertyu9" hash1 := hashSHA256(password) hash2 := hashSHA256(password) assert.Equal(t, hash1, hash2) }