Files
Arkie-Library-Backend/internal/handlers/upload.go

144 lines
4.1 KiB
Go
Raw Normal View History

2026-05-16 00:18:22 +08:00
package handlers
import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
	"github.com/google/uuid"
)
// uploadMaxBytes is the largest file size, in bytes, accepted by a single
// upload (512 MiB).
const uploadMaxBytes = 512 * 1024 * 1024
// UploadDeps configures admin multipart upload (local disk and/or S3).
type UploadDeps struct {
LocalDir string
MaxMultipartMem int64
S3 *s3.Client
S3Bucket string
AWSRegion string
S3Prefix string // e.g. "uploads" (no leading/trailing slashes)
S3PublicBase string // optional, e.g. https://cdn.example.com — else virtual-hosted S3 URL
2026-05-19 07:37:25 +08:00
S3ObjectACL string // optional canned ACL, e.g. public-read (bucket must allow ACLs)
2026-05-16 00:18:22 +08:00
}
// UploadFile returns an HTTP handler for admin multipart uploads.
//
// The request must be multipart/form-data with the file in the "file" field.
// The file is stored under a fresh UUID name that keeps the client's
// extension only when it is short and purely ASCII alphanumeric (otherwise
// ".bin"). If d.S3 is non-nil and d.S3Bucket is non-blank, the file is put to
// S3 under "<S3Prefix>/<name>" (prefix defaults to "uploads"; the canned ACL
// from d.S3ObjectACL is applied when recognized), and the JSON response
// carries the public URL with storage "s3". Otherwise the file is written to
// d.LocalDir and the response URL is "/uploads/<name>" with storage "local".
//
// Content-Type comes from the part header, falling back to content sniffing.
// Files larger than uploadMaxBytes, and request bodies larger than that plus a
// small allowance for multipart framing, are rejected with 413.
func UploadFile(d UploadDeps) http.HandlerFunc {
	// Headroom beyond uploadMaxBytes for multipart boundaries, part headers
	// and any other small form fields.
	const formOverhead = 1 << 20

	maxMem := d.MaxMultipartMem
	if maxMem <= 0 {
		maxMem = 32 << 20
	}
	return func(w http.ResponseWriter, r *http.Request) {
		// Cap the body before parsing: ParseMultipartForm spills parts larger
		// than maxMem to temporary files, so without a limit a client could
		// stream an arbitrarily large body to disk before the size check below.
		r.Body = http.MaxBytesReader(w, r.Body, uploadMaxBytes+formOverhead)
		if err := r.ParseMultipartForm(maxMem); err != nil {
			var tooLarge *http.MaxBytesError
			if errors.As(err, &tooLarge) {
				http.Error(w, "file too large", http.StatusRequestEntityTooLarge)
				return
			}
			http.Error(w, "multipart required", http.StatusBadRequest)
			return
		}
		file, hdr, err := r.FormFile("file")
		if err != nil {
			http.Error(w, "file field required", http.StatusBadRequest)
			return
		}
		defer file.Close()

		name := uuid.NewString() + sanitizedUploadExt(hdr.Filename)

		// Read one byte past the limit so an oversize file is detectable.
		data, err := io.ReadAll(io.LimitReader(file, uploadMaxBytes+1))
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if len(data) > uploadMaxBytes {
			http.Error(w, "file too large", http.StatusRequestEntityTooLarge)
			return
		}
		ct := hdr.Header.Get("Content-Type")
		if ct == "" {
			ct = http.DetectContentType(data)
		}

		if d.S3 != nil && strings.TrimSpace(d.S3Bucket) != "" {
			pfx := strings.Trim(strings.TrimSpace(d.S3Prefix), "/")
			if pfx == "" {
				pfx = "uploads"
			}
			key := pfx + "/" + name
			put := &s3.PutObjectInput{
				Bucket:      aws.String(d.S3Bucket),
				Key:         aws.String(key),
				Body:        bytes.NewReader(data),
				ContentType: aws.String(ct),
			}
			if acl, ok := s3PutObjectCannedACL(d.S3ObjectACL); ok {
				put.ACL = acl
			}
			if _, err := d.S3.PutObject(r.Context(), put); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			pub := publicObjectURL(d.S3PublicBase, d.S3Bucket, d.AWSRegion, key)
			writeJSON(w, map[string]any{"url": pub, "filename": name, "storage": "s3"})
			return
		}

		if err := saveUploadLocal(d.LocalDir, name, data); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]any{"url": "/uploads/" + name, "filename": name, "storage": "local"})
	}
}

