diff --git a/api/v1/accounts.go b/api/v1/accounts.go index e268f3a..95de164 100644 --- a/api/v1/accounts.go +++ b/api/v1/accounts.go @@ -48,16 +48,23 @@ func (a *AccountsServer) RefreshToken(ctx context.Context, in *empty.Empty) (*em return nil, status.Error(codes.Unauthenticated, "Invalid session") } - accessToken, _, err := a.authorizationCtrl.GenerateToken(session.UserID, controllers.TokenTypeAccess) + accessToken, _, err := a.authorizationCtrl.GenerateToken(&controllers.JWTData{ + UserID: claims.UserID, + TokenType: controllers.TokenTypeAccess, + TokenAud: controllers.TokenAudWeb, + }) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } - refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(session.UserID, controllers.TokenTypeRefresh) + refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(&controllers.JWTData{ + UserID: claims.UserID, + TokenType: controllers.TokenTypeRefresh, + TokenAud: controllers.TokenAudWeb, + }) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } - newSession := &controllers.Session{UserID: session.UserID} if err := a.authorizationCtrl.SaveSession(ctx, tokenID, newSession); err != nil { diff --git a/api/v1/public_accounts.go b/api/v1/public_accounts.go index 9afd2dd..8ec563e 100644 --- a/api/v1/public_accounts.go +++ b/api/v1/public_accounts.go @@ -34,12 +34,20 @@ func (a *PublicAccountService) SignIn(ctx context.Context, in *accounts.SignInRe if err != nil { return nil, status.Error(codes.Aborted, "Couldn't create a user") } - accessToken, _, err := a.authorizationCtrl.GenerateToken(id, controllers.TokenTypeAccess) + accessToken, _, err := a.authorizationCtrl.GenerateToken(&controllers.JWTData{ + UserID: id, + TokenType: controllers.TokenTypeAccess, + TokenAud: controllers.TokenAudWeb, + }) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } - refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(id, controllers.TokenTypeRefresh) + refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(&controllers.JWTData{ + UserID: id, + TokenType: controllers.TokenTypeRefresh, + TokenAud: controllers.TokenAudWeb, + }) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } @@ -70,12 +78,20 @@ func (a *PublicAccountService) SignUp(ctx context.Context, in *accounts.SignUpRe return nil, status.Error(codes.Aborted, "Couldn't create a user") } - accessToken, _, err := a.authorizationCtrl.GenerateToken(id, controllers.TokenTypeAccess) + accessToken, _, err := a.authorizationCtrl.GenerateToken(&controllers.JWTData{ + UserID: id, + TokenType: controllers.TokenTypeAccess, + TokenAud: controllers.TokenAudWeb, + }) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } - refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(id, controllers.TokenTypeRefresh) + refreshToken, tokenID, err := a.authorizationCtrl.GenerateToken(&controllers.JWTData{ + UserID: id, + TokenType: controllers.TokenTypeRefresh, + TokenAud: controllers.TokenAudWeb, + }) if err != nil { return nil, status.Error(codes.Aborted, "Couldn't generate an access token") } diff --git a/api/v1/public_tokens.go b/api/v1/public_tokens.go new file mode 100644 index 0000000..91e6075 --- /dev/null +++ b/api/v1/public_tokens.go @@ -0,0 +1,64 @@ +package v1 + +import ( + "context" + "errors" + + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers" + tokens "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/tokens/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" +) + +// var _ tokens.PublicTokensServiceServer = (*PublicTokensServer)(nil) + +type PublicTokensServer struct { + tokens.UnimplementedPublicTokensServiceServer + tokenCtrl *controllers.TokenController + authorizationCtrl *controllers.AuthController +} + +func NewPublicTokensServer( + tokenCtrl *controllers.TokenController, + authorizationCtrl *controllers.AuthController, +) *PublicTokensServer { + return &PublicTokensServer{ + tokenCtrl: tokenCtrl, + authorizationCtrl: authorizationCtrl, + } +} + +func (srv *PublicTokensServer) AuthenticateWithToken(ctx context.Context, in *tokens.AuthenticateWithTokenRequest) (*emptypb.Empty, error) { + tokenAuthRes, err := srv.tokenCtrl.AuthenticateWithToken(ctx, in.TokenValue.Token) + if err != nil { + if errors.Is(err, controllers.ErrBadToken) { + return nil, status.Error(codes.Unauthenticated, "Token is not valid") + } + if errors.Is(err, controllers.ErrServerError) { + return nil, status.Error(codes.Internal, "Something is broken on our side") + } + return nil, status.Error(codes.Aborted, "Couldn't list tokens") + } + + jwtData := &controllers.JWTData{ + UserID: tokenAuthRes.UserID, + TokenType: controllers.TokenTypeAccess, + TokenAud: controllers.TokenAudToken, + Scope: tokenAuthRes.Scope, + } + accessToken, _, err := srv.authorizationCtrl.GenerateToken(jwtData) + if err != nil { + return nil, status.Error(codes.Aborted, "Couldn't generate an access token") + } + + header := metadata.New(map[string]string{ + "X-Access-Token": accessToken, + }) + if err := grpc.SetHeader(ctx, header); err != nil { + return nil, status.Error(codes.Aborted, "Couldn't set metadata") + } + return &emptypb.Empty{}, nil +} diff --git a/api/v1/tokens.go b/api/v1/tokens.go index 937e482..f6aec00 100644 --- a/api/v1/tokens.go +++ b/api/v1/tokens.go @@ -3,7 +3,6 @@ package v1 import ( "context" "errors" - "fmt" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers" tokens "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/tokens/v1" @@ -261,18 +260,3 @@ func (srv *TokensServer) ListPermissions(in *emptypb.Empty, stream grpc.ServerSt } return nil } - -func (srv *TokensServer) AuthenticateWithToken(ctx context.Context, in *tokens.AuthenticateWithTokenRequest) (*emptypb.Empty, error) { - scopes, err := srv.tokenCtrl.AuthenticateWithToken(ctx, in.TokenValue.Token) - if err != nil { - if errors.Is(err, controllers.ErrBadToken) { - return nil, status.Error(codes.Unauthenticated, "Token is not valid") - } - if errors.Is(err, controllers.ErrServerError) { - return nil, status.Error(codes.Internal, "Something is broken on our side") - } - return nil, status.Error(codes.Aborted, "Couldn't list tokens") - } - fmt.Println(scopes) - return nil, nil -} diff --git a/cmd/server.go b/cmd/server.go index ed97fc4..a158f07 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -67,7 +67,7 @@ func (cmd *Server) Run(ctx context.Context) error { Addr: cmd.RedisHost, }) - authInterceptor := controllers.NewAuthController( + authController := controllers.NewAuthController( []byte(cmd.JWTSecret), cmd.AccessTokenTTL, cmd.RefrestTokenTTL, @@ -79,14 +79,14 @@ func (cmd *Server) Run(ctx context.Context) error { grpc_zap.UnaryServerInterceptor(logger.SetupLogger("info")), // jwtVerifier.JWTAuthInterceptor, selector.UnaryServerInterceptor( - auth.UnaryServerInterceptor(authInterceptor.AuthInterceptorFN), + auth.UnaryServerInterceptor(authController.AuthInterceptorFN), selector.MatchFunc(selectorRequireAuth), ), ), grpc.ChainStreamInterceptor( grpc_zap.StreamServerInterceptor(logger.SetupLogger("info")), selector.StreamServerInterceptor( - auth.StreamServerInterceptor(authInterceptor.AuthInterceptorFN), + auth.StreamServerInterceptor(authController.AuthInterceptorFN), selector.MatchFunc(selectorRequireAuth), ), ), @@ -113,11 +113,12 @@ func (cmd *Server) Run(ctx context.Context) error { } // Services that should be accessible for tokens should go here - accounts.RegisterAccountsServiceServer(grpcServer, v1.NewAccountServer(accountCtrl, authInterceptor)) + accounts.RegisterAccountsServiceServer(grpcServer, v1.NewAccountServer(accountCtrl, authController)) test.RegisterTestServiceServer(grpcServer, v1.NewTestServer()) test.RegisterPublicTestServiceServer(grpcServer, v1.NewPublicTestServer()) - tokens.RegisterTokensServiceServer(grpcServer, v1.NewTokensServer(tokenCtrl, authInterceptor)) - accounts.RegisterPublicAccountsServiceServer(grpcServer, v1.NewPublicAccountServer(accountCtrl, authInterceptor)) + tokens.RegisterTokensServiceServer(grpcServer, v1.NewTokensServer(tokenCtrl, authController)) + tokens.RegisterPublicTokensServiceServer(grpcServer, v1.NewPublicTokensServer(tokenCtrl, authController)) + accounts.RegisterPublicAccountsServiceServer(grpcServer, v1.NewPublicAccountServer(accountCtrl, authController)) info := grpcServer.GetServiceInfo() tokenCtrl.SetGRPCInfo(info) diff --git a/go.mod b/go.mod index ec03edb..60f6fdf 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( ) require ( - gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260514173933-48bddcf5c686 + gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260515083721-50411957979f github.com/golang/protobuf v1.5.4 golang.org/x/net v0.49.0 // indirect golang.org/x/sys v0.40.0 // indirect diff --git a/go.sum b/go.sum index ed898b9..cb2407e 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT cloud.google.com/go v0.121.6 h1:waZiuajrI28iAf40cWgycWNgaXPO06dupuS+sgibK6c= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260514173933-48bddcf5c686 h1:tOSfg7VeD0Xq2NhVQblSiWGICvSH8RWfaaPH7mCvw0Y= -gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260514173933-48bddcf5c686/go.mod h1:AgOh1lkPHyRgBf3/s1btKcAqke/33LbKYarTD13qeAg= +gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260515083721-50411957979f h1:o+GpCFtuw59LrWw9ZkWQOXhQxJjLaJYM+uZm0gdtrRI= +gitea.badhouseplants.net/softplayer/softplayer-go-proto v0.0.0-20260515083721-50411957979f/go.mod h1:AgOh1lkPHyRgBf3/s1btKcAqke/33LbKYarTD13qeAg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= diff --git a/internal/controllers/authorization.go b/internal/controllers/authorization.go index 7a6be33..07f14a3 100644 --- a/internal/controllers/authorization.go +++ b/internal/controllers/authorization.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "time" @@ -22,6 +23,8 @@ type TokenType string const ( TokenTypeAccess TokenType = "access" TokenTypeRefresh TokenType = "refresh" + TokenAudToken string = "token" + TokenAudWeb string = "web" ) var ( @@ -33,6 +36,7 @@ type Claims struct { UserID string `json:"user_id"` TokenID string `json:"token_id"` TokenType TokenType `json:"token_type"` + Scope string `json:"scope,omitempty"` jwt.RegisteredClaims } @@ -56,6 +60,13 @@ func NewAuthController(jwtSecret []byte, accessTTL, refreshTTL time.Duration, re } } +type JWTData struct { + UserID string + TokenType TokenType + TokenAud string + Scope string +} + // Write claims into context func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context { return context.WithValue(ctx, claimsContextKey, claims) @@ -87,15 +98,43 @@ func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context } } + // If it's a cli token, we need to check the scope + if slices.Contains(claims.Audience, TokenAudToken) { + currentMethod, ok := grpc.Method(ctx) + if !ok { + return nil, errors.New("unknown method") + } + + scopeMap := map[string][]string{} + if err := json.Unmarshal([]byte(claims.Scope), &scopeMap); err != nil { + return nil, ErrServerError + } + allowed := isAllowed(scopeMap, currentMethod) + if !allowed { + return nil, errors.New("not authorized") + } + } + ctx = a.WithClaims(ctx, claims) return ctx, nil } +func isAllowed(scope map[string][]string, currentMethod string) bool { + for service, methods := range scope { + for _, method := range methods { + if fmt.Sprintf("/%s/%s", service, method) == currentMethod { + return true + } + } + } + return false +} + // Generate JWT token -func (a *AuthController) GenerateToken(userID string, tokenType TokenType) (token, tokenID string, err error) { +func (a *AuthController) GenerateToken(data *JWTData) (token, tokenID string, err error) { var expiresAt time.Time notBefore := time.Now() - switch tokenType { + switch data.TokenType { case TokenTypeAccess: expiresAt = time.Now().Add(a.accessTTL) case TokenTypeRefresh: @@ -103,25 +142,25 @@ func (a *AuthController) GenerateToken(userID string, tokenType TokenType) (toke default: return "", "", ErrUnknownTokenType } - if tokenType != TokenTypeAccess && tokenType != TokenTypeRefresh { - return "", "", ErrUnknownTokenType - } tokenID = uuid.New().String() + claims := Claims{ - UserID: userID, + UserID: data.UserID, TokenID: tokenID, - TokenType: tokenType, + TokenType: data.TokenType, + Scope: data.Scope, RegisteredClaims: jwt.RegisteredClaims{ Issuer: "", - Subject: "", - Audience: jwt.ClaimStrings{}, + Subject: data.UserID, + Audience: jwt.ClaimStrings{data.TokenAud}, ExpiresAt: jwt.NewNumericDate(expiresAt), NotBefore: jwt.NewNumericDate(notBefore), IssuedAt: jwt.NewNumericDate(time.Now()), - ID: userID, + ID: tokenID, }, } + tokenJwt := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token, err = tokenJwt.SignedString(a.jwtSecret) if err != nil { diff --git a/internal/controllers/authorization_test.go b/internal/controllers/authorization_test.go index face84a..567014a 100644 --- a/internal/controllers/authorization_test.go +++ b/internal/controllers/authorization_test.go @@ -16,16 +16,26 @@ var ( ) func TestGenerateInvalidTokenType(t *testing.T) { + data := &controllers.JWTData{ + UserID: testUserID, + TokenType: "invalid_type", + } + authCtrl := controllers.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil) - token, _, err := authCtrl.GenerateToken(testUserID, "invalid_type") + + token, _, err := authCtrl.GenerateToken(data) assert.Equal(t, "", token) assert.ErrorIs(t, controllers.ErrUnknownTokenType, err) } func TestGenerateValidateAccessToken(t *testing.T) { + data := &controllers.JWTData{ + UserID: testUserID, + TokenType: controllers.TokenTypeAccess, + } authCtrl := controllers.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil) now := time.Now() - token, _, err := authCtrl.GenerateToken(testUserID, controllers.TokenTypeAccess) + token, _, err := authCtrl.GenerateToken(data) assert.NoError(t, err) assert.NotEmpty(t, token) @@ -40,9 +50,13 @@ func TestGenerateValidateAccessToken(t *testing.T) { } func TestGenerateValidateRefreshToken(t *testing.T) { + data := &controllers.JWTData{ + UserID: testUserID, + TokenType: controllers.TokenTypeRefresh, + } authCtrl := controllers.NewAuthController([]byte("test"), testAccessTTL, testRefreshTTL, nil) now := time.Now() - token, _, err := authCtrl.GenerateToken(testUserID, controllers.TokenTypeRefresh) + token, _, err := authCtrl.GenerateToken(data) assert.NoError(t, err) assert.NotEmpty(t, token) diff --git a/internal/controllers/tokens.go b/internal/controllers/tokens.go index c673871..7a4d772 100644 --- a/internal/controllers/tokens.go +++ b/internal/controllers/tokens.go @@ -330,22 +330,27 @@ func shouldSkip(s string, rules []rule) bool { return false } -func (ctrl *TokenController) AuthenticateWithToken(ctx context.Context, token string) (map[string][]string, error) { +type TokenAuthResult struct { + UserID string + Scope string +} + +func (ctrl *TokenController) AuthenticateWithToken(ctx context.Context, token string) (*TokenAuthResult, error) { log := logger.FromContext(ctx) log.V(2).Info("Authenticating with a token") query := ` - SELECT scopes, expires_at, revoked_at + SELECT user_id, scopes, expires_at, revoked_at FROM tokens WHERE token_hash = $1` + var userID string var expiresAt sql.NullTime var revokedAt sql.NullTime - var scopes string - fmt.Println(hashSHA256(token)) - fmt.Println(hashSHA256(token)) + var scope string if err := ctrl.DB.QueryRowContext(ctx, query, hashSHA256(token)).Scan( - &scopes, + &userID, + &scope, &expiresAt, &revokedAt, ); err != nil { @@ -356,20 +361,20 @@ func (ctrl *TokenController) AuthenticateWithToken(ctx context.Context, token st return nil, ErrServerError } - if !revokedAt.Valid { + if revokedAt.Valid { return nil, ErrBadToken } - if expiresAt.Time.After(time.Now()) { + if expiresAt.Time.Before(time.Now()) { return nil, ErrBadToken } - scopesMap := map[string][]string{} - - if err := json.Unmarshal([]byte(scopes), scopesMap); err != nil { - return nil, ErrServerError + result := &TokenAuthResult{ + UserID: userID, + Scope: scope, } - return scopesMap, nil + + return result, nil } func hashSHA256(s string) string {