Writing tests for main package and for GET handlers

This commit is contained in:
Muyao CHEN 2024-07-01 22:34:16 +02:00
parent dc7ece7d12
commit 21478b20ae
8 changed files with 312 additions and 18 deletions

View File

@ -23,8 +23,25 @@ var (
// main is the main application function
func main() {
// what am I going to put in the session
gob.Register(models.Reservation{})
err := run()
if err != nil {
log.Fatal(err)
}
fmt.Printf("Starting application on port %s\n", portNumber)
srv := &http.Server{
Addr: portNumber,
Handler: routes(&app),
}
err = srv.ListenAndServe()
log.Fatal(err)
}
func run() error {
// what am I going to put in the session
gob.Register(models.Reservation{})
// change this to true when in production
app.InProduction = false
@ -48,14 +65,5 @@ func main() {
handlers.NewHandlers(repo)
render.NewTemplates(&app)
fmt.Printf("Starting application on port %s\n", portNumber)
srv := &http.Server{
Addr: portNumber,
Handler: routes(&app),
}
err = srv.ListenAndServe()
log.Fatal(err)
return nil
}

10
cmd/web/main_test.go Normal file
View File

@ -0,0 +1,10 @@
package main
import "testing"
func TestRun(t *testing.T) {
err := run()
if err != nil {
t.Error("failed run()")
}
}

View File

@ -0,0 +1,45 @@
package main
import (
"net/http"
"testing"
)
func TestNoSurf(t *testing.T) {
var myH myHandler
h := NoSurf(&myH)
switch v := h.(type) {
case http.Handler:
// do nothing
default:
t.Errorf("type is not http.Handler, but is %T", v)
}
}
func TestSessionLoad(t *testing.T) {
var myH myHandler
s := SessionLoad(&myH)
switch v := s.(type) {
case http.Handler:
// do nothing
default:
t.Errorf("type is not http.Handler, but is %T", v)
}
}
func TestWriteToConsole(t *testing.T) {
var myH myHandler
w := WriteToConsole(&myH)
switch v := w.(type) {
case http.Handler:
// do nothing
default:
t.Errorf("type is not http.Handler, but is %T", v)
}
}

21
cmd/web/routes_test.go Normal file
View File

@ -0,0 +1,21 @@
package main
import (
"go-udemy-web-1/internal/config"
"testing"
"github.com/go-chi/chi/v5"
)
func TestRoutes(t *testing.T) {
var app config.AppConfig
mux := routes(&app)
switch v := mux.(type) {
case *chi.Mux:
// do nothing
default:
t.Errorf("type is not *chi.Mux, is %T", v)
}
}

16
cmd/web/setup_test.go Normal file
View File

@ -0,0 +1,16 @@
package main
import (
"net/http"
"os"
"testing"
)
func TestMain(m *testing.M) {
os.Exit(m.Run())
}
type myHandler struct{}
func (mh *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

View File

@ -0,0 +1,48 @@
package handlers
import (
"net/http"
"net/http/httptest"
"testing"
)
type postData struct {
key string
value string
}
var theTests = []struct {
name string
url string
method string
params []postData
expectedStatusCode int
}{
{"home", "/", "GET", []postData{}, http.StatusOK},
{"about", "/about", "GET", []postData{}, http.StatusOK},
{"gq", "/generals-quarters", "GET", []postData{}, http.StatusOK},
{"ms", "/majors-suite", "GET", []postData{}, http.StatusOK},
{"sa", "/availability", "GET", []postData{}, http.StatusOK},
{"contact", "/contact", "GET", []postData{}, http.StatusOK},
{"ma", "/make-reservation", "GET", []postData{}, http.StatusOK},
}
func TestHandlers(t *testing.T) {
routes := getRoutes()
ts := httptest.NewTLSServer(routes)
defer ts.Close()
for _, e := range theTests {
if e.method == "GET" {
resp, err := ts.Client().Get(ts.URL + e.url)
if err != nil {
t.Log(err)
t.Fatal(err)
}
if resp.StatusCode != e.expectedStatusCode {
t.Errorf("for %s, expected %d but got %d", e.name, e.expectedStatusCode, resp.StatusCode)
}
} else {
}
}
}

View File

@ -0,0 +1,140 @@
package handlers
import (
"encoding/gob"
"fmt"
"go-udemy-web-1/internal/config"
"go-udemy-web-1/internal/models"
"go-udemy-web-1/internal/render"
"html/template"
"log"
"net/http"
"path/filepath"
"time"
"github.com/alexedwards/scs/v2"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/justinas/nosurf"
)
var functions = template.FuncMap{}
var (
app config.AppConfig
session *scs.SessionManager
)
func getRoutes() http.Handler {
gob.Register(models.Reservation{})
// change this to true when in production
app.InProduction = false
session = scs.New()
session.Lifetime = 24 * time.Hour
session.Cookie.Persist = true
session.Cookie.SameSite = http.SameSiteLaxMode
session.Cookie.Secure = app.InProduction
app.Session = session
tc, err := CreateTestTemplateCache()
if err != nil {
log.Fatalf("cannot create template cache: %s", err)
}
app.TemplateCahce = tc
app.UseCache = true // Not to use ./templates
repo := NewRepo(&app)
NewHandlers(repo)
render.NewTemplates(&app)
mux := chi.NewMux()
mux.Use(middleware.Recoverer)
mux.Use(NoSurf)
mux.Use(SessionLoad)
mux.Get("/", Repo.Home)
mux.Get("/about", Repo.About)
mux.Get("/contact", Repo.Contact)
mux.Get("/generals-quarters", Repo.Generals)
mux.Get("/majors-suite", Repo.Majors)
mux.Get("/availability", Repo.Availability)
mux.Post("/availability", Repo.PostAvailability)
mux.Post("/availability-json", Repo.AvailabilityJSON)
mux.Get("/make-reservation", Repo.MakeReservation)
mux.Post("/make-reservation", Repo.PostMakeReservation)
mux.Get("/reservation-summary", Repo.ReservationSummary)
fileServer := http.FileServer(http.Dir("./static/"))
mux.Handle("/static/*", http.StripPrefix("/static", fileServer))
return mux
}
// WriteToConsole writes a log when user hits a page
func WriteToConsole(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Hit the page %s\n", r.URL.String())
next.ServeHTTP(w, r)
})
}
// NoSurf adds CSRF protection to all POST requests
func NoSurf(next http.Handler) http.Handler {
csrfHandler := nosurf.New(next)
csrfHandler.SetBaseCookie(http.Cookie{
HttpOnly: true,
Path: "/",
Secure: app.InProduction,
SameSite: http.SameSiteLaxMode,
})
return csrfHandler
}
// SessionLoad loads and saves the session on every request
func SessionLoad(next http.Handler) http.Handler {
return session.LoadAndSave(next)
}
var pathToTemplates = "../../templates"
func CreateTestTemplateCache() (map[string]*template.Template, error) {
myCache := map[string]*template.Template{}
// get all of the files named *.page.tmpl from templates
pages, err := filepath.Glob(fmt.Sprintf("%s/*.page.tmpl", pathToTemplates))
if err != nil {
return myCache, err
}
// range through all files ending with *page.tmpl
for _, page := range pages {
name := filepath.Base(page)
ts, err := template.New(name).Funcs(functions).ParseFiles(page)
if err != nil {
return myCache, err
}
matches, err := filepath.Glob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates))
if err != nil {
return myCache, err
}
if len(matches) > 0 {
ts, err = ts.ParseGlob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates))
if err != nil {
return myCache, err
}
}
myCache[name] = ts
}
return myCache, nil
}

View File

@ -2,6 +2,7 @@ package render
import (
"bytes"
"fmt"
"go-udemy-web-1/internal/config"
"go-udemy-web-1/internal/models"
"html/template"
@ -12,7 +13,12 @@ import (
"github.com/justinas/nosurf"
)
var app *config.AppConfig
var functions = template.FuncMap{}
var (
app *config.AppConfig
pathToTemplates = "./templates"
)
// NewTemplates sets the config for the template package
func NewTemplates(a *config.AppConfig) {
@ -65,8 +71,8 @@ func RenderTemplate(w http.ResponseWriter, r *http.Request, tmpl string, td *mod
func CreateTemplateCache() (map[string]*template.Template, error) {
myCache := map[string]*template.Template{}
// get all of the files named *.page.tmpl from ./templates
pages, err := filepath.Glob("./templates/*.page.tmpl")
// get all of the files named *.page.tmpl from templates
pages, err := filepath.Glob(fmt.Sprintf("%s/*.page.tmpl", pathToTemplates))
if err != nil {
return myCache, err
}
@ -74,18 +80,18 @@ func CreateTemplateCache() (map[string]*template.Template, error) {
// range through all files ending with *page.tmpl
for _, page := range pages {
name := filepath.Base(page)
ts, err := template.New(name).ParseFiles(page)
ts, err := template.New(name).Funcs(functions).ParseFiles(page)
if err != nil {
return myCache, err
}
matches, err := filepath.Glob("./templates/*.layout.tmpl")
matches, err := filepath.Glob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates))
if err != nil {
return myCache, err
}
if len(matches) > 0 {
ts, err = ts.ParseGlob("./templates/*.layout.tmpl")
ts, err = ts.ParseGlob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates))
if err != nil {
return myCache, err
}