From 14ad203f63e859ddd9b8a1484b5e020c350172dd Mon Sep 17 00:00:00 2001 From: Nikolai Rodionov Date: Sun, 17 May 2026 20:07:31 +0200 Subject: [PATCH] Structure code a wee bit better Signed-off-by: Nikolai Rodionov --- api/v1/accounts.go | 2 +- api/v1/tokens.go | 12 +-- cmd/migrate.go | 4 +- cmd/server.go | 4 +- internal/cache/cache.go | 1 + internal/controllers/accounts.go | 120 +++++++++----------------- internal/controllers/accounts_test.go | 66 ++++++++++++-- internal/controllers/authorization.go | 6 +- internal/helpers/postgres/postgres.go | 7 +- internal/repository/accounts.go | 62 +++++++++++++ 10 files changed, 181 insertions(+), 103 deletions(-) create mode 100644 internal/cache/cache.go create mode 100644 internal/repository/accounts.go diff --git a/api/v1/accounts.go b/api/v1/accounts.go index 95de164..90e2e44 100644 --- a/api/v1/accounts.go +++ b/api/v1/accounts.go @@ -30,7 +30,7 @@ type AccountsServer struct { } func (a *AccountsServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) { - claims, err := a.authorizationCtrl.ClaimsFromContext(ctx) + claims, err := controllers.ClaimsFromContext(ctx) if err != nil { return nil, status.Error(codes.Aborted, "Context is invalid") } diff --git a/api/v1/tokens.go b/api/v1/tokens.go index f6aec00..16887da 100644 --- a/api/v1/tokens.go +++ b/api/v1/tokens.go @@ -32,7 +32,7 @@ func NewTokensServer( // CreateToken implements [v1.TokensServiceServer]. func (srv *TokensServer) CreateToken(ctx context.Context, in *tokens.CreateTokenRequest) (*tokens.CreateTokenResponse, error) { - claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx) + claims, err := controllers.ClaimsFromContext(ctx) if err != nil { return nil, status.Error(codes.Aborted, "Context is invalid") } @@ -70,7 +70,7 @@ func (srv *TokensServer) CreateToken(ctx context.Context, in *tokens.CreateToken // ForceTokenExpiration implements [v1.TokensServiceServer]. func (srv *TokensServer) ForceTokenExpiration(ctx context.Context, in *tokens.ForceTokenExpirationRequest) (*emptypb.Empty, error) { - claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx) + claims, err := controllers.ClaimsFromContext(ctx) if err != nil { return nil, status.Error(codes.Aborted, "Context is invalid") } @@ -96,7 +96,7 @@ func (srv *TokensServer) ForceTokenExpiration(ctx context.Context, in *tokens.Fo // GetToken implements [v1.TokensServiceServer]. func (srv *TokensServer) GetToken(ctx context.Context, in *tokens.GetTokenRequest) (*tokens.GetTokenResponse, error) { - claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx) + claims, err := controllers.ClaimsFromContext(ctx) if err != nil { return nil, status.Error(codes.Aborted, "Context is invalid") } @@ -136,7 +136,7 @@ func (srv *TokensServer) GetToken(ctx context.Context, in *tokens.GetTokenReques // ListTokens implements [v1.TokensServiceServer]. func (srv *TokensServer) ListTokens(in *emptypb.Empty, stream grpc.ServerStreamingServer[tokens.ListTokensResponse]) error { - claims, err := srv.authorizationCtrl.ClaimsFromContext(stream.Context()) + claims, err := controllers.ClaimsFromContext(stream.Context()) if err != nil { return status.Error(codes.Aborted, "Context is invalid") } @@ -174,7 +174,7 @@ func (srv *TokensServer) ListTokens(in *emptypb.Empty, stream grpc.ServerStreami // RegenerateToken implements [v1.TokensServiceServer]. func (srv *TokensServer) RegenerateToken(ctx context.Context, in *tokens.RegenerateTokenRequest) (*tokens.RegenerateTokenResponse, error) { - claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx) + claims, err := controllers.ClaimsFromContext(ctx) if err != nil { return nil, status.Error(codes.Aborted, "Context is invalid") } @@ -204,7 +204,7 @@ func (srv *TokensServer) RegenerateToken(ctx context.Context, in *tokens.Regener // UpdateToken implements [v1.TokensServiceServer]. func (srv *TokensServer) UpdateToken(ctx context.Context, in *tokens.UpdateTokenRequest) (*tokens.UpdateTokenResponse, error) { - claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx) + claims, err := controllers.ClaimsFromContext(ctx) if err != nil { return nil, status.Error(codes.Aborted, "Context is invalid") } diff --git a/cmd/migrate.go b/cmd/migrate.go index 10f2047..466ac84 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -2,10 +2,10 @@ package cmd import ( "context" - "database/sql" "errors" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger" + postgres_helper "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/postgres" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" @@ -24,7 +24,7 @@ func (cmd *Migrate) Run(ctx context.Context) error { log := logger.FromContext(ctx) log.Info("Starting a database migration driver") - db, err := sql.Open("postgres", cmd.DBConnectionString) + db, err := postgres_helper.Open(ctx, cmd.DBConnectionString) if err != nil { log.Error(err, "Couldn't start a database driver") return err diff --git a/cmd/server.go b/cmd/server.go index 35d4747..487c5c0 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "database/sql" "fmt" "net" "strings" @@ -11,6 +10,7 @@ import ( v1 "gitea.badhouseplants.net/softplayer/softplayer-backend/api/v1" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger" + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/postgres" accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1" test "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/test/v1" tokens "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/tokens/v1" @@ -52,7 +52,7 @@ func (cmd *Server) Run(ctx context.Context) error { log := logger.FromContext(ctx) log.Info("Opening a database connection") - db, err := sql.Open("postgres", cmd.DBConnectionString) + db, err := postgres.Open(ctx, cmd.DBConnectionString) if err != nil { log.Error(err, "Couldn't start a database driver") return err diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..08bf029 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1 @@ +package cache diff --git a/internal/controllers/accounts.go b/internal/controllers/accounts.go index ee44d91..d6ad0fb 100644 --- a/internal/controllers/accounts.go +++ b/internal/controllers/accounts.go @@ -2,25 +2,26 @@ package controllers import ( "context" + "database/sql" "errors" - "fmt" "time" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/hash" "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger" - "github.com/golang-jwt/jwt/v5" + "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/repository" "github.com/google/uuid" - "github.com/jackc/pgconn" - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgxpool" "github.com/redis/go-redis/v9" ) -var ErrEmailUsed = errors.New("email is already used") +var ( + ErrEmailUsed = errors.New("email is already used") + ErrUserNotFound = errors.New("user not found") + ErrWrongPassword = errors.New("wrong password") +) type AccountController struct { - DB *pgxpool.Pool + DB *sql.DB Redis *redis.Client DevMode bool HashCost int16 @@ -44,30 +45,27 @@ type AccountData struct { // Create a new account func (c *AccountController) Create(ctx context.Context, data *AccountData) (string, error) { - log := logger.FromContext(ctx) + log := logger.FromContext(ctx).WithValues("email", data.Email) + log.V(2).Info("Creating a user") data.UUID = uuid.New().String() passwordHash, err := hash.HashPassword(data.Password, int(c.HashCost)) if err != nil { log.Error(err, "Couldn't crate the password hash") - return "", nil + return "", ErrServerError } - query := "INSERT INTO accounts (uuid, email, password_hash) VALUES ($1, $2, $3)" + queryData := &repository.AccountData{ + UUID: data.UUID, + Email: data.Email, + PasswordHash: passwordHash, + } - if _, err := c.DB.ExecContext(ctx, query, data.UUID, data.Email, passwordHash); err != nil { - fmt.Printf("ERR TYPE: %T\n", err) - fmt.Printf("ERR VALUE: %#v\n", err) - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - log.Error(nil, "error", "err", pgErr) - if pgErr.Code == pgerrcode.UniqueViolation { - return "", ErrEmailUsed - } - log.Error(err, "Couldn't create a user") - return "", ErrServerError + if err := repository.CreateAccount(ctx, c.DB, queryData); err != nil { + if errors.Is(err, repository.ErrAlreadyExists) { + return "", ErrEmailUsed } - log.Error(err, "Couldn't create a user, wtf") + log.Error(err, "Couldn't create a user") return "", ErrServerError } @@ -76,70 +74,30 @@ func (c *AccountController) Create(ctx context.Context, data *AccountData) (stri // Login into an existing account (check password and email) func (c *AccountController) Login(ctx context.Context, email, password string) (string, error) { - log := logger.FromContext(ctx) - query := "SELECT uuid, password_hash FROM accounts WHERE email = $1;" + log := logger.FromContext(ctx).WithValues("email", email) + log.V(2).Info("Trying to verify user login") - var passwordHash string - var uuid string - - if err := c.DB.QueryRow(query, email).Scan(&uuid, &passwordHash); err != nil { - log.Error(err, "Couldn't get a user from the database") - return "", err + passwordHash, err := repository.GetPasswordHashForEmail(ctx, c.DB, email) + if err != nil { + if errors.Is(err, repository.ErrNotFound) { + return "", ErrUserNotFound + } + log.Error(err, "Couldn't get the password hash") + return "", ErrServerError } if err := hash.CheckPasswordHash(password, passwordHash); err != nil { - log.Error(err, "Wrong password") - return "", err + return "", ErrWrongPassword + } + + uuid, err := repository.GetUUIDForEmail(ctx, c.DB, email) + if err != nil { + if errors.Is(err, repository.ErrNotFound) { + return "", ErrUserNotFound + } + log.Error(err, "Couldn't get tha password hash") + return "", ErrServerError } return uuid, nil } - -func (c *AccountController) GenerateAccessToken(userID string) (string, error) { - tokenID := uuid.New().String() - claims := jwt.MapClaims{ - "user_id": userID, - "type": "access", - "exp": time.Now().Add(c.AccessTokenTTL).Unix(), - "token_id": tokenID, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(c.JWTSecret) -} - -func redisKey(id string) string { - return fmt.Sprintf("refresh:%s", id) -} - -func (c *AccountController) GenerateRefreshToken(ctx context.Context, userID string) (string, error) { - tokenID := uuid.New().String() - claims := jwt.MapClaims{ - "user_id": userID, - "token_id": tokenID, - "type": "refresh", - "exp": time.Now().Add(c.RefreshTokenTTL).Unix(), - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - if err := c.Redis.Set(ctx, redisKey(tokenID), userID, c.RefreshTokenTTL).Err(); err != nil { - return "", err - } - return token.SignedString(c.JWTSecret) -} - -// It must validate the refresh token -// Get it's id from the content -// Find a corresponding token in redis, and if it's found, remove it and create a new one -func (c *AccountController) ValidateRefreshToken(ctx context.Context, tokenID, userID string) (string, error) { - log := logger.FromContext(ctx) - userIDRedis := c.Redis.Get(ctx, redisKey(tokenID)).Val() - if err := c.Redis.Del(ctx, redisKey(tokenID)).Err(); err != nil { - log.Error(err, "Couldn't delete redis entry") - return "", err - } - if userID != userIDRedis { - return "", errors.New("user id doesn't match") - } - return userIDRedis, nil -} diff --git a/internal/controllers/accounts_test.go b/internal/controllers/accounts_test.go index e9090a7..9667700 100644 --- a/internal/controllers/accounts_test.go +++ b/internal/controllers/accounts_test.go @@ -1,6 +1,7 @@ package controllers_test import ( + "context" "database/sql" "fmt" "os" @@ -15,12 +16,12 @@ import ( "github.com/stretchr/testify/assert" ) -func newTestDbConnection() *sql.DB { +func newTestDbConnection(ctx context.Context) *sql.DB { connStr, ok := os.LookupEnv("SOFTPLAYER_DB_CONNECTION_STRING") if !ok { panic("set the db connection string env var") } - db, err := postgres.Open(connStr) + db, err := postgres.Open(ctx, connStr) if err != nil { panic(err) } @@ -37,9 +38,9 @@ func newTestRedisConnection() *redis.Client { }) } -func newTestAccountController() *controllers.AccountController { +func newTestAccountController(ctx context.Context) *controllers.AccountController { return &controllers.AccountController{ - DB: newTestDbConnection(), + DB: newTestDbConnection(ctx), Redis: newTestRedisConnection(), DevMode: true, HashCost: 3, @@ -63,7 +64,7 @@ func newTestUniqueEmail(prefix string) string { } func TestIntegrationAccountCreate_Success(t *testing.T) { - ctrl := newTestAccountController() + ctrl := newTestAccountController(t.Context()) accountData := &controllers.AccountData{ Password: "qwertyu9", Email: newTestUniqueEmail("accounts"), @@ -74,7 +75,7 @@ func TestIntegrationAccountCreate_Success(t *testing.T) { } func TestIntegrationAccountCreate_ExistingAccountErr(t *testing.T) { - ctrl := newTestAccountController() + ctrl := newTestAccountController(t.Context()) email := newTestUniqueEmail("accounts") accountData := &controllers.AccountData{ Password: "qwertyu9", @@ -89,3 +90,56 @@ func TestIntegrationAccountCreate_ExistingAccountErr(t *testing.T) { assert.Error(t, err) assert.ErrorIs(t, err, controllers.ErrEmailUsed) } + +func TestIntegrationAccountLogin_Success(t *testing.T) { + ctrl := newTestAccountController(t.Context()) + email := newTestUniqueEmail("accounts") + accountData := &controllers.AccountData{ + Password: "qwertyu9", + Email: email, + } + id, err := ctrl.Create(t.Context(), accountData) + assert.NoError(t, err) + assert.NotEmpty(t, id) + accountData.UUID = id + + id, err = ctrl.Login(t.Context(), accountData.Email, accountData.Password) + assert.NoError(t, err) + assert.NotEmpty(t, id) +} + +func TestIntegrationAccountLogin_WrongPassword(t *testing.T) { + ctrl := newTestAccountController(t.Context()) + email := newTestUniqueEmail("accounts") + accountData := &controllers.AccountData{ + Password: "qwertyu9", + Email: email, + } + id, err := ctrl.Create(t.Context(), accountData) + assert.NoError(t, err) + assert.NotEmpty(t, id) + accountData.UUID = id + + id, err = ctrl.Login(t.Context(), accountData.Email, "Wrong Password") + assert.Empty(t, id) + assert.Error(t, err) + assert.ErrorIs(t, err, controllers.ErrWrongPassword) +} + +func TestIntegrationAccountLogin_WrongEmail(t *testing.T) { + ctrl := newTestAccountController(t.Context()) + email := newTestUniqueEmail("accounts") + accountData := &controllers.AccountData{ + Password: "qwertyu9", + Email: email, + } + id, err := ctrl.Create(t.Context(), accountData) + assert.NoError(t, err) + assert.NotEmpty(t, id) + accountData.UUID = id + + id, err = ctrl.Login(t.Context(), "some@email.com", "Wrong Password") + assert.Empty(t, id) + assert.Error(t, err) + assert.ErrorIs(t, err, controllers.ErrUserNotFound) +} diff --git a/internal/controllers/authorization.go b/internal/controllers/authorization.go index 07f14a3..430058d 100644 --- a/internal/controllers/authorization.go +++ b/internal/controllers/authorization.go @@ -68,12 +68,12 @@ type JWTData struct { } // Write claims into context -func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context { +func WithClaims(ctx context.Context, claims *Claims) context.Context { return context.WithValue(ctx, claimsContextKey, claims) } // Extract claims from context -func (a *AuthController) ClaimsFromContext(ctx context.Context) (*Claims, error) { +func ClaimsFromContext(ctx context.Context) (*Claims, error) { claims, ok := ctx.Value(claimsContextKey).(*Claims) if !ok || claims == nil { return nil, errors.New("claims not found in context") @@ -115,7 +115,7 @@ func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context } } - ctx = a.WithClaims(ctx, claims) + ctx = WithClaims(ctx, claims) return ctx, nil } diff --git a/internal/helpers/postgres/postgres.go b/internal/helpers/postgres/postgres.go index 9f7362a..51c9c91 100644 --- a/internal/helpers/postgres/postgres.go +++ b/internal/helpers/postgres/postgres.go @@ -2,12 +2,14 @@ package postgres import ( "context" + "database/sql" "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" _ "github.com/jackc/pgx/v5/stdlib" ) -func Open(ctx context.Context, dsn string) (*pgxpool.Pool, error) { +func Open(ctx context.Context, dsn string) (*sql.DB, error) { dbpool, err := pgxpool.New(context.Background(), dsn) if err != nil { return nil, err @@ -17,5 +19,6 @@ func Open(ctx context.Context, dsn string) (*pgxpool.Pool, error) { return nil, err } - return dbpool, nil + db := stdlib.OpenDBFromPool(dbpool) + return db, nil } diff --git a/internal/repository/accounts.go b/internal/repository/accounts.go new file mode 100644 index 0000000..0b259b8 --- /dev/null +++ b/internal/repository/accounts.go @@ -0,0 +1,62 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" +) + +var ( + ErrAlreadyExists = errors.New("entry already exists") + ErrNotFound = errors.New("entry not found") +) + +type AccountData struct { + UUID string + Email string + PasswordHash string +} + +// CreateAccount adds a new account to a database +func CreateAccount(ctx context.Context, db *sql.DB, account *AccountData) error { + query := "INSERT INTO accounts (uuid, email, password_hash) VALUES ($1, $2, $3)" + + if _, err := db.ExecContext(ctx, query, account.UUID, account.Email, account.PasswordHash); err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + if pgErr.Code == pgerrcode.UniqueViolation { + return ErrAlreadyExists + } + return err + } + return err + } + return nil +} + +// GetPasswordHashForEmail returns the password hash for a user +func GetPasswordHashForEmail(ctx context.Context, db *sql.DB, email string) (hash string, err error) { + query := "SELECT password_hash FROM accounts WHERE email = $1;" + if err = db.QueryRowContext(ctx, query, email).Scan(&hash); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNotFound + } + return + } + return +} + +// GetUUIDForEmail returns a uuid of a user +func GetUUIDForEmail(ctx context.Context, db *sql.DB, email string) (uuid string, err error) { + query := "SELECT uuid FROM accounts WHERE email = $1;" + if err = db.QueryRow(query, email).Scan(&uuid); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNotFound + } + return + } + return +}