All checks were successful
ci/woodpecker/push/build Pipeline was successful
Reviewed-on: #8
391 lines
9.4 KiB
Go
391 lines
9.4 KiB
Go
package controllers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger"
|
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/token"
|
|
"github.com/google/uuid"
|
|
"github.com/redis/go-redis/v9"
|
|
"github.com/stretchr/testify/assert"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
type TokenController struct {
|
|
DB *sql.DB
|
|
HashCost int16
|
|
ServiceInfo map[string]grpc.ServiceInfo
|
|
rules []rule
|
|
Redis *redis.Client
|
|
}
|
|
|
|
// DisabledServicesRegex lists the patterns of services that are not available
// for tokens.
// NOTE(review): nothing in this file reads it; SetRules compiles its own
// pattern list — confirm whether it is still used elsewhere.
var DisabledServicesRegex = []string{
	".*Accounts.*",
	".*Tokens.*",
}
|
|
|
|
// Sentinel errors returned by the token controller.
var (
	// ErrServerError is returned when a database or token-generation step
	// fails; the underlying cause is logged rather than returned.
	ErrServerError = errors.New("internal server error")
	// ErrUserTokenMismatch is returned by VerifyTokenOwner when the token
	// belongs to a different user.
	ErrUserTokenMismatch = errors.New("user doesn't match the token")
	// ErrBadToken is returned by AuthenticateWithToken when the token is
	// revoked or expired.
	ErrBadToken = errors.New("bad token")
)
|
|
|
|
// TokenData is the in-memory representation of a row of the tokens table.
// Timestamps that are NULL in the database are left as the zero time.Time
// by Get and List.
type TokenData struct {
	// UUID identifies the token (column uuid).
	UUID string
	// Name is stored in the description column.
	Name string
	// UserID is the owner of the token (column user_id).
	UserID string
	// CreatedAt is read from the created_at column.
	CreatedAt time.Time
	// LastUsedAt is read from the last_used_at column.
	LastUsedAt time.Time
	// RevokedAt is read from the revoked_at column, which ForceExpiration sets.
	RevokedAt time.Time
	// ExpiresAt is stored in and read from the expires_at column.
	ExpiresAt time.Time
	// GeneratedAt is read from the generated_at column, which Create and
	// Regenerate set.
	GeneratedAt time.Time
	// Scopes is marshalled to JSON into the scopes column by Create and Update.
	// NOTE(review): looks like service name -> allowed method names, mirroring
	// the shape returned by ListPermissions — confirm.
	Scopes map[string][]string
}
|
|
|
|
// Scopes is an empty placeholder type.
// NOTE(review): nothing in this file references it — confirm it is still needed.
type Scopes struct{}
|
|
|
|
// Set the grpc info, must happen after all the service are initialized
|
|
func (ctrl *TokenController) SetGRPCInfo(info map[string]grpc.ServiceInfo) {
|
|
ctrl.ServiceInfo = info
|
|
}
|
|
|
|
func (ctrl *TokenController) SetRules() {
|
|
rules := []rule{
|
|
{re: regexp.MustCompile(`.*Tokens.*`)},
|
|
{re: regexp.MustCompile(`.*Accounts.*`)},
|
|
{re: regexp.MustCompile(`.*Reflection.*`)},
|
|
}
|
|
ctrl.rules = rules
|
|
}
|
|
|
|
// Each token operation must first verify that the current user
|
|
// is allowed to manipulate the token.
|
|
func (ctrl *TokenController) VerifyTokenOwner(ctx context.Context, userID, tokenID string) error {
|
|
log := logger.FromContext(ctx).WithValues("uuid", tokenID, "user_id", userID)
|
|
log.V(2).Info("Verifying the token owner")
|
|
|
|
// First try to get from the redis
|
|
redisKey := fmt.Sprintf("token:%s", tokenID)
|
|
realUserID := ctrl.Redis.Get(ctx, redisKey).Val()
|
|
// If not found in cache, get from postgres
|
|
if realUserID == "" {
|
|
query := "SELECT user_id FROM tokens WHERE uuid = $1;"
|
|
if err := ctrl.DB.QueryRowContext(ctx, query, tokenID).Scan(
|
|
&realUserID,
|
|
); err != nil {
|
|
log.Error(err, "Couldn't get user_id for a token")
|
|
return ErrServerError
|
|
}
|
|
}
|
|
|
|
if realUserID != userID {
|
|
return ErrUserTokenMismatch
|
|
}
|
|
err := ctrl.Redis.Set(ctx, redisKey, realUserID, time.Hour)
|
|
if err != nil {
|
|
log.Info("Couldn't write to cache", "error", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Create a new token, store its hash in the database and return the token value
|
|
func (ctrl *TokenController) Create(ctx context.Context, data *TokenData) (string, error) {
|
|
id := uuid.NewString()
|
|
log := logger.FromContext(ctx).WithValues("uuid", id)
|
|
log.V(2).Info("Creating a new token")
|
|
|
|
tokenValue, err := token.GenerateToken()
|
|
if err != nil {
|
|
log.Error(err, "Couldn't create a token")
|
|
return "", ErrServerError
|
|
}
|
|
|
|
tokenHash := hashSHA256(tokenValue)
|
|
|
|
scopesJson, err := json.Marshal(data.Scopes)
|
|
if err != nil {
|
|
log.Error(err, "Couldn't marshal permissions into json")
|
|
return "", ErrServerError
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO tokens (uuid, description, token_hash, user_id, scopes, created_at, generated_at, expires_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`
|
|
|
|
if _, err := ctrl.DB.Query(
|
|
query,
|
|
id,
|
|
data.Name,
|
|
tokenHash,
|
|
data.UserID,
|
|
scopesJson,
|
|
time.Now(),
|
|
time.Now(),
|
|
data.ExpiresAt,
|
|
); err != nil {
|
|
log.Error(err, "Couldn't insert a token in the database")
|
|
return "", ErrServerError
|
|
}
|
|
|
|
return tokenValue, nil
|
|
}
|
|
|
|
// Update token name or permissions, other changes are ignored by this method
|
|
func (ctrl *TokenController) Update(ctx context.Context, data *TokenData) error {
|
|
log := logger.FromContext(ctx).WithValues("uuid", data.UUID)
|
|
log.V(2).Info("Updating a token")
|
|
|
|
scopesJson, err := json.Marshal(data.Scopes)
|
|
if err != nil {
|
|
log.Error(err, "Couldn't marshal permissions into json")
|
|
return ErrServerError
|
|
}
|
|
|
|
query := "UPDATE tokens SET description = $1, scopes = $2 WHERE uuid = $3;"
|
|
if _, err := ctrl.DB.Query(query, data.Name, scopesJson, data.UUID); err != nil {
|
|
log.Error(err, "Couldn't update a token in the database")
|
|
return ErrServerError
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ForceExpiration of a token, so it can no longer be used
|
|
func (ctrl *TokenController) ForceExpiration(ctx context.Context, id string) error {
|
|
log := logger.FromContext(ctx).WithValues("uuid", id)
|
|
log.V(2).Info("Forcing a token expiration")
|
|
|
|
query := "UPDATE tokens SET revoked_at = $1 WHERE uuid = $2;"
|
|
if _, err := ctrl.DB.Query(query, time.Now(), id); err != nil {
|
|
log.Error(err, "Couldn't update a token in the database")
|
|
return ErrServerError
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Regenerate a token and get a new value
|
|
func (ctrl *TokenController) Regenerate(ctx context.Context, id string) (string, error) {
|
|
log := logger.FromContext(ctx).WithValues("uuid", id)
|
|
log.V(2).Info("Regenerating a token")
|
|
|
|
tokenValue, err := token.GenerateToken()
|
|
if err != nil {
|
|
log.Error(err, "Couldn't create a token")
|
|
return "", ErrServerError
|
|
}
|
|
|
|
tokenHash := hashSHA256(tokenValue)
|
|
|
|
query := `
|
|
UPDATE tokens
|
|
SET
|
|
token_hash = $1,
|
|
generated_at = $2,
|
|
expires_at = NOW() + (expires_at - generated_at),
|
|
WHERE uuid = $3;`
|
|
|
|
if _, err := ctrl.DB.Query(query, tokenHash, time.Now(), id); err != nil {
|
|
log.Error(err, "Couldn't insert a token in the database")
|
|
return "", ErrServerError
|
|
}
|
|
|
|
return "", nil
|
|
}
|
|
|
|
// Get an existing token data
|
|
func (ctrl *TokenController) Get(ctx context.Context, id, userID string) (*TokenData, error) {
|
|
log := logger.FromContext(ctx).WithValues("uuid", id, "user_id", userID)
|
|
log.V(2).Info("Regenerating a token")
|
|
query := `
|
|
SELECT uuid, description, generated_at, expires_at, last_used_at, revoked_at, created_at
|
|
FROM tokens
|
|
WHERE uuid = $1 AND user_id = $2`
|
|
token := &TokenData{}
|
|
var generatedAt sql.NullTime
|
|
var expiresAt sql.NullTime
|
|
var revokedAt sql.NullTime
|
|
var lastUsedAt sql.NullTime
|
|
var createdAt sql.NullTime
|
|
|
|
if err := ctrl.DB.QueryRowContext(ctx, query, id, userID).Scan(
|
|
&token.UUID,
|
|
&token.Name,
|
|
&generatedAt,
|
|
&expiresAt,
|
|
&lastUsedAt,
|
|
&revokedAt,
|
|
&createdAt,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, err
|
|
}
|
|
log.Error(err, "Couldn't find a token")
|
|
return nil, ErrServerError
|
|
}
|
|
token.GeneratedAt = generatedAt.Time
|
|
token.ExpiresAt = expiresAt.Time
|
|
token.RevokedAt = revokedAt.Time
|
|
token.LastUsedAt = lastUsedAt.Time
|
|
token.CreatedAt = createdAt.Time
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// List all available token
|
|
func (ctrl *TokenController) List(ctx context.Context, userID string) ([]TokenData, error) {
|
|
log := logger.FromContext(ctx).WithValues("user_id", userID)
|
|
log.V(2).Info("Regenerating a token")
|
|
|
|
result := []TokenData{}
|
|
|
|
query := `
|
|
SELECT uuid, description, generated_at, expires_at, revoked_at, last_used_at, created_at
|
|
FROM tokens
|
|
WHERE user_id = $1`
|
|
|
|
rows, err := ctrl.DB.QueryContext(ctx, query, userID)
|
|
if err != nil {
|
|
log.Error(err, "Couldn't list tokens")
|
|
return nil, ErrServerError
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
if err := rows.Err(); err != nil {
|
|
log.Error(err, "Couldn't list tokens")
|
|
return nil, ErrServerError
|
|
}
|
|
|
|
for rows.Next() {
|
|
var t TokenData
|
|
|
|
var generatedAt sql.NullTime
|
|
var expiresAt sql.NullTime
|
|
var revokedAt sql.NullTime
|
|
var lastUsedAt sql.NullTime
|
|
var createdAt sql.NullTime
|
|
|
|
err := rows.Scan(
|
|
&t.UUID,
|
|
&t.Name,
|
|
&generatedAt,
|
|
&expiresAt,
|
|
&revokedAt,
|
|
&lastUsedAt,
|
|
&createdAt,
|
|
)
|
|
|
|
t.GeneratedAt = generatedAt.Time
|
|
t.ExpiresAt = expiresAt.Time
|
|
t.RevokedAt = revokedAt.Time
|
|
t.LastUsedAt = lastUsedAt.Time
|
|
t.CreatedAt = createdAt.Time
|
|
|
|
if err != nil {
|
|
log.Error(err, "Couldn't write token into a struct")
|
|
return nil, err
|
|
}
|
|
|
|
result = append(result, t)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Lis all available permissions
|
|
func (ctrl *TokenController) ListPermissions(ctx context.Context) (result map[string][]string) {
|
|
result = map[string][]string{}
|
|
for key, val := range ctrl.ServiceInfo {
|
|
if shouldSkip(key, ctrl.rules) {
|
|
continue
|
|
}
|
|
var services []string
|
|
for _, svc := range val.Methods {
|
|
services = append(services, svc.Name)
|
|
}
|
|
result[key] = services
|
|
}
|
|
return
|
|
}
|
|
|
|
// rule wraps one compiled pattern that is matched against service names.
type rule struct {
	re *regexp.Regexp
}

// shouldSkip reports whether s matches the pattern of at least one of the
// given rules. It returns false for an empty or nil rule list.
func shouldSkip(s string, rules []rule) bool {
	for i := range rules {
		if rules[i].re.MatchString(s) {
			return true
		}
	}
	return false
}
|
|
|
|
// TokenAuthResult is what AuthenticateWithToken returns for an accepted token.
type TokenAuthResult struct {
	// UserID is the owner of the token (column user_id).
	UserID string
	// Scope is the raw content of the token's scopes column, as scanned into
	// a string.
	Scope string
}
|
|
|
|
func (ctrl *TokenController) AuthenticateWithToken(ctx context.Context, token string) (*TokenAuthResult, error) {
|
|
log := logger.FromContext(ctx)
|
|
log.V(2).Info("Authenticating with a token")
|
|
|
|
query := `
|
|
SELECT user_id, scopes, expires_at, revoked_at
|
|
FROM tokens
|
|
WHERE token_hash = $1`
|
|
|
|
var userID string
|
|
var expiresAt sql.NullTime
|
|
var revokedAt sql.NullTime
|
|
var scope string
|
|
if err := ctrl.DB.QueryRowContext(ctx, query, hashSHA256(token)).Scan(
|
|
&userID,
|
|
&scope,
|
|
&expiresAt,
|
|
&revokedAt,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, err
|
|
}
|
|
log.Error(err, "Couldn't find a token")
|
|
return nil, ErrServerError
|
|
}
|
|
|
|
if revokedAt.Valid {
|
|
return nil, ErrBadToken
|
|
}
|
|
|
|
if expiresAt.Time.Before(time.Now()) {
|
|
return nil, ErrBadToken
|
|
}
|
|
|
|
result := &TokenAuthResult{
|
|
UserID: userID,
|
|
Scope: scope,
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// hashSHA256 returns the SHA-256 digest of s as a lowercase hexadecimal
// string (64 characters).
func hashSHA256(s string) string {
	hasher := sha256.New()
	// Write on a hash.Hash never returns an error.
	hasher.Write([]byte(s))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
func TestUnitHashPersistence(t *testing.T) {
|
|
password := "qwertyu9"
|
|
hash1 := hashSHA256(password)
|
|
hash2 := hashSHA256(password)
|
|
assert.Equal(t, hash1, hash2)
|
|
}
|