From 93b3a0d7ec95c33dc327397ab1756c36853328ee Mon Sep 17 00:00:00 2001 From: Eason Lin Date: Sun, 16 Jul 2017 11:42:08 +0800 Subject: [PATCH] feat(context): add SaveUploadedFile func. (#1022) * feat(context): add SaveUploadedFile func. * feat(context): update multiple upload examples. * style(example): fix gofmt * fix(example): add missing return --- README.md | 8 ++++- context.go | 19 ++++++++++++ context_test.go | 42 +++++++++++++++++++++++++++ examples/upload-file/multiple/main.go | 20 ++----------- examples/upload-file/single/main.go | 20 ++----------- 5 files changed, 73 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 86001ae..ccc5bcc 100644 --- a/README.md +++ b/README.md @@ -275,7 +275,10 @@ func main() { // single file file, _ := c.FormFile("file") log.Println(file.Filename) - + + // Upload the file to specific dst. + // c.SaveUploadedFile(file, dst) + c.String(http.StatusOK, fmt.Sprintf("'%s' uploaded!", file.Filename)) }) router.Run(":8080") @@ -304,6 +307,9 @@ func main() { for _, file := range files { log.Println(file.Filename) + + // Upload the file to specific dst. + // c.SaveUploadedFile(file, dst) } c.String(http.StatusOK, fmt.Sprintf("%d files uploaded!", len(files))) }) diff --git a/context.go b/context.go index 198dd3e..f29464d 100644 --- a/context.go +++ b/context.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/url" + "os" "strings" "time" @@ -431,6 +432,24 @@ func (c *Context) MultipartForm() (*multipart.Form, error) { return c.Request.MultipartForm, err } +// SaveUploadedFile uploads the form file to specific dst. +func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dst string) error { + src, err := file.Open() + if err != nil { + return err + } + defer src.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + io.Copy(out, src) + return nil +} + // Bind checks the Content-Type to select a binding engine automatically, // Depending the "Content-Type" header different bindings are used: // "application/json" --> JSON binding diff --git a/context_test.go b/context_test.go index 758fecd..db960fb 100644 --- a/context_test.go +++ b/context_test.go @@ -72,12 +72,18 @@ func TestContextFormFile(t *testing.T) { if assert.NoError(t, err) { assert.Equal(t, "test", f.Filename) } + + assert.NoError(t, c.SaveUploadedFile(f, "test")) } func TestContextMultipartForm(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) mw.WriteField("foo", "bar") + w, err := mw.CreateFormFile("file", "test") + if assert.NoError(t, err) { + w.Write([]byte("test")) + } mw.Close() c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest("POST", "/", buf) @@ -86,6 +92,42 @@ func TestContextMultipartForm(t *testing.T) { if assert.NoError(t, err) { assert.NotNil(t, f) } + + assert.NoError(t, c.SaveUploadedFile(f.File["file"][0], "test")) +} + +func TestSaveUploadedOpenFailed(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + mw.Close() + + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + + f := &multipart.FileHeader{ + Filename: "file", + } + assert.Error(t, c.SaveUploadedFile(f, "test")) +} + +func TestSaveUploadedCreateFailed(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + w, err := mw.CreateFormFile("file", "test") + if assert.NoError(t, err) { + w.Write([]byte("test")) + } + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + f, err := c.FormFile("file") + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) + } + + assert.Error(t, c.SaveUploadedFile(f, "/")) } func TestContextReset(t *testing.T) { diff --git a/examples/upload-file/multiple/main.go b/examples/upload-file/multiple/main.go index 2258834..4bb4cdc 100644 --- a/examples/upload-file/multiple/main.go +++ b/examples/upload-file/multiple/main.go @@ -2,9 +2,7 @@ package main import ( "fmt" - "io" "net/http" - "os" "github.com/gin-gonic/gin" ) @@ -25,24 +23,10 @@ func main() { files := form.File["files"] for _, file := range files { - // Source - src, err := file.Open() - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("file open err: %s", err.Error())) + if err := c.SaveUploadedFile(file, file.Filename); err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("upload file err: %s", err.Error())) return } - defer src.Close() - - // Destination - dst, err := os.Create(file.Filename) - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("Create file err: %s", err.Error())) - return - } - defer dst.Close() - - // Copy - io.Copy(dst, src) } c.String(http.StatusOK, fmt.Sprintf("Uploaded successfully %d files with fields name=%s and email=%s.", len(files), name, email)) diff --git a/examples/upload-file/single/main.go b/examples/upload-file/single/main.go index 1e9596c..372a299 100644 --- a/examples/upload-file/single/main.go +++ b/examples/upload-file/single/main.go @@ -2,9 +2,7 @@ package main import ( "fmt" - "io" "net/http" - "os" "github.com/gin-gonic/gin" ) @@ -22,23 +20,11 @@ func main() { c.String(http.StatusBadRequest, fmt.Sprintf("get form err: %s", err.Error())) return } - src, err := file.Open() - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("file open err: %s", err.Error())) + + if err := c.SaveUploadedFile(file, file.Filename); err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("upload file err: %s", err.Error())) return } - defer src.Close() - - // Destination - dst, err := os.Create(file.Filename) - if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("Create file err: %s", err.Error())) - return - } - defer dst.Close() - - // Copy - io.Copy(dst, src) c.String(http.StatusOK, fmt.Sprintf("File %s uploaded successfully with fields name=%s and email=%s.", file.Filename, name, email)) })