diff --git a/main.go b/main.go index 4714d62..54e60e5 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "html/template" "os" @@ -11,15 +12,19 @@ import ( ) func main() { + // 设置上传文件夹 + // 定义命令行参数 -u 或 --upload + uploadDir := flag.String("u", "./uploads", "upload directory") + flag.Parse() + + // 创建目录(如果不存在) + os.MkdirAll(*uploadDir, 0755) + r := gin.Default() // 注册模板函数 r.SetFuncMap(template.FuncMap{ "hasSuffix": strings.HasSuffix, }) - // 设置上传文件夹 - uploadDir := "./uploads" - os.MkdirAll(uploadDir, 0755) - // 加载模板 r.LoadHTMLGlob("templates/*") @@ -27,7 +32,7 @@ func main() { r.GET("/", func(c *gin.Context) { var files []string - filepath.Walk(uploadDir, func(path string, info os.FileInfo, err error) error { + filepath.Walk(*uploadDir, func(path string, info os.FileInfo, err error) error { if !info.IsDir() { files = append(files, info.Name()) } @@ -55,7 +60,7 @@ func main() { } for _, file := range files { - dst := filepath.Join(uploadDir, file.Filename) + dst := filepath.Join(*uploadDir, file.Filename) if err := c.SaveUploadedFile(file, dst); err != nil { c.JSON(500, gin.H{"error": err.Error()}) return @@ -69,7 +74,7 @@ func main() { // 删除文件 r.DELETE("/delete/:filename", func(c *gin.Context) { filename := c.Param("filename") - path := filepath.Join(uploadDir, filename) + path := filepath.Join(*uploadDir, filename) if _, err := os.Stat(path); os.IsNotExist(err) { c.JSON(404, gin.H{"error": "File not found"}) @@ -87,10 +92,11 @@ func main() { // 下载接口 r.GET("/download/:filename", func(c *gin.Context) { filename := c.Param("filename") - filepath := filepath.Join(uploadDir, filename) + filepath := filepath.Join(*uploadDir, filename) c.FileAttachment(filepath, filename) }) + fmt.Println("upload dir =", *uploadDir) r.Run("0.0.0.0:8080") fmt.Println("启动") }