mirror of
https://github.com/lyx0/yaf.git
synced 2024-11-13 19:49:53 +01:00
Merge pull request #1 from leon-richardt/ref/check-file-existence
ref: check file existence instead of keeping in-memory set
This commit is contained in:
commit
3b247a932f
4 changed files with 38 additions and 136 deletions
14
jaf.go
14
jaf.go
|
@ -3,7 +3,6 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
|
@ -13,8 +12,7 @@ import (
|
|||
const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
var (
|
||||
savedFileNames = NewSet()
|
||||
config Config
|
||||
config Config
|
||||
)
|
||||
|
||||
type parameters struct {
|
||||
|
@ -42,16 +40,6 @@ func main() {
|
|||
log.Fatalf("could not read config file: %s\n", err.Error())
|
||||
}
|
||||
|
||||
files, err := ioutil.ReadDir(config.FileDir)
|
||||
if err != nil {
|
||||
log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error())
|
||||
}
|
||||
|
||||
// Cache taken file names on start-up
|
||||
for _, fileInfo := range files {
|
||||
savedFileNames.Insert(fileInfo.Name())
|
||||
}
|
||||
|
||||
// Start server
|
||||
uploadServer := &http.Server{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
|
|
34
set.go
34
set.go
|
@ -1,34 +0,0 @@
|
|||
package main
|
||||
|
||||
type Set struct {
|
||||
_map map[interface{}]struct{}
|
||||
}
|
||||
|
||||
func NewSet() *Set {
|
||||
set := &Set{}
|
||||
set._map = make(map[interface{}]struct{})
|
||||
return set
|
||||
}
|
||||
|
||||
func (set *Set) Contains(value interface{}) bool {
|
||||
_, ok := set._map[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (set *Set) Insert(value interface{}) bool {
|
||||
if set.Contains(value) {
|
||||
return false
|
||||
}
|
||||
|
||||
set._map[value] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (set *Set) Remove(value interface{}) bool {
|
||||
if !set.Contains(value) {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(set._map, value)
|
||||
return true
|
||||
}
|
75
set_test.go
75
set_test.go
|
@ -1,75 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
set := NewSet()
|
||||
|
||||
// Oracle testing
|
||||
dummy := 0
|
||||
in := set.Contains(dummy)
|
||||
if in {
|
||||
t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy)
|
||||
}
|
||||
|
||||
set.Insert(dummy)
|
||||
in = set.Contains(dummy)
|
||||
if !in {
|
||||
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
|
||||
}
|
||||
|
||||
// Property testing
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
const reps = 1000
|
||||
for i := 0; i < reps; i++ {
|
||||
lastInsert := rand.Int()
|
||||
set.Insert(lastInsert)
|
||||
|
||||
in = set.Contains(lastInsert)
|
||||
|
||||
if !in {
|
||||
t.Errorf("property > set.Contains(%d) = false after insertion", dummy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
set := NewSet()
|
||||
|
||||
// Oracle testing
|
||||
dummy := 0
|
||||
innovative := set.Insert(dummy)
|
||||
|
||||
if !innovative {
|
||||
t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy)
|
||||
}
|
||||
|
||||
in := set.Contains(dummy)
|
||||
if !in {
|
||||
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
|
||||
}
|
||||
|
||||
// Duplicate insertion should return false
|
||||
innovative = set.Insert(dummy)
|
||||
if innovative {
|
||||
t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy)
|
||||
}
|
||||
|
||||
// Property testing
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
const reps = 1000
|
||||
for i := 0; i < reps; i++ {
|
||||
val := rand.Int()
|
||||
|
||||
inBefore := set.Contains(val)
|
||||
innovative = set.Insert(val)
|
||||
|
||||
if inBefore && innovative {
|
||||
t.Errorf("property > included value reported as innovative")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
defer uploadFile.Close()
|
||||
|
||||
_, fileExtension := splitFileName(header.Filename)
|
||||
|
||||
// Find an unused file name
|
||||
fileID := createRandomFileName(h.config.LinkLength)
|
||||
for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) {
|
||||
}
|
||||
|
||||
fullFileName := fileID + fileExtension
|
||||
savePath := h.config.FileDir + fullFileName
|
||||
link := h.config.LinkPrefix + fullFileName
|
||||
|
||||
err = saveFile(uploadFile, savePath)
|
||||
link, err := generateLink(h, &uploadFile, fileExtension)
|
||||
if err != nil {
|
||||
http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError)
|
||||
log.Println(" could not save file: " + err.Error())
|
||||
return
|
||||
}
|
||||
savedFileNames.Insert(fullFileName)
|
||||
|
||||
// Implicitly means code 200
|
||||
w.Write([]byte(link))
|
||||
}
|
||||
|
||||
func saveFile(data multipart.File, name string) error {
|
||||
// Generates a valid link to uploadFile with the specified file extension.
|
||||
// Returns the link or an error in case of failure.
|
||||
// Does not close the passed file pointer.
|
||||
func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) {
|
||||
// Find an unused file name
|
||||
var fullFileName string
|
||||
var savePath string
|
||||
for {
|
||||
fileStem := createRandomFileName(handler.config.LinkLength)
|
||||
fullFileName = fileStem + fileExtension
|
||||
savePath = handler.config.FileDir + fullFileName
|
||||
|
||||
if !fileExists(savePath) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
link := handler.config.LinkPrefix + fullFileName
|
||||
|
||||
err := saveFile(uploadFile, savePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func saveFile(data *multipart.File, name string) error {
|
||||
file, err := os.Create(name)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error {
|
|||
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, data)
|
||||
_, err = io.Copy(file, *data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func fileExists(filename string) bool {
|
||||
_, err := os.Stat(filename)
|
||||
|
||||
return !errors.Is(err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
func createRandomFileName(length int) string {
|
||||
chars := make([]byte, length)
|
||||
|
||||
|
|
Loading…
Reference in a new issue