refacto: add db tx as a possible input for repo methods
All checks were successful
Build and test / Build (push) Successful in 2m22s

This commit is contained in:
Muyao CHEN 2024-10-25 23:52:43 +02:00
parent b30a5c5c2d
commit 14ee642aab
10 changed files with 176 additions and 58 deletions

View File

@ -577,3 +577,8 @@ But between `[]*T` and `[]T`, the only difference that I see (pointed out by
`ChatGPT`) is how the memory is allocated. With `[]T` it might be better for `ChatGPT`) is how the memory is allocated. With `[]T` it might be better for
the GC to deal with the memory free. I thing for my project I will stick to the GC to deal with the memory free. I thing for my project I will stick to
`[]T`. `[]T`.
### 2024/10/25
Read this [article](https://konradreiche.com/blog/two-common-go-interface-misuses/)
today, maybe I am abusing the usage of interfaces?

View File

@ -26,6 +26,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"git.vinchent.xyz/vinchent/howmuch/internal/howmuch/adapter/repo/sqlc"
"git.vinchent.xyz/vinchent/howmuch/internal/howmuch/usecase/repo" "git.vinchent.xyz/vinchent/howmuch/internal/howmuch/usecase/repo"
"git.vinchent.xyz/vinchent/howmuch/internal/pkg/log" "git.vinchent.xyz/vinchent/howmuch/internal/pkg/log"
) )
@ -66,3 +67,11 @@ func (dr *dbRepository) Transaction(
data, err := txFunc(ctx, tx) data, err := txFunc(ctx, tx)
return data, err return data, err
} }
func getQueries(queries *sqlc.Queries, tx any) *sqlc.Queries {
transaction, ok := tx.(*sql.Tx)
if ok {
return sqlc.New(transaction)
}
return queries
}

View File

@ -48,8 +48,35 @@ func NewEventRepository(db *sql.DB) repo.EventRepository {
func (e *eventRepository) Create( func (e *eventRepository) Create(
ctx context.Context, ctx context.Context,
evEntity *model.EventEntity, evEntity *model.EventEntity,
tx any,
) (*model.EventEntity, error) { ) (*model.EventEntity, error) {
panic("unimplemented") timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
event, err := queries.InsertEvent(timeoutCtx, sqlc.InsertEventParams{
Name: evEntity.Name,
Description: sql.NullString{String: evEntity.Description, Valid: true},
TotalAmount: sql.NullInt32{Int32: int32(evEntity.TotalAmount), Valid: true},
DefaultCurrency: evEntity.DefaultCurrency,
OwnerID: int32(evEntity.OwnerID),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
if err != nil {
return nil, err
}
return &model.EventEntity{
ID: int(event.ID),
Name: event.Name,
Description: event.Description.String,
TotalAmount: int(event.TotalAmount.Int32),
DefaultCurrency: event.DefaultCurrency,
OwnerID: int(event.OwnerID),
CreatedAt: event.CreatedAt,
UpdatedAt: event.UpdatedAt,
}, nil
} }
func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved, error) { func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved, error) {
@ -89,8 +116,17 @@ func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved
} }
// GetByID implements repo.EventRepository. // GetByID implements repo.EventRepository.
func (e *eventRepository) GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error) { func (e *eventRepository) GetByID(
eventDTO, err := e.queries.GetEventByID(ctx, int32(eventID)) ctx context.Context,
eventID int,
tx any,
) (*model.EventRetrieved, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
eventDTO, err := queries.GetEventByID(timeoutCtx, int32(eventID))
if err != nil { if err != nil {
log.ErrorLog("query error", "err", err) log.ErrorLog("query error", "err", err)
return nil, err return nil, err
@ -128,8 +164,14 @@ func convToEventList(eventsDTO []sqlc.ListEventsByUserIDRow) ([]model.EventListR
func (e *eventRepository) ListEventsByUserID( func (e *eventRepository) ListEventsByUserID(
ctx context.Context, ctx context.Context,
userID int, userID int,
tx any,
) ([]model.EventListRetrieved, error) { ) ([]model.EventListRetrieved, error) {
eventsDTO, err := e.queries.ListEventsByUserID(ctx, int32(userID)) timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
eventsDTO, err := queries.ListEventsByUserID(timeoutCtx, int32(userID))
if err != nil { if err != nil {
log.ErrorLog("query error", "err", err) log.ErrorLog("query error", "err", err)
return nil, err return nil, err
@ -142,8 +184,14 @@ func (e *eventRepository) ListEventsByUserID(
func (e *eventRepository) UpdateEventByID( func (e *eventRepository) UpdateEventByID(
ctx context.Context, ctx context.Context,
event *model.EventUpdateEntity, event *model.EventUpdateEntity,
tx any,
) error { ) error {
err := e.queries.UpdateEventByID(ctx, sqlc.UpdateEventByIDParams{ timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
err := queries.UpdateEventByID(timeoutCtx, sqlc.UpdateEventByIDParams{
ID: int32(event.ID), ID: int32(event.ID),
Name: event.Name, Name: event.Name,
Description: sql.NullString{String: event.Description, Valid: true}, Description: sql.NullString{String: event.Description, Valid: true},

View File

@ -45,21 +45,39 @@ func NewExpenseRepository(db *sql.DB) repo.ExpenseRepository {
} }
// DeleteExpense implements repo.ExpenseRepository. // DeleteExpense implements repo.ExpenseRepository.
func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int) error { func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int, tx any) error {
return e.queries.DeleteExpense(ctx, int32(expenseID)) timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
return queries.DeleteExpense(timeoutCtx, int32(expenseID))
} }
// DeleteTransactionsOfExpense implements repo.ExpenseRepository. // DeleteTransactionsOfExpense implements repo.ExpenseRepository.
func (e *expenseRepository) DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error { func (e *expenseRepository) DeleteTransactionsOfExpense(
return e.queries.DeleteTransactionsOfExpenseID(ctx, int32(expenseID)) ctx context.Context,
expenseID int,
tx any,
) error {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
return queries.DeleteTransactionsOfExpenseID(timeoutCtx, int32(expenseID))
} }
// GetExpenseByID implements repo.ExpenseRepository. // GetExpenseByID implements repo.ExpenseRepository.
func (e *expenseRepository) GetExpenseByID( func (e *expenseRepository) GetExpenseByID(
ctx context.Context, ctx context.Context,
expenseID int, expenseID int,
tx any,
) (*model.ExpenseRetrieved, error) { ) (*model.ExpenseRetrieved, error) {
expenseDTO, err := e.queries.GetExpenseByID(ctx, int32(expenseID)) timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
expenseDTO, err := queries.GetExpenseByID(timeoutCtx, int32(expenseID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -150,8 +168,14 @@ func convToExpenseRetrieved(expenseDTO *sqlc.GetExpenseByIDRow) (*model.ExpenseR
func (e *expenseRepository) InsertExpense( func (e *expenseRepository) InsertExpense(
ctx context.Context, ctx context.Context,
expenseEntity *model.ExpenseEntity, expenseEntity *model.ExpenseEntity,
tx any,
) (*model.ExpenseEntity, error) { ) (*model.ExpenseEntity, error) {
expenseDTO, err := e.queries.InsertExpense(ctx, sqlc.InsertExpenseParams{ timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
expenseDTO, err := queries.InsertExpense(timeoutCtx, sqlc.InsertExpenseParams{
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
Amount: int32(expenseEntity.Amount), Amount: int32(expenseEntity.Amount),
@ -179,8 +203,14 @@ func (e *expenseRepository) InsertExpense(
func (e *expenseRepository) ListExpensesByEventID( func (e *expenseRepository) ListExpensesByEventID(
ctx context.Context, ctx context.Context,
id int, id int,
tx any,
) ([]model.ExpensesListRetrieved, error) { ) ([]model.ExpensesListRetrieved, error) {
listDTO, err := e.queries.ListExpensesByEventID(ctx, int32(id)) timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
listDTO, err := queries.ListExpensesByEventID(timeoutCtx, int32(id))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -206,8 +236,14 @@ func (e *expenseRepository) ListExpensesByEventID(
func (e *expenseRepository) UpdateExpenseByID( func (e *expenseRepository) UpdateExpenseByID(
ctx context.Context, ctx context.Context,
expenseUpdate *model.ExpenseUpdateEntity, expenseUpdate *model.ExpenseUpdateEntity,
tx any,
) (*model.ExpenseEntity, error) { ) (*model.ExpenseEntity, error) {
expenseDTO, err := e.queries.UpdateExpenseByID(ctx, sqlc.UpdateExpenseByIDParams{ timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
expenseDTO, err := queries.UpdateExpenseByID(timeoutCtx, sqlc.UpdateExpenseByIDParams{
ID: int32(expenseUpdate.ID), ID: int32(expenseUpdate.ID),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
Amount: int32(expenseUpdate.Amount), Amount: int32(expenseUpdate.Amount),

View File

@ -34,43 +34,36 @@ import (
) )
type userRepository struct { type userRepository struct {
querier *sqlc.Queries queries *sqlc.Queries
} }
const insertTimeout = 1 * time.Second const queryTimeout = 3 * time.Second
func NewUserRepository(db *sql.DB) repo.UserRepository { func NewUserRepository(db *sql.DB) repo.UserRepository {
return &userRepository{ return &userRepository{
querier: sqlc.New(db), queries: sqlc.New(db),
} }
} }
// Create // Create
func (ur *userRepository) Create( func (u *userRepository) Create(
ctx context.Context, ctx context.Context,
transaction interface{}, userEntity *model.UserEntity,
u *model.UserEntity, tx any,
) (*model.UserEntity, error) { ) (*model.UserEntity, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, insertTimeout) timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel() defer cancel()
args := sqlc.InsertUserParams{ queries := getQueries(u.queries, tx)
Email: u.Email,
FirstName: u.FirstName, userDB, err := queries.InsertUser(timeoutCtx, sqlc.InsertUserParams{
LastName: u.LastName, Email: userEntity.Email,
Password: u.Password, FirstName: userEntity.FirstName,
LastName: userEntity.LastName,
Password: userEntity.Password,
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
} })
tx, ok := transaction.(*sql.Tx)
if !ok {
return nil, errors.New("transaction is not a *sql.Tx")
}
queries := sqlc.New(tx)
userDB, err := queries.InsertUser(timeoutCtx, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,8 +80,17 @@ func (ur *userRepository) Create(
} }
// GetByEmail if not found, return nil for user but not error. // GetByEmail if not found, return nil for user but not error.
func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.UserEntity, error) { func (u *userRepository) GetByEmail(
userDB, err := ur.querier.GetUserByEmail(ctx, email) ctx context.Context,
email string,
tx any,
) (*model.UserEntity, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(u.queries, tx)
userDB, err := queries.GetUserByEmail(timeoutCtx, email)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
// No query error, but user not found // No query error, but user not found
return nil, nil return nil, nil
@ -107,8 +109,13 @@ func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.
}, nil }, nil
} }
func (ur *userRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) { func (u *userRepository) GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error) {
userDB, err := ur.querier.GetUserByID(ctx, int32(id)) timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(u.queries, tx)
userDB, err := queries.GetUserByID(timeoutCtx, int32(id))
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
// No query error, but user not found // No query error, but user not found
return nil, nil return nil, nil

View File

@ -29,30 +29,36 @@ import (
) )
type EventRepository interface { type EventRepository interface {
Create(ctx context.Context, evEntity *model.EventEntity) (*model.EventEntity, error) Create(ctx context.Context, evEntity *model.EventEntity, tx any) (*model.EventEntity, error)
// UpdateEventByID updates the event related information (name, descriptions) // UpdateEventByID updates the event related information (name, descriptions)
UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity) error UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity, tx any) error
GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error) GetByID(ctx context.Context, eventID int, tx any) (*model.EventRetrieved, error)
// related to events of a user // related to events of a user
ListEventsByUserID(ctx context.Context, userID int) ([]model.EventListRetrieved, error) ListEventsByUserID(ctx context.Context, userID int, tx any) ([]model.EventListRetrieved, error)
// CheckParticipation(ctx context.Context, userID, eventID int) error // CheckParticipation(ctx context.Context, userID, eventID int) error
} }
type ExpenseRepository interface { type ExpenseRepository interface {
DeleteExpense(ctx context.Context, expenseID int) error DeleteExpense(ctx context.Context, expenseID int, tx any) error
DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error DeleteTransactionsOfExpense(ctx context.Context, expenseID int, tx any) error
GetExpenseByID(ctx context.Context, expenseID int) (*model.ExpenseRetrieved, error) GetExpenseByID(ctx context.Context, expenseID int, tx any) (*model.ExpenseRetrieved, error)
InsertExpense( InsertExpense(
ctx context.Context, ctx context.Context,
expenseEntity *model.ExpenseEntity, expenseEntity *model.ExpenseEntity,
tx any,
) (*model.ExpenseEntity, error) ) (*model.ExpenseEntity, error)
ListExpensesByEventID(ctx context.Context, id int) ([]model.ExpensesListRetrieved, error) ListExpensesByEventID(
ctx context.Context,
id int,
tx any,
) ([]model.ExpensesListRetrieved, error)
UpdateExpenseByID( UpdateExpenseByID(
ctx context.Context, ctx context.Context,
expenseUpdate *model.ExpenseUpdateEntity, expenseUpdate *model.ExpenseUpdateEntity,
tx any,
) (*model.ExpenseEntity, error) ) (*model.ExpenseEntity, error)
} }

View File

@ -31,9 +31,9 @@ import (
type UserRepository interface { type UserRepository interface {
Create( Create(
ctx context.Context, ctx context.Context,
transaction interface{},
u *model.UserEntity, u *model.UserEntity,
tx any,
) (*model.UserEntity, error) ) (*model.UserEntity, error)
GetByEmail(ctx context.Context, email string) (*model.UserEntity, error) GetByEmail(ctx context.Context, email string, tx any) (*model.UserEntity, error)
GetByID(ctx context.Context, id int) (*model.UserEntity, error) GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error)
} }

View File

@ -62,7 +62,7 @@ func (evuc *eventUsecase) CreateEvent(
ctx context.Context, ctx context.Context,
evRequest *model.EventCreateRequest, evRequest *model.EventCreateRequest,
) (*model.EventInfoResponse, error) { ) (*model.EventInfoResponse, error) {
// transfer evRequest to PO // transfer evRequest to evEntity
evEntity := &model.EventEntity{ evEntity := &model.EventEntity{
Name: evRequest.Name, Name: evRequest.Name,
@ -75,7 +75,8 @@ func (evuc *eventUsecase) CreateEvent(
data, err := evuc.dbRepo.Transaction( data, err := evuc.dbRepo.Transaction(
ctx, ctx,
func(txCtx context.Context, tx interface{}) (interface{}, error) { func(txCtx context.Context, tx interface{}) (interface{}, error) {
created, err := evuc.eventRepo.Create(ctx, evEntity) // Create
created, err := evuc.eventRepo.Create(ctx, evEntity, tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -104,6 +105,7 @@ func (evuc *eventUsecase) CreateEvent(
), ),
Owner: ownerResponse, Owner: ownerResponse,
CreatedAt: created.CreatedAt, CreatedAt: created.CreatedAt,
UpdatedAt: created.UpdatedAt,
} }
return evResponse, err return evResponse, err
}) })

View File

@ -36,8 +36,8 @@ type TestUserRepository struct{}
func (tur *TestUserRepository) Create( func (tur *TestUserRepository) Create(
ctx context.Context, ctx context.Context,
transaction interface{},
u *model.UserEntity, u *model.UserEntity,
tx any,
) (*model.UserEntity, error) { ) (*model.UserEntity, error) {
user := *u user := *u
@ -53,6 +53,7 @@ func (tur *TestUserRepository) Create(
func (tur *TestUserRepository) GetByEmail( func (tur *TestUserRepository) GetByEmail(
ctx context.Context, ctx context.Context,
email string, email string,
tx any,
) (*model.UserEntity, error) { ) (*model.UserEntity, error) {
hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12) hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12)
switch email { switch email {
@ -71,7 +72,11 @@ func (tur *TestUserRepository) GetByEmail(
return nil, UserTestDummyErr return nil, UserTestDummyErr
} }
func (tur *TestUserRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) { func (tur *TestUserRepository) GetByID(
ctx context.Context,
id int,
tx any,
) (*model.UserEntity, error) {
hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12) hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12)
switch id { switch id {
case 123: case 123:

View File

@ -86,12 +86,12 @@ func (uuc *userUsecase) Create(
data, err := uuc.dbRepo.Transaction( data, err := uuc.dbRepo.Transaction(
ctx, ctx,
func(txCtx context.Context, tx interface{}) (interface{}, error) { func(txCtx context.Context, tx interface{}) (interface{}, error) {
created, err := uuc.userRepo.Create(txCtx, tx, &model.UserEntity{ created, err := uuc.userRepo.Create(txCtx, &model.UserEntity{
Email: u.Email, Email: u.Email,
Password: u.Password, Password: u.Password,
FirstName: u.FirstName, FirstName: u.FirstName,
LastName: u.LastName, LastName: u.LastName,
}) }, tx)
if err != nil { if err != nil {
match, _ := regexp.MatchString("SQLSTATE 23505", err.Error()) match, _ := regexp.MatchString("SQLSTATE 23505", err.Error())
if match { if match {
@ -132,7 +132,7 @@ func (uuc *userUsecase) Create(
} }
func (uuc *userUsecase) Exist(ctx context.Context, u *model.UserExistRequest) error { func (uuc *userUsecase) Exist(ctx context.Context, u *model.UserExistRequest) error {
got, err := uuc.userRepo.GetByEmail(ctx, u.Email) got, err := uuc.userRepo.GetByEmail(ctx, u.Email, nil)
// Any query error? // Any query error?
if err != nil { if err != nil {
return err return err
@ -160,7 +160,7 @@ func (uuc *userUsecase) GetUserBaseResponseByID(
// If not exists, get from the DB. And then put back // If not exists, get from the DB. And then put back
// into the cache with a timeout. // into the cache with a timeout.
// Refresh the cache when the user data is updated (for now it cannot be updated) // Refresh the cache when the user data is updated (for now it cannot be updated)
got, err := uuc.userRepo.GetByID(ctx, userID) got, err := uuc.userRepo.GetByID(ctx, userID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }