package controllers import ( "context" "crypto/sha256" "database/sql" "encoding/hex" "encoding/json" "errors" "regexp" "testing" "time" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/cache" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/token" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/repository" "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "google.golang.org/grpc" ) type TokenController struct { DB *sql.DB ServiceInfo map[string]grpc.ServiceInfo rules []rule Redis *redis.Client } // Services that are not available for tokens var DisabledServicesRegex = []string{".*Accounts.*", ".*Tokens.*"} // Errors var ( ErrServerError = errors.New("internal server error") ErrUserTokenMismatch = errors.New("user doesn't match the token") ErrBadToken = errors.New("bad token") ErrTokenNotFound = errors.New("token not found") ) type TokenData struct { UUID string Name string UserID string CreatedAt time.Time LastUsedAt time.Time RevokedAt time.Time ExpiresAt time.Time GeneratedAt time.Time Scopes map[string][]string } type Scopes struct{} // Set the grpc info, must happen after all the service are initialized func (ctrl *TokenController) SetGRPCInfo(info map[string]grpc.ServiceInfo) { ctrl.ServiceInfo = info } func (ctrl *TokenController) SetRules() { rules := []rule{ {re: regexp.MustCompile(`.*Tokens.*`)}, {re: regexp.MustCompile(`.*Accounts.*`)}, {re: regexp.MustCompile(`.*Reflection.*`)}, } ctrl.rules = rules } // Each token operation must first verify that the current user // is allowed to manipulate the token. func (ctrl *TokenController) VerifyTokenOwner(ctx context.Context, userID, tokenID string) error { log := logger.FromContext(ctx).WithValues("uuid", tokenID, "user_id", userID) log.V(2).Info("Verifying the token owner") realUserID := cache.GetFromCache(ctx, ctrl.Redis, cache.CacheFolderToken, tokenID) // If not found in cache, get from postgres if realUserID == "" { query := "SELECT user_id FROM tokens WHERE uuid = $1;" if err := ctrl.DB.QueryRowContext(ctx, query, tokenID).Scan( &realUserID, ); err != nil { log.Error(err, "Couldn't get user_id for a token") return ErrServerError } } if realUserID != userID { return ErrUserTokenMismatch } if err := cache.SaveToCache(ctx, ctrl.Redis, cache.CacheFolderToken, realUserID, tokenID, time.Hour); err != nil { log.Info("Couldn't write to cache", "error", err) } return nil } // Create a new token, store its hash in the database and return the token value func (ctrl *TokenController) Create(ctx context.Context, data *TokenData) (string, string, error) { id := uuid.NewString() log := logger.FromContext(ctx).WithValues("uuid", id) log.V(2).Info("Creating a new token") userExists, err := repository.IsAccountExist(ctx, ctrl.DB, data.UserID) if err != nil { log.Error(err, "Couldn't check whether a user exists") return "", "", ErrServerError } // If user doesn't exist, do not generate a token if !userExists { return "", "", ErrUserNotFound } tokenValue, err := token.GenerateToken() if err != nil { log.Error(err, "Couldn't create a token") return "", "", ErrServerError } tokenHash := hashSHA256(tokenValue) scopesJson, err := json.Marshal(data.Scopes) if err != nil { log.Error(err, "Couldn't marshal permissions into json") return "", "", ErrServerError } queryData := &repository.TokenData{ UUID: id, Decsription: data.Name, TokenHash: tokenHash, UserID: data.UserID, CreatedAt: time.Now(), GeneratedAt: time.Now(), ExpiresAt: data.ExpiresAt, Scope: string(scopesJson), } if err := repository.CreateToken(ctx, ctrl.DB, queryData); err != nil { log.Error(err, "Couldn't create a token") return "", "", ErrServerError } return tokenValue, id, nil } // Update token name or permissions, other changes are ignored by this method func (ctrl *TokenController) Update(ctx context.Context, data *TokenData) error { log := logger.FromContext(ctx).WithValues("uuid", data.UUID) log.V(2).Info("Updating a token") scopesJson, err := json.Marshal(data.Scopes) if err != nil { log.Error(err, "Couldn't marshal permissions into json") return ErrServerError } queryData := &repository.TokenData{ UUID: data.UUID, Scope: string(scopesJson), Decsription: data.Name, } if err := repository.UpdateToken(ctx, ctrl.DB, queryData); err != nil { log.Error(err, "Couldn't update a token") return ErrServerError } return nil } // ForceExpiration of a token, so it can no longer be used func (ctrl *TokenController) ForceExpiration(ctx context.Context, id string) error { log := logger.FromContext(ctx).WithValues("uuid", id) log.V(2).Info("Forcing a token expiration") query := "UPDATE tokens SET revoked_at = $1 WHERE uuid = $2;" if _, err := ctrl.DB.Query(query, time.Now(), id); err != nil { log.Error(err, "Couldn't update a token in the database") return ErrServerError } return nil } // Regenerate a token and get a new value func (ctrl *TokenController) Regenerate(ctx context.Context, id string) (string, error) { log := logger.FromContext(ctx).WithValues("uuid", id) log.V(2).Info("Regenerating a token") tokenValue, err := token.GenerateToken() if err != nil { log.Error(err, "Couldn't create a token") return "", ErrServerError } tokenHash := hashSHA256(tokenValue) query := ` UPDATE tokens SET token_hash = $1, generated_at = $2, expires_at = NOW() + (expires_at - generated_at), WHERE uuid = $3;` if _, err := ctrl.DB.Query(query, tokenHash, time.Now(), id); err != nil { log.Error(err, "Couldn't insert a token in the database") return "", ErrServerError } return "", nil } // Get an existing token data func (ctrl *TokenController) Get(ctx context.Context, id, userID string) (*TokenData, error) { log := logger.FromContext(ctx).WithValues("uuid", id, "user_id", userID) log.V(2).Info("Getting a token") queryResult, err := repository.GetToken(ctx, ctrl.DB, id, userID) if err != nil { if errors.Is(err, repository.ErrNotFound) { return nil, ErrTokenNotFound } log.Error(err, "Couldn't get a token from DB") return nil, ErrServerError } scope := map[string][]string{} if err := json.Unmarshal([]byte(queryResult.Scope), &scope); err != nil { log.Error(err, "Couldn't unmarshal scope into json") return nil, ErrServerError } result := &TokenData{ UUID: queryResult.UUID, Name: queryResult.Decsription, CreatedAt: queryResult.CreatedAt, LastUsedAt: queryResult.LastUsedAt, RevokedAt: queryResult.RevokedAt, ExpiresAt: queryResult.ExpiresAt, GeneratedAt: queryResult.GeneratedAt, Scopes: scope, } return result, nil } // List all available token func (ctrl *TokenController) List(ctx context.Context, userID string) ([]TokenData, error) { log := logger.FromContext(ctx).WithValues("user_id", userID) log.V(2).Info("Regenerating a token") result := []TokenData{} query := ` SELECT uuid, description, generated_at, expires_at, revoked_at, last_used_at, created_at FROM tokens WHERE user_id = $1` rows, err := ctrl.DB.QueryContext(ctx, query, userID) if err != nil { log.Error(err, "Couldn't list tokens") return nil, ErrServerError } defer rows.Close() if err := rows.Err(); err != nil { log.Error(err, "Couldn't list tokens") return nil, ErrServerError } for rows.Next() { var t TokenData var generatedAt sql.NullTime var expiresAt sql.NullTime var revokedAt sql.NullTime var lastUsedAt sql.NullTime var createdAt sql.NullTime err := rows.Scan( &t.UUID, &t.Name, &generatedAt, &expiresAt, &revokedAt, &lastUsedAt, &createdAt, ) t.GeneratedAt = generatedAt.Time t.ExpiresAt = expiresAt.Time t.RevokedAt = revokedAt.Time t.LastUsedAt = lastUsedAt.Time t.CreatedAt = createdAt.Time if err != nil { log.Error(err, "Couldn't write token into a struct") return nil, err } result = append(result, t) } return result, nil } // Lis all available permissions func (ctrl *TokenController) ListPermissions(ctx context.Context) (result map[string][]string) { result = map[string][]string{} for key, val := range ctrl.ServiceInfo { if shouldSkip(key, ctrl.rules) { continue } var services []string for _, svc := range val.Methods { services = append(services, svc.Name) } result[key] = services } return } type rule struct { re *regexp.Regexp } func shouldSkip(s string, rules []rule) bool { for _, r := range rules { if r.re.MatchString(s) { return true } } return false } type TokenAuthResult struct { UserID string Scope string } func (ctrl *TokenController) AuthenticateWithToken(ctx context.Context, token string) (*TokenAuthResult, error) { log := logger.FromContext(ctx) log.V(2).Info("Authenticating with a token") query := ` SELECT user_id, scopes, expires_at, revoked_at FROM tokens WHERE token_hash = $1` var userID string var expiresAt sql.NullTime var revokedAt sql.NullTime var scope string if err := ctrl.DB.QueryRowContext(ctx, query, hashSHA256(token)).Scan( &userID, &scope, &expiresAt, &revokedAt, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, err } log.Error(err, "Couldn't find a token") return nil, ErrServerError } if revokedAt.Valid { return nil, ErrBadToken } if expiresAt.Time.Before(time.Now()) { return nil, ErrBadToken } result := &TokenAuthResult{ UserID: userID, Scope: scope, } return result, nil } func hashSHA256(s string) string { hash := sha256.Sum256([]byte(s)) return hex.EncodeToString(hash[:]) } // Unit Tests func TestUnitHashPersistence(t *testing.T) { password := "qwertyu9" hash1 := hashSHA256(password) hash2 := hashSHA256(password) assert.Equal(t, hash1, hash2) }