package handlers

import (
	"context"
	"errors"
	"log"
	"net/http"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

// postAttachmentDownloadInfo is the subset of a post_attachments row needed to
// hand a download off to the client.
type postAttachmentDownloadInfo struct {
	// URL is the stored attachment location with surrounding whitespace
	// trimmed; loadPublishedPostAttachment never returns it empty on success.
	URL string
	// Filename falls back to "download" when the stored value is NULL or blank.
	Filename string
	// Mime falls back to "application/octet-stream" when NULL or blank.
	Mime string
}

// loadPublishedPostAttachment fetches attachment aid belonging to post postID,
// provided the post also satisfies publicPostWhere.
//
// It returns pgx.ErrNoRows when no matching row exists or when the stored URL
// is NULL or blank, so callers can treat all of these as "not found". On any
// error the returned info is the zero value.
func loadPublishedPostAttachment(ctx context.Context, pool *pgxpool.Pool, postID, aid uuid.UUID) (postAttachmentDownloadInfo, error) {
	var info postAttachmentDownloadInfo
	// NOTE(review): publicPostWhere is appended verbatim here (posts aliased as
	// p) and in incrPostDownloadCount (posts unaliased); it must be valid in
	// both contexts — confirm its column references.
	err := pool.QueryRow(ctx, `
		SELECT COALESCE(pa.url, ''),
		       COALESCE(NULLIF(TRIM(pa.filename), ''), 'download'),
		       COALESCE(NULLIF(TRIM(pa.mime), ''), 'application/octet-stream')
		FROM post_attachments pa
		JOIN posts p ON p.id = pa.post_id
		WHERE pa.id = $1 AND pa.post_id = $2 AND `+publicPostWhere,
		aid, postID,
	).Scan(&info.URL, &info.Filename, &info.Mime)
	if err != nil {
		return postAttachmentDownloadInfo{}, err
	}
	info.URL = strings.TrimSpace(info.URL)
	if info.URL == "" {
		// A row without a usable URL cannot be downloaded; report it as missing.
		return postAttachmentDownloadInfo{}, pgx.ErrNoRows
	}
	return info, nil
}

// incrPostDownloadCount adds one to download_count of post postID and sets its
// updated_at to NOW(), in a single UPDATE that only matches posts satisfying
// publicPostWhere. It returns pgx.ErrNoRows when no row was updated.
func incrPostDownloadCount(ctx context.Context, pool *pgxpool.Pool, postID uuid.UUID) error {
	// NOTE(review): bumping updated_at on every download makes the post look
	// edited to anything that sorts or caches by updated_at — confirm intent.
	tag, err := pool.Exec(ctx, `
		UPDATE posts
		SET download_count = download_count + 1, updated_at = NOW()
		WHERE id = $1 AND `+publicPostWhere,
		postID,
	)
	if err != nil {
		return err
	}
	if tag.RowsAffected() == 0 {
		return pgx.ErrNoRows
	}
	return nil
}

// writePostAttachmentDownloadError turns a load/increment error into an HTTP
// response. pgx.ErrNoRows (including wrapped forms) becomes 404 Not Found.
// Anything else is logged and answered with a generic 500, so database error
// text is not exposed to clients.
func writePostAttachmentDownloadError(w http.ResponseWriter, r *http.Request, err error) {
	if errors.Is(err, pgx.ErrNoRows) {
		http.NotFound(w, r)
		return
	}
	log.Printf("post attachment download %q: %v", r.URL.Path, err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

// handlePostAttachmentDownload implements both download endpoints for
// attachment {aid} of post {id}.
//
// It parses both chi URL params as UUIDs (400 on failure) and loads the
// attachment of a publicly visible post. It then increments that post's
// download_count. Finally it either redirects to the attachment URL
// (redirect == true, 302 Found) or writes a JSON object with ok, url,
// filename and mime.
func handlePostAttachmentDownload(w http.ResponseWriter, r *http.Request, redirect bool) {
	pool := poolFrom(r)

	postID, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		http.Error(w, "bad id", http.StatusBadRequest)
		return
	}
	aid, err := uuid.Parse(chi.URLParam(r, "aid"))
	if err != nil {
		http.Error(w, "bad attachment id", http.StatusBadRequest)
		return
	}

	ctx := r.Context()
	info, err := loadPublishedPostAttachment(ctx, pool, postID, aid)
	if err != nil {
		writePostAttachmentDownloadError(w, r, err)
		return
	}
	// The post may stop matching publicPostWhere between the two queries. The
	// UPDATE re-checks it and reports ErrNoRows (-> 404) in that case.
	if err := incrPostDownloadCount(ctx, pool, postID); err != nil {
		writePostAttachmentDownloadError(w, r, err)
		return
	}

	if redirect {
		// NOTE(review): http.Redirect resolves a relative URL against the
		// request path; this assumes stored URLs are absolute or root-relative.
		http.Redirect(w, r, info.URL, http.StatusFound)
		return
	}
	writeJSON(w, map[string]any{
		"ok":       true,
		"url":      info.URL,
		"filename": info.Filename,
		"mime":     info.Mime,
	})
}

// PostAttachmentDownload increments download_count and returns the attachment
// file URL (plus filename and MIME type) as JSON for the client to open or
// save. Only POST is accepted; other methods get 405 with an Allow header.
func PostAttachmentDownload(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	handlePostAttachmentDownload(w, r, false)
}

// GetPostAttachmentDownload increments download_count and redirects the
// browser (302 Found) to the attachment URL. Only GET is accepted; other
// methods get 405 with an Allow header.
func GetPostAttachmentDownload(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		w.Header().Set("Allow", http.MethodGet)
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	handlePostAttachmentDownload(w, r, true)
}