diff --git a/cmd/server/main.go b/cmd/server/main.go index a955152..35e059a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -127,6 +127,7 @@ func main() { r.Get("/posts/latest", handlers.ListPostsLatest) r.Get("/posts/search", handlers.SearchPosts) r.Get("/posts/{id}", handlers.GetPost) + r.Get("/posts/{id}/attachments/{aid}/download", handlers.GetPostAttachmentDownload) r.Post("/posts/{id}/attachments/{aid}/download", handlers.PostAttachmentDownload) r.Get("/resources", handlers.ListResources) r.Get("/resources/recommended", handlers.ListRecommended) diff --git a/internal/handlers/post_download.go b/internal/handlers/post_download.go new file mode 100644 index 0000000..92a0149 --- /dev/null +++ b/internal/handlers/post_download.go @@ -0,0 +1,110 @@ +package handlers + +import ( + "context" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +type postAttachmentDownloadInfo struct { + URL string + Filename string + Mime string +} + +func loadPublishedPostAttachment(ctx context.Context, pool *pgxpool.Pool, postID, aid uuid.UUID) (postAttachmentDownloadInfo, error) { + var info postAttachmentDownloadInfo + err := pool.QueryRow(ctx, ` + SELECT pa.url, COALESCE(NULLIF(TRIM(pa.filename), ''), 'download'), COALESCE(NULLIF(TRIM(pa.mime), ''), 'application/octet-stream') + FROM post_attachments pa + JOIN posts p ON p.id = pa.post_id + WHERE pa.id = $1 AND pa.post_id = $2 + AND `+publicPostWhere, aid, postID, + ).Scan(&info.URL, &info.Filename, &info.Mime) + if err != nil { + return info, err + } + info.URL = strings.TrimSpace(info.URL) + if info.URL == "" { + return info, pgx.ErrNoRows + } + return info, nil +} + +func incrPostDownloadCount(ctx context.Context, pool *pgxpool.Pool, postID uuid.UUID) error { + cmd, err := pool.Exec(ctx, ` + UPDATE posts SET download_count = download_count + 1, updated_at = NOW() + WHERE id = $1 AND `+publicPostWhere, postID) + if err != nil { + return err + } + if cmd.RowsAffected() == 0 { + return pgx.ErrNoRows + } + return nil +} + +func handlePostAttachmentDownload(w http.ResponseWriter, r *http.Request, redirect bool) { + pool := poolFrom(r) + postID, err := uuid.Parse(chi.URLParam(r, "id")) + if err != nil { + http.Error(w, "bad id", http.StatusBadRequest) + return + } + aid, err := uuid.Parse(chi.URLParam(r, "aid")) + if err != nil { + http.Error(w, "bad attachment id", http.StatusBadRequest) + return + } + ctx := r.Context() + info, err := loadPublishedPostAttachment(ctx, pool, postID, aid) + if err != nil { + if err == pgx.ErrNoRows { + http.NotFound(w, r) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := incrPostDownloadCount(ctx, pool, postID); err != nil { + if err == pgx.ErrNoRows { + http.NotFound(w, r) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if redirect { + http.Redirect(w, r, info.URL, http.StatusFound) + return + } + writeJSON(w, map[string]any{ + "ok": true, + "url": info.URL, + "filename": info.Filename, + "mime": info.Mime, + }) +} + +// PostAttachmentDownload increments download_count and returns the attachment file URL for the client to open/save. +func PostAttachmentDownload(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + handlePostAttachmentDownload(w, r, false) +} + +// GetPostAttachmentDownload increments download_count and redirects the browser to the attachment URL. +func GetPostAttachmentDownload(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + handlePostAttachmentDownload(w, r, true) +} diff --git a/internal/handlers/posts_public.go b/internal/handlers/posts_public.go index ed96bee..675a5da 100644 --- a/internal/handlers/posts_public.go +++ b/internal/handlers/posts_public.go @@ -93,30 +93,6 @@ func GetPost(w http.ResponseWriter, r *http.Request) { writeJSON(w, dto) } -func PostAttachmentDownload(w http.ResponseWriter, r *http.Request) { - pool := poolFrom(r) - postID, err := uuid.Parse(chi.URLParam(r, "id")) - if err != nil { - http.Error(w, "bad id", http.StatusBadRequest) - return - } - aid, err := uuid.Parse(chi.URLParam(r, "aid")) - if err != nil { - http.Error(w, "bad attachment id", http.StatusBadRequest) - return - } - cmd, err := pool.Exec(r.Context(), ` - UPDATE posts SET download_count = download_count + 1, updated_at = NOW() - WHERE id = $1 AND status = 'published' AND is_public = TRUE - AND (published_at IS NULL OR published_at <= NOW())`, postID) - if err != nil || cmd.RowsAffected() == 0 { - http.NotFound(w, r) - return - } - writeJSON(w, map[string]any{"ok": true}) - _ = aid -} - func listPostsQuery(w http.ResponseWriter, r *http.Request, searchMode bool) { pool := poolFrom(r) limit := postLimitDef(r, 20, 50)