From e58eba1b165e57f207d093d00488d07c68c78ee7 Mon Sep 17 00:00:00 2001 From: Nikolai Rodionov Date: Sat, 9 May 2026 21:36:23 +0200 Subject: [PATCH] Implement refresh token endpoint Signed-off-by: Nikolai Rodionov --- Taskfile.yml | 2 + api/v1/accounts_auth.go | 66 +++++++++- api/v1/accounts_no_auth.go | 65 ++++------ go.mod | 3 +- go.sum | 4 +- internal/authorization/auth.go | 184 ++++++++++++++++++++++++++++ internal/authorization/auth_test.go | 57 +++++++++ internal/controllers/accounts.go | 8 +- internal/interceptors/authjwt.go | 64 ---------- main.go | 51 +++----- 10 files changed, 356 insertions(+), 148 deletions(-) create mode 100644 internal/authorization/auth.go create mode 100644 internal/authorization/auth_test.go delete mode 100644 internal/interceptors/authjwt.go diff --git a/Taskfile.yml b/Taskfile.yml index b0e3b0b..e7ede5c 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -21,6 +21,8 @@ tasks: run-server-dev: desc: Run the local dev server + deps: + - run-migrations-dev env: SOFTPLAYER_DB_CONNECTION_STRING: postgres://softplayer:qwertyu9@localhost:30432/softplayer?sslmode=disable SOFTPLAYER_REDIS_HOST: localhost:30379 diff --git a/api/v1/accounts_auth.go b/api/v1/accounts_auth.go index 2c42abe..dee0454 100644 --- a/api/v1/accounts_auth.go +++ b/api/v1/accounts_auth.go @@ -1,17 +1,77 @@ package v1 import ( + "context" + + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers" accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1" + "github.com/golang/protobuf/ptypes/empty" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" ) -func NewAccountAuthRPCImpl(ctrl *controllers.AccountController) *AccountsAuthServer { +func NewAccountAuthRPCImpl( + accountsCtrl *controllers.AccountController, + authorizationCtrl *authorization.AuthController, +) *AccountsAuthServer { return &AccountsAuthServer{ - ctrl: ctrl, + accountsCtrl: accountsCtrl, + authorizationCtrl: authorizationCtrl, } } type AccountsAuthServer struct { accounts.UnimplementedAccountsAuthServiceServer - ctrl *controllers.AccountController + accountsCtrl *controllers.AccountController + authorizationCtrl *authorization.AuthController +} + +func (a *AccountsAuthServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) { + claims, err := a.authorizationCtrl.ClaimsFromContext(ctx) + if err != nil { + return nil, status.Error(codes.Aborted, "Context is invalid") + } + + if claims.TokenType != authorization.TokenTypeRefresh { + return nil, status.Error(codes.Unauthenticated, "Invalid token") + } + + session, err := a.authorizationCtrl.GetSession(ctx, claims.TokenID) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "Session doesn't exists") + } + + if session.UserID != claims.UserID { + return nil, status.Error(codes.Unauthenticated, "Invalid session") + } + + accessToken, _, err := a.authorizationCtrl.GenerateToken(session.UserID, authorization.TokenTypeAccess) + if err != nil { + return nil, status.Error(codes.Aborted, "Couldn't generate an access token") + } + + refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(session.UserID, authorization.TokenTypeRefresh) + if err != nil { + return nil, status.Error(codes.Aborted, "Couldn't generate an access token") + } + + newSession := &authorization.Session{UserID: session.UserID} + + if err := a.authorizationCtrl.SaveSession(ctx, tokenID, newSession); err != nil { + return nil, status.Error(codes.Aborted, "Couldn't store session") + } + + header := metadata.New(map[string]string{ + "X-Access-Token": accessToken, + "X-Refresh-Token": refreshToken, + }) + if err := grpc.SetHeader(ctx, header); err != nil { + return nil, status.Error(codes.Aborted, "Couldn't set metadata") + } + + return &emptypb.Empty{}, nil } diff --git a/api/v1/accounts_no_auth.go b/api/v1/accounts_no_auth.go index 431225b..c4fba1b 100644 --- a/api/v1/accounts_no_auth.go +++ b/api/v1/accounts_no_auth.go @@ -3,6 +3,7 @@ package v1 import ( "context" + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers" accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1" "github.com/golang/protobuf/ptypes/empty" @@ -13,32 +14,42 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -func NewAccountNoAuthRPCImpl(ctrl *controllers.AccountController) *AccountsNoAuthServer { +func NewAccountNoAuthRPCImpl( + accountsCtrl *controllers.AccountController, + authorizationCtrl *authorization.AuthController, +) *AccountsNoAuthServer { return &AccountsNoAuthServer{ - ctrl: ctrl, + accountsCtrl: accountsCtrl, + authorizationCtrl: authorizationCtrl, } } type AccountsNoAuthServer struct { accounts.UnimplementedAccountsNoAuthServiceServer - ctrl *controllers.AccountController + accountsCtrl *controllers.AccountController + authorizationCtrl *authorization.AuthController } func (a *AccountsNoAuthServer) SignIn(ctx context.Context, in *accounts.SignInRequest) (*empty.Empty, error) { - id, err := a.ctrl.Login(ctx, in.GetEmail(), in.GetPassword()) + id, err := a.accountsCtrl.Login(ctx, in.GetEmail(), in.GetPassword()) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't create a user") } - accessToken, err := a.ctrl.GenerateAccessToken(id) + accessToken, _, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeAccess) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } - refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, id) + refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeRefresh) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } + session := &authorization.Session{UserID: id} + + if err := a.authorizationCtrl.SaveSession(ctx, tokenID, session); err != nil { + return nil, status.Error(codes.Aborted, "Couldn't store session") + } header := metadata.New(map[string]string{ "X-Access-Token": accessToken, "X-Refresh-Token": refreshToken, @@ -55,21 +66,27 @@ func (a *AccountsNoAuthServer) SignUp(ctx context.Context, in *accounts.SignUpRe Password: in.GetPassword(), Email: in.GetEmail(), } - id, err := a.ctrl.Create(ctx, data) + id, err := a.accountsCtrl.Create(ctx, data) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't create a user") } - accessToken, err := a.ctrl.GenerateAccessToken(id) + accessToken, _, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeAccess) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } - refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, id) + refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(id, authorization.TokenTypeRefresh) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } + session := &authorization.Session{UserID: id} + + if err := a.authorizationCtrl.SaveSession(ctx, tokenID, session); err != nil { + return nil, status.Error(codes.Aborted, "Couldn't store session") + } + header := metadata.New(map[string]string{ "X-Access-Token": accessToken, "X-Refresh-Token": refreshToken, @@ -79,33 +96,3 @@ func (a *AccountsNoAuthServer) SignUp(ctx context.Context, in *accounts.SignUpRe } return &emptypb.Empty{}, nil } - -func (a *AccountsAuthServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) { - //uuid, err := a.ctrl.ValidateRefreshToken(ctx, , userID) - //if err != nil { - // return nil, status.Error(codes.Unauthenticated, "refresh token is invalid") - //} - //accessToken, err := a.ctrl.GenerateAccessToken(uuid) - //if err != nil { - // log.Error(err, "Couldn't generate an access token") - // return nil, status.Error(codes.Aborted, "Couldn't generate Access Token") - //} - - //refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, uuid) - //if err != nil { - // log.Error(err, "Couldn't generate a refresh token") - // return nil, status.Error(codes.Aborted, "Couldn't generate Access Token") - //} - - //header := metadata.Pairs( - // "access-token", accessToken, - // "refreshToken", refreshToken, - //) - - //if err := grpc.SetHeader(ctx, header); err != nil { - // log.Error(err, "Couldn't set headers") - // return nil, status.Error(codes.Unknown, "Couldn't set headers") - //} - - return nil, status.Error(codes.Unimplemented, "Endpoint is not Unimplemented yet") -} diff --git a/go.mod b/go.mod index cb5d648..b728fb6 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/mattn/go-colorable v0.1.14 github.com/redis/go-redis/v9 v9.18.0 github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.11.1 go.uber.org/zap v1.27.0 golang.org/x/crypto v0.47.0 gopkg.in/yaml.v2 v2.4.0 @@ -138,7 +139,7 @@ require ( ) require ( - gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260508191738-b9850db6fe45 + gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260509192230-bf7467c36f59 github.com/golang/protobuf v1.5.4 golang.org/x/net v0.49.0 // indirect golang.org/x/sys v0.40.0 // indirect diff --git a/go.sum b/go.sum index 1a1c7d3..796c637 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260508191738-b9850db6fe45 h1:7oWiqVQHcBI/7uGTLBubIV62/gRtxe//XjssZW1eWks= -gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260508191738-b9850db6fe45/go.mod h1:AgOh1lkPHyRgBf3/s1btKcAqke/33LbKYarTD13qeAg= +gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260509192230-bf7467c36f59 h1:pI25/wjcfvX62PcxyZ/i7XPTxdyCV9tV34JFSWQxYNw= +gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260509192230-bf7467c36f59/go.mod h1:AgOh1lkPHyRgBf3/s1btKcAqke/33LbKYarTD13qeAg= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= diff --git a/internal/authorization/auth.go b/internal/authorization/auth.go new file mode 100644 index 0000000..f22e6ed --- /dev/null +++ b/internal/authorization/auth.go @@ -0,0 +1,184 @@ +package authorization + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type TokenType string + +const ( + TokenTypeAccess TokenType = "access" + TokenTypeRefresh TokenType = "refresh" +) + +var ( + ErrUnknownTokenType = errors.New("token type unknown") + ErrInvalidToken = errors.New("invalid token") +) + +type Claims struct { + UserID string `json:"user_id"` + TokenID string `json:"token_id"` + TokenType TokenType `json:"token_type"` + jwt.RegisteredClaims +} + +type AuthController struct { + jwtSecret []byte + accessTTL time.Duration + refreshTTL time.Duration + redis *redis.Client +} + +type contextKey string + +const claimsContextKey contextKey = "jwt_claims" + +func NewAuthController(jwtSecret []byte, accessTTL, refreshTTL time.Duration, redis *redis.Client) *AuthController { + return &AuthController{ + jwtSecret: jwtSecret, + accessTTL: accessTTL, + refreshTTL: refreshTTL, + redis: redis, + } +} + +// Write claims into context +func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context { + return context.WithValue(ctx, claimsContextKey, claims) +} + +// Extract claims from context +func (a *AuthController) ClaimsFromContext(ctx context.Context) (*Claims, error) { + claims, ok := ctx.Value(claimsContextKey).(*Claims) + if !ok || claims == nil { + return nil, errors.New("claims not found in context") + } + + return claims, nil +} + +func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context, error) { + tokenString, err := auth.AuthFromMD(ctx, "bearer") + if err != nil { + return nil, err + } + + claims, err := a.ParseToken(tokenString) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "Invalid JWT token") + } + if method, ok := grpc.Method(ctx); ok { + if claims.TokenType == TokenTypeRefresh && !strings.Contains(method, "RefreshToken") { + return nil, status.Error(codes.Unauthenticated, "Refresh token is not allowed for this method") + } + } + + ctx = a.WithClaims(ctx, claims) + return ctx, nil +} + +// Generate JWT token +func (a *AuthController) GenerateToken(userID string, tokenType TokenType) (token, tokenID string, err error) { + var expiresAt time.Time + notBefore := time.Now() + switch tokenType { + case TokenTypeAccess: + expiresAt = time.Now().Add(a.accessTTL) + case TokenTypeRefresh: + expiresAt = time.Now().Add(a.refreshTTL) + default: + return "", "", ErrUnknownTokenType + } + if tokenType != TokenTypeAccess && tokenType != TokenTypeRefresh { + return "", "", ErrUnknownTokenType + } + + tokenID = uuid.New().String() + claims := Claims{ + UserID: userID, + TokenID: tokenID, + TokenType: tokenType, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "", + Subject: "", + Audience: jwt.ClaimStrings{}, + ExpiresAt: jwt.NewNumericDate(expiresAt), + NotBefore: jwt.NewNumericDate(notBefore), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: userID, + }, + } + tokenJwt := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token, err = tokenJwt.SignedString(a.jwtSecret) + if err != nil { + return "", "", err + } + return +} + +func (a *AuthController) ParseToken(tokenStr string) (*Claims, error) { + token, err := jwt.ParseWithClaims( + tokenStr, + &Claims{}, + func(token *jwt.Token) (interface{}, error) { + return a.jwtSecret, nil + }, + ) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(*Claims) + if !ok || !token.Valid { + return nil, ErrInvalidToken + } + + return claims, nil +} + +type Session struct { + UserID string `json:"user_id"` +} + +func redisSessionKey(input string) string { + return fmt.Sprintf("session:%s", input) +} + +func (a *AuthController) SaveSession(ctx context.Context, tokenID string, session *Session) error { + sessionJson, err := json.Marshal(session) + if err != nil { + return err + } + + if err := a.redis.Set(ctx, redisSessionKey(tokenID), string(sessionJson), a.refreshTTL).Err(); err != nil { + return err + } + + return nil +} + +func (a *AuthController) GetSession(ctx context.Context, tokenID string) (*Session, error) { + sessionRaw := a.redis.Get(ctx, redisSessionKey(tokenID)).Val() + if err := a.redis.Del(ctx, redisSessionKey(tokenID)).Err(); err != nil { + return nil, err + } + session := &Session{} + if err := json.Unmarshal([]byte(sessionRaw), session); err != nil { + return nil, err + } + return session, nil +} diff --git a/internal/authorization/auth_test.go b/internal/authorization/auth_test.go new file mode 100644 index 0000000..c4bd9ae --- /dev/null +++ b/internal/authorization/auth_test.go @@ -0,0 +1,57 @@ +package authorization_test + +import ( + "testing" + "time" + + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +var ( + testAccessTTL = time.Second * 5 + testRefreshTTL = time.Second * 20 + testUserID = uuid.New().String() +) + +func TestGenerateInvalidTokenType(t *testing.T) { + authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil) + token, _, err := authCtrl.GenerateToken(testUserID, "invalid_type") + assert.Equal(t, "", token) + assert.ErrorIs(t, authorization.ErrUnknownTokenType, err) +} + +func TestGenerateValidateAccessToken(t *testing.T) { + authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil) + now := time.Now() + token, _, err := authCtrl.GenerateToken(testUserID, authorization.TokenTypeAccess) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + claims, err := authCtrl.ParseToken(token) + assert.NoError(t, err) + assert.Equal(t, testUserID, claims.UserID) + assert.NotEmpty(t, claims.TokenID) + assert.Equal(t, authorization.TokenTypeAccess, claims.TokenType) + assert.Equal(t, now.Add(testAccessTTL).Unix(), claims.ExpiresAt.Unix()) + assert.Equal(t, now.Unix(), claims.IssuedAt.Unix()) + assert.Equal(t, now.Unix(), claims.NotBefore.Unix()) +} + +func TestGenerateValidateRefreshToken(t *testing.T) { + authCtrl := authorization.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil) + now := time.Now() + token, _, err := authCtrl.GenerateToken(testUserID, authorization.TokenTypeRefresh) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + claims, err := authCtrl.ParseToken(token) + assert.NoError(t, err) + assert.Equal(t, testUserID, claims.UserID) + assert.NotEmpty(t, claims.TokenID) + assert.Equal(t, authorization.TokenTypeRefresh, claims.TokenType) + assert.Equal(t, now.Add(testRefreshTTL).Unix(), claims.ExpiresAt.Unix()) + assert.Equal(t, now.Unix(), claims.IssuedAt.Unix()) + assert.Equal(t, now.Unix(), claims.NotBefore.Unix()) +} diff --git a/internal/controllers/accounts.go b/internal/controllers/accounts.go index d941f4a..14e1217 100644 --- a/internal/controllers/accounts.go +++ b/internal/controllers/accounts.go @@ -77,10 +77,12 @@ func (c *AccountController) Login(ctx context.Context, email, password string) ( } func (c *AccountController) GenerateAccessToken(userID string) (string, error) { + tokenID := uuid.New().String() claims := jwt.MapClaims{ - "user_id": userID, - "type": "access", - "exp": time.Now().Add(c.AccessTokenTTL).Unix(), + "user_id": userID, + "type": "access", + "exp": time.Now().Add(c.AccessTokenTTL).Unix(), + "token_id": tokenID, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) diff --git a/internal/interceptors/authjwt.go b/internal/interceptors/authjwt.go deleted file mode 100644 index 30eca99..0000000 --- a/internal/interceptors/authjwt.go +++ /dev/null @@ -1,64 +0,0 @@ -package interceptors - -import ( - "context" - "fmt" - "strings" - - "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/tools/logger" - "github.com/golang-jwt/jwt/v5" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" -) - -type JWTVerifier struct { - secret []byte - serverCtx context.Context -} - -func NewJWTVerifier(ctx context.Context, secret []byte) *JWTVerifier { - return &JWTVerifier{ - serverCtx: ctx, - secret: secret, - } -} - -// This is an interceptors that should verify that a user is authorized -func (v *JWTVerifier) JWTAuthInterceptor( - ctx context.Context, - req interface{}, - info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, -) (interface{}, error) { - log := logger.FromContext(v.serverCtx).WithValues("method", info.FullMethod) - if !strings.Contains(info.FullMethod, "NoAuth") { - log.Info("Checking the JWT token") - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Error(codes.Unauthenticated, "User is not authorized") - } - - tokenString := md.Get("token")[0] - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { - // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") - return v.secret, nil - }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) - if err != nil { - return nil, status.Error(codes.Unauthenticated, "User is not authorized") - } - - if claims, ok := token.Claims.(jwt.MapClaims); ok { - fmt.Println(claims["userID"]) - } else { - fmt.Println(err) - } - // Get the token from the metadata - // Validate the token - // Get the user id from the token - } else { - log.Info("Auth is not required for this request") - } - return handler(ctx, req) -} diff --git a/main.go b/main.go index d3e8b63..325fcc5 100644 --- a/main.go +++ b/main.go @@ -10,12 +10,12 @@ import ( "time" v1 "gitea.badhouseplants.net/softplayer/softplayer-backend/api/v1" + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/authorization" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/tools/logger" accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1" test "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/test/v1" "github.com/alecthomas/kong" - "github.com/golang-jwt/jwt/v5" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" @@ -25,9 +25,7 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" _ "github.com/lib/pq" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/reflection" - "google.golang.org/grpc/status" "github.com/redis/go-redis/v9" ) @@ -161,48 +159,29 @@ func server(ctx context.Context, params Serve) error { return err } - // jwtVerifier := interceptors.NewJWTVerifier(ctx, []byte(params.JWTSecret)) - - authFn := func(ctx context.Context) (context.Context, error) { - tokenString, err := auth.AuthFromMD(ctx, "bearer") - if err != nil { - return nil, err - } - - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { - return []byte(params.JWTSecret), nil - }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) - if err != nil { - return nil, status.Error(codes.Unauthenticated, "User is not authorized") - } - - if claims, ok := token.Claims.(jwt.MapClaims); ok { - userIDRaw, userIDOk := claims["user_id"] - if !userIDOk { - return nil, errors.New("required claims are missing in the token") - } - userID := userIDRaw.(string) - log.Info(userID) - } else { - return ctx, errors.New("claims are missing in the token") - } - return ctx, nil - } authReqServices := func(ctx context.Context, callMeta interceptors.CallMeta) bool { return !strings.Contains(callMeta.Service, "NoAuth") } + rdb := redis.NewClient(&redis.Options{ + Addr: params.RedisHost, + }) + + authInterceptor := authorization.NewAuthController( + []byte(params.JWTSecret), + params.AccessTokenTTL, + params.RefrestTokenTTL, + rdb, + ) + grpcServer := grpc.NewServer( grpc.ChainUnaryInterceptor( grpc_zap.UnaryServerInterceptor(logger.SetupLogger("info")), // jwtVerifier.JWTAuthInterceptor, - selector.UnaryServerInterceptor(auth.UnaryServerInterceptor(authFn), selector.MatchFunc(authReqServices)), + selector.UnaryServerInterceptor(auth.UnaryServerInterceptor(authInterceptor.AuthInterceptorFN), selector.MatchFunc(authReqServices)), ), grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(logger.SetupLogger("info"))), ) - rdb := redis.NewClient(&redis.Options{ - Addr: params.RedisHost, - }) if params.Reflection { reflection.Register(grpcServer) } @@ -216,8 +195,8 @@ func server(ctx context.Context, params Serve) error { JWTSecret: []byte(params.JWTSecret), Redis: rdb, } - accounts.RegisterAccountsNoAuthServiceServer(grpcServer, v1.NewAccountNoAuthRPCImpl(accountCtrl)) - accounts.RegisterAccountsAuthServiceServer(grpcServer, v1.NewAccountAuthRPCImpl(accountCtrl)) + accounts.RegisterAccountsNoAuthServiceServer(grpcServer, v1.NewAccountNoAuthRPCImpl(accountCtrl, authInterceptor)) + accounts.RegisterAccountsAuthServiceServer(grpcServer, v1.NewAccountAuthRPCImpl(accountCtrl, authInterceptor)) test.RegisterTestAuthServiceServer(grpcServer, v1.NewTestAuthRPCImpl()) test.RegisterTestNoAuthServiceServer(grpcServer, v1.NewTestNoAuthRPCImpl()) if err := grpcServer.Serve(lis); err != nil {