diff --git a/README.md b/README.md index 3ab8654..60fc879 100644 --- a/README.md +++ b/README.md @@ -577,3 +577,8 @@ But between `[]*T` and `[]T`, the only difference that I see (pointed out by `ChatGPT`) is how the memory is allocated. With `[]T` it might be better for the GC to deal with the memory free. I thing for my project I will stick to `[]T`. + +### 2024/10/25 + +Read this [article](https://konradreiche.com/blog/two-common-go-interface-misuses/) +today, maybe I am abusing the usage of interfaces? diff --git a/internal/howmuch/adapter/repo/db.go b/internal/howmuch/adapter/repo/db.go index e1db2f7..b949459 100644 --- a/internal/howmuch/adapter/repo/db.go +++ b/internal/howmuch/adapter/repo/db.go @@ -26,6 +26,7 @@ import ( "context" "database/sql" + "git.vinchent.xyz/vinchent/howmuch/internal/howmuch/adapter/repo/sqlc" "git.vinchent.xyz/vinchent/howmuch/internal/howmuch/usecase/repo" "git.vinchent.xyz/vinchent/howmuch/internal/pkg/log" ) @@ -66,3 +67,11 @@ func (dr *dbRepository) Transaction( data, err := txFunc(ctx, tx) return data, err } + +func getQueries(queries *sqlc.Queries, tx any) *sqlc.Queries { + transaction, ok := tx.(*sql.Tx) + if ok { + return sqlc.New(transaction) + } + return queries +} diff --git a/internal/howmuch/adapter/repo/event.go b/internal/howmuch/adapter/repo/event.go index d36f7ac..4944a45 100644 --- a/internal/howmuch/adapter/repo/event.go +++ b/internal/howmuch/adapter/repo/event.go @@ -48,8 +48,35 @@ func NewEventRepository(db *sql.DB) repo.EventRepository { func (e *eventRepository) Create( ctx context.Context, evEntity *model.EventEntity, + tx any, ) (*model.EventEntity, error) { - panic("unimplemented") + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + event, err := queries.InsertEvent(timeoutCtx, sqlc.InsertEventParams{ + Name: evEntity.Name, + Description: sql.NullString{String: evEntity.Description, Valid: true}, + TotalAmount: sql.NullInt32{Int32: int32(evEntity.TotalAmount), Valid: true}, + DefaultCurrency: evEntity.DefaultCurrency, + OwnerID: int32(evEntity.OwnerID), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + if err != nil { + return nil, err + } + return &model.EventEntity{ + ID: int(event.ID), + Name: event.Name, + Description: event.Description.String, + TotalAmount: int(event.TotalAmount.Int32), + DefaultCurrency: event.DefaultCurrency, + OwnerID: int(event.OwnerID), + CreatedAt: event.CreatedAt, + UpdatedAt: event.UpdatedAt, + }, nil } func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved, error) { @@ -89,8 +116,17 @@ func convToEventRetrieved(eventDTO *sqlc.GetEventByIDRow) (*model.EventRetrieved } // GetByID implements repo.EventRepository. -func (e *eventRepository) GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error) { - eventDTO, err := e.queries.GetEventByID(ctx, int32(eventID)) +func (e *eventRepository) GetByID( + ctx context.Context, + eventID int, + tx any, +) (*model.EventRetrieved, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + eventDTO, err := queries.GetEventByID(timeoutCtx, int32(eventID)) if err != nil { log.ErrorLog("query error", "err", err) return nil, err @@ -128,8 +164,14 @@ func convToEventList(eventsDTO []sqlc.ListEventsByUserIDRow) ([]model.EventListR func (e *eventRepository) ListEventsByUserID( ctx context.Context, userID int, + tx any, ) ([]model.EventListRetrieved, error) { - eventsDTO, err := e.queries.ListEventsByUserID(ctx, int32(userID)) + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + eventsDTO, err := queries.ListEventsByUserID(timeoutCtx, int32(userID)) if err != nil { log.ErrorLog("query error", "err", err) return nil, err @@ -142,8 +184,14 @@ func (e *eventRepository) ListEventsByUserID( func (e *eventRepository) UpdateEventByID( ctx context.Context, event *model.EventUpdateEntity, + tx any, ) error { - err := e.queries.UpdateEventByID(ctx, sqlc.UpdateEventByIDParams{ + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + err := queries.UpdateEventByID(timeoutCtx, sqlc.UpdateEventByIDParams{ ID: int32(event.ID), Name: event.Name, Description: sql.NullString{String: event.Description, Valid: true}, diff --git a/internal/howmuch/adapter/repo/expense.go b/internal/howmuch/adapter/repo/expense.go index c846eba..2fce177 100644 --- a/internal/howmuch/adapter/repo/expense.go +++ b/internal/howmuch/adapter/repo/expense.go @@ -45,21 +45,39 @@ func NewExpenseRepository(db *sql.DB) repo.ExpenseRepository { } // DeleteExpense implements repo.ExpenseRepository. -func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int) error { - return e.queries.DeleteExpense(ctx, int32(expenseID)) +func (e *expenseRepository) DeleteExpense(ctx context.Context, expenseID int, tx any) error { + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + return queries.DeleteExpense(timeoutCtx, int32(expenseID)) } // DeleteTransactionsOfExpense implements repo.ExpenseRepository. -func (e *expenseRepository) DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error { - return e.queries.DeleteTransactionsOfExpenseID(ctx, int32(expenseID)) +func (e *expenseRepository) DeleteTransactionsOfExpense( + ctx context.Context, + expenseID int, + tx any, +) error { + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + return queries.DeleteTransactionsOfExpenseID(timeoutCtx, int32(expenseID)) } // GetExpenseByID implements repo.ExpenseRepository. func (e *expenseRepository) GetExpenseByID( ctx context.Context, expenseID int, + tx any, ) (*model.ExpenseRetrieved, error) { - expenseDTO, err := e.queries.GetExpenseByID(ctx, int32(expenseID)) + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + expenseDTO, err := queries.GetExpenseByID(timeoutCtx, int32(expenseID)) if err != nil { return nil, err } @@ -150,8 +168,14 @@ func convToExpenseRetrieved(expenseDTO *sqlc.GetExpenseByIDRow) (*model.ExpenseR func (e *expenseRepository) InsertExpense( ctx context.Context, expenseEntity *model.ExpenseEntity, + tx any, ) (*model.ExpenseEntity, error) { - expenseDTO, err := e.queries.InsertExpense(ctx, sqlc.InsertExpenseParams{ + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + expenseDTO, err := queries.InsertExpense(timeoutCtx, sqlc.InsertExpenseParams{ CreatedAt: time.Now(), UpdatedAt: time.Now(), Amount: int32(expenseEntity.Amount), @@ -179,8 +203,14 @@ func (e *expenseRepository) InsertExpense( func (e *expenseRepository) ListExpensesByEventID( ctx context.Context, id int, + tx any, ) ([]model.ExpensesListRetrieved, error) { - listDTO, err := e.queries.ListExpensesByEventID(ctx, int32(id)) + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + listDTO, err := queries.ListExpensesByEventID(timeoutCtx, int32(id)) if err != nil { return nil, err } @@ -206,8 +236,14 @@ func (e *expenseRepository) ListExpensesByEventID( func (e *expenseRepository) UpdateExpenseByID( ctx context.Context, expenseUpdate *model.ExpenseUpdateEntity, + tx any, ) (*model.ExpenseEntity, error) { - expenseDTO, err := e.queries.UpdateExpenseByID(ctx, sqlc.UpdateExpenseByIDParams{ + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(e.queries, tx) + + expenseDTO, err := queries.UpdateExpenseByID(timeoutCtx, sqlc.UpdateExpenseByIDParams{ ID: int32(expenseUpdate.ID), UpdatedAt: time.Now(), Amount: int32(expenseUpdate.Amount), diff --git a/internal/howmuch/adapter/repo/user.go b/internal/howmuch/adapter/repo/user.go index 8dc5100..2a32fb7 100644 --- a/internal/howmuch/adapter/repo/user.go +++ b/internal/howmuch/adapter/repo/user.go @@ -34,43 +34,36 @@ import ( ) type userRepository struct { - querier *sqlc.Queries + queries *sqlc.Queries } -const insertTimeout = 1 * time.Second +const queryTimeout = 3 * time.Second func NewUserRepository(db *sql.DB) repo.UserRepository { return &userRepository{ - querier: sqlc.New(db), + queries: sqlc.New(db), } } // Create -func (ur *userRepository) Create( +func (u *userRepository) Create( ctx context.Context, - transaction interface{}, - u *model.UserEntity, + userEntity *model.UserEntity, + tx any, ) (*model.UserEntity, error) { - timeoutCtx, cancel := context.WithTimeout(ctx, insertTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) defer cancel() - args := sqlc.InsertUserParams{ - Email: u.Email, - FirstName: u.FirstName, - LastName: u.LastName, - Password: u.Password, + queries := getQueries(u.queries, tx) + + userDB, err := queries.InsertUser(timeoutCtx, sqlc.InsertUserParams{ + Email: userEntity.Email, + FirstName: userEntity.FirstName, + LastName: userEntity.LastName, + Password: userEntity.Password, CreatedAt: time.Now(), UpdatedAt: time.Now(), - } - - tx, ok := transaction.(*sql.Tx) - if !ok { - return nil, errors.New("transaction is not a *sql.Tx") - } - - queries := sqlc.New(tx) - - userDB, err := queries.InsertUser(timeoutCtx, args) + }) if err != nil { return nil, err } @@ -87,8 +80,17 @@ func (ur *userRepository) Create( } // GetByEmail if not found, return nil for user but not error. -func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model.UserEntity, error) { - userDB, err := ur.querier.GetUserByEmail(ctx, email) +func (u *userRepository) GetByEmail( + ctx context.Context, + email string, + tx any, +) (*model.UserEntity, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(u.queries, tx) + + userDB, err := queries.GetUserByEmail(timeoutCtx, email) if errors.Is(err, sql.ErrNoRows) { // No query error, but user not found return nil, nil @@ -107,8 +109,13 @@ func (ur *userRepository) GetByEmail(ctx context.Context, email string) (*model. }, nil } -func (ur *userRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) { - userDB, err := ur.querier.GetUserByID(ctx, int32(id)) +func (u *userRepository) GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + queries := getQueries(u.queries, tx) + + userDB, err := queries.GetUserByID(timeoutCtx, int32(id)) if errors.Is(err, sql.ErrNoRows) { // No query error, but user not found return nil, nil diff --git a/internal/howmuch/usecase/repo/event.go b/internal/howmuch/usecase/repo/event.go index fbf69af..4a3b888 100644 --- a/internal/howmuch/usecase/repo/event.go +++ b/internal/howmuch/usecase/repo/event.go @@ -29,30 +29,36 @@ import ( ) type EventRepository interface { - Create(ctx context.Context, evEntity *model.EventEntity) (*model.EventEntity, error) + Create(ctx context.Context, evEntity *model.EventEntity, tx any) (*model.EventEntity, error) // UpdateEventByID updates the event related information (name, descriptions) - UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity) error + UpdateEventByID(ctx context.Context, event *model.EventUpdateEntity, tx any) error - GetByID(ctx context.Context, eventID int) (*model.EventRetrieved, error) + GetByID(ctx context.Context, eventID int, tx any) (*model.EventRetrieved, error) // related to events of a user - ListEventsByUserID(ctx context.Context, userID int) ([]model.EventListRetrieved, error) + ListEventsByUserID(ctx context.Context, userID int, tx any) ([]model.EventListRetrieved, error) // CheckParticipation(ctx context.Context, userID, eventID int) error } type ExpenseRepository interface { - DeleteExpense(ctx context.Context, expenseID int) error - DeleteTransactionsOfExpense(ctx context.Context, expenseID int) error - GetExpenseByID(ctx context.Context, expenseID int) (*model.ExpenseRetrieved, error) + DeleteExpense(ctx context.Context, expenseID int, tx any) error + DeleteTransactionsOfExpense(ctx context.Context, expenseID int, tx any) error + GetExpenseByID(ctx context.Context, expenseID int, tx any) (*model.ExpenseRetrieved, error) InsertExpense( ctx context.Context, expenseEntity *model.ExpenseEntity, + tx any, ) (*model.ExpenseEntity, error) - ListExpensesByEventID(ctx context.Context, id int) ([]model.ExpensesListRetrieved, error) + ListExpensesByEventID( + ctx context.Context, + id int, + tx any, + ) ([]model.ExpensesListRetrieved, error) UpdateExpenseByID( ctx context.Context, expenseUpdate *model.ExpenseUpdateEntity, + tx any, ) (*model.ExpenseEntity, error) } diff --git a/internal/howmuch/usecase/repo/user.go b/internal/howmuch/usecase/repo/user.go index ff3e3e8..8fcf119 100644 --- a/internal/howmuch/usecase/repo/user.go +++ b/internal/howmuch/usecase/repo/user.go @@ -31,9 +31,9 @@ import ( type UserRepository interface { Create( ctx context.Context, - transaction interface{}, u *model.UserEntity, + tx any, ) (*model.UserEntity, error) - GetByEmail(ctx context.Context, email string) (*model.UserEntity, error) - GetByID(ctx context.Context, id int) (*model.UserEntity, error) + GetByEmail(ctx context.Context, email string, tx any) (*model.UserEntity, error) + GetByID(ctx context.Context, id int, tx any) (*model.UserEntity, error) } diff --git a/internal/howmuch/usecase/usecase/event.go b/internal/howmuch/usecase/usecase/event.go index 8a1c049..6b745c0 100644 --- a/internal/howmuch/usecase/usecase/event.go +++ b/internal/howmuch/usecase/usecase/event.go @@ -62,7 +62,7 @@ func (evuc *eventUsecase) CreateEvent( ctx context.Context, evRequest *model.EventCreateRequest, ) (*model.EventInfoResponse, error) { - // transfer evRequest to PO + // transfer evRequest to evEntity evEntity := &model.EventEntity{ Name: evRequest.Name, @@ -75,7 +75,8 @@ func (evuc *eventUsecase) CreateEvent( data, err := evuc.dbRepo.Transaction( ctx, func(txCtx context.Context, tx interface{}) (interface{}, error) { - created, err := evuc.eventRepo.Create(ctx, evEntity) + // Create + created, err := evuc.eventRepo.Create(ctx, evEntity, tx) if err != nil { return nil, err } @@ -104,6 +105,7 @@ func (evuc *eventUsecase) CreateEvent( ), Owner: ownerResponse, CreatedAt: created.CreatedAt, + UpdatedAt: created.UpdatedAt, } return evResponse, err }) diff --git a/internal/howmuch/usecase/usecase/repomock/testuserrepo.go b/internal/howmuch/usecase/usecase/repomock/testuserrepo.go index 29f6163..e09b866 100644 --- a/internal/howmuch/usecase/usecase/repomock/testuserrepo.go +++ b/internal/howmuch/usecase/usecase/repomock/testuserrepo.go @@ -36,8 +36,8 @@ type TestUserRepository struct{} func (tur *TestUserRepository) Create( ctx context.Context, - transaction interface{}, u *model.UserEntity, + tx any, ) (*model.UserEntity, error) { user := *u @@ -53,6 +53,7 @@ func (tur *TestUserRepository) Create( func (tur *TestUserRepository) GetByEmail( ctx context.Context, email string, + tx any, ) (*model.UserEntity, error) { hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12) switch email { @@ -71,7 +72,11 @@ func (tur *TestUserRepository) GetByEmail( return nil, UserTestDummyErr } -func (tur *TestUserRepository) GetByID(ctx context.Context, id int) (*model.UserEntity, error) { +func (tur *TestUserRepository) GetByID( + ctx context.Context, + id int, + tx any, +) (*model.UserEntity, error) { hashedPwd, _ := bcrypt.GenerateFromPassword([]byte("strongHashed"), 12) switch id { case 123: diff --git a/internal/howmuch/usecase/usecase/user.go b/internal/howmuch/usecase/usecase/user.go index b95b60a..60a974b 100644 --- a/internal/howmuch/usecase/usecase/user.go +++ b/internal/howmuch/usecase/usecase/user.go @@ -86,12 +86,12 @@ func (uuc *userUsecase) Create( data, err := uuc.dbRepo.Transaction( ctx, func(txCtx context.Context, tx interface{}) (interface{}, error) { - created, err := uuc.userRepo.Create(txCtx, tx, &model.UserEntity{ + created, err := uuc.userRepo.Create(txCtx, &model.UserEntity{ Email: u.Email, Password: u.Password, FirstName: u.FirstName, LastName: u.LastName, - }) + }, tx) if err != nil { match, _ := regexp.MatchString("SQLSTATE 23505", err.Error()) if match { @@ -132,7 +132,7 @@ func (uuc *userUsecase) Create( } func (uuc *userUsecase) Exist(ctx context.Context, u *model.UserExistRequest) error { - got, err := uuc.userRepo.GetByEmail(ctx, u.Email) + got, err := uuc.userRepo.GetByEmail(ctx, u.Email, nil) // Any query error? if err != nil { return err @@ -160,7 +160,7 @@ func (uuc *userUsecase) GetUserBaseResponseByID( // If not exists, get from the DB. And then put back // into the cache with a timeout. // Refresh the cache when the user data is updated (for now it cannot be updated) - got, err := uuc.userRepo.GetByID(ctx, userID) + got, err := uuc.userRepo.GetByID(ctx, userID, nil) if err != nil { return nil, err }