// sanitizedUploadExt returns the extension (with leading dot) of the
// client-supplied filename when it consists of 1 to 15 ASCII letters or
// digits, and ".bin" otherwise. The filename is untrusted and its extension
// ends up in S3 keys, local paths and the URL returned to the client, so
// characters such as spaces, '#', '?', '%' or '\' are rejected outright.
func sanitizedUploadExt(filename string) string {
	const maxLen = 16 // including the leading dot
	ext := filepath.Ext(filename)
	if len(ext) < 2 || len(ext) > maxLen {
		return ".bin"
	}
	for _, c := range []byte(ext[1:]) {
		isAlnum := 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9'
		if !isAlnum {
			return ".bin"
		}
	}
	return ext
}

// saveUploadLocal creates dir if needed and writes data to dir/name. If the
// write fails (including on close, where delayed write errors surface), the
// partially written file is removed so a truncated upload is not left behind.
func saveUploadLocal(dir, name string, data []byte) error {
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return err
	}
	dst := filepath.Join(dir, name)
	if err := os.WriteFile(dst, data, 0o666); err != nil {
		_ = os.Remove(dst) // best effort: the write error is what gets reported
		return err
	}
	return nil
}
2026-05-19 07:37:25 +08:00
// s3PutObjectCannedACL translates the S3_OBJECT_ACL setting into the SDK's
// canned-ACL enum. Matching ignores case and surrounding whitespace. For
// blank or unrecognized input it returns an empty ACL and false, so the caller
// can leave the object's ACL unset.
func s3PutObjectCannedACL(raw string) (types.ObjectCannedACL, bool) {
	known := map[string]types.ObjectCannedACL{
		"private":                   types.ObjectCannedACLPrivate,
		"public-read":               types.ObjectCannedACLPublicRead,
		"public-read-write":         types.ObjectCannedACLPublicReadWrite,
		"authenticated-read":        types.ObjectCannedACLAuthenticatedRead,
		"bucket-owner-full-control": types.ObjectCannedACLBucketOwnerFullControl,
		"bucket-owner-read":         types.ObjectCannedACLBucketOwnerRead,
		"aws-exec-read":             types.ObjectCannedACLAwsExecRead,
	}
	normalized := strings.TrimSpace(strings.ToLower(raw))
	acl, found := known[normalized]
	return acl, found
}
2026-05-16 00:18:22 +08:00
// publicObjectURL returns the URL clients should use to fetch the S3 object
// stored under key.
//
// When base (e.g. a CDN origin) is non-blank, the result is base + "/" + key
// with surrounding whitespace and any trailing slashes on base removed.
// Otherwise a virtual-hosted-style S3 URL is built from bucket and region; a
// blank region falls back to the regionless s3.amazonaws.com host instead of
// producing an invalid hostname. Each "/"-separated segment of key is
// path-escaped, so keys containing spaces, '#', '?' or '%' still yield a
// well-formed URL.
func publicObjectURL(base, bucket, region, key string) string {
	segments := strings.Split(key, "/")
	for i, seg := range segments {
		segments[i] = url.PathEscape(seg)
	}
	escapedKey := strings.Join(segments, "/")

	base = strings.TrimSpace(base)
	if base != "" {
		return strings.TrimRight(base, "/") + "/" + escapedKey
	}
	// Virtual-hosted-style URL (works for most buckets; use S3_PUBLIC_BASE_URL if not).
	region = strings.TrimSpace(region)
	if region == "" {
		return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", bucket, escapedKey)
	}
	return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucket, region, escapedKey)
}