Implement refresh token endpoint
All checks were successful
ci/woodpecker/push/build Pipeline was successful
All checks were successful
ci/woodpecker/push/build Pipeline was successful
Signed-off-by: Nikolai Rodionov <allanger@badhouseplants.net>
This commit is contained in:
@@ -21,6 +21,8 @@ tasks:
|
|||||||
|
|
||||||
run-server-dev:
|
run-server-dev:
|
||||||
desc: Run the local dev server
|
desc: Run the local dev server
|
||||||
|
deps:
|
||||||
|
- run-migrations-dev
|
||||||
env:
|
env:
|
||||||
SOFTPLAYER_DB_CONNECTION_STRING: postgres://softplayer:qwertyu9@localhost:30432/softplayer?sslmode=disable
|
SOFTPLAYER_DB_CONNECTION_STRING: postgres://softplayer:qwertyu9@localhost:30432/softplayer?sslmode=disable
|
||||||
SOFTPLAYER_REDIS_HOST: localhost:30379
|
SOFTPLAYER_REDIS_HOST: localhost:30379
|
||||||
|
|||||||
@@ -1,17 +1,77 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization"
|
||||||
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
|
||||||
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
|
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
|
||||||
|
"github.com/golang/protobuf/ptypes/empty"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewAccountAuthRPCImpl(ctrl *controllers.AccountController) *AccountsAuthServer {
|
func NewAccountAuthRPCImpl(
|
||||||
|
accountsCtrl *controllers.AccountController,
|
||||||
|
authorizationCtrl *authorization.AuthController,
|
||||||
|
) *AccountsAuthServer {
|
||||||
return &AccountsAuthServer{
|
return &AccountsAuthServer{
|
||||||
ctrl: ctrl,
|
accountsCtrl: accountsCtrl,
|
||||||
|
authorizationCtrl: authorizationCtrl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountsAuthServer struct {
|
type AccountsAuthServer struct {
|
||||||
accounts.UnimplementedAccountsAuthServiceServer
|
accounts.UnimplementedAccountsAuthServiceServer
|
||||||
ctrl *controllers.AccountController
|
accountsCtrl *controllers.AccountController
|
||||||
|
authorizationCtrl *authorization.AuthController
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AccountsAuthServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) {
|
||||||
|
claims, err := a.authorizationCtrl.ClaimsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Context is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.TokenType != authorization.TokenTypeRefresh {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "Invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := a.authorizationCtrl.GetSession(ctx, claims.TokenID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "Session doesn't exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.UserID != claims.UserID {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "Invalid session")
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, _, err := a.authorizationCtrl.GenerateToken(session.UserID, authorization.TokenTypeAccess)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(session.UserID, authorization.TokenTypeRefresh)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
newSession := &authorization.Session{UserID: session.UserID}
|
||||||
|
|
||||||
|
if err := a.authorizationCtrl.SaveSession(ctx, tokenID, newSession); err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Couldn't store session")
|
||||||
|
}
|
||||||
|
|
||||||
|
header := metadata.New(map[string]string{
|
||||||
|
"X-Access-Token": accessToken,
|
||||||
|
"X-Refresh-Token": refreshToken,
|
||||||
|
})
|
||||||
|
if err := grpc.SetHeader(ctx, header); err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Couldn't set metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &emptypb.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package v1
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization"
|
||||||
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
|
||||||
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
|
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
|
||||||
"github.com/golang/protobuf/ptypes/empty"
|
"github.com/golang/protobuf/ptypes/empty"
|
||||||
@@ -13,32 +14,42 @@ import (
|
|||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewAccountNoAuthRPCImpl(ctrl *controllers.AccountController) *AccountsNoAuthServer {
|
func NewAccountNoAuthRPCImpl(
|
||||||
|
accountsCtrl *controllers.AccountController,
|
||||||
|
authorizationCtrl *authorization.AuthController,
|
||||||
|
) *AccountsNoAuthServer {
|
||||||
return &AccountsNoAuthServer{
|
return &AccountsNoAuthServer{
|
||||||
ctrl: ctrl,
|
accountsCtrl: accountsCtrl,
|
||||||
|
authorizationCtrl: authorizationCtrl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountsNoAuthServer struct {
|
type AccountsNoAuthServer struct {
|
||||||
accounts.UnimplementedAccountsNoAuthServiceServer
|
accounts.UnimplementedAccountsNoAuthServiceServer
|
||||||
ctrl *controllers.AccountController
|
accountsCtrl *controllers.AccountController
|
||||||
|
authorizationCtrl *authorization.AuthController
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AccountsNoAuthServer) SignIn(ctx context.Context, in *accounts.SignInRequest) (*empty.Empty, error) {
|
func (a *AccountsNoAuthServer) SignIn(ctx context.Context, in *accounts.SignInRequest) (*empty.Empty, error) {
|
||||||
id, err := a.ctrl.Login(ctx, in.GetEmail(), in.GetPassword())
|
id, err := a.accountsCtrl.Login(ctx, in.GetEmail(), in.GetPassword())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Aborted, "Couldn't create a user")
|
return nil, status.Error(codes.Aborted, "Couldn't create a user")
|
||||||
}
|
}
|
||||||
accessToken, err := a.ctrl.GenerateAccessToken(id)
|
accessToken, _, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeAccess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, id)
|
refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeRefresh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session := &authorization.Session{UserID: id}
|
||||||
|
|
||||||
|
if err := a.authorizationCtrl.SaveSession(ctx, tokenID, session); err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Couldn't store session")
|
||||||
|
}
|
||||||
header := metadata.New(map[string]string{
|
header := metadata.New(map[string]string{
|
||||||
"X-Access-Token": accessToken,
|
"X-Access-Token": accessToken,
|
||||||
"X-Refresh-Token": refreshToken,
|
"X-Refresh-Token": refreshToken,
|
||||||
@@ -55,21 +66,27 @@ func (a *AccountsNoAuthServer) SignUp(ctx context.Context, in *accounts.SignUpRe
|
|||||||
Password: in.GetPassword(),
|
Password: in.GetPassword(),
|
||||||
Email: in.GetEmail(),
|
Email: in.GetEmail(),
|
||||||
}
|
}
|
||||||
id, err := a.ctrl.Create(ctx, data)
|
id, err := a.accountsCtrl.Create(ctx, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Aborted, "Couldn't create a user")
|
return nil, status.Error(codes.Aborted, "Couldn't create a user")
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := a.ctrl.GenerateAccessToken(id)
|
accessToken, _, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeAccess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, id)
|
refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeRefresh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
return nil, status.Error(codes.Aborted, "Couldn't generate an access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session := &authorization.Session{UserID: id}
|
||||||
|
|
||||||
|
if err := a.authorizationCtrl.SaveSession(ctx, tokenID, session); err != nil {
|
||||||
|
return nil, status.Error(codes.Aborted, "Couldn't store session")
|
||||||
|
}
|
||||||
|
|
||||||
header := metadata.New(map[string]string{
|
header := metadata.New(map[string]string{
|
||||||
"X-Access-Token": accessToken,
|
"X-Access-Token": accessToken,
|
||||||
"X-Refresh-Token": refreshToken,
|
"X-Refresh-Token": refreshToken,
|
||||||
@@ -79,33 +96,3 @@ func (a *AccountsNoAuthServer) SignUp(ctx context.Context, in *accounts.SignUpRe
|
|||||||
}
|
}
|
||||||
return &emptypb.Empty{}, nil
|
return &emptypb.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AccountsAuthServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) {
|
|
||||||
//uuid, err := a.ctrl.ValidateRefreshToken(ctx, , userID)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, status.Error(codes.Unauthenticated, "refresh token is invalid")
|
|
||||||
//}
|
|
||||||
//accessToken, err := a.ctrl.GenerateAccessToken(uuid)
|
|
||||||
//if err != nil {
|
|
||||||
// log.Error(err, "Couldn't generate an access token")
|
|
||||||
// return nil, status.Error(codes.Aborted, "Couldn't generate Access Token")
|
|
||||||
//}
|
|
||||||
|
|
||||||
//refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, uuid)
|
|
||||||
//if err != nil {
|
|
||||||
// log.Error(err, "Couldn't generate a refresh token")
|
|
||||||
// return nil, status.Error(codes.Aborted, "Couldn't generate Access Token")
|
|
||||||
//}
|
|
||||||
|
|
||||||
//header := metadata.Pairs(
|
|
||||||
// "access-token", accessToken,
|
|
||||||
// "refreshToken", refreshToken,
|
|
||||||
//)
|
|
||||||
|
|
||||||
//if err := grpc.SetHeader(ctx, header); err != nil {
|
|
||||||
// log.Error(err, "Couldn't set headers")
|
|
||||||
// return nil, status.Error(codes.Unknown, "Couldn't set headers")
|
|
||||||
//}
|
|
||||||
|
|
||||||
return nil, status.Error(codes.Unimplemented, "Endpoint is not Unimplemented yet")
|
|
||||||
}
|
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -17,6 +17,7 @@ require (
|
|||||||
github.com/mattn/go-colorable v0.1.14
|
github.com/mattn/go-colorable v0.1.14
|
||||||
github.com/redis/go-redis/v9 v9.18.0
|
github.com/redis/go-redis/v9 v9.18.0
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
golang.org/x/crypto v0.47.0
|
golang.org/x/crypto v0.47.0
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
@@ -138,7 +139,7 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260508191738-b9850db6fe45
|
gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260509192230-bf7467c36f59
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
golang.org/x/net v0.49.0 // indirect
|
golang.org/x/net v0.49.0 // indirect
|
||||||
golang.org/x/sys v0.40.0 // indirect
|
golang.org/x/sys v0.40.0 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -6,8 +6,8 @@ dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
|||||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260508191738-b9850db6fe45 h1:7oWiqVQHcBI/7uGTLBubIV62/gRtxe//XjssZW1eWks=
|
gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260509192230-bf7467c36f59 h1:pI25/wjcfvX62PcxyZ/i7XPTxdyCV9tV34JFSWQxYNw=
|
||||||
gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260508191738-b9850db6fe45/go.mod h1:AgOh1lkPHyRgBf3/s1btKcAqke/33LbKYarTD13qeAg=
|
gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260509192230-bf7467c36f59/go.mod h1:AgOh1lkPHyRgBf3/s1btKcAqke/33LbKYarTD13qeAg=
|
||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||||
|
|||||||
184
internal/authorization/auth.go
Normal file
184
internal/authorization/auth.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package authorization
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenTypeAccess TokenType = "access"
|
||||||
|
TokenTypeRefresh TokenType = "refresh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUnknownTokenType = errors.New("token type unknown")
|
||||||
|
ErrInvalidToken = errors.New("invalid token")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
TokenID string `json:"token_id"`
|
||||||
|
TokenType TokenType `json:"token_type"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthController struct {
|
||||||
|
jwtSecret []byte
|
||||||
|
accessTTL time.Duration
|
||||||
|
refreshTTL time.Duration
|
||||||
|
redis *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const claimsContextKey contextKey = "jwt_claims"
|
||||||
|
|
||||||
|
func NewAuthController(jwtSecret []byte, accessTTL, refreshTTL time.Duration, redis *redis.Client) *AuthController {
|
||||||
|
return &AuthController{
|
||||||
|
jwtSecret: jwtSecret,
|
||||||
|
accessTTL: accessTTL,
|
||||||
|
refreshTTL: refreshTTL,
|
||||||
|
redis: redis,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write claims into context
|
||||||
|
func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context {
|
||||||
|
return context.WithValue(ctx, claimsContextKey, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract claims from context
|
||||||
|
func (a *AuthController) ClaimsFromContext(ctx context.Context) (*Claims, error) {
|
||||||
|
claims, ok := ctx.Value(claimsContextKey).(*Claims)
|
||||||
|
if !ok || claims == nil {
|
||||||
|
return nil, errors.New("claims not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context, error) {
|
||||||
|
tokenString, err := auth.AuthFromMD(ctx, "bearer")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := a.ParseToken(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "Invalid JWT token")
|
||||||
|
}
|
||||||
|
if method, ok := grpc.Method(ctx); ok {
|
||||||
|
if claims.TokenType == TokenTypeRefresh && !strings.Contains(method, "RefreshToken") {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "Refresh token is not allowed for this method")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = a.WithClaims(ctx, claims)
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate JWT token
|
||||||
|
func (a *AuthController) GenerateToken(userID string, tokenType TokenType) (token, tokenID string, err error) {
|
||||||
|
var expiresAt time.Time
|
||||||
|
notBefore := time.Now()
|
||||||
|
switch tokenType {
|
||||||
|
case TokenTypeAccess:
|
||||||
|
expiresAt = time.Now().Add(a.accessTTL)
|
||||||
|
case TokenTypeRefresh:
|
||||||
|
expiresAt = time.Now().Add(a.refreshTTL)
|
||||||
|
default:
|
||||||
|
return "", "", ErrUnknownTokenType
|
||||||
|
}
|
||||||
|
if tokenType != TokenTypeAccess && tokenType != TokenTypeRefresh {
|
||||||
|
return "", "", ErrUnknownTokenType
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenID = uuid.New().String()
|
||||||
|
claims := Claims{
|
||||||
|
UserID: userID,
|
||||||
|
TokenID: tokenID,
|
||||||
|
TokenType: tokenType,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: "",
|
||||||
|
Subject: "",
|
||||||
|
Audience: jwt.ClaimStrings{},
|
||||||
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||||
|
NotBefore: jwt.NewNumericDate(notBefore),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
ID: userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokenJwt := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
token, err = tokenJwt.SignedString(a.jwtSecret)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthController) ParseToken(tokenStr string) (*Claims, error) {
|
||||||
|
token, err := jwt.ParseWithClaims(
|
||||||
|
tokenStr,
|
||||||
|
&Claims{},
|
||||||
|
func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return a.jwtSecret, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(*Claims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func redisSessionKey(input string) string {
|
||||||
|
return fmt.Sprintf("session:%s", input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthController) SaveSession(ctx context.Context, tokenID string, session *Session) error {
|
||||||
|
sessionJson, err := json.Marshal(session)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.redis.Set(ctx, redisSessionKey(tokenID), string(sessionJson), a.refreshTTL).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthController) GetSession(ctx context.Context, tokenID string) (*Session, error) {
|
||||||
|
sessionRaw := a.redis.Get(ctx, redisSessionKey(tokenID)).Val()
|
||||||
|
if err := a.redis.Del(ctx, redisSessionKey(tokenID)).Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
session := &Session{}
|
||||||
|
if err := json.Unmarshal([]byte(sessionRaw), session); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
57
internal/authorization/auth_test.go
Normal file
57
internal/authorization/auth_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package authorization_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testAccessTTL = time.Second * 5
|
||||||
|
testRefreshTTL = time.Second * 20
|
||||||
|
testUserID = uuid.New().String()
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateInvalidTokenType(t *testing.T) {
|
||||||
|
authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil)
|
||||||
|
token, _, err := authCtrl.GenerateToken(testUserID, "invalid_type")
|
||||||
|
assert.Equal(t, "", token)
|
||||||
|
assert.ErrorIs(t, authorization.ErrUnknownTokenType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateValidateAccessToken(t *testing.T) {
|
||||||
|
authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil)
|
||||||
|
now := time.Now()
|
||||||
|
token, _, err := authCtrl.GenerateToken(testUserID, authorization.TokenTypeAccess)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
claims, err := authCtrl.ParseToken(token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testUserID, claims.UserID)
|
||||||
|
assert.NotEmpty(t, claims.TokenID)
|
||||||
|
assert.Equal(t, authorization.TokenTypeAccess, claims.TokenType)
|
||||||
|
assert.Equal(t, now.Add(testAccessTTL).Unix(), claims.ExpiresAt.Unix())
|
||||||
|
assert.Equal(t, now.Unix(), claims.IssuedAt.Unix())
|
||||||
|
assert.Equal(t, now.Unix(), claims.NotBefore.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateValidateRefreshToken(t *testing.T) {
|
||||||
|
authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil)
|
||||||
|
now := time.Now()
|
||||||
|
token, _, err := authCtrl.GenerateToken(testUserID, authorization.TokenTypeRefresh)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
claims, err := authCtrl.ParseToken(token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, testUserID, claims.UserID)
|
||||||
|
assert.NotEmpty(t, claims.TokenID)
|
||||||
|
assert.Equal(t, authorization.TokenTypeRefresh, claims.TokenType)
|
||||||
|
assert.Equal(t, now.Add(testRefreshTTL).Unix(), claims.ExpiresAt.Unix())
|
||||||
|
assert.Equal(t, now.Unix(), claims.IssuedAt.Unix())
|
||||||
|
assert.Equal(t, now.Unix(), claims.NotBefore.Unix())
|
||||||
|
}
|
||||||
@@ -77,10 +77,12 @@ func (c *AccountController) Login(ctx context.Context, email, password string) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *AccountController) GenerateAccessToken(userID string) (string, error) {
|
func (c *AccountController) GenerateAccessToken(userID string) (string, error) {
|
||||||
|
tokenID := uuid.New().String()
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"user_id": userID,
|
"user_id": userID,
|
||||||
"type": "access",
|
"type": "access",
|
||||||
"exp": time.Now().Add(c.AccessTokenTTL).Unix(),
|
"exp": time.Now().Add(c.AccessTokenTTL).Unix(),
|
||||||
|
"token_id": tokenID,
|
||||||
}
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
package interceptors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/tools/logger"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JWTVerifier struct {
|
|
||||||
secret []byte
|
|
||||||
serverCtx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJWTVerifier(ctx context.Context, secret []byte) *JWTVerifier {
|
|
||||||
return &JWTVerifier{
|
|
||||||
serverCtx: ctx,
|
|
||||||
secret: secret,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is an interceptors that should verify that a user is authorized
|
|
||||||
func (v *JWTVerifier) JWTAuthInterceptor(
|
|
||||||
ctx context.Context,
|
|
||||||
req interface{},
|
|
||||||
info *grpc.UnaryServerInfo,
|
|
||||||
handler grpc.UnaryHandler,
|
|
||||||
) (interface{}, error) {
|
|
||||||
log := logger.FromContext(v.serverCtx).WithValues("method", info.FullMethod)
|
|
||||||
if !strings.Contains(info.FullMethod, "NoAuth") {
|
|
||||||
log.Info("Checking the JWT token")
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "User is not authorized")
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenString := md.Get("token")[0]
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
|
|
||||||
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
|
|
||||||
return v.secret, nil
|
|
||||||
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "User is not authorized")
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
||||||
fmt.Println(claims["userID"])
|
|
||||||
} else {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
// Get the token from the metadata
|
|
||||||
// Validate the token
|
|
||||||
// Get the user id from the token
|
|
||||||
} else {
|
|
||||||
log.Info("Auth is not required for this request")
|
|
||||||
}
|
|
||||||
return handler(ctx, req)
|
|
||||||
}
|
|
||||||
51
main.go
51
main.go
@@ -10,12 +10,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "gitea.badhouseplants.net/softplayer/softplayer-backend/api/v1"
|
v1 "gitea.badhouseplants.net/softplayer/softplayer-backend/api/v1"
|
||||||
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization"
|
||||||
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
|
||||||
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/tools/logger"
|
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/tools/logger"
|
||||||
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
|
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
|
||||||
test "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/test/v1"
|
test "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/test/v1"
|
||||||
"github.com/alecthomas/kong"
|
"github.com/alecthomas/kong"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||||
@@ -25,9 +25,7 @@ import (
|
|||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/reflection"
|
"google.golang.org/grpc/reflection"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
@@ -161,48 +159,29 @@ func server(ctx context.Context, params Serve) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// jwtVerifier := interceptors.NewJWTVerifier(ctx, []byte(params.JWTSecret))
|
|
||||||
|
|
||||||
authFn := func(ctx context.Context) (context.Context, error) {
|
|
||||||
tokenString, err := auth.AuthFromMD(ctx, "bearer")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
|
|
||||||
return []byte(params.JWTSecret), nil
|
|
||||||
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "User is not authorized")
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
||||||
userIDRaw, userIDOk := claims["user_id"]
|
|
||||||
if !userIDOk {
|
|
||||||
return nil, errors.New("required claims are missing in the token")
|
|
||||||
}
|
|
||||||
userID := userIDRaw.(string)
|
|
||||||
log.Info(userID)
|
|
||||||
} else {
|
|
||||||
return ctx, errors.New("claims are missing in the token")
|
|
||||||
}
|
|
||||||
return ctx, nil
|
|
||||||
}
|
|
||||||
authReqServices := func(ctx context.Context, callMeta interceptors.CallMeta) bool {
|
authReqServices := func(ctx context.Context, callMeta interceptors.CallMeta) bool {
|
||||||
return !strings.Contains(callMeta.Service, "NoAuth")
|
return !strings.Contains(callMeta.Service, "NoAuth")
|
||||||
}
|
}
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: params.RedisHost,
|
||||||
|
})
|
||||||
|
|
||||||
|
authInterceptor := authorization.NewAuthController(
|
||||||
|
[]byte(params.JWTSecret),
|
||||||
|
params.AccessTokenTTL,
|
||||||
|
params.RefrestTokenTTL,
|
||||||
|
rdb,
|
||||||
|
)
|
||||||
|
|
||||||
grpcServer := grpc.NewServer(
|
grpcServer := grpc.NewServer(
|
||||||
grpc.ChainUnaryInterceptor(
|
grpc.ChainUnaryInterceptor(
|
||||||
grpc_zap.UnaryServerInterceptor(logger.SetupLogger("info")),
|
grpc_zap.UnaryServerInterceptor(logger.SetupLogger("info")),
|
||||||
// jwtVerifier.JWTAuthInterceptor,
|
// jwtVerifier.JWTAuthInterceptor,
|
||||||
selector.UnaryServerInterceptor(auth.UnaryServerInterceptor(authFn), selector.MatchFunc(authReqServices)),
|
selector.UnaryServerInterceptor(auth.UnaryServerInterceptor(authInterceptor.AuthInterceptorFN), selector.MatchFunc(authReqServices)),
|
||||||
),
|
),
|
||||||
grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(logger.SetupLogger("info"))),
|
grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(logger.SetupLogger("info"))),
|
||||||
)
|
)
|
||||||
|
|
||||||
rdb := redis.NewClient(&redis.Options{
|
|
||||||
Addr: params.RedisHost,
|
|
||||||
})
|
|
||||||
if params.Reflection {
|
if params.Reflection {
|
||||||
reflection.Register(grpcServer)
|
reflection.Register(grpcServer)
|
||||||
}
|
}
|
||||||
@@ -216,8 +195,8 @@ func server(ctx context.Context, params Serve) error {
|
|||||||
JWTSecret: []byte(params.JWTSecret),
|
JWTSecret: []byte(params.JWTSecret),
|
||||||
Redis: rdb,
|
Redis: rdb,
|
||||||
}
|
}
|
||||||
accounts.RegisterAccountsNoAuthServiceServer(grpcServer, v1.NewAccountNoAuthRPCImpl(accountCtrl))
|
accounts.RegisterAccountsNoAuthServiceServer(grpcServer, v1.NewAccountNoAuthRPCImpl(accountCtrl, authInterceptor))
|
||||||
accounts.RegisterAccountsAuthServiceServer(grpcServer, v1.NewAccountAuthRPCImpl(accountCtrl))
|
accounts.RegisterAccountsAuthServiceServer(grpcServer, v1.NewAccountAuthRPCImpl(accountCtrl, authInterceptor))
|
||||||
test.RegisterTestAuthServiceServer(grpcServer, v1.NewTestAuthRPCImpl())
|
test.RegisterTestAuthServiceServer(grpcServer, v1.NewTestAuthRPCImpl())
|
||||||
test.RegisterTestNoAuthServiceServer(grpcServer, v1.NewTestNoAuthRPCImpl())
|
test.RegisterTestNoAuthServiceServer(grpcServer, v1.NewTestNoAuthRPCImpl())
|
||||||
if err := grpcServer.Serve(lis); err != nil {
|
if err := grpcServer.Serve(lis); err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user