mirror of
https://github.com/lyx0/yaf.git
synced 2024-11-13 19:49:53 +01:00
Merge pull request #1 from leon-richardt/ref/check-file-existence
ref: check file existence instead of keeping in-memory set
This commit is contained in:
commit
3b247a932f
4 changed files with 38 additions and 136 deletions
14
jaf.go
14
jaf.go
|
@ -3,7 +3,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -13,8 +12,7 @@ import (
|
||||||
const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz"
|
const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
savedFileNames = NewSet()
|
config Config
|
||||||
config Config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type parameters struct {
|
type parameters struct {
|
||||||
|
@ -42,16 +40,6 @@ func main() {
|
||||||
log.Fatalf("could not read config file: %s\n", err.Error())
|
log.Fatalf("could not read config file: %s\n", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
files, err := ioutil.ReadDir(config.FileDir)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache taken file names on start-up
|
|
||||||
for _, fileInfo := range files {
|
|
||||||
savedFileNames.Insert(fileInfo.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
uploadServer := &http.Server{
|
uploadServer := &http.Server{
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: 30 * time.Second,
|
||||||
|
|
34
set.go
34
set.go
|
@ -1,34 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
type Set struct {
|
|
||||||
_map map[interface{}]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSet() *Set {
|
|
||||||
set := &Set{}
|
|
||||||
set._map = make(map[interface{}]struct{})
|
|
||||||
return set
|
|
||||||
}
|
|
||||||
|
|
||||||
func (set *Set) Contains(value interface{}) bool {
|
|
||||||
_, ok := set._map[value]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (set *Set) Insert(value interface{}) bool {
|
|
||||||
if set.Contains(value) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
set._map[value] = struct{}{}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (set *Set) Remove(value interface{}) bool {
|
|
||||||
if !set.Contains(value) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(set._map, value)
|
|
||||||
return true
|
|
||||||
}
|
|
75
set_test.go
75
set_test.go
|
@ -1,75 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestContains(t *testing.T) {
|
|
||||||
set := NewSet()
|
|
||||||
|
|
||||||
// Oracle testing
|
|
||||||
dummy := 0
|
|
||||||
in := set.Contains(dummy)
|
|
||||||
if in {
|
|
||||||
t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy)
|
|
||||||
}
|
|
||||||
|
|
||||||
set.Insert(dummy)
|
|
||||||
in = set.Contains(dummy)
|
|
||||||
if !in {
|
|
||||||
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property testing
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
const reps = 1000
|
|
||||||
for i := 0; i < reps; i++ {
|
|
||||||
lastInsert := rand.Int()
|
|
||||||
set.Insert(lastInsert)
|
|
||||||
|
|
||||||
in = set.Contains(lastInsert)
|
|
||||||
|
|
||||||
if !in {
|
|
||||||
t.Errorf("property > set.Contains(%d) = false after insertion", dummy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInsert(t *testing.T) {
|
|
||||||
set := NewSet()
|
|
||||||
|
|
||||||
// Oracle testing
|
|
||||||
dummy := 0
|
|
||||||
innovative := set.Insert(dummy)
|
|
||||||
|
|
||||||
if !innovative {
|
|
||||||
t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy)
|
|
||||||
}
|
|
||||||
|
|
||||||
in := set.Contains(dummy)
|
|
||||||
if !in {
|
|
||||||
t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Duplicate insertion should return false
|
|
||||||
innovative = set.Insert(dummy)
|
|
||||||
if innovative {
|
|
||||||
t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property testing
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
const reps = 1000
|
|
||||||
for i := 0; i < reps; i++ {
|
|
||||||
val := rand.Int()
|
|
||||||
|
|
||||||
inBefore := set.Contains(val)
|
|
||||||
innovative = set.Insert(val)
|
|
||||||
|
|
||||||
if inBefore && innovative {
|
|
||||||
t.Errorf("property > included value reported as innovative")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
@ -26,29 +27,45 @@ func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
defer uploadFile.Close()
|
defer uploadFile.Close()
|
||||||
|
|
||||||
_, fileExtension := splitFileName(header.Filename)
|
_, fileExtension := splitFileName(header.Filename)
|
||||||
|
link, err := generateLink(h, &uploadFile, fileExtension)
|
||||||
// Find an unused file name
|
|
||||||
fileID := createRandomFileName(h.config.LinkLength)
|
|
||||||
for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) {
|
|
||||||
}
|
|
||||||
|
|
||||||
fullFileName := fileID + fileExtension
|
|
||||||
savePath := h.config.FileDir + fullFileName
|
|
||||||
link := h.config.LinkPrefix + fullFileName
|
|
||||||
|
|
||||||
err = saveFile(uploadFile, savePath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError)
|
http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError)
|
||||||
log.Println(" could not save file: " + err.Error())
|
log.Println(" could not save file: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
savedFileNames.Insert(fullFileName)
|
|
||||||
|
|
||||||
// Implicitly means code 200
|
// Implicitly means code 200
|
||||||
w.Write([]byte(link))
|
w.Write([]byte(link))
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveFile(data multipart.File, name string) error {
|
// Generates a valid link to uploadFile with the specified file extension.
|
||||||
|
// Returns the link or an error in case of failure.
|
||||||
|
// Does not close the passed file pointer.
|
||||||
|
func generateLink(handler *uploadHandler, uploadFile *multipart.File, fileExtension string) (string, error) {
|
||||||
|
// Find an unused file name
|
||||||
|
var fullFileName string
|
||||||
|
var savePath string
|
||||||
|
for {
|
||||||
|
fileStem := createRandomFileName(handler.config.LinkLength)
|
||||||
|
fullFileName = fileStem + fileExtension
|
||||||
|
savePath = handler.config.FileDir + fullFileName
|
||||||
|
|
||||||
|
if !fileExists(savePath) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
link := handler.config.LinkPrefix + fullFileName
|
||||||
|
|
||||||
|
err := saveFile(uploadFile, savePath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return link, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveFile(data *multipart.File, name string) error {
|
||||||
file, err := os.Create(name)
|
file, err := os.Create(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -56,7 +73,7 @@ func saveFile(data multipart.File, name string) error {
|
||||||
|
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
_, err = io.Copy(file, data)
|
_, err = io.Copy(file, *data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -64,6 +81,12 @@ func saveFile(data multipart.File, name string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fileExists(filename string) bool {
|
||||||
|
_, err := os.Stat(filename)
|
||||||
|
|
||||||
|
return !errors.Is(err, os.ErrNotExist)
|
||||||
|
}
|
||||||
|
|
||||||
func createRandomFileName(length int) string {
|
func createRandomFileName(length int) string {
|
||||||
chars := make([]byte, length)
|
chars := make([]byte, length)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue