From 9ca7b2a898fb57bb71100eb55abd6dcbc8d293c1 Mon Sep 17 00:00:00 2001 From: Leon Richardt Date: Tue, 21 Jun 2022 17:28:07 +0200 Subject: [PATCH] ref: check file existence instead of keeping in-memory set The benefit of keeping all managed file names in memory instead of checking on demand does not outweigh the increased memory usage. Additionally, this method allows users to manually move files into the served directory without fearing they might be overwritten by jaf. --- jaf.go | 14 +-------- set.go | 34 ---------------------- set_test.go | 75 ------------------------------------------------ uploadhandler.go | 51 +++++++++++++++++++++++--------- 4 files changed, 38 insertions(+), 136 deletions(-) delete mode 100644 set.go delete mode 100644 set_test.go diff --git a/jaf.go b/jaf.go index cd62e39..7af4a36 100644 --- a/jaf.go +++ b/jaf.go @@ -3,7 +3,6 @@ package main import ( "flag" "fmt" - "io/ioutil" "log" "math/rand" "net/http" @@ -13,8 +12,7 @@ import ( const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz" var ( - savedFileNames = NewSet() - config Config + config Config ) type parameters struct { @@ -42,16 +40,6 @@ func main() { log.Fatalf("could not read config file: %s\n", err.Error()) } - files, err := ioutil.ReadDir(config.FileDir) - if err != nil { - log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error()) - } - - // Cache taken file names on start-up - for _, fileInfo := range files { - savedFileNames.Insert(fileInfo.Name()) - } - // Start server uploadServer := &http.Server{ ReadTimeout: 30 * time.Second, diff --git a/set.go b/set.go deleted file mode 100644 index 547bcb0..0000000 --- a/set.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -type Set struct { - _map map[interface{}]struct{} -} - -func NewSet() *Set { - set := &Set{} - set._map = make(map[interface{}]struct{}) - return set -} - -func (set *Set) Contains(value interface{}) bool { - _, ok := set._map[value] - return ok -} - -func (set *Set) Insert(value interface{}) bool { - if set.Contains(value) { - return false - } - - set._map[value] = struct{}{} - return true -} - -func (set *Set) Remove(value interface{}) bool { - if !set.Contains(value) { - return false - } - - delete(set._map, value) - return true -} diff --git a/set_test.go b/set_test.go deleted file mode 100644 index 56282a8..0000000 --- a/set_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "math/rand" - "testing" - "time" -) - -func TestContains(t *testing.T) { - set := NewSet() - - // Oracle testing - dummy := 0 - in := set.Contains(dummy) - if in { - t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy) - } - - set.Insert(dummy) - in = set.Contains(dummy) - if !in { - t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy) - } - - // Property testing - rand.Seed(time.Now().UnixNano()) - const reps = 1000 - for i := 0; i < reps; i++ { - lastInsert := rand.Int() - set.Insert(lastInsert) - - in = set.Contains(lastInsert) - - if !in { - t.Errorf("property > set.Contains(%d) = false after insertion", dummy) - } - } -} - -func TestInsert(t *testing.T) { - set := NewSet() - - // Oracle testing - dummy := 0 - innovative := set.Insert(dummy) - - if !innovative { - t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy) - } - - in := set.Contains(dummy) - if !in { - t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy) - } - - // Duplicate insertion should return false - innovative = set.Insert(dummy) - if innovative { - t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy) - } - - // Property testing - rand.Seed(time.Now().UnixNano()) - const reps = 1000 - for i := 0; i < reps; i++ { - val := rand.Int() - - inBefore := set.Contains(val) - innovative = set.Insert(val) - - if inBefore && innovative { - t.Errorf("property > included value reported as innovative") - } - } -} diff --git a/uploadhandler.go b/uploadhandler.go index e2bcb88..fdd135b 100644 --- a/uploadhandler.go +++ b/uploadhandler.go @@ -1,6 +1,7 @@ package main import ( + "errors" "io" "log" "math/rand" @@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer uploadFile.Close() _, fileExtension := splitFileName(header.Filename) - - // Find an unused file name - fileID := createRandomFileName(h.config.LinkLength) - for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) { - } - - fullFileName := fileID + fileExtension - savePath := h.config.FileDir + fullFileName - link := h.config.LinkPrefix + fullFileName - - err = saveFile(uploadFile, savePath) + link, err := generateLink(h, &uploadFile, fileExtension) if err != nil { http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError) log.Println(" could not save file: " + err.Error()) return } - savedFileNames.Insert(fullFileName) // Implicitly means code 200 w.Write([]byte(link)) } -func saveFile(data multipart.File, name string) error { +// Generates a valid link to uploadFile with the specified file extension. +// Returns the link or an error in case of failure. +// Does not close the passed file pointer. +func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) { + // Find an unused file name + var fullFileName string + var savePath string + for { + fileStem := createRandomFileName(handler.config.LinkLength) + fullFileName = fileStem + fileExtension + savePath = handler.config.FileDir + fullFileName + + if !fileExists(savePath) { + break + } + } + + link := handler.config.LinkPrefix + fullFileName + + err := saveFile(uploadFile, savePath) + if err != nil { + return "", err + } + + return link, nil +} + +func saveFile(data *multipart.File, name string) error { file, err := os.Create(name) if err != nil { return err @@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error { defer file.Close() - _, err = io.Copy(file, data) + _, err = io.Copy(file, *data) if err != nil { return err } @@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error { return nil } +func fileExists(filename string) bool { + _, err := os.Stat(filename) + + return !errors.Is(err, os.ErrNotExist) +} + func createRandomFileName(length int) string { chars := make([]byte, length)