refacto: add db tx as a possible input for repo methods
All checks were successful
Build and test / Build (push) Successful in 2m22s

This commit is contained in:
Muyao CHEN 2024-10-25 23:52:43 +02:00
parent b30a5c5c2d
commit 14ee642aab
10 changed files with 176 additions and 58 deletions

View File

@ -577,3 +577,8 @@ But between `[]*T` and `[]T`, the only difference that I see (pointed out by
`ChatGPT`) is how the memory is allocated. With `[]T` it might be better for
the GC to deal with the memory free. I thing for my project I will stick to
`[]T`.
### 2024/10/25
Read this [article](https://konradreiche.com/blog/two-common-go-interface-misuses/)
today, maybe I am abusing the usage of interfaces?

View File

@ -26,6 +26,7 @@ import (
"context"
"database/sql"
"git.vinchent.xyz/vinchent/howmuch/internal/howmuch/adapter/repo/sqlc"
"git.vinchent.xyz/vinchent/howmuch/internal/howmuch/usecase/repo"
"git.vinchent.xyz/vinchent/howmuch/internal/pkg/log"
)
@ -66,3 +67,11 @@ func (dr *dbRepository) Transaction(
data, err := txFunc(ctx, tx)
return data, err
}
func getQueries(queries *sqlc.Queries, tx any) *sqlc.Queries {
transaction, ok := tx.(*sql.Tx)
if ok {
return sqlc.New(transaction)
}
return queries
}

View File

@ -48,8 +48,35 @@ func NewEventRepository(db *sql.DB) repo.EventRepository {
func (e *eventRepository) Create(
ctx context.Context,
evEntity *model.EventEntity,
tx any,
) (*model.EventEntity, error) {
panic("unimplemented")
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
event, err := queries.InsertEvent(timeoutCtx, sqlc.InsertEventParams{
Name: evEntity.Name,
Description: sql.NullString{String: evEntity.Description, Valid: true},
TotalAmount: sql.NullInt32{Int32: int32(evEntity.TotalAmount), Valid: true},
DefaultCurrency: evEntity.DefaultCurrency,
OwnerID: int32(evEntity.OwnerID),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
if err != nil {
return nil, err
}
return &model.EventEntity{
ID: int(event.ID),
Name: event.Name,
Description: event.Description.String,
TotalAmount: int(event.TotalAmount.Int32),
DefaultCurrency: event.DefaultCurrency,
OwnerID: int(event.OwnerID),
CreatedAt: event.CreatedAt,
UpdatedAt: event.UpdatedAt,
}, nil
}
func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved, error) {
@ -89,8 +116,17 @@ func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved
}
// GetByID implements repo.EventRepository.
func (e *eventRepository) GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error) {
eventDTO, err := e.queries.GetEventByID(ctx, int32(eventID))
func (e *eventRepository) GetByID(
ctx context.Context,
eventID int,
tx any,
) (*model.EventRetrieved, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
eventDTO, err := queries.GetEventByID(timeoutCtx, int32(eventID))
if err != nil {
log.ErrorLog("query error", "err", err)
return nil, err
@ -128,8 +164,14 @@ func convToEventList(eventsDTO []sqlc.ListEventsByUserIDRow) ([]model.EventListR
func (e *eventRepository) ListEventsByUserID(
ctx context.Context,
userID int,
tx any,
) ([]model.EventListRetrieved, error) {
eventsDTO, err := e.queries.ListEventsByUserID(ctx, int32(userID))
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
eventsDTO, err := queries.ListEventsByUserID(timeoutCtx, int32(userID))
if err != nil {
log.ErrorLog("query error", "err", err)
return nil, err
@ -142,8 +184,14 @@ func (e *eventRepository) ListEventsByUserID(
func (e *eventRepository) UpdateEventByID(
ctx context.Context,
event *model.EventUpdateEntity,
tx any,
) error {
err := e.queries.UpdateEventByID(ctx, sqlc.UpdateEventByIDParams{
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
err := queries.UpdateEventByID(timeoutCtx, sqlc.UpdateEventByIDParams{
ID: int32(event.ID),
Name: event.Name,
Description: sql.NullString{String: event.Description, Valid: true},

View File

@ -45,21 +45,39 @@ func NewExpenseRepository(db *sql.DB) repo.ExpenseRepository {
}
// DeleteExpense implements repo.ExpenseRepository.
func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int) error {
return e.queries.DeleteExpense(ctx, int32(expenseID))
func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int, tx any) error {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
return queries.DeleteExpense(timeoutCtx, int32(expenseID))
}
// DeleteTransactionsOfExpense implements repo.ExpenseRepository.
func (e *expenseRepository) DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error {
return e.queries.DeleteTransactionsOfExpenseID(ctx, int32(expenseID))
func (e *expenseRepository) DeleteTransactionsOfExpense(
ctx context.Context,
expenseID int,
tx any,
) error {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
return queries.DeleteTransactionsOfExpenseID(timeoutCtx, int32(expenseID))
}
// GetExpenseByID implements repo.ExpenseRepository.
func (e *expenseRepository) GetExpenseByID(
ctx context.Context,
expenseID int,
tx any,
) (*model.ExpenseRetrieved, error) {
expenseDTO, err := e.queries.GetExpenseByID(ctx, int32(expenseID))
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
expenseDTO, err := queries.GetExpenseByID(timeoutCtx, int32(expenseID))
if err != nil {
return nil, err
}
@ -150,8 +168,14 @@ func convToExpenseRetrieved(expenseDTO *sqlc.GetExpenseByIDRow) (*model.ExpenseR
func (e *expenseRepository) InsertExpense(
ctx context.Context,
expenseEntity *model.ExpenseEntity,
tx any,
) (*model.ExpenseEntity, error) {
expenseDTO, err := e.queries.InsertExpense(ctx, sqlc.InsertExpenseParams{
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
expenseDTO, err := queries.InsertExpense(timeoutCtx, sqlc.InsertExpenseParams{
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Amount: int32(expenseEntity.Amount),
@ -179,8 +203,14 @@ func (e *expenseRepository) InsertExpense(
func (e *expenseRepository) ListExpensesByEventID(
ctx context.Context,
id int,
tx any,
) ([]model.ExpensesListRetrieved, error) {
listDTO, err := e.queries.ListExpensesByEventID(ctx, int32(id))
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
listDTO, err := queries.ListExpensesByEventID(timeoutCtx, int32(id))
if err != nil {
return nil, err
}
@ -206,8 +236,14 @@ func (e *expenseRepository) ListExpensesByEventID(
func (e *expenseRepository) UpdateExpenseByID(
ctx context.Context,
expenseUpdate *model.ExpenseUpdateEntity,
tx any,
) (*model.ExpenseEntity, error) {
expenseDTO, err := e.queries.UpdateExpenseByID(ctx, sqlc.UpdateExpenseByIDParams{
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(e.queries, tx)
expenseDTO, err := queries.UpdateExpenseByID(timeoutCtx, sqlc.UpdateExpenseByIDParams{
ID: int32(expenseUpdate.ID),
UpdatedAt: time.Now(),
Amount: int32(expenseUpdate.Amount),

View File

@ -34,43 +34,36 @@ import (
)
type userRepository struct {
querier *sqlc.Queries
queries *sqlc.Queries
}
const insertTimeout = 1 * time.Second
const queryTimeout = 3 * time.Second
func NewUserRepository(db *sql.DB) repo.UserRepository {
return &userRepository{
querier: sqlc.New(db),
queries: sqlc.New(db),
}
}
// Create
func (ur *userRepository) Create(
func (u *userRepository) Create(
ctx context.Context,
transaction interface{},
u *model.UserEntity,
userEntity *model.UserEntity,
tx any,
) (*model.UserEntity, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, insertTimeout)
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
args := sqlc.InsertUserParams{
Email: u.Email,
FirstName: u.FirstName,
LastName: u.LastName,
Password: u.Password,
queries := getQueries(u.queries, tx)
userDB, err := queries.InsertUser(timeoutCtx, sqlc.InsertUserParams{
Email: userEntity.Email,
FirstName: userEntity.FirstName,
LastName: userEntity.LastName,
Password: userEntity.Password,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
tx, ok := transaction.(*sql.Tx)
if !ok {
return nil, errors.New("transaction is not a *sql.Tx")
}
queries := sqlc.New(tx)
userDB, err := queries.InsertUser(timeoutCtx, args)
})
if err != nil {
return nil, err
}
@ -87,8 +80,17 @@ func (ur *userRepository) Create(
}
// GetByEmail if not found, return nil for user but not error.
func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.UserEntity, error) {
userDB, err := ur.querier.GetUserByEmail(ctx, email)
func (u *userRepository) GetByEmail(
ctx context.Context,
email string,
tx any,
) (*model.UserEntity, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(u.queries, tx)
userDB, err := queries.GetUserByEmail(timeoutCtx, email)
if errors.Is(err, sql.ErrNoRows) {
// No query error, but user not found
return nil, nil
@ -107,8 +109,13 @@ func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.
}, nil
}
func (ur *userRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) {
userDB, err := ur.querier.GetUserByID(ctx, int32(id))
func (u *userRepository) GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
queries := getQueries(u.queries, tx)
userDB, err := queries.GetUserByID(timeoutCtx, int32(id))
if errors.Is(err, sql.ErrNoRows) {
// No query error, but user not found
return nil, nil

View File

@ -29,30 +29,36 @@ import (
)
type EventRepository interface {
Create(ctx context.Context, evEntity *model.EventEntity) (*model.EventEntity, error)
Create(ctx context.Context, evEntity *model.EventEntity, tx any) (*model.EventEntity, error)
// UpdateEventByID updates the event related information (name, descriptions)
UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity) error
UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity, tx any) error
GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error)
GetByID(ctx context.Context, eventID int, tx any) (*model.EventRetrieved, error)
// related to events of a user
ListEventsByUserID(ctx context.Context, userID int) ([]model.EventListRetrieved, error)
ListEventsByUserID(ctx context.Context, userID int, tx any) ([]model.EventListRetrieved, error)
// CheckParticipation(ctx context.Context, userID, eventID int) error
}
type ExpenseRepository interface {
DeleteExpense(ctx context.Context, expenseID int) error
DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error
GetExpenseByID(ctx context.Context, expenseID int) (*model.ExpenseRetrieved, error)
DeleteExpense(ctx context.Context, expenseID int, tx any) error
DeleteTransactionsOfExpense(ctx context.Context, expenseID int, tx any) error
GetExpenseByID(ctx context.Context, expenseID int, tx any) (*model.ExpenseRetrieved, error)
InsertExpense(
ctx context.Context,
expenseEntity *model.ExpenseEntity,
tx any,
) (*model.ExpenseEntity, error)
ListExpensesByEventID(ctx context.Context, id int) ([]model.ExpensesListRetrieved, error)
ListExpensesByEventID(
ctx context.Context,
id int,
tx any,
) ([]model.ExpensesListRetrieved, error)
UpdateExpenseByID(
ctx context.Context,
expenseUpdate *model.ExpenseUpdateEntity,
tx any,
) (*model.ExpenseEntity, error)
}

View File

@ -31,9 +31,9 @@ import (
type UserRepository interface {
Create(
ctx context.Context,
transaction interface{},
u *model.UserEntity,
tx any,
) (*model.UserEntity, error)
GetByEmail(ctx context.Context, email string) (*model.UserEntity, error)
GetByID(ctx context.Context, id int) (*model.UserEntity, error)
GetByEmail(ctx context.Context, email string, tx any) (*model.UserEntity, error)
GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error)
}

View File

@ -62,7 +62,7 @@ func (evuc *eventUsecase) CreateEvent(
ctx context.Context,
evRequest *model.EventCreateRequest,
) (*model.EventInfoResponse, error) {
// transfer evRequest to PO
// transfer evRequest to evEntity
evEntity := &model.EventEntity{
Name: evRequest.Name,
@ -75,7 +75,8 @@ func (evuc *eventUsecase) CreateEvent(
data, err := evuc.dbRepo.Transaction(
ctx,
func(txCtx context.Context, tx interface{}) (interface{}, error) {
created, err := evuc.eventRepo.Create(ctx, evEntity)
// Create
created, err := evuc.eventRepo.Create(ctx, evEntity, tx)
if err != nil {
return nil, err
}
@ -104,6 +105,7 @@ func (evuc *eventUsecase) CreateEvent(
),
Owner: ownerResponse,
CreatedAt: created.CreatedAt,
UpdatedAt: created.UpdatedAt,
}
return evResponse, err
})

View File

@ -36,8 +36,8 @@ type TestUserRepository struct{}
func (tur *TestUserRepository) Create(
ctx context.Context,
transaction interface{},
u *model.UserEntity,
tx any,
) (*model.UserEntity, error) {
user := *u
@ -53,6 +53,7 @@ func (tur *TestUserRepository) Create(
func (tur *TestUserRepository) GetByEmail(
ctx context.Context,
email string,
tx any,
) (*model.UserEntity, error) {
hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12)
switch email {
@ -71,7 +72,11 @@ func (tur *TestUserRepository) GetByEmail(
return nil, UserTestDummyErr
}
func (tur *TestUserRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) {
func (tur *TestUserRepository) GetByID(
ctx context.Context,
id int,
tx any,
) (*model.UserEntity, error) {
hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12)
switch id {
case 123:

View File

@ -86,12 +86,12 @@ func (uuc *userUsecase) Create(
data, err := uuc.dbRepo.Transaction(
ctx,
func(txCtx context.Context, tx interface{}) (interface{}, error) {
created, err := uuc.userRepo.Create(txCtx, tx, &model.UserEntity{
created, err := uuc.userRepo.Create(txCtx, &model.UserEntity{
Email: u.Email,
Password: u.Password,
FirstName: u.FirstName,
LastName: u.LastName,
})
}, tx)
if err != nil {
match, _ := regexp.MatchString("SQLSTATE 23505", err.Error())
if match {
@ -132,7 +132,7 @@ func (uuc *userUsecase) Create(
}
func (uuc *userUsecase) Exist(ctx context.Context, u *model.UserExistRequest) error {
got, err := uuc.userRepo.GetByEmail(ctx, u.Email)
got, err := uuc.userRepo.GetByEmail(ctx, u.Email, nil)
// Any query error?
if err != nil {
return err
@ -160,7 +160,7 @@ func (uuc *userUsecase) GetUserBaseResponseByID(
// If not exists, get from the DB. And then put back
// into the cache with a timeout.
// Refresh the cache when the user data is updated (for now it cannot be updated)
got, err := uuc.userRepo.GetByID(ctx, userID)
got, err := uuc.userRepo.GetByID(ctx, userID, nil)
if err != nil {
return nil, err
}