diff --git a/cmd/web/main.go b/cmd/web/main.go index c6972a1..12bceec 100644 --- a/cmd/web/main.go +++ b/cmd/web/main.go @@ -23,8 +23,25 @@ var ( // main is the main application function func main() { - // what am I going to put in the session - gob.Register(models.Reservation{}) + err := run() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Starting application on port %s\n", portNumber) + + srv := &http.Server{ + Addr: portNumber, + Handler: routes(&app), + } + + err = srv.ListenAndServe() + log.Fatal(err) +} + +func run() error { + // what am I going to put in the session + gob.Register(models.Reservation{}) // change this to true when in production app.InProduction = false @@ -48,14 +65,5 @@ func main() { handlers.NewHandlers(repo) render.NewTemplates(&app) - - fmt.Printf("Starting application on port %s\n", portNumber) - - srv := &http.Server{ - Addr: portNumber, - Handler: routes(&app), - } - - err = srv.ListenAndServe() - log.Fatal(err) + return nil } diff --git a/cmd/web/main_test.go b/cmd/web/main_test.go new file mode 100644 index 0000000..acce8cb --- /dev/null +++ b/cmd/web/main_test.go @@ -0,0 +1,10 @@ +package main + +import "testing" + +func TestRun(t *testing.T) { + err := run() + if err != nil { + t.Error("failed run()") + } +} diff --git a/cmd/web/middleware_test.go b/cmd/web/middleware_test.go new file mode 100644 index 0000000..0838340 --- /dev/null +++ b/cmd/web/middleware_test.go @@ -0,0 +1,45 @@ +package main + +import ( + "net/http" + "testing" +) + +func TestNoSurf(t *testing.T) { + var myH myHandler + + h := NoSurf(&myH) + + switch v := h.(type) { + case http.Handler: + // do nothing + default: + t.Errorf("type is not http.Handler, but is %T", v) + } +} + +func TestSessionLoad(t *testing.T) { + var myH myHandler + + s := SessionLoad(&myH) + + switch v := s.(type) { + case http.Handler: + // do nothing + default: + t.Errorf("type is not http.Handler, but is %T", v) + } +} + +func TestWriteToConsole(t *testing.T) { + var myH myHandler + + w := WriteToConsole(&myH) + + switch v := w.(type) { + case http.Handler: + // do nothing + default: + t.Errorf("type is not http.Handler, but is %T", v) + } +} diff --git a/cmd/web/routes_test.go b/cmd/web/routes_test.go new file mode 100644 index 0000000..d9fe90a --- /dev/null +++ b/cmd/web/routes_test.go @@ -0,0 +1,21 @@ +package main + +import ( + "go-udemy-web-1/internal/config" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestRoutes(t *testing.T) { + var app config.AppConfig + + mux := routes(&app) + + switch v := mux.(type) { + case *chi.Mux: + // do nothing + default: + t.Errorf("type is not *chi.Mux, is %T", v) + } +} diff --git a/cmd/web/setup_test.go b/cmd/web/setup_test.go new file mode 100644 index 0000000..57b6405 --- /dev/null +++ b/cmd/web/setup_test.go @@ -0,0 +1,16 @@ +package main + +import ( + "net/http" + "os" + "testing" +) + +func TestMain(m *testing.M) { + os.Exit(m.Run()) +} + +type myHandler struct{} + +func (mh *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +} diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go new file mode 100644 index 0000000..310cc0e --- /dev/null +++ b/internal/handlers/handlers_test.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +type postData struct { + key string + value string +} + +var theTests = []struct { + name string + url string + method string + params []postData + expectedStatusCode int +}{ + {"home", "/", "GET", []postData{}, http.StatusOK}, + {"about", "/about", "GET", []postData{}, http.StatusOK}, + {"gq", "/generals-quarters", "GET", []postData{}, http.StatusOK}, + {"ms", "/majors-suite", "GET", []postData{}, http.StatusOK}, + {"sa", "/availability", "GET", []postData{}, http.StatusOK}, + {"contact", "/contact", "GET", []postData{}, http.StatusOK}, + {"ma", "/make-reservation", "GET", []postData{}, http.StatusOK}, +} + +func TestHandlers(t *testing.T) { + routes := getRoutes() + ts := httptest.NewTLSServer(routes) + defer ts.Close() + + for _, e := range theTests { + if e.method == "GET" { + resp, err := ts.Client().Get(ts.URL + e.url) + if err != nil { + t.Log(err) + t.Fatal(err) + } + if resp.StatusCode != e.expectedStatusCode { + t.Errorf("for %s, expected %d but got %d", e.name, e.expectedStatusCode, resp.StatusCode) + } + } else { + } + } +} diff --git a/internal/handlers/setup_test.go b/internal/handlers/setup_test.go new file mode 100644 index 0000000..bc6187e --- /dev/null +++ b/internal/handlers/setup_test.go @@ -0,0 +1,140 @@ +package handlers + +import ( + "encoding/gob" + "fmt" + "go-udemy-web-1/internal/config" + "go-udemy-web-1/internal/models" + "go-udemy-web-1/internal/render" + "html/template" + "log" + "net/http" + "path/filepath" + "time" + + "github.com/alexedwards/scs/v2" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/justinas/nosurf" +) + +var functions = template.FuncMap{} + +var ( + app config.AppConfig + session *scs.SessionManager +) + +func getRoutes() http.Handler { + gob.Register(models.Reservation{}) + // change this to true when in production + app.InProduction = false + + session = scs.New() + session.Lifetime = 24 * time.Hour + session.Cookie.Persist = true + session.Cookie.SameSite = http.SameSiteLaxMode + session.Cookie.Secure = app.InProduction + + app.Session = session + + tc, err := CreateTestTemplateCache() + if err != nil { + log.Fatalf("cannot create template cache: %s", err) + } + app.TemplateCahce = tc + app.UseCache = true // Not to use ./templates + + repo := NewRepo(&app) + NewHandlers(repo) + + render.NewTemplates(&app) + + mux := chi.NewMux() + + mux.Use(middleware.Recoverer) + mux.Use(NoSurf) + mux.Use(SessionLoad) + + mux.Get("/", Repo.Home) + mux.Get("/about", Repo.About) + mux.Get("/contact", Repo.Contact) + mux.Get("/generals-quarters", Repo.Generals) + mux.Get("/majors-suite", Repo.Majors) + mux.Get("/availability", Repo.Availability) + mux.Post("/availability", Repo.PostAvailability) + mux.Post("/availability-json", Repo.AvailabilityJSON) + mux.Get("/make-reservation", Repo.MakeReservation) + mux.Post("/make-reservation", Repo.PostMakeReservation) + mux.Get("/reservation-summary", Repo.ReservationSummary) + + fileServer := http.FileServer(http.Dir("./static/")) + mux.Handle("/static/*", http.StripPrefix("/static", fileServer)) + + return mux +} + +// WriteToConsole writes a log when user hits a page +func WriteToConsole(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("Hit the page %s\n", r.URL.String()) + + next.ServeHTTP(w, r) + }) +} + +// NoSurf adds CSRF protection to all POST requests +func NoSurf(next http.Handler) http.Handler { + csrfHandler := nosurf.New(next) + + csrfHandler.SetBaseCookie(http.Cookie{ + HttpOnly: true, + Path: "/", + Secure: app.InProduction, + SameSite: http.SameSiteLaxMode, + }) + + return csrfHandler +} + +// SessionLoad loads and saves the session on every request +func SessionLoad(next http.Handler) http.Handler { + return session.LoadAndSave(next) +} + +var pathToTemplates = "../../templates" + +func CreateTestTemplateCache() (map[string]*template.Template, error) { + myCache := map[string]*template.Template{} + + // get all of the files named *.page.tmpl from templates + pages, err := filepath.Glob(fmt.Sprintf("%s/*.page.tmpl", pathToTemplates)) + if err != nil { + return myCache, err + } + + // range through all files ending with *page.tmpl + for _, page := range pages { + name := filepath.Base(page) + ts, err := template.New(name).Funcs(functions).ParseFiles(page) + if err != nil { + return myCache, err + } + + matches, err := filepath.Glob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates)) + if err != nil { + return myCache, err + } + + if len(matches) > 0 { + ts, err = ts.ParseGlob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates)) + if err != nil { + return myCache, err + } + } + + myCache[name] = ts + } + + return myCache, nil +} diff --git a/internal/render/render.go b/internal/render/render.go index 17caaea..14ae06b 100644 --- a/internal/render/render.go +++ b/internal/render/render.go @@ -2,6 +2,7 @@ package render import ( "bytes" + "fmt" "go-udemy-web-1/internal/config" "go-udemy-web-1/internal/models" "html/template" @@ -12,7 +13,12 @@ import ( "github.com/justinas/nosurf" ) -var app *config.AppConfig +var functions = template.FuncMap{} + +var ( + app *config.AppConfig + pathToTemplates = "./templates" +) // NewTemplates sets the config for the template package func NewTemplates(a *config.AppConfig) { @@ -65,8 +71,8 @@ func RenderTemplate(w http.ResponseWriter, r *http.Request, tmpl string, td *mod func CreateTemplateCache() (map[string]*template.Template, error) { myCache := map[string]*template.Template{} - // get all of the files named *.page.tmpl from ./templates - pages, err := filepath.Glob("./templates/*.page.tmpl") + // get all of the files named *.page.tmpl from templates + pages, err := filepath.Glob(fmt.Sprintf("%s/*.page.tmpl", pathToTemplates)) if err != nil { return myCache, err } @@ -74,18 +80,18 @@ func CreateTemplateCache() (map[string]*template.Template, error) { // range through all files ending with *page.tmpl for _, page := range pages { name := filepath.Base(page) - ts, err := template.New(name).ParseFiles(page) + ts, err := template.New(name).Funcs(functions).ParseFiles(page) if err != nil { return myCache, err } - matches, err := filepath.Glob("./templates/*.layout.tmpl") + matches, err := filepath.Glob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates)) if err != nil { return myCache, err } if len(matches) > 0 { - ts, err = ts.ParseGlob("./templates/*.layout.tmpl") + ts, err = ts.ParseGlob(fmt.Sprintf("%s/*.layout.tmpl", pathToTemplates)) if err != nil { return myCache, err }