diff --git a/blogposts/blogposts.go b/blogposts/blogposts.go index f959fc9..bd41183 100644 --- a/blogposts/blogposts.go +++ b/blogposts/blogposts.go @@ -1,9 +1,17 @@ package blogposts import ( + "errors" "io/fs" + "strings" ) +var markdownSuffix = map[string]struct{}{ + "md": {}, +} + +var ErrUnknownFileType = errors.New("unknown file type, must be markdown") + func NewPostsFromFS(fileSystem fs.FS) ([]Post, error) { dir, err := fs.ReadDir(fileSystem, ".") if err != nil { @@ -11,7 +19,11 @@ func NewPostsFromFS(fileSystem fs.FS) ([]Post, error) { } var posts []Post for _, f := range dir { - post, err := getPost(fileSystem, f.Name()) + fileName := f.Name() + if !isMarkdownFile(fileName) { + return nil, ErrUnknownFileType + } + post, err := getPost(fileSystem, fileName) if err != nil { return nil, err } @@ -19,3 +31,9 @@ func NewPostsFromFS(fileSystem fs.FS) ([]Post, error) { } return posts, nil } + +func isMarkdownFile(fileName string) bool { + splitted := strings.Split(fileName, ".") + _, ok := markdownSuffix[splitted[len(splitted)-1:][0]] + return ok +} diff --git a/blogposts/blogposts_test.go b/blogposts/blogposts_test.go index ca5f8f2..702aa15 100644 --- a/blogposts/blogposts_test.go +++ b/blogposts/blogposts_test.go @@ -56,6 +56,17 @@ World`, }) } +func TestWrongFile(t *testing.T) { + fs := fstest.MapFS{ + "hello world.txt": {Data: []byte("Yolo")}, + } + + _, err := NewPostsFromFS(fs) + if err == nil { + t.Errorf("should be an error but not") + } +} + func assertPost(t testing.TB, got Post, want Post) { t.Helper() if !reflect.DeepEqual(got, want) {