refacto: add db tx as a possible input for repo methods
All checks were successful
Build and test / Build (push) Successful in 2m22s
All checks were successful
Build and test / Build (push) Successful in 2m22s
This commit is contained in:
parent
b30a5c5c2d
commit
14ee642aab
@ -577,3 +577,8 @@ But between `[]*T` and `[]T`, the only difference that I see (pointed out by
|
||||
`ChatGPT`) is how the memory is allocated. With `[]T` it might be better for
|
||||
the GC to deal with the memory free. I thing for my project I will stick to
|
||||
`[]T`.
|
||||
|
||||
### 2024/10/25
|
||||
|
||||
Read this [article](https://konradreiche.com/blog/two-common-go-interface-misuses/)
|
||||
today, maybe I am abusing the usage of interfaces?
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"git.vinchent.xyz/vinchent/howmuch/internal/howmuch/adapter/repo/sqlc"
|
||||
"git.vinchent.xyz/vinchent/howmuch/internal/howmuch/usecase/repo"
|
||||
"git.vinchent.xyz/vinchent/howmuch/internal/pkg/log"
|
||||
)
|
||||
@ -66,3 +67,11 @@ func (dr *dbRepository) Transaction(
|
||||
data, err := txFunc(ctx, tx)
|
||||
return data, err
|
||||
}
|
||||
|
||||
func getQueries(queries *sqlc.Queries, tx any) *sqlc.Queries {
|
||||
transaction, ok := tx.(*sql.Tx)
|
||||
if ok {
|
||||
return sqlc.New(transaction)
|
||||
}
|
||||
return queries
|
||||
}
|
||||
|
@ -48,8 +48,35 @@ func NewEventRepository(db *sql.DB) repo.EventRepository {
|
||||
func (e *eventRepository) Create(
|
||||
ctx context.Context,
|
||||
evEntity *model.EventEntity,
|
||||
tx any,
|
||||
) (*model.EventEntity, error) {
|
||||
panic("unimplemented")
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
event, err := queries.InsertEvent(timeoutCtx, sqlc.InsertEventParams{
|
||||
Name: evEntity.Name,
|
||||
Description: sql.NullString{String: evEntity.Description, Valid: true},
|
||||
TotalAmount: sql.NullInt32{Int32: int32(evEntity.TotalAmount), Valid: true},
|
||||
DefaultCurrency: evEntity.DefaultCurrency,
|
||||
OwnerID: int32(evEntity.OwnerID),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.EventEntity{
|
||||
ID: int(event.ID),
|
||||
Name: event.Name,
|
||||
Description: event.Description.String,
|
||||
TotalAmount: int(event.TotalAmount.Int32),
|
||||
DefaultCurrency: event.DefaultCurrency,
|
||||
OwnerID: int(event.OwnerID),
|
||||
CreatedAt: event.CreatedAt,
|
||||
UpdatedAt: event.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved, error) {
|
||||
@ -89,8 +116,17 @@ func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved
|
||||
}
|
||||
|
||||
// GetByID implements repo.EventRepository.
|
||||
func (e *eventRepository) GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error) {
|
||||
eventDTO, err := e.queries.GetEventByID(ctx, int32(eventID))
|
||||
func (e *eventRepository) GetByID(
|
||||
ctx context.Context,
|
||||
eventID int,
|
||||
tx any,
|
||||
) (*model.EventRetrieved, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
eventDTO, err := queries.GetEventByID(timeoutCtx, int32(eventID))
|
||||
if err != nil {
|
||||
log.ErrorLog("query error", "err", err)
|
||||
return nil, err
|
||||
@ -128,8 +164,14 @@ func convToEventList(eventsDTO []sqlc.ListEventsByUserIDRow) ([]model.EventListR
|
||||
func (e *eventRepository) ListEventsByUserID(
|
||||
ctx context.Context,
|
||||
userID int,
|
||||
tx any,
|
||||
) ([]model.EventListRetrieved, error) {
|
||||
eventsDTO, err := e.queries.ListEventsByUserID(ctx, int32(userID))
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
eventsDTO, err := queries.ListEventsByUserID(timeoutCtx, int32(userID))
|
||||
if err != nil {
|
||||
log.ErrorLog("query error", "err", err)
|
||||
return nil, err
|
||||
@ -142,8 +184,14 @@ func (e *eventRepository) ListEventsByUserID(
|
||||
func (e *eventRepository) UpdateEventByID(
|
||||
ctx context.Context,
|
||||
event *model.EventUpdateEntity,
|
||||
tx any,
|
||||
) error {
|
||||
err := e.queries.UpdateEventByID(ctx, sqlc.UpdateEventByIDParams{
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
err := queries.UpdateEventByID(timeoutCtx, sqlc.UpdateEventByIDParams{
|
||||
ID: int32(event.ID),
|
||||
Name: event.Name,
|
||||
Description: sql.NullString{String: event.Description, Valid: true},
|
||||
|
@ -45,21 +45,39 @@ func NewExpenseRepository(db *sql.DB) repo.ExpenseRepository {
|
||||
}
|
||||
|
||||
// DeleteExpense implements repo.ExpenseRepository.
|
||||
func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int) error {
|
||||
return e.queries.DeleteExpense(ctx, int32(expenseID))
|
||||
func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int, tx any) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
return queries.DeleteExpense(timeoutCtx, int32(expenseID))
|
||||
}
|
||||
|
||||
// DeleteTransactionsOfExpense implements repo.ExpenseRepository.
|
||||
func (e *expenseRepository) DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error {
|
||||
return e.queries.DeleteTransactionsOfExpenseID(ctx, int32(expenseID))
|
||||
func (e *expenseRepository) DeleteTransactionsOfExpense(
|
||||
ctx context.Context,
|
||||
expenseID int,
|
||||
tx any,
|
||||
) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
return queries.DeleteTransactionsOfExpenseID(timeoutCtx, int32(expenseID))
|
||||
}
|
||||
|
||||
// GetExpenseByID implements repo.ExpenseRepository.
|
||||
func (e *expenseRepository) GetExpenseByID(
|
||||
ctx context.Context,
|
||||
expenseID int,
|
||||
tx any,
|
||||
) (*model.ExpenseRetrieved, error) {
|
||||
expenseDTO, err := e.queries.GetExpenseByID(ctx, int32(expenseID))
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
expenseDTO, err := queries.GetExpenseByID(timeoutCtx, int32(expenseID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -150,8 +168,14 @@ func convToExpenseRetrieved(expenseDTO *sqlc.GetExpenseByIDRow) (*model.ExpenseR
|
||||
func (e *expenseRepository) InsertExpense(
|
||||
ctx context.Context,
|
||||
expenseEntity *model.ExpenseEntity,
|
||||
tx any,
|
||||
) (*model.ExpenseEntity, error) {
|
||||
expenseDTO, err := e.queries.InsertExpense(ctx, sqlc.InsertExpenseParams{
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
expenseDTO, err := queries.InsertExpense(timeoutCtx, sqlc.InsertExpenseParams{
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Amount: int32(expenseEntity.Amount),
|
||||
@ -179,8 +203,14 @@ func (e *expenseRepository) InsertExpense(
|
||||
func (e *expenseRepository) ListExpensesByEventID(
|
||||
ctx context.Context,
|
||||
id int,
|
||||
tx any,
|
||||
) ([]model.ExpensesListRetrieved, error) {
|
||||
listDTO, err := e.queries.ListExpensesByEventID(ctx, int32(id))
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
listDTO, err := queries.ListExpensesByEventID(timeoutCtx, int32(id))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -206,8 +236,14 @@ func (e *expenseRepository) ListExpensesByEventID(
|
||||
func (e *expenseRepository) UpdateExpenseByID(
|
||||
ctx context.Context,
|
||||
expenseUpdate *model.ExpenseUpdateEntity,
|
||||
tx any,
|
||||
) (*model.ExpenseEntity, error) {
|
||||
expenseDTO, err := e.queries.UpdateExpenseByID(ctx, sqlc.UpdateExpenseByIDParams{
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(e.queries, tx)
|
||||
|
||||
expenseDTO, err := queries.UpdateExpenseByID(timeoutCtx, sqlc.UpdateExpenseByIDParams{
|
||||
ID: int32(expenseUpdate.ID),
|
||||
UpdatedAt: time.Now(),
|
||||
Amount: int32(expenseUpdate.Amount),
|
||||
|
@ -34,43 +34,36 @@ import (
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
querier *sqlc.Queries
|
||||
queries *sqlc.Queries
|
||||
}
|
||||
|
||||
const insertTimeout = 1 * time.Second
|
||||
const queryTimeout = 3 * time.Second
|
||||
|
||||
func NewUserRepository(db *sql.DB) repo.UserRepository {
|
||||
return &userRepository{
|
||||
querier: sqlc.New(db),
|
||||
queries: sqlc.New(db),
|
||||
}
|
||||
}
|
||||
|
||||
// Create
|
||||
func (ur *userRepository) Create(
|
||||
func (u *userRepository) Create(
|
||||
ctx context.Context,
|
||||
transaction interface{},
|
||||
u *model.UserEntity,
|
||||
userEntity *model.UserEntity,
|
||||
tx any,
|
||||
) (*model.UserEntity, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, insertTimeout)
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := sqlc.InsertUserParams{
|
||||
Email: u.Email,
|
||||
FirstName: u.FirstName,
|
||||
LastName: u.LastName,
|
||||
Password: u.Password,
|
||||
queries := getQueries(u.queries, tx)
|
||||
|
||||
userDB, err := queries.InsertUser(timeoutCtx, sqlc.InsertUserParams{
|
||||
Email: userEntity.Email,
|
||||
FirstName: userEntity.FirstName,
|
||||
LastName: userEntity.LastName,
|
||||
Password: userEntity.Password,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
tx, ok := transaction.(*sql.Tx)
|
||||
if !ok {
|
||||
return nil, errors.New("transaction is not a *sql.Tx")
|
||||
}
|
||||
|
||||
queries := sqlc.New(tx)
|
||||
|
||||
userDB, err := queries.InsertUser(timeoutCtx, args)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -87,8 +80,17 @@ func (ur *userRepository) Create(
|
||||
}
|
||||
|
||||
// GetByEmail if not found, return nil for user but not error.
|
||||
func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.UserEntity, error) {
|
||||
userDB, err := ur.querier.GetUserByEmail(ctx, email)
|
||||
func (u *userRepository) GetByEmail(
|
||||
ctx context.Context,
|
||||
email string,
|
||||
tx any,
|
||||
) (*model.UserEntity, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(u.queries, tx)
|
||||
|
||||
userDB, err := queries.GetUserByEmail(timeoutCtx, email)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// No query error, but user not found
|
||||
return nil, nil
|
||||
@ -107,8 +109,13 @@ func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ur *userRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) {
|
||||
userDB, err := ur.querier.GetUserByID(ctx, int32(id))
|
||||
func (u *userRepository) GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
queries := getQueries(u.queries, tx)
|
||||
|
||||
userDB, err := queries.GetUserByID(timeoutCtx, int32(id))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// No query error, but user not found
|
||||
return nil, nil
|
||||
|
@ -29,30 +29,36 @@ import (
|
||||
)
|
||||
|
||||
type EventRepository interface {
|
||||
Create(ctx context.Context, evEntity *model.EventEntity) (*model.EventEntity, error)
|
||||
Create(ctx context.Context, evEntity *model.EventEntity, tx any) (*model.EventEntity, error)
|
||||
|
||||
// UpdateEventByID updates the event related information (name, descriptions)
|
||||
UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity) error
|
||||
UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity, tx any) error
|
||||
|
||||
GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error)
|
||||
GetByID(ctx context.Context, eventID int, tx any) (*model.EventRetrieved, error)
|
||||
|
||||
// related to events of a user
|
||||
ListEventsByUserID(ctx context.Context, userID int) ([]model.EventListRetrieved, error)
|
||||
ListEventsByUserID(ctx context.Context, userID int, tx any) ([]model.EventListRetrieved, error)
|
||||
|
||||
// CheckParticipation(ctx context.Context, userID, eventID int) error
|
||||
}
|
||||
|
||||
type ExpenseRepository interface {
|
||||
DeleteExpense(ctx context.Context, expenseID int) error
|
||||
DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error
|
||||
GetExpenseByID(ctx context.Context, expenseID int) (*model.ExpenseRetrieved, error)
|
||||
DeleteExpense(ctx context.Context, expenseID int, tx any) error
|
||||
DeleteTransactionsOfExpense(ctx context.Context, expenseID int, tx any) error
|
||||
GetExpenseByID(ctx context.Context, expenseID int, tx any) (*model.ExpenseRetrieved, error)
|
||||
InsertExpense(
|
||||
ctx context.Context,
|
||||
expenseEntity *model.ExpenseEntity,
|
||||
tx any,
|
||||
) (*model.ExpenseEntity, error)
|
||||
ListExpensesByEventID(ctx context.Context, id int) ([]model.ExpensesListRetrieved, error)
|
||||
ListExpensesByEventID(
|
||||
ctx context.Context,
|
||||
id int,
|
||||
tx any,
|
||||
) ([]model.ExpensesListRetrieved, error)
|
||||
UpdateExpenseByID(
|
||||
ctx context.Context,
|
||||
expenseUpdate *model.ExpenseUpdateEntity,
|
||||
tx any,
|
||||
) (*model.ExpenseEntity, error)
|
||||
}
|
||||
|
@ -31,9 +31,9 @@ import (
|
||||
type UserRepository interface {
|
||||
Create(
|
||||
ctx context.Context,
|
||||
transaction interface{},
|
||||
u *model.UserEntity,
|
||||
tx any,
|
||||
) (*model.UserEntity, error)
|
||||
GetByEmail(ctx context.Context, email string) (*model.UserEntity, error)
|
||||
GetByID(ctx context.Context, id int) (*model.UserEntity, error)
|
||||
GetByEmail(ctx context.Context, email string, tx any) (*model.UserEntity, error)
|
||||
GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error)
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func (evuc *eventUsecase) CreateEvent(
|
||||
ctx context.Context,
|
||||
evRequest *model.EventCreateRequest,
|
||||
) (*model.EventInfoResponse, error) {
|
||||
// transfer evRequest to PO
|
||||
// transfer evRequest to evEntity
|
||||
|
||||
evEntity := &model.EventEntity{
|
||||
Name: evRequest.Name,
|
||||
@ -75,7 +75,8 @@ func (evuc *eventUsecase) CreateEvent(
|
||||
data, err := evuc.dbRepo.Transaction(
|
||||
ctx,
|
||||
func(txCtx context.Context, tx interface{}) (interface{}, error) {
|
||||
created, err := evuc.eventRepo.Create(ctx, evEntity)
|
||||
// Create
|
||||
created, err := evuc.eventRepo.Create(ctx, evEntity, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -104,6 +105,7 @@ func (evuc *eventUsecase) CreateEvent(
|
||||
),
|
||||
Owner: ownerResponse,
|
||||
CreatedAt: created.CreatedAt,
|
||||
UpdatedAt: created.UpdatedAt,
|
||||
}
|
||||
return evResponse, err
|
||||
})
|
||||
|
@ -36,8 +36,8 @@ type TestUserRepository struct{}
|
||||
|
||||
func (tur *TestUserRepository) Create(
|
||||
ctx context.Context,
|
||||
transaction interface{},
|
||||
u *model.UserEntity,
|
||||
tx any,
|
||||
) (*model.UserEntity, error) {
|
||||
user := *u
|
||||
|
||||
@ -53,6 +53,7 @@ func (tur *TestUserRepository) Create(
|
||||
func (tur *TestUserRepository) GetByEmail(
|
||||
ctx context.Context,
|
||||
email string,
|
||||
tx any,
|
||||
) (*model.UserEntity, error) {
|
||||
hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12)
|
||||
switch email {
|
||||
@ -71,7 +72,11 @@ func (tur *TestUserRepository) GetByEmail(
|
||||
return nil, UserTestDummyErr
|
||||
}
|
||||
|
||||
func (tur *TestUserRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) {
|
||||
func (tur *TestUserRepository) GetByID(
|
||||
ctx context.Context,
|
||||
id int,
|
||||
tx any,
|
||||
) (*model.UserEntity, error) {
|
||||
hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12)
|
||||
switch id {
|
||||
case 123:
|
||||
|
@ -86,12 +86,12 @@ func (uuc *userUsecase) Create(
|
||||
data, err := uuc.dbRepo.Transaction(
|
||||
ctx,
|
||||
func(txCtx context.Context, tx interface{}) (interface{}, error) {
|
||||
created, err := uuc.userRepo.Create(txCtx, tx, &model.UserEntity{
|
||||
created, err := uuc.userRepo.Create(txCtx, &model.UserEntity{
|
||||
Email: u.Email,
|
||||
Password: u.Password,
|
||||
FirstName: u.FirstName,
|
||||
LastName: u.LastName,
|
||||
})
|
||||
}, tx)
|
||||
if err != nil {
|
||||
match, _ := regexp.MatchString("SQLSTATE 23505", err.Error())
|
||||
if match {
|
||||
@ -132,7 +132,7 @@ func (uuc *userUsecase) Create(
|
||||
}
|
||||
|
||||
func (uuc *userUsecase) Exist(ctx context.Context, u *model.UserExistRequest) error {
|
||||
got, err := uuc.userRepo.GetByEmail(ctx, u.Email)
|
||||
got, err := uuc.userRepo.GetByEmail(ctx, u.Email, nil)
|
||||
// Any query error?
|
||||
if err != nil {
|
||||
return err
|
||||
@ -160,7 +160,7 @@ func (uuc *userUsecase) GetUserBaseResponseByID(
|
||||
// If not exists, get from the DB. And then put back
|
||||
// into the cache with a timeout.
|
||||
// Refresh the cache when the user data is updated (for now it cannot be updated)
|
||||
got, err := uuc.userRepo.GetByID(ctx, userID)
|
||||
got, err := uuc.userRepo.GetByID(ctx, userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user