udemy-go-microservices/authentication-service/data/models.go

265 lines
5.8 KiB
Go

package data
import (
"context"
"database/sql"
"errors"
"log"
"time"
"golang.org/x/crypto/bcrypt"
)
const dbTimeout = time.Second * 3
var db *sql.DB
// New is the function used to create an instance of the data package. It returns the type
// Model, which embeds all the types we want to be available to our application.
func New(dbPool *sql.DB) Models {
db = dbPool
return Models{
User: User{},
}
}
// Models is the type for this package. Note that any model that is included as a member
// in this type is available to us throughout the application, anywhere that the
// app variable is used, provided that the model is also added in the New function.
type Models struct {
User User
}
// User is the structure which holds one user from the database.
type User struct {
ID int `json:"id"`
Email string `json:"email"`
FirstName string `json:"first_name,omitempty"`
LastName string `json:"last_name,omitempty"`
Password string `json:"-"`
Active int `json:"active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// GetAll returns a slice of all users, sorted by last name
func (u *User) GetAll() ([]*User, error) {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
query := `select id, email, first_name, last_name, password, user_active, created_at, updated_at
from users order by last_name`
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*User
for rows.Next() {
var user User
err := rows.Scan(
&user.ID,
&user.Email,
&user.FirstName,
&user.LastName,
&user.Password,
&user.Active,
&user.CreatedAt,
&user.UpdatedAt,
)
if err != nil {
log.Println("Error scanning", err)
return nil, err
}
users = append(users, &user)
}
return users, nil
}
// GetByEmail returns one user by email
func (u *User) GetByEmail(email string) (*User, error) {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
query := `select id, email, first_name, last_name, password, user_active, created_at, updated_at from users where email = $1`
var user User
row := db.QueryRowContext(ctx, query, email)
err := row.Scan(
&user.ID,
&user.Email,
&user.FirstName,
&user.LastName,
&user.Password,
&user.Active,
&user.CreatedAt,
&user.UpdatedAt,
)
if err != nil {
return nil, err
}
return &user, nil
}
// GetOne returns one user by id
func (u *User) GetOne(id int) (*User, error) {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
query := `select id, email, first_name, last_name, password, user_active, created_at, updated_at from users where id = $1`
var user User
row := db.QueryRowContext(ctx, query, id)
err := row.Scan(
&user.ID,
&user.Email,
&user.FirstName,
&user.LastName,
&user.Password,
&user.Active,
&user.CreatedAt,
&user.UpdatedAt,
)
if err != nil {
return nil, err
}
return &user, nil
}
// Update updates one user in the database, using the information
// stored in the receiver u
func (u *User) Update() error {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
stmt := `update users set
email = $1,
first_name = $2,
last_name = $3,
user_active = $4,
updated_at = $5
where id = $6
`
_, err := db.ExecContext(ctx, stmt,
u.Email,
u.FirstName,
u.LastName,
u.Active,
time.Now(),
u.ID,
)
if err != nil {
return err
}
return nil
}
// Delete deletes one user from the database, by User.ID
func (u *User) Delete() error {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
stmt := `delete from users where id = $1`
_, err := db.ExecContext(ctx, stmt, u.ID)
if err != nil {
return err
}
return nil
}
// DeleteByID deletes one user from the database, by ID
func (u *User) DeleteByID(id int) error {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
stmt := `delete from users where id = $1`
_, err := db.ExecContext(ctx, stmt, id)
if err != nil {
return err
}
return nil
}
// Insert inserts a new user into the database, and returns the ID of the newly inserted row
func (u *User) Insert(user User) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), 12)
if err != nil {
return 0, err
}
var newID int
stmt := `insert into users (email, first_name, last_name, password, user_active, created_at, updated_at)
values ($1, $2, $3, $4, $5, $6, $7) returning id`
err = db.QueryRowContext(ctx, stmt,
user.Email,
user.FirstName,
user.LastName,
hashedPassword,
user.Active,
time.Now(),
time.Now(),
).Scan(&newID)
if err != nil {
return 0, err
}
return newID, nil
}
// ResetPassword is the method we will use to change a user's password.
func (u *User) ResetPassword(password string) error {
ctx, cancel := context.WithTimeout(context.Background(), dbTimeout)
defer cancel()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 12)
if err != nil {
return err
}
stmt := `update users set password = $1 where id = $2`
_, err = db.ExecContext(ctx, stmt, hashedPassword, u.ID)
if err != nil {
return err
}
return nil
}
// PasswordMatches uses Go's bcrypt package to compare a user supplied password
// with the hash we have stored for a given user in the database. If the password
// and hash match, we return true; otherwise, we return false.
func (u *User) PasswordMatches(plainText string) (bool, error) {
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(plainText))
if err != nil {
switch {
case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
// invalid password
return false, nil
default:
return false, err
}
}
return true, nil
}