mirror of
https://github.com/lyx0/yaf.git
synced 2024-11-13 19:49:53 +01:00
ref: check file existence instead of keeping in-memory set
The benefit of keeping all managed file names in memory instead of checking on demand does not outweigh the increased memory usage. Additionally, this method allows users to manually move files into the served directory without fearing they might be overwritten by jaf.
This commit is contained in:
parent
bcd34a7a33
commit
9ca7b2a898
4 changed files with 38 additions and 136 deletions
12
jaf.go
12
jaf.go
|
@ -3,7 +3,6 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
|
@ -13,7 +12,6 @@ import (
|
|||
const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
var (
|
||||
savedFileNames = NewSet()
|
||||
config Config
|
||||
)
|
||||
|
||||
|
@ -42,16 +40,6 @@ func main() {
|
|||
log.Fatalf("could not read config file: %s\n", err.Error())
|
||||
}
|
||||
|
||||
files, err := ioutil.ReadDir(config.FileDir)
|
||||
if err != nil {
|
||||
log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error())
|
||||
}
|
||||
|
||||
// Cache taken file names on start-up
|
||||
for _, fileInfo := range files {
|
||||
savedFileNames.Insert(fileInfo.Name())
|
||||
}
|
||||
|
||||
// Start server
|
||||
uploadServer := &http.Server{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
|
|
34
set.go
34
set.go
|
@ -1,34 +0,0 @@
|
|||
package main
|
||||
|
||||
type Set struct {
|
||||
_map map[interface{}]struct{}
|
||||
}
|
||||
|
||||
func NewSet() *Set {
|
||||
set := &Set{}
|
||||
set._map = make(map[interface{}]struct{})
|
||||
return set
|
||||
}
|
||||
|
||||
func (set *Set) Contains(value interface{}) bool {
|
||||
_, ok := set._map[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (set *Set) Insert(value interface{}) bool {
|
||||
if set.Contains(value) {
|
||||
return false
|
||||
}
|
||||
|
||||
set._map[value] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (set *Set) Remove(value interface{}) bool {
|
||||
if !set.Contains(value) {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(set._map, value)
|
||||
return true
|
||||
}
|
75
set_test.go
75
set_test.go
|
@ -1,75 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
set := NewSet()
|
||||
|
||||
// Oracle testing
|
||||
dummy := 0
|
||||
in := set.Contains(dummy)
|
||||
if in {
|
||||
t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy)
|
||||
}
|
||||
|
||||
set.Insert(dummy)
|
||||
in = set.Contains(dummy)
|
||||
if !in {
|
||||
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
|
||||
}
|
||||
|
||||
// Property testing
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
const reps = 1000
|
||||
for i := 0; i < reps; i++ {
|
||||
lastInsert := rand.Int()
|
||||
set.Insert(lastInsert)
|
||||
|
||||
in = set.Contains(lastInsert)
|
||||
|
||||
if !in {
|
||||
t.Errorf("property > set.Contains(%d) = false after insertion", dummy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
set := NewSet()
|
||||
|
||||
// Oracle testing
|
||||
dummy := 0
|
||||
innovative := set.Insert(dummy)
|
||||
|
||||
if !innovative {
|
||||
t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy)
|
||||
}
|
||||
|
||||
in := set.Contains(dummy)
|
||||
if !in {
|
||||
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
|
||||
}
|
||||
|
||||
// Duplicate insertion should return false
|
||||
innovative = set.Insert(dummy)
|
||||
if innovative {
|
||||
t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy)
|
||||
}
|
||||
|
||||
// Property testing
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
const reps = 1000
|
||||
for i := 0; i < reps; i++ {
|
||||
val := rand.Int()
|
||||
|
||||
inBefore := set.Contains(val)
|
||||
innovative = set.Insert(val)
|
||||
|
||||
if inBefore && innovative {
|
||||
t.Errorf("property > included value reported as innovative")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
defer uploadFile.Close()
|
||||
|
||||
_, fileExtension := splitFileName(header.Filename)
|
||||
|
||||
// Find an unused file name
|
||||
fileID := createRandomFileName(h.config.LinkLength)
|
||||
for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) {
|
||||
}
|
||||
|
||||
fullFileName := fileID + fileExtension
|
||||
savePath := h.config.FileDir + fullFileName
|
||||
link := h.config.LinkPrefix + fullFileName
|
||||
|
||||
err = saveFile(uploadFile, savePath)
|
||||
link, err := generateLink(h, &uploadFile, fileExtension)
|
||||
if err != nil {
|
||||
http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError)
|
||||
log.Println(" could not save file: " + err.Error())
|
||||
return
|
||||
}
|
||||
savedFileNames.Insert(fullFileName)
|
||||
|
||||
// Implicitly means code 200
|
||||
w.Write([]byte(link))
|
||||
}
|
||||
|
||||
func saveFile(data multipart.File, name string) error {
|
||||
// Generates a valid link to uploadFile with the specified file extension.
|
||||
// Returns the link or an error in case of failure.
|
||||
// Does not close the passed file pointer.
|
||||
func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) {
|
||||
// Find an unused file name
|
||||
var fullFileName string
|
||||
var savePath string
|
||||
for {
|
||||
fileStem := createRandomFileName(handler.config.LinkLength)
|
||||
fullFileName = fileStem + fileExtension
|
||||
savePath = handler.config.FileDir + fullFileName
|
||||
|
||||
if !fileExists(savePath) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
link := handler.config.LinkPrefix + fullFileName
|
||||
|
||||
err := saveFile(uploadFile, savePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func saveFile(data *multipart.File, name string) error {
|
||||
file, err := os.Create(name)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error {
|
|||
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, data)
|
||||
_, err = io.Copy(file, *data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func fileExists(filename string) bool {
|
||||
_, err := os.Stat(filename)
|
||||
|
||||
return !errors.Is(err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
func createRandomFileName(length int) string {
|
||||
chars := make([]byte, length)
|
||||
|
||||
|
|
Loading…
Reference in a new issue