ref: check file existence instead of keeping in-memory set

The benefit of keeping all managed file names in memory instead of
checking on demand does not outweigh the increased memory usage.
Additionally, this method allows users to manually move files into the
served directory without fearing they might be overwritten by jaf.
This commit is contained in:
Leon Richardt 2022-06-21 17:28:07 +02:00
parent bcd34a7a33
commit 9ca7b2a898
No known key found for this signature in database
GPG key ID: AD8BDD6273FE8FC5
4 changed files with 38 additions and 136 deletions

12
jaf.go
View file

@ -3,7 +3,6 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
@ -13,7 +12,6 @@ import (
const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz" const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz"
var ( var (
savedFileNames = NewSet()
config Config config Config
) )
@ -42,16 +40,6 @@ func main() {
log.Fatalf("could not read config file: %s\n", err.Error()) log.Fatalf("could not read config file: %s\n", err.Error())
} }
files, err := ioutil.ReadDir(config.FileDir)
if err != nil {
log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error())
}
// Cache taken file names on start-up
for _, fileInfo := range files {
savedFileNames.Insert(fileInfo.Name())
}
// Start server // Start server
uploadServer := &http.Server{ uploadServer := &http.Server{
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,

34
set.go
View file

@ -1,34 +0,0 @@
package main
type Set struct {
_map map[interface{}]struct{}
}
func NewSet() *Set {
set := &Set{}
set._map = make(map[interface{}]struct{})
return set
}
func (set *Set) Contains(value interface{}) bool {
_, ok := set._map[value]
return ok
}
func (set *Set) Insert(value interface{}) bool {
if set.Contains(value) {
return false
}
set._map[value] = struct{}{}
return true
}
func (set *Set) Remove(value interface{}) bool {
if !set.Contains(value) {
return false
}
delete(set._map, value)
return true
}

View file

@ -1,75 +0,0 @@
package main
import (
"math/rand"
"testing"
"time"
)
func TestContains(t *testing.T) {
set := NewSet()
// Oracle testing
dummy := 0
in := set.Contains(dummy)
if in {
t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy)
}
set.Insert(dummy)
in = set.Contains(dummy)
if !in {
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
}
// Property testing
rand.Seed(time.Now().UnixNano())
const reps = 1000
for i := 0; i < reps; i++ {
lastInsert := rand.Int()
set.Insert(lastInsert)
in = set.Contains(lastInsert)
if !in {
t.Errorf("property > set.Contains(%d) = false after insertion", dummy)
}
}
}
func TestInsert(t *testing.T) {
set := NewSet()
// Oracle testing
dummy := 0
innovative := set.Insert(dummy)
if !innovative {
t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy)
}
in := set.Contains(dummy)
if !in {
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
}
// Duplicate insertion should return false
innovative = set.Insert(dummy)
if innovative {
t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy)
}
// Property testing
rand.Seed(time.Now().UnixNano())
const reps = 1000
for i := 0; i < reps; i++ {
val := rand.Int()
inBefore := set.Contains(val)
innovative = set.Insert(val)
if inBefore && innovative {
t.Errorf("property > included value reported as innovative")
}
}
}

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"io" "io"
"log" "log"
"math/rand" "math/rand"
@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer uploadFile.Close() defer uploadFile.Close()
_, fileExtension := splitFileName(header.Filename) _, fileExtension := splitFileName(header.Filename)
link, err := generateLink(h, &uploadFile, fileExtension)
// Find an unused file name
fileID := createRandomFileName(h.config.LinkLength)
for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) {
}
fullFileName := fileID + fileExtension
savePath := h.config.FileDir + fullFileName
link := h.config.LinkPrefix + fullFileName
err = saveFile(uploadFile, savePath)
if err != nil { if err != nil {
http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError) http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError)
log.Println(" could not save file: " + err.Error()) log.Println(" could not save file: " + err.Error())
return return
} }
savedFileNames.Insert(fullFileName)
// Implicitly means code 200 // Implicitly means code 200
w.Write([]byte(link)) w.Write([]byte(link))
} }
func saveFile(data multipart.File, name string) error { // Generates a valid link to uploadFile with the specified file extension.
// Returns the link or an error in case of failure.
// Does not close the passed file pointer.
func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) {
// Find an unused file name
var fullFileName string
var savePath string
for {
fileStem := createRandomFileName(handler.config.LinkLength)
fullFileName = fileStem + fileExtension
savePath = handler.config.FileDir + fullFileName
if !fileExists(savePath) {
break
}
}
link := handler.config.LinkPrefix + fullFileName
err := saveFile(uploadFile, savePath)
if err != nil {
return "", err
}
return link, nil
}
func saveFile(data *multipart.File, name string) error {
file, err := os.Create(name) file, err := os.Create(name)
if err != nil { if err != nil {
return err return err
@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error {
defer file.Close() defer file.Close()
_, err = io.Copy(file, data) _, err = io.Copy(file, *data)
if err != nil { if err != nil {
return err return err
} }
@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error {
return nil return nil
} }
func fileExists(filename string) bool {
_, err := os.Stat(filename)
return !errors.Is(err, os.ErrNotExist)
}
func createRandomFileName(length int) string { func createRandomFileName(length int) string {
chars := make([]byte, length) chars := make([]byte, length)