Structure code a wee bit better
All checks were successful
ci/woodpecker/push/build Pipeline was successful

Signed-off-by: Nikolai Rodionov <allanger@badhouseplants.net>
This commit is contained in:
2026-05-17 20:07:31 +02:00
parent 60297dfaf7
commit 14ad203f63
10 changed files with 181 additions and 103 deletions

View File

@@ -30,7 +30,7 @@ type AccountsServer struct {
}
func (a *AccountsServer) RefreshToken(ctx context.Context, in *empty.Empty) (*empty.Empty, error) {
claims, err := a.authorizationCtrl.ClaimsFromContext(ctx)
claims, err := controllers.ClaimsFromContext(ctx)
if err != nil {
return nil, status.Error(codes.Aborted, "Context is invalid")
}

View File

@@ -32,7 +32,7 @@ func NewTokensServer(
// CreateToken implements [v1.TokensServiceServer].
func (srv *TokensServer) CreateToken(ctx context.Context, in *tokens.CreateTokenRequest) (*tokens.CreateTokenResponse, error) {
claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx)
claims, err := controllers.ClaimsFromContext(ctx)
if err != nil {
return nil, status.Error(codes.Aborted, "Context is invalid")
}
@@ -70,7 +70,7 @@ func (srv *TokensServer) CreateToken(ctx context.Context, in *tokens.CreateToken
// ForceTokenExpiration implements [v1.TokensServiceServer].
func (srv *TokensServer) ForceTokenExpiration(ctx context.Context, in *tokens.ForceTokenExpirationRequest) (*emptypb.Empty, error) {
claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx)
claims, err := controllers.ClaimsFromContext(ctx)
if err != nil {
return nil, status.Error(codes.Aborted, "Context is invalid")
}
@@ -96,7 +96,7 @@ func (srv *TokensServer) ForceTokenExpiration(ctx context.Context, in *tokens.Fo
// GetToken implements [v1.TokensServiceServer].
func (srv *TokensServer) GetToken(ctx context.Context, in *tokens.GetTokenRequest) (*tokens.GetTokenResponse, error) {
claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx)
claims, err := controllers.ClaimsFromContext(ctx)
if err != nil {
return nil, status.Error(codes.Aborted, "Context is invalid")
}
@@ -136,7 +136,7 @@ func (srv *TokensServer) GetToken(ctx context.Context, in *tokens.GetTokenReques
// ListTokens implements [v1.TokensServiceServer].
func (srv *TokensServer) ListTokens(in *emptypb.Empty, stream grpc.ServerStreamingServer[tokens.ListTokensResponse]) error {
claims, err := srv.authorizationCtrl.ClaimsFromContext(stream.Context())
claims, err := controllers.ClaimsFromContext(stream.Context())
if err != nil {
return status.Error(codes.Aborted, "Context is invalid")
}
@@ -174,7 +174,7 @@ func (srv *TokensServer) ListTokens(in *emptypb.Empty, stream grpc.ServerStreami
// RegenerateToken implements [v1.TokensServiceServer].
func (srv *TokensServer) RegenerateToken(ctx context.Context, in *tokens.RegenerateTokenRequest) (*tokens.RegenerateTokenResponse, error) {
claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx)
claims, err := controllers.ClaimsFromContext(ctx)
if err != nil {
return nil, status.Error(codes.Aborted, "Context is invalid")
}
@@ -204,7 +204,7 @@ func (srv *TokensServer) RegenerateToken(ctx context.Context, in *tokens.Regener
// UpdateToken implements [v1.TokensServiceServer].
func (srv *TokensServer) UpdateToken(ctx context.Context, in *tokens.UpdateTokenRequest) (*tokens.UpdateTokenResponse, error) {
claims, err := srv.authorizationCtrl.ClaimsFromContext(ctx)
claims, err := controllers.ClaimsFromContext(ctx)
if err != nil {
return nil, status.Error(codes.Aborted, "Context is invalid")
}

View File

@@ -2,10 +2,10 @@ package cmd
import (
"context"
"database/sql"
"errors"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger"
postgres_helper "gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/postgres"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -24,7 +24,7 @@ func (cmd *Migrate) Run(ctx context.Context) error {
log := logger.FromContext(ctx)
log.Info("Starting a database migration driver")
db, err := sql.Open("postgres", cmd.DBConnectionString)
db, err := postgres_helper.Open(ctx, cmd.DBConnectionString)
if err != nil {
log.Error(err, "Couldn't start a database driver")
return err

View File

@@ -2,7 +2,6 @@ package cmd
import (
"context"
"database/sql"
"fmt"
"net"
"strings"
@@ -11,6 +10,7 @@ import (
v1 "gitea.badhouseplants.net/softplayer/softplayer-backend/api/v1"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/controllers"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/postgres"
accounts "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/accounts/v1"
test "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/test/v1"
tokens "gitea.badhouseplants.net/softplayer/softplayer-go-proto/pkg/tokens/v1"
@@ -52,7 +52,7 @@ func (cmd *Server) Run(ctx context.Context) error {
log := logger.FromContext(ctx)
log.Info("Opening a database connection")
db, err := sql.Open("postgres", cmd.DBConnectionString)
db, err := postgres.Open(ctx, cmd.DBConnectionString)
if err != nil {
log.Error(err, "Couldn't start a database driver")
return err

1
internal/cache/cache.go vendored Normal file
View File

@@ -0,0 +1 @@
package cache

View File

@@ -2,25 +2,26 @@ package controllers
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/hash"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/helpers/logger"
"github.com/golang-jwt/jwt/v5"
"gitea.badhouseplants.net/softplayer/softplayer-backend/internal/repository"
"github.com/google/uuid"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
)
var ErrEmailUsed = errors.New("email is already used")
var (
ErrEmailUsed = errors.New("email is already used")
ErrUserNotFound = errors.New("user not found")
ErrWrongPassword = errors.New("wrong password")
)
type AccountController struct {
DB *pgxpool.Pool
DB *sql.DB
Redis *redis.Client
DevMode bool
HashCost int16
@@ -44,30 +45,27 @@ type AccountData struct {
// Create a new account
func (c *AccountController) Create(ctx context.Context, data *AccountData) (string, error) {
log := logger.FromContext(ctx)
log := logger.FromContext(ctx).WithValues("email", data.Email)
log.V(2).Info("Creating a user")
data.UUID = uuid.New().String()
passwordHash, err := hash.HashPassword(data.Password, int(c.HashCost))
if err != nil {
log.Error(err, "Couldn't crate the password hash")
return "", nil
return "", ErrServerError
}
query := "INSERT INTO accounts (uuid, email, password_hash) VALUES ($1, $2, $3)"
queryData := &repository.AccountData{
UUID: data.UUID,
Email: data.Email,
PasswordHash: passwordHash,
}
if _, err := c.DB.ExecContext(ctx, query, data.UUID, data.Email, passwordHash); err != nil {
fmt.Printf("ERR TYPE: %T\n", err)
fmt.Printf("ERR VALUE: %#v\n", err)
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
log.Error(nil, "error", "err", pgErr)
if pgErr.Code == pgerrcode.UniqueViolation {
return "", ErrEmailUsed
}
log.Error(err, "Couldn't create a user")
return "", ErrServerError
if err := repository.CreateAccount(ctx, c.DB, queryData); err != nil {
if errors.Is(err, repository.ErrAlreadyExists) {
return "", ErrEmailUsed
}
log.Error(err, "Couldn't create a user, wtf")
log.Error(err, "Couldn't create a user")
return "", ErrServerError
}
@@ -76,70 +74,30 @@ func (c *AccountController) Create(ctx context.Context, data *AccountData) (stri
// Login into an existing account (check password and email)
func (c *AccountController) Login(ctx context.Context, email, password string) (string, error) {
log := logger.FromContext(ctx)
query := "SELECT uuid, password_hash FROM accounts WHERE email = $1;"
log := logger.FromContext(ctx).WithValues("email", email)
log.V(2).Info("Trying to verify user login")
var passwordHash string
var uuid string
if err := c.DB.QueryRow(query, email).Scan(&uuid, &passwordHash); err != nil {
log.Error(err, "Couldn't get a user from the database")
return "", err
passwordHash, err := repository.GetPasswordHashForEmail(ctx, c.DB, email)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return "", ErrUserNotFound
}
log.Error(err, "Couldn't get the password hash")
return "", ErrServerError
}
if err := hash.CheckPasswordHash(password, passwordHash); err != nil {
log.Error(err, "Wrong password")
return "", err
return "", ErrWrongPassword
}
uuid, err := repository.GetUUIDForEmail(ctx, c.DB, email)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return "", ErrUserNotFound
}
log.Error(err, "Couldn't get tha password hash")
return "", ErrServerError
}
return uuid, nil
}
func (c *AccountController) GenerateAccessToken(userID string) (string, error) {
tokenID := uuid.New().String()
claims := jwt.MapClaims{
"user_id": userID,
"type": "access",
"exp": time.Now().Add(c.AccessTokenTTL).Unix(),
"token_id": tokenID,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(c.JWTSecret)
}
func redisKey(id string) string {
return fmt.Sprintf("refresh:%s", id)
}
func (c *AccountController) GenerateRefreshToken(ctx context.Context, userID string) (string, error) {
tokenID := uuid.New().String()
claims := jwt.MapClaims{
"user_id": userID,
"token_id": tokenID,
"type": "refresh",
"exp": time.Now().Add(c.RefreshTokenTTL).Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
if err := c.Redis.Set(ctx, redisKey(tokenID), userID, c.RefreshTokenTTL).Err(); err != nil {
return "", err
}
return token.SignedString(c.JWTSecret)
}
// It must validate the refresh token
// Get it's id from the content
// Find a corresponding token in redis, and if it's found, remove it and create a new one
func (c *AccountController) ValidateRefreshToken(ctx context.Context, tokenID, userID string) (string, error) {
log := logger.FromContext(ctx)
userIDRedis := c.Redis.Get(ctx, redisKey(tokenID)).Val()
if err := c.Redis.Del(ctx, redisKey(tokenID)).Err(); err != nil {
log.Error(err, "Couldn't delete redis entry")
return "", err
}
if userID != userIDRedis {
return "", errors.New("user id doesn't match")
}
return userIDRedis, nil
}

View File

@@ -1,6 +1,7 @@
package controllers_test
import (
"context"
"database/sql"
"fmt"
"os"
@@ -15,12 +16,12 @@ import (
"github.com/stretchr/testify/assert"
)
func newTestDbConnection() *sql.DB {
func newTestDbConnection(ctx context.Context) *sql.DB {
connStr, ok := os.LookupEnv("SOFTPLAYER_DB_CONNECTION_STRING")
if !ok {
panic("set the db connection string env var")
}
db, err := postgres.Open(connStr)
db, err := postgres.Open(ctx, connStr)
if err != nil {
panic(err)
}
@@ -37,9 +38,9 @@ func newTestRedisConnection() *redis.Client {
})
}
func newTestAccountController() *controllers.AccountController {
func newTestAccountController(ctx context.Context) *controllers.AccountController {
return &controllers.AccountController{
DB: newTestDbConnection(),
DB: newTestDbConnection(ctx),
Redis: newTestRedisConnection(),
DevMode: true,
HashCost: 3,
@@ -63,7 +64,7 @@ func newTestUniqueEmail(prefix string) string {
}
func TestIntegrationAccountCreate_Success(t *testing.T) {
ctrl := newTestAccountController()
ctrl := newTestAccountController(t.Context())
accountData := &controllers.AccountData{
Password: "qwertyu9",
Email: newTestUniqueEmail("accounts"),
@@ -74,7 +75,7 @@ func TestIntegrationAccountCreate_Success(t *testing.T) {
}
func TestIntegrationAccountCreate_ExistingAccountErr(t *testing.T) {
ctrl := newTestAccountController()
ctrl := newTestAccountController(t.Context())
email := newTestUniqueEmail("accounts")
accountData := &controllers.AccountData{
Password: "qwertyu9",
@@ -89,3 +90,56 @@ func TestIntegrationAccountCreate_ExistingAccountErr(t *testing.T) {
assert.Error(t, err)
assert.ErrorIs(t, err, controllers.ErrEmailUsed)
}
func TestIntegrationAccountLogin_Success(t *testing.T) {
ctrl := newTestAccountController(t.Context())
email := newTestUniqueEmail("accounts")
accountData := &controllers.AccountData{
Password: "qwertyu9",
Email: email,
}
id, err := ctrl.Create(t.Context(), accountData)
assert.NoError(t, err)
assert.NotEmpty(t, id)
accountData.UUID = id
id, err = ctrl.Login(t.Context(), accountData.Email, accountData.Password)
assert.NoError(t, err)
assert.NotEmpty(t, id)
}
func TestIntegrationAccountLogin_WrongPassword(t *testing.T) {
ctrl := newTestAccountController(t.Context())
email := newTestUniqueEmail("accounts")
accountData := &controllers.AccountData{
Password: "qwertyu9",
Email: email,
}
id, err := ctrl.Create(t.Context(), accountData)
assert.NoError(t, err)
assert.NotEmpty(t, id)
accountData.UUID = id
id, err = ctrl.Login(t.Context(), accountData.Email, "Wrong Password")
assert.Empty(t, id)
assert.Error(t, err)
assert.ErrorIs(t, err, controllers.ErrWrongPassword)
}
func TestIntegrationAccountLogin_WrongEmail(t *testing.T) {
ctrl := newTestAccountController(t.Context())
email := newTestUniqueEmail("accounts")
accountData := &controllers.AccountData{
Password: "qwertyu9",
Email: email,
}
id, err := ctrl.Create(t.Context(), accountData)
assert.NoError(t, err)
assert.NotEmpty(t, id)
accountData.UUID = id
id, err = ctrl.Login(t.Context(), "some@email.com", "Wrong Password")
assert.Empty(t, id)
assert.Error(t, err)
assert.ErrorIs(t, err, controllers.ErrUserNotFound)
}

View File

@@ -68,12 +68,12 @@ type JWTData struct {
}
// Write claims into context
func (a *AuthController) WithClaims(ctx context.Context, claims *Claims) context.Context {
func WithClaims(ctx context.Context, claims *Claims) context.Context {
return context.WithValue(ctx, claimsContextKey, claims)
}
// Extract claims from context
func (a *AuthController) ClaimsFromContext(ctx context.Context) (*Claims, error) {
func ClaimsFromContext(ctx context.Context) (*Claims, error) {
claims, ok := ctx.Value(claimsContextKey).(*Claims)
if !ok || claims == nil {
return nil, errors.New("claims not found in context")
@@ -115,7 +115,7 @@ func (a *AuthController) AuthInterceptorFN(ctx context.Context) (context.Context
}
}
ctx = a.WithClaims(ctx, claims)
ctx = WithClaims(ctx, claims)
return ctx, nil
}

View File

@@ -2,12 +2,14 @@ package postgres
import (
"context"
"database/sql"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
_ "github.com/jackc/pgx/v5/stdlib"
)
func Open(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
func Open(ctx context.Context, dsn string) (*sql.DB, error) {
dbpool, err := pgxpool.New(context.Background(), dsn)
if err != nil {
return nil, err
@@ -17,5 +19,6 @@ func Open(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
return nil, err
}
return dbpool, nil
db := stdlib.OpenDBFromPool(dbpool)
return db, nil
}

View File

@@ -0,0 +1,62 @@
package repository
import (
"context"
"database/sql"
"errors"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
)
var (
ErrAlreadyExists = errors.New("entry already exists")
ErrNotFound = errors.New("entry not found")
)
type AccountData struct {
UUID string
Email string
PasswordHash string
}
// CreateAccount adds a new account to a database
func CreateAccount(ctx context.Context, db *sql.DB, account *AccountData) error {
query := "INSERT INTO accounts (uuid, email, password_hash) VALUES ($1, $2, $3)"
if _, err := db.ExecContext(ctx, query, account.UUID, account.Email, account.PasswordHash); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgErr.Code == pgerrcode.UniqueViolation {
return ErrAlreadyExists
}
return err
}
return err
}
return nil
}
// GetPasswordHashForEmail returns the password hash for a user
func GetPasswordHashForEmail(ctx context.Context, db *sql.DB, email string) (hash string, err error) {
query := "SELECT password_hash FROM accounts WHERE email = $1;"
if err = db.QueryRowContext(ctx, query, email).Scan(&hash); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrNotFound
}
return
}
return
}
// GetUUIDForEmail returns a uuid of a user
func GetUUIDForEmail(ctx context.Context, db *sql.DB, email string) (uuid string, err error) {
query := "SELECT uuid FROM accounts WHERE email = $1;"
if err = db.QueryRow(query, email).Scan(&uuid); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrNotFound
}
return
}
return
}