// Package authorization issues and verifies HS256-signed JWT access and
// refresh tokens, provides a gRPC auth function that stores the parsed claims
// in the request context, and keeps per-token sessions in Redis.
package authorization

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
	"github.com/redis/go-redis/v9"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// TokenType distinguishes the purpose of an issued token.
type TokenType string

const (
	// TokenTypeAccess marks a token whose lifetime is the controller's accessTTL.
	TokenTypeAccess TokenType = "access"
	// TokenTypeRefresh marks a token whose lifetime is the controller's
	// refreshTTL; AuthInterceptorFN only accepts it on methods whose full
	// name contains "RefreshToken".
	TokenTypeRefresh TokenType = "refresh"
)

var (
	// ErrUnknownTokenType is returned by GenerateToken for an unsupported TokenType.
	ErrUnknownTokenType = errors.New("token type unknown")
	// ErrInvalidToken is returned by ParseToken when the parsed token is not valid.
	ErrInvalidToken = errors.New("invalid token")
	// ErrSessionNotFound is returned by GetSession when no session is stored
	// under the given token ID (never saved, expired, or already consumed).
	ErrSessionNotFound = errors.New("session not found")
)

// Claims is the JWT payload produced by GenerateToken and returned by ParseToken.
type Claims struct {
	UserID    string    `json:"user_id"`
	TokenID   string    `json:"token_id"`
	TokenType TokenType `json:"token_type"`
	jwt.RegisteredClaims
}

// AuthController signs and verifies tokens with a shared HMAC secret and
// stores sessions in Redis.
type AuthController struct {
	jwtSecret  []byte
	accessTTL  time.Duration
	refreshTTL time.Duration
	redis      *redis.Client
}

// contextKey is unexported so no other package can collide with claimsContextKey.
type contextKey string

const claimsContextKey contextKey = "jwt_claims"

// NewAuthController returns a controller that signs tokens with jwtSecret,
// issues access/refresh tokens with the given lifetimes, and stores sessions
// through rdb (sessions expire after refreshTTL).
func NewAuthController(jwtSecret []byte, accessTTL, refreshTTL time.Duration, rdb *redis.Client) *AuthController {
	return &AuthController{
		jwtSecret:  jwtSecret,
		accessTTL:  accessTTL,
		refreshTTL: refreshTTL,
		redis:      rdb,
	}
}

// WithClaims returns a copy of ctx carrying claims.
func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context {
	return context.WithValue(ctx, claimsContextKey, claims)
}

// ClaimsFromContext returns the claims stored by WithClaims, or an error if
// ctx carries none (or carries a nil pointer).
func (a *AuthController) ClaimsFromContext(ctx context.Context) (*Claims, error) {
	claims, ok := ctx.Value(claimsContextKey).(*Claims)
	if !ok || claims == nil {
		return nil, errors.New("claims not found in context")
	}
	return claims, nil
}

// AuthInterceptorFN has the signature of go-grpc-middleware's auth.AuthFunc.
// It reads the "bearer" token from incoming metadata, verifies it with
// ParseToken, rejects refresh tokens on any method whose full name does not
// contain "RefreshToken", and returns ctx enriched with the parsed claims.
// Failures are reported as codes.Unauthenticated (the AuthFromMD error is
// returned as-is).
func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context, error) {
	tokenString, err := auth.AuthFromMD(ctx, "bearer")
	if err != nil {
		return nil, err
	}

	claims, err := a.ParseToken(tokenString)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "Invalid JWT token")
	}

	if claims.TokenType == TokenTypeRefresh {
		// Fail closed: if the method name is unavailable, a refresh token is
		// rejected instead of being accepted for every method.
		method, ok := grpc.Method(ctx)
		if !ok || !strings.Contains(method, "RefreshToken") {
			return nil, status.Error(codes.Unauthenticated, "Refresh token is not allowed for this method")
		}
	}

	return a.WithClaims(ctx, claims), nil
}

// GenerateToken issues an HS256-signed JWT of the given type for userID.
// Access tokens expire after accessTTL and refresh tokens after refreshTTL;
// any other tokenType yields ErrUnknownTokenType. It returns the signed token
// and a freshly generated random token ID, which is carried in both the
// token_id claim and the standard jti claim.
func (a *AuthController) GenerateToken(userID string, tokenType TokenType) (token, tokenID string, err error) {
	var ttl time.Duration
	switch tokenType {
	case TokenTypeAccess:
		ttl = a.accessTTL
	case TokenTypeRefresh:
		ttl = a.refreshTTL
	default:
		return "", "", ErrUnknownTokenType
	}

	// A single timestamp keeps iat, nbf and exp mutually consistent.
	now := time.Now()
	tokenID = uuid.New().String()

	claims := Claims{
		UserID:    userID,
		TokenID:   tokenID,
		TokenType: tokenType,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
			NotBefore: jwt.NewNumericDate(now),
			IssuedAt:  jwt.NewNumericDate(now),
			// jti identifies this token, not the user (RFC 7519 §4.1.7).
			ID: tokenID,
		},
	}

	token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSecret)
	if err != nil {
		return "", "", err
	}
	return token, tokenID, nil
}

// ParseToken verifies tokenStr's HS256 signature and standard time claims and
// returns its Claims. Parse, signature and validation failures are returned as
// the underlying jwt error; ErrInvalidToken is returned if the parsed token is
// still not usable.
func (a *AuthController) ParseToken(tokenStr string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(
		tokenStr,
		&Claims{},
		func(*jwt.Token) (any, error) { return a.jwtSecret, nil },
		// Pin the algorithm so a token whose header names a different
		// signing method is rejected before the key is used.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidToken
	}
	return claims, nil
}

// Session is the JSON document stored in Redis for a token ID.
type Session struct {
	UserID string `json:"user_id"`
}

// redisSessionKey returns the Redis key under which the session for input is stored.
func redisSessionKey(input string) string {
	return fmt.Sprintf("session:%s", input)
}

// SaveSession stores session as JSON under tokenID; the key expires after refreshTTL.
func (a *AuthController) SaveSession(ctx context.Context, tokenID string, session *Session) error {
	sessionJSON, err := json.Marshal(session)
	if err != nil {
		return fmt.Errorf("encoding session: %w", err)
	}
	if err := a.redis.Set(ctx, redisSessionKey(tokenID), sessionJSON, a.refreshTTL).Err(); err != nil {
		return fmt.Errorf("storing session %q: %w", tokenID, err)
	}
	return nil
}

// GetSession fetches and deletes the session stored under tokenID in a single
// atomic Redis command, so a session can be consumed at most once. It returns
// ErrSessionNotFound when no session exists for tokenID.
func (a *AuthController) GetSession(ctx context.Context, tokenID string) (*Session, error) {
	// GETDEL (Redis >= 6.2) reads and removes the key in one step; a separate
	// GET followed by DEL would let two concurrent callers read the same session.
	raw, err := a.redis.GetDel(ctx, redisSessionKey(tokenID)).Bytes()
	if errors.Is(err, redis.Nil) {
		return nil, ErrSessionNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("loading session %q: %w", tokenID, err)
	}

	session := &Session{}
	if err := json.Unmarshal(raw, session); err != nil {
		return nil, fmt.Errorf("decoding session %q: %w", tokenID, err)
	}
	return session, nil
}