Merge pull request #1 from leon-richardt/ref/check-file-existence

ref: check file existence instead of keeping in-memory set
This commit is contained in:
Leon Richardt 2022-06-21 18:18:35 +02:00 committed by GitHub
commit 3b247a932f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 136 deletions

14
jaf.go
View file

@ -3,7 +3,6 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
@ -13,8 +12,7 @@ import (
const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz" const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz"
var ( var (
savedFileNames = NewSet() config Config
config Config
) )
type parameters struct { type parameters struct {
@ -42,16 +40,6 @@ func main() {
log.Fatalf("could not read config file: %s\n", err.Error()) log.Fatalf("could not read config file: %s\n", err.Error())
} }
files, err := ioutil.ReadDir(config.FileDir)
if err != nil {
log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error())
}
// Cache taken file names on start-up
for _, fileInfo := range files {
savedFileNames.Insert(fileInfo.Name())
}
// Start server // Start server
uploadServer := &http.Server{ uploadServer := &http.Server{
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,

34
set.go
View file

@ -1,34 +0,0 @@
package main
type Set struct {
_map map[interface{}]struct{}
}
func NewSet() *Set {
set := &Set{}
set._map = make(map[interface{}]struct{})
return set
}
func (set *Set) Contains(value interface{}) bool {
_, ok := set._map[value]
return ok
}
func (set *Set) Insert(value interface{}) bool {
if set.Contains(value) {
return false
}
set._map[value] = struct{}{}
return true
}
func (set *Set) Remove(value interface{}) bool {
if !set.Contains(value) {
return false
}
delete(set._map, value)
return true
}

View file

@ -1,75 +0,0 @@
package main
import (
"math/rand"
"testing"
"time"
)
func TestContains(t *testing.T) {
set := NewSet()
// Oracle testing
dummy := 0
in := set.Contains(dummy)
if in {
t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy)
}
set.Insert(dummy)
in = set.Contains(dummy)
if !in {
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
}
// Property testing
rand.Seed(time.Now().UnixNano())
const reps = 1000
for i := 0; i < reps; i++ {
lastInsert := rand.Int()
set.Insert(lastInsert)
in = set.Contains(lastInsert)
if !in {
t.Errorf("property > set.Contains(%d) = false after insertion", dummy)
}
}
}
func TestInsert(t *testing.T) {
set := NewSet()
// Oracle testing
dummy := 0
innovative := set.Insert(dummy)
if !innovative {
t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy)
}
in := set.Contains(dummy)
if !in {
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
}
// Duplicate insertion should return false
innovative = set.Insert(dummy)
if innovative {
t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy)
}
// Property testing
rand.Seed(time.Now().UnixNano())
const reps = 1000
for i := 0; i < reps; i++ {
val := rand.Int()
inBefore := set.Contains(val)
innovative = set.Insert(val)
if inBefore && innovative {
t.Errorf("property > included value reported as innovative")
}
}
}

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"io" "io"
"log" "log"
"math/rand" "math/rand"
@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer uploadFile.Close() defer uploadFile.Close()
_, fileExtension := splitFileName(header.Filename) _, fileExtension := splitFileName(header.Filename)
link, err := generateLink(h, &uploadFile, fileExtension)
// Find an unused file name
fileID := createRandomFileName(h.config.LinkLength)
for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) {
}
fullFileName := fileID + fileExtension
savePath := h.config.FileDir + fullFileName
link := h.config.LinkPrefix + fullFileName
err = saveFile(uploadFile, savePath)
if err != nil { if err != nil {
http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError) http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError)
log.Println(" could not save file: " + err.Error()) log.Println(" could not save file: " + err.Error())
return return
} }
savedFileNames.Insert(fullFileName)
// Implicitly means code 200 // Implicitly means code 200
w.Write([]byte(link)) w.Write([]byte(link))
} }
func saveFile(data multipart.File, name string) error { // Generates a valid link to uploadFile with the specified file extension.
// Returns the link or an error in case of failure.
// Does not close the passed file pointer.
func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) {
// Find an unused file name
var fullFileName string
var savePath string
for {
fileStem := createRandomFileName(handler.config.LinkLength)
fullFileName = fileStem + fileExtension
savePath = handler.config.FileDir + fullFileName
if !fileExists(savePath) {
break
}
}
link := handler.config.LinkPrefix + fullFileName
err := saveFile(uploadFile, savePath)
if err != nil {
return "", err
}
return link, nil
}
func saveFile(data *multipart.File, name string) error {
file, err := os.Create(name) file, err := os.Create(name)
if err != nil { if err != nil {
return err return err
@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error {
defer file.Close() defer file.Close()
_, err = io.Copy(file, data) _, err = io.Copy(file, *data)
if err != nil { if err != nil {
return err return err
} }
@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error {
return nil return nil
} }
func fileExists(filename string) bool {
_, err := os.Stat(filename)
return !errors.Is(err, os.ErrNotExist)
}
func createRandomFileName(length int) string { func createRandomFileName(length int) string {
chars := make([]byte, length) chars := make([]byte, length)