Implement refresh token endpoint
All checks were successful
ci/woodpecker/push/build Pipeline was successful

Signed-off-by: Nikolai Rodionov <allanger@badhouseplants.net>
This commit is contained in:
2026-05-09 21:36:23 +02:00
parent 19e47876f0
commit e58eba1b16
10 changed files with 356 additions and 148 deletions

View File

@@ -0,0 +1,184 @@
package authorization
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/redis/go-redis/v9"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type TokenType string
const (
TokenTypeAccess TokenType = "access"
TokenTypeRefresh TokenType = "refresh"
)
var (
ErrUnknownTokenType = errors.New("token type unknown")
ErrInvalidToken = errors.New("invalid token")
)
type Claims struct {
UserID string `json:"user_id"`
TokenID string `json:"token_id"`
TokenType TokenType `json:"token_type"`
jwt.RegisteredClaims
}
type AuthController struct {
jwtSecret []byte
accessTTL time.Duration
refreshTTL time.Duration
redis *redis.Client
}
type contextKey string
const claimsContextKey contextKey = "jwt_claims"
func NewAuthController(jwtSecret []byte, accessTTL, refreshTTL time.Duration, redis *redis.Client) *AuthController {
return &AuthController{
jwtSecret: jwtSecret,
accessTTL: accessTTL,
refreshTTL: refreshTTL,
redis: redis,
}
}
// Write claims into context
func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context {
return context.WithValue(ctx, claimsContextKey, claims)
}
// Extract claims from context
func (a *AuthController) ClaimsFromContext(ctx context.Context) (*Claims, error) {
claims, ok := ctx.Value(claimsContextKey).(*Claims)
if !ok || claims == nil {
return nil, errors.New("claims not found in context")
}
return claims, nil
}
func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context, error) {
tokenString, err := auth.AuthFromMD(ctx, "bearer")
if err != nil {
return nil, err
}
claims, err := a.ParseToken(tokenString)
if err != nil {
return nil, status.Error(codes.Unauthenticated, "Invalid JWT token")
}
if method, ok := grpc.Method(ctx); ok {
if claims.TokenType == TokenTypeRefresh && !strings.Contains(method, "RefreshToken") {
return nil, status.Error(codes.Unauthenticated, "Refresh token is not allowed for this method")
}
}
ctx = a.WithClaims(ctx, claims)
return ctx, nil
}
// Generate JWT token
func (a *AuthController) GenerateToken(userID string, tokenType TokenType) (token, tokenID string, err error) {
var expiresAt time.Time
notBefore := time.Now()
switch tokenType {
case TokenTypeAccess:
expiresAt = time.Now().Add(a.accessTTL)
case TokenTypeRefresh:
expiresAt = time.Now().Add(a.refreshTTL)
default:
return "", "", ErrUnknownTokenType
}
if tokenType != TokenTypeAccess && tokenType != TokenTypeRefresh {
return "", "", ErrUnknownTokenType
}
tokenID = uuid.New().String()
claims := Claims{
UserID: userID,
TokenID: tokenID,
TokenType: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "",
Subject: "",
Audience: jwt.ClaimStrings{},
ExpiresAt: jwt.NewNumericDate(expiresAt),
NotBefore: jwt.NewNumericDate(notBefore),
IssuedAt: jwt.NewNumericDate(time.Now()),
ID: userID,
},
}
tokenJwt := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token, err = tokenJwt.SignedString(a.jwtSecret)
if err != nil {
return "", "", err
}
return
}
func (a *AuthController) ParseToken(tokenStr string) (*Claims, error) {
token, err := jwt.ParseWithClaims(
tokenStr,
&Claims{},
func(token *jwt.Token) (interface{}, error) {
return a.jwtSecret, nil
},
)
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, ErrInvalidToken
}
return claims, nil
}
type Session struct {
UserID string `json:"user_id"`
}
func redisSessionKey(input string) string {
return fmt.Sprintf("session:%s", input)
}
func (a *AuthController) SaveSession(ctx context.Context, tokenID string, session *Session) error {
sessionJson, err := json.Marshal(session)
if err != nil {
return err
}
if err := a.redis.Set(ctx, redisSessionKey(tokenID), string(sessionJson), a.refreshTTL).Err(); err != nil {
return err
}
return nil
}
func (a *AuthController) GetSession(ctx context.Context, tokenID string) (*Session, error) {
sessionRaw := a.redis.Get(ctx, redisSessionKey(tokenID)).Val()
if err := a.redis.Del(ctx, redisSessionKey(tokenID)).Err(); err != nil {
return nil, err
}
session := &Session{}
if err := json.Unmarshal([]byte(sessionRaw), session); err != nil {
return nil, err
}
return session, nil
}

View File

@@ -0,0 +1,57 @@
package authorization_test
import (
"testing"
"time"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
var (
testAccessTTL = time.Second * 5
testRefreshTTL = time.Second * 20
testUserID = uuid.New().String()
)
func TestGenerateInvalidTokenType(t *testing.T) {
authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil)
token, _, err := authCtrl.GenerateToken(testUserID, "invalid_type")
assert.Equal(t, "", token)
assert.ErrorIs(t, authorization.ErrUnknownTokenType, err)
}
func TestGenerateValidateAccessToken(t *testing.T) {
authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil)
now := time.Now()
token, _, err := authCtrl.GenerateToken(testUserID, authorization.TokenTypeAccess)
assert.NoError(t, err)
assert.NotEmpty(t, token)
claims, err := authCtrl.ParseToken(token)
assert.NoError(t, err)
assert.Equal(t, testUserID, claims.UserID)
assert.NotEmpty(t, claims.TokenID)
assert.Equal(t, authorization.TokenTypeAccess, claims.TokenType)
assert.Equal(t, now.Add(testAccessTTL).Unix(), claims.ExpiresAt.Unix())
assert.Equal(t, now.Unix(), claims.IssuedAt.Unix())
assert.Equal(t, now.Unix(), claims.NotBefore.Unix())
}
func TestGenerateValidateRefreshToken(t *testing.T) {
authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil)
now := time.Now()
token, _, err := authCtrl.GenerateToken(testUserID, authorization.TokenTypeRefresh)
assert.NoError(t, err)
assert.NotEmpty(t, token)
claims, err := authCtrl.ParseToken(token)
assert.NoError(t, err)
assert.Equal(t, testUserID, claims.UserID)
assert.NotEmpty(t, claims.TokenID)
assert.Equal(t, authorization.TokenTypeRefresh, claims.TokenType)
assert.Equal(t, now.Add(testRefreshTTL).Unix(), claims.ExpiresAt.Unix())
assert.Equal(t, now.Unix(), claims.IssuedAt.Unix())
assert.Equal(t, now.Unix(), claims.NotBefore.Unix())
}

View File

@@ -77,10 +77,12 @@ func (c *AccountController) Login(ctx context.Context, email, password string) (
}
func (c *AccountController) GenerateAccessToken(userID string) (string, error) {
tokenID := uuid.New().String()
claims := jwt.MapClaims{
"user_id": userID,
"type": "access",
"exp": time.Now().Add(c.AccessTokenTTL).Unix(),
"user_id": userID,
"type": "access",
"exp": time.Now().Add(c.AccessTokenTTL).Unix(),
"token_id": tokenID,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

View File

@@ -1,64 +0,0 @@
package interceptors
import (
"context"
"fmt"
"strings"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/tools/logger"
"github.com/golang-jwt/jwt/v5"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type JWTVerifier struct {
secret []byte
serverCtx context.Context
}
func NewJWTVerifier(ctx context.Context, secret []byte) *JWTVerifier {
return &JWTVerifier{
serverCtx: ctx,
secret: secret,
}
}
// This is an interceptors that should verify that a user is authorized
func (v *JWTVerifier) JWTAuthInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
log := logger.FromContext(v.serverCtx).WithValues("method", info.FullMethod)
if !strings.Contains(info.FullMethod, "NoAuth") {
log.Info("Checking the JWT token")
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.Unauthenticated, "User is not authorized")
}
tokenString := md.Get("token")[0]
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return v.secret, nil
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
if err != nil {
return nil, status.Error(codes.Unauthenticated, "User is not authorized")
}
if claims, ok := token.Claims.(jwt.MapClaims); ok {
fmt.Println(claims["userID"])
} else {
fmt.Println(err)
}
// Get the token from the metadata
// Validate the token
// Get the user id from the token
} else {
log.Info("Auth is not required for this request")
}
return handler(ctx, req)
}