diff --git a/jaf.go b/jaf.go index cd62e39..7af4a36 100644 --- a/jaf.go +++ b/jaf.go @@ -3,7 +3,6 @@ package main import ( "flag" "fmt" - "io/ioutil" "log" "math/rand" "net/http" @@ -13,8 +12,7 @@ import ( const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz" var ( - savedFileNames = NewSet() - config Config + config Config ) type parameters struct { @@ -42,16 +40,6 @@ func main() { log.Fatalf("could not read config file: %s\n", err.Error()) } - files, err := ioutil.ReadDir(config.FileDir) - if err != nil { - log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error()) - } - - // Cache taken file names on start-up - for _, fileInfo := range files { - savedFileNames.Insert(fileInfo.Name()) - } - // Start server uploadServer := &http.Server{ ReadTimeout: 30 * time.Second, diff --git a/set.go b/set.go deleted file mode 100644 index 547bcb0..0000000 --- a/set.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -type Set struct { - _map map[interface{}]struct{} -} - -func NewSet() *Set { - set := &Set{} - set._map = make(map[interface{}]struct{}) - return set -} - -func (set *Set) Contains(value interface{}) bool { - _, ok := set._map[value] - return ok -} - -func (set *Set) Insert(value interface{}) bool { - if set.Contains(value) { - return false - } - - set._map[value] = struct{}{} - return true -} - -func (set *Set) Remove(value interface{}) bool { - if !set.Contains(value) { - return false - } - - delete(set._map, value) - return true -} diff --git a/set_test.go b/set_test.go deleted file mode 100644 index 56282a8..0000000 --- a/set_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "math/rand" - "testing" - "time" -) - -func TestContains(t *testing.T) { - set := NewSet() - - // Oracle testing - dummy := 0 - in := set.Contains(dummy) - if in { - t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy) - } - - set.Insert(dummy) - in = set.Contains(dummy) - if !in { - t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy) - } - - // Property testing - rand.Seed(time.Now().UnixNano()) - const reps = 1000 - for i := 0; i < reps; i++ { - lastInsert := rand.Int() - set.Insert(lastInsert) - - in = set.Contains(lastInsert) - - if !in { - t.Errorf("property > set.Contains(%d) = false after insertion", dummy) - } - } -} - -func TestInsert(t *testing.T) { - set := NewSet() - - // Oracle testing - dummy := 0 - innovative := set.Insert(dummy) - - if !innovative { - t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy) - } - - in := set.Contains(dummy) - if !in { - t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy) - } - - // Duplicate insertion should return false - innovative = set.Insert(dummy) - if innovative { - t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy) - } - - // Property testing - rand.Seed(time.Now().UnixNano()) - const reps = 1000 - for i := 0; i < reps; i++ { - val := rand.Int() - - inBefore := set.Contains(val) - innovative = set.Insert(val) - - if inBefore && innovative { - t.Errorf("property > included value reported as innovative") - } - } -} diff --git a/uploadhandler.go b/uploadhandler.go index e2bcb88..fdd135b 100644 --- a/uploadhandler.go +++ b/uploadhandler.go @@ -1,6 +1,7 @@ package main import ( + "errors" "io" "log" "math/rand" @@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer uploadFile.Close() _, fileExtension := splitFileName(header.Filename) - - // Find an unused file name - fileID := createRandomFileName(h.config.LinkLength) - for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) { - } - - fullFileName := fileID + fileExtension - savePath := h.config.FileDir + fullFileName - link := h.config.LinkPrefix + fullFileName - - err = saveFile(uploadFile, savePath) + link, err := generateLink(h, &uploadFile, fileExtension) if err != nil { http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError) log.Println(" could not save file: " + err.Error()) return } - savedFileNames.Insert(fullFileName) // Implicitly means code 200 w.Write([]byte(link)) } -func saveFile(data multipart.File, name string) error { +// Generates a valid link to uploadFile with the specified file extension. +// Returns the link or an error in case of failure. +// Does not close the passed file pointer. +func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) { + // Find an unused file name + var fullFileName string + var savePath string + for { + fileStem := createRandomFileName(handler.config.LinkLength) + fullFileName = fileStem + fileExtension + savePath = handler.config.FileDir + fullFileName + + if !fileExists(savePath) { + break + } + } + + link := handler.config.LinkPrefix + fullFileName + + err := saveFile(uploadFile, savePath) + if err != nil { + return "", err + } + + return link, nil +} + +func saveFile(data *multipart.File, name string) error { file, err := os.Create(name) if err != nil { return err @@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error { defer file.Close() - _, err = io.Copy(file, data) + _, err = io.Copy(file, *data) if err != nil { return err } @@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error { return nil } +func fileExists(filename string) bool { + _, err := os.Stat(filename) + + return !errors.Is(err, os.ErrNotExist) +} + func createRandomFileName(length int) string { chars := make([]byte, length)