diff --git a/api/v1/accounts_no_auth.go b/api/v1/accounts_no_auth.go index 4331e9c..e099fc1 100644 --- a/api/v1/accounts_no_auth.go +++ b/api/v1/accounts_no_auth.go @@ -76,3 +76,41 @@ func (a *AccountsNoAuthServer) ResetPassword(ctx context.Context, in *accounts.R func (acc *AccountsNoAuthServer) NewPassword(ctx context.Context, in *accounts.NewPasswordRequest) (*empty.Empty, error) { return nil, status.Error(codes.Unimplemented, "Endpoint is not implemented") } + +func (a *AccountsNoAuthServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) { + log := logger.FromContext(ctx) + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "User is not authorized") + } + + tokenString := md.Get("token")[0] + uuid, err := a.ctrl.ValidateRefreshToken(ctx, tokenString) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "refresh token is invalid") + } + accessToken, err := a.ctrl.GenerateAccessToken(uuid) + if err != nil { + log.Error(err, "Couldn't generate an access token") + return nil, status.Error(codes.Aborted, "Couldn't generate Access Token") + } + + refreshToken, err := a.ctrl.GenerateRefreshToken(ctx, uuid) + if err != nil { + log.Error(err, "Couldn't generate a refresh token") + return nil, status.Error(codes.Aborted, "Couldn't generate Access Token") + } + + header := metadata.Pairs( + "access-token", accessToken, + "refreshToken", refreshToken, + ) + + if err := grpc.SetHeader(ctx, header); err != nil { + log.Error(err, "Couldn't set headers") + return nil, status.Error(codes.Unknown, "Couldn't set headers") + } + + return &emptypb.Empty{}, nil +} diff --git a/internal/controllers/accounts.go b/internal/controllers/accounts.go index e11d151..6495932 100644 --- a/internal/controllers/accounts.go +++ b/internal/controllers/accounts.go @@ -3,6 +3,7 @@ package controllers import ( "context" "database/sql" + "errors" "fmt" "time" @@ -79,3 +80,34 @@ func (c *AccountController) GenerateRefreshToken(ctx context.Context, userID str } return token.SignedString(c.JWTSecret) } + +// It must validate the refresh token +// Get it's id from the content +// Find a corresponding token in redis, and if it's found, remove it and create a new one +func (c *AccountController) ValidateRefreshToken(ctx context.Context, tokenString string) (string, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { + // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") + return c.JWTSecret, nil + }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) + if err != nil { + return "", err + } + + var tokenID string + var userID string + if claims, ok := token.Claims.(jwt.MapClaims); ok { + tokenID = claims["token_id"].(string) + userID = claims["user_id"].(string) + } else { + return "", errors.New("token id is not set") + } + + userIDRedis := c.Redis.Get(ctx, tokenID).String() + if c.Redis.Del(ctx, tokenID).Err() != nil { + return "", err + } + if userID != userIDRedis { + return "", errors.New("user id doesn't match") + } + return userIDRedis, nil +}