commit 2b9de725078820701cf440ebff2bb119205c25ed Author: Leon Richardt Date: Tue Oct 27 17:17:41 2020 +0100 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..391c53e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +# Binary +jaf diff --git a/config.go b/config.go new file mode 100644 index 0000000..1829906 --- /dev/null +++ b/config.go @@ -0,0 +1,82 @@ +package main + +import ( + "bufio" + "log" + "os" + "strconv" + "strings" +) + +const ( + commentPrefix = "#" +) + +type Config struct { + Port int + LinkPrefix string + FileDir string + LinkLength int +} + +func ConfigFromFile(filePath string) (*Config, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() + + oldPrefix := log.Prefix() + defer log.SetPrefix(oldPrefix) + + log.SetPrefix("config.FromFile > ") + + retval := &Config{} + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.HasPrefix(line, commentPrefix) { + // Skip comments + continue + } + + tokens := strings.Split(line, ": ") + if len(tokens) != 2 { + log.Printf("unexpected line: \"%s\", ignoring\n", line) + continue + } + + key, val := strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]) + + switch key { + case "Port": + parsed, err := strconv.Atoi(val) + if err != nil { + return nil, err + } + + retval.Port = parsed + case "LinkPrefix": + retval.LinkPrefix = val + case "FileDir": + retval.FileDir = val + case "LinkLength": + parsed, err := strconv.Atoi(val) + if err != nil { + return nil, err + } + + retval.LinkLength = parsed + default: + log.Printf("unexpected key: \"%s\", ignoring\n", key) + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return retval, nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..5e7a096 --- /dev/null +++ b/config_test.go @@ -0,0 +1,29 @@ +package main + +import ( + "testing" +) + +func assertEqualInt(have int, want int, t *testing.T) { + if have != want { + t.Errorf("have: %d, want: %d\n", have, want) + } +} + +func assertEqualString(have string, want string, t *testing.T) { + if have != want { + t.Errorf("have: %s, want: %s\n", have, want) + } +} + +func TestConfigFromFile(t *testing.T) { + config, err := ConfigFromFile("example.conf") + if err != nil { + panic(err) + } + + assertEqualInt(config.Port, 4711, t) + assertEqualString(config.LinkPrefix, "https://jaf.example.com/", t) + assertEqualString(config.FileDir, "/var/www/jaf.example.com/", t) + assertEqualInt(config.LinkLength, 5, t) +} diff --git a/example.conf b/example.conf new file mode 100644 index 0000000..46ca837 --- /dev/null +++ b/example.conf @@ -0,0 +1,5 @@ +Port: 4711 +# a comment +LinkPrefix: https://jaf.example.com/ +FileDir: /var/www/jaf.example.com/ +LinkLength: 5 diff --git a/jaf.go b/jaf.go new file mode 100644 index 0000000..cd62e39 --- /dev/null +++ b/jaf.go @@ -0,0 +1,65 @@ +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net/http" + "time" +) + +const allowedChars = "0123456789ABCDEFGHIJKLMNOPQRTSUVWXYZabcdefghijklmnopqrstuvwxyz" + +var ( + savedFileNames = NewSet() + config Config +) + +type parameters struct { + configFile string +} + +func parseParams() *parameters { + configFile := flag.String("configFile", "jaf.conf", "path to config file") + flag.Parse() + + retval := ¶meters{} + retval.configFile = *configFile + return retval +} + +func main() { + rand.Seed(time.Now().UnixNano()) + log.SetPrefix("jaf > ") + + params := parseParams() + + // Read config + config, err := ConfigFromFile(params.configFile) + if err != nil { + log.Fatalf("could not read config file: %s\n", err.Error()) + } + + files, err := ioutil.ReadDir(config.FileDir) + if err != nil { + log.Fatalf("could not read file root %s: %s\n", config.FileDir, err.Error()) + } + + // Cache taken file names on start-up + for _, fileInfo := range files { + savedFileNames.Insert(fileInfo.Name()) + } + + // Start server + uploadServer := &http.Server{ + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + Addr: fmt.Sprintf(":%d", config.Port), + } + + log.Printf("starting jaf on port %d\n", config.Port) + http.Handle("/upload", &uploadHandler{config: config}) + uploadServer.ListenAndServe() +} diff --git a/set.go b/set.go new file mode 100644 index 0000000..547bcb0 --- /dev/null +++ b/set.go @@ -0,0 +1,34 @@ +package main + +type Set struct { + _map map[interface{}]struct{} +} + +func NewSet() *Set { + set := &Set{} + set._map = make(map[interface{}]struct{}) + return set +} + +func (set *Set) Contains(value interface{}) bool { + _, ok := set._map[value] + return ok +} + +func (set *Set) Insert(value interface{}) bool { + if set.Contains(value) { + return false + } + + set._map[value] = struct{}{} + return true +} + +func (set *Set) Remove(value interface{}) bool { + if !set.Contains(value) { + return false + } + + delete(set._map, value) + return true +} diff --git a/set_test.go b/set_test.go new file mode 100644 index 0000000..56282a8 --- /dev/null +++ b/set_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "math/rand" + "testing" + "time" +) + +func TestContains(t *testing.T) { + set := NewSet() + + // Oracle testing + dummy := 0 + in := set.Contains(dummy) + if in { + t.Errorf("oracle > set.Contains(%d) = true before insertion", dummy) + } + + set.Insert(dummy) + in = set.Contains(dummy) + if !in { + t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy) + } + + // Property testing + rand.Seed(time.Now().UnixNano()) + const reps = 1000 + for i := 0; i < reps; i++ { + lastInsert := rand.Int() + set.Insert(lastInsert) + + in = set.Contains(lastInsert) + + if !in { + t.Errorf("property > set.Contains(%d) = false after insertion", dummy) + } + } +} + +func TestInsert(t *testing.T) { + set := NewSet() + + // Oracle testing + dummy := 0 + innovative := set.Insert(dummy) + + if !innovative { + t.Errorf("oracle > set.Insert(%d) = false but was innovative", dummy) + } + + in := set.Contains(dummy) + if !in { + t.Errorf("oracle > set.Contains(%d) = false after insertion", dummy) + } + + // Duplicate insertion should return false + innovative = set.Insert(dummy) + if innovative { + t.Errorf("oracle > set.Insert(%d) = true but was not innovative", dummy) + } + + // Property testing + rand.Seed(time.Now().UnixNano()) + const reps = 1000 + for i := 0; i < reps; i++ { + val := rand.Int() + + inBefore := set.Contains(val) + innovative = set.Insert(val) + + if inBefore && innovative { + t.Errorf("property > included value reported as innovative") + } + } +} diff --git a/uploadhandler.go b/uploadhandler.go new file mode 100644 index 0000000..9fc4830 --- /dev/null +++ b/uploadhandler.go @@ -0,0 +1,91 @@ +package main + +import ( + "io" + "log" + "math/rand" + "mime/multipart" + "net/http" + "os" + "strings" +) + +type uploadHandler struct { + config *Config +} + +func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + log.Println("request received from " + r.RemoteAddr) + + uploadFile, header, err := r.FormFile("file") + if err != nil { + http.Error(w, "could not read uploaded file: "+err.Error(), http.StatusBadRequest) + log.Println(" could not read uploaded file: " + err.Error()) + return + } + defer uploadFile.Close() + + originalName, fileExtension := splitFileName(header.Filename) + log.Println(" received file: " + originalName) + + // Find an unused file name + fileID := createRandomFileName(h.config.LinkLength) + for ; savedFileNames.Contains(fileID); fileID = createRandomFileName(h.config.LinkLength) { + } + log.Println(" generated random id: " + fileID) + + fullFileName := fileID + fileExtension + savePath := h.config.FileDir + fullFileName + link := h.config.LinkPrefix + fullFileName + + err = saveFile(uploadFile, savePath) + if err != nil { + http.Error(w, "could not save file: "+err.Error(), http.StatusInternalServerError) + log.Println(" could not save file: " + err.Error()) + return + } + savedFileNames.Insert(fullFileName) + log.Println(" saved file as: " + fullFileName) + + // Implicitly means code 200 + w.Write([]byte(link)) +} + +func saveFile(data multipart.File, name string) error { + file, err := os.Create(name) + if err != nil { + return err + } + + defer file.Close() + + _, err = io.Copy(file, data) + if err != nil { + return err + } + + return nil +} + +func createRandomFileName(length int) string { + chars := make([]byte, length) + + for i := 0; i < length; i++ { + index := rand.Intn(len(allowedChars)) + chars[i] = allowedChars[index] + } + + return string(chars) +} + +func splitFileName(name string) (string, string) { + extIndex := strings.LastIndex(name, ".") + + if extIndex == -1 { + // No dot at all + return name, "" + } + + return name[:extIndex], name[extIndex:] +}