diff --git a/plugins/host_taskqueue.go b/plugins/host_taskqueue.go index ec68de6d8..336ae1fe6 100644 --- a/plugins/host_taskqueue.go +++ b/plugins/host_taskqueue.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "math" "os" "path/filepath" "sync" @@ -29,13 +28,11 @@ const ( maxRetentionMs int64 = 604_800_000 // 1 week maxQueueNameLength = 128 maxPayloadSize = 1 * 1024 * 1024 // 1MB - maxBackoffMs = 3_600_000 // 1 hour + maxBackoffMs int64 = 3_600_000 // 1 hour cleanupInterval = 5 * time.Minute pollInterval = 5 * time.Second shutdownTimeout = 10 * time.Second -) -const ( taskStatusPending = "pending" taskStatusRunning = "running" taskStatusCompleted = "completed" @@ -44,7 +41,6 @@ const ( ) // CapabilityTaskWorker indicates the plugin can receive task execution callbacks. -// Detected when the plugin exports the task worker callback function. const CapabilityTaskWorker Capability = "TaskWorker" const FuncTaskWorkerCallback = "nd_task_execute" @@ -53,11 +49,18 @@ func init() { registerCapability(CapabilityTaskWorker, FuncTaskWorkerCallback) } -// queueState holds in-memory state for a single task queue. type queueState struct { config host.QueueConfig signal chan struct{} - limiter *rate.Limiter // rate limiter for delay enforcement between dispatches + limiter *rate.Limiter +} + +// notifyWorkers sends a non-blocking signal to wake up queue workers. +func (qs *queueState) notifyWorkers() { + select { + case qs.signal <- struct{}{}: + default: + } } // taskQueueServiceImpl implements host.TaskQueueService with SQLite persistence @@ -79,13 +82,11 @@ type taskQueueServiceImpl struct { // newTaskQueueService creates a new taskQueueServiceImpl with its own SQLite database. 
func newTaskQueueService(pluginName string, manager *Manager, maxConcurrency int32) (*taskQueueServiceImpl, error) { - // Create plugin data directory dataDir := filepath.Join(conf.Server.DataFolder, "plugins", pluginName) if err := os.MkdirAll(dataDir, 0700); err != nil { return nil, fmt.Errorf("creating plugin data directory: %w", err) } - // Open SQLite database dbPath := filepath.Join(dataDir, "taskqueue.db") db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_foreign_keys=off") if err != nil { @@ -95,7 +96,6 @@ func newTaskQueueService(pluginName string, manager *Manager, maxConcurrency int db.SetMaxOpenConns(3) db.SetMaxIdleConns(1) - // Create schema if err := createTaskQueueSchema(db); err != nil { db.Close() return nil, fmt.Errorf("creating taskqueue schema: %w", err) @@ -114,7 +114,6 @@ func newTaskQueueService(pluginName string, manager *Manager, maxConcurrency int } s.invokeCallbackFn = s.defaultInvokeCallback - // Start cleanup goroutine s.wg.Add(1) go s.cleanupLoop() @@ -150,17 +149,9 @@ func createTaskQueueSchema(db *sql.DB) error { return err } -// CreateQueue creates a named task queue with the given configuration. -func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, config host.QueueConfig) error { - // Validate queue name - if len(name) == 0 { - return fmt.Errorf("queue name cannot be empty") - } - if len(name) > maxQueueNameLength { - return fmt.Errorf("queue name exceeds maximum length of %d bytes", maxQueueNameLength) - } - - // Apply defaults +// applyConfigDefaults fills zero-value config fields with sensible defaults +// and clamps values to valid ranges, logging warnings for clamped values. 
+func (s *taskQueueServiceImpl) applyConfigDefaults(ctx context.Context, name string, config *host.QueueConfig) { if config.Concurrency <= 0 { config.Concurrency = defaultConcurrency } @@ -171,7 +162,6 @@ func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, con config.RetentionMs = defaultRetentionMs } - // Clamp retention if config.RetentionMs < minRetentionMs { log.Warn(ctx, "TaskQueue retention clamped to minimum", "plugin", s.pluginName, "queue", name, "requested", config.RetentionMs, "min", minRetentionMs) @@ -182,26 +172,38 @@ func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, con "requested", config.RetentionMs, "max", maxRetentionMs) config.RetentionMs = maxRetentionMs } +} - s.mu.Lock() - defer s.mu.Unlock() - - // Clamp concurrency based on maxConcurrency minus already-allocated concurrency +// clampConcurrency reduces config.Concurrency if it exceeds the remaining budget. +// Must be called with s.mu held. +func (s *taskQueueServiceImpl) clampConcurrency(ctx context.Context, name string, config *host.QueueConfig) { var allocated int32 for _, qs := range s.queues { allocated += qs.config.Concurrency } - available := s.maxConcurrency - allocated - if available <= 0 { - available = 1 // Always allow at least 1 - } + available := max(s.maxConcurrency-allocated, 1) if config.Concurrency > available { log.Warn(ctx, "TaskQueue concurrency clamped", "plugin", s.pluginName, "queue", name, "requested", config.Concurrency, "available", available, "maxConcurrency", s.maxConcurrency) config.Concurrency = available } +} + +func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, config host.QueueConfig) error { + if len(name) == 0 { + return fmt.Errorf("queue name cannot be empty") + } + if len(name) > maxQueueNameLength { + return fmt.Errorf("queue name exceeds maximum length of %d bytes", maxQueueNameLength) + } + + s.applyConfigDefaults(ctx, name, &config) + + s.mu.Lock() + defer s.mu.Unlock() + 
+ s.clampConcurrency(ctx, name, &config) - // Check queue name doesn't already exist if _, exists := s.queues[name]; exists { return fmt.Errorf("queue %q already exists", name) } @@ -230,7 +232,6 @@ func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, con return fmt.Errorf("resetting stale tasks: %w", err) } - // Store queue state qs := &queueState{ config: config, signal: make(chan struct{}, 1), @@ -242,7 +243,6 @@ func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, con } s.queues[name] = qs - // Start worker goroutines for i := int32(0); i < config.Concurrency; i++ { s.wg.Add(1) go s.worker(name, qs) @@ -254,7 +254,6 @@ func (s *taskQueueServiceImpl) CreateQueue(ctx context.Context, name string, con return nil } -// Enqueue adds a task to the named queue and returns the task ID. func (s *taskQueueServiceImpl) Enqueue(ctx context.Context, queueName string, payload []byte) (string, error) { s.mu.Lock() qs, exists := s.queues[queueName] @@ -263,7 +262,6 @@ func (s *taskQueueServiceImpl) Enqueue(ctx context.Context, queueName string, pa if !exists { return "", fmt.Errorf("queue %q does not exist", queueName) } - if len(payload) > maxPayloadSize { return "", fmt.Errorf("payload size %d exceeds maximum of %d bytes", len(payload), maxPayloadSize) } @@ -279,12 +277,7 @@ func (s *taskQueueServiceImpl) Enqueue(ctx context.Context, queueName string, pa return "", fmt.Errorf("enqueuing task: %w", err) } - // Signal workers (non-blocking) - select { - case qs.signal <- struct{}{}: - default: - } - + qs.notifyWorkers() log.Trace(ctx, "Enqueued task", "plugin", s.pluginName, "queue", queueName, "taskID", taskID) return taskID, nil } @@ -356,15 +349,8 @@ func (s *taskQueueServiceImpl) worker(queueName string, qs *queueState) { } } -// drainQueue processes tasks until the queue is empty. 
func (s *taskQueueServiceImpl) drainQueue(queueName string, qs *queueState) {
-	for {
-		if s.ctx.Err() != nil {
-			return
-		}
-		if !s.processTask(queueName, qs) {
-			return
-		}
+	for s.ctx.Err() == nil && s.processTask(queueName, qs) {
 	}
 }
 
@@ -375,8 +361,7 @@ func (s *taskQueueServiceImpl) processTask(queueName string, qs *queueState) boo
 	// Atomically dequeue a task
 	var taskID string
 	var payload []byte
-	var attempt int32
-	var maxRetries int32
+	var attempt, maxRetries int32
 	err := s.db.QueryRowContext(s.ctx, `
 		UPDATE tasks SET status = ?, attempt = attempt + 1, updated_at = ?
 		WHERE id = (
@@ -414,53 +399,61 @@ func (s *taskQueueServiceImpl) processTask(queueName string, qs *queueState) boo
 		return false
 	}
 
-	now = time.Now().UnixMilli()
 	if callbackErr == nil {
-		// Success: mark as completed
-		_, err = s.db.ExecContext(s.ctx, `UPDATE tasks SET status = ?, updated_at = ? WHERE id = ?`, taskStatusCompleted, now, taskID)
-		if err != nil {
-			log.Error(s.ctx, "Failed to mark task as completed", "plugin", s.pluginName, "taskID", taskID, err)
-		}
-		log.Debug(s.ctx, "Task completed", "plugin", s.pluginName, "queue", queueName, "taskID", taskID)
+		s.completeTask(queueName, taskID)
 	} else {
-		// Failure: retry or mark as failed
-		log.Warn(s.ctx, "Task execution failed", "plugin", s.pluginName, "queue", queueName,
-			"taskID", taskID, "attempt", attempt, "maxRetries", maxRetries, "err", callbackErr)
+		s.handleTaskFailure(queueName, taskID, attempt, maxRetries, qs, callbackErr)
+	}
+	return true
+}
 
-		if attempt <= maxRetries {
-			// Retry with exponential backoff: backoffMs * 2^(attempt-1)
-			backoff := qs.config.BackoffMs * int64(math.Pow(2, float64(attempt-1)))
-			if backoff < 0 || backoff > maxBackoffMs {
-				backoff = maxBackoffMs
-			}
-			nextRunAt := now + backoff
-			_, err = s.db.ExecContext(s.ctx, `
-				UPDATE tasks SET status = ?, next_run_at = ?, updated_at = ? WHERE id = ?
-			`, taskStatusPending, nextRunAt, now, taskID)
-			if err != nil {
-				log.Error(s.ctx, "Failed to reschedule task for retry", "plugin", s.pluginName, "taskID", taskID, err)
-			}
+// completeTask marks the task as completed in the database. A failed update is
+// logged but not returned to the caller.
+func (s *taskQueueServiceImpl) completeTask(queueName, taskID string) {
+	now := time.Now().UnixMilli()
+	if _, err := s.db.ExecContext(s.ctx, `UPDATE tasks SET status = ?, updated_at = ? WHERE id = ?`, taskStatusCompleted, now, taskID); err != nil {
+		log.Error(s.ctx, "Failed to mark task as completed", "plugin", s.pluginName, "taskID", taskID, err)
+	}
+	log.Debug(s.ctx, "Task completed", "plugin", s.pluginName, "queue", queueName, "taskID", taskID)
+}
 
-			// Schedule a delayed signal so the worker picks up the retried task
-			// after the backoff period, rather than waiting for the next poll.
-			backoffDuration := time.Duration(backoff) * time.Millisecond
-			time.AfterFunc(backoffDuration, func() {
-				select {
-				case qs.signal <- struct{}{}:
-				default:
-				}
-			})
-		} else {
-			// Exhausted retries: mark as failed
-			_, err = s.db.ExecContext(s.ctx, `UPDATE tasks SET status = ?, updated_at = ? WHERE id = ?`, taskStatusFailed, now, taskID)
-			if err != nil {
-				log.Error(s.ctx, "Failed to mark task as failed", "plugin", s.pluginName, "taskID", taskID, err)
-			}
-			log.Warn(s.ctx, "Task failed after all retries", "plugin", s.pluginName, "queue", queueName, "taskID", taskID)
+// handleTaskFailure logs the callback error and then either marks the task as
+// failed (attempt > maxRetries) or sets it back to pending with next_run_at
+// pushed out by an exponential backoff, arming a timer that signals the queue
+// once that backoff has elapsed.
+func (s *taskQueueServiceImpl) handleTaskFailure(queueName, taskID string, attempt, maxRetries int32, qs *queueState, callbackErr error) {
+	log.Warn(s.ctx, "Task execution failed", "plugin", s.pluginName, "queue", queueName,
+		"taskID", taskID, "attempt", attempt, "maxRetries", maxRetries, "err", callbackErr)
+
+	now := time.Now().UnixMilli()
+	if attempt > maxRetries {
+		if _, err := s.db.ExecContext(s.ctx, `UPDATE tasks SET status = ?, updated_at = ? WHERE id = ?`, taskStatusFailed, now, taskID); err != nil {
+			log.Error(s.ctx, "Failed to mark task as failed", "plugin", s.pluginName, "taskID", taskID, err)
 		}
+		log.Warn(s.ctx, "Task failed after all retries", "plugin", s.pluginName, "queue", queueName, "taskID", taskID)
+		return
 	}
 
-	return true
+	// Exponential backoff: backoffMs * 2^(attempt-1), capped at maxBackoffMs.
+	// The base is compared against the pre-shifted cap rather than shifted
+	// first: an int64 left shift can wrap around to a small positive value,
+	// and a negative shift count panics. A non-positive base also yields the cap.
+	shift := max(attempt-1, 0)
+	backoff := maxBackoffMs
+	if base := qs.config.BackoffMs; base > 0 && base <= maxBackoffMs>>shift {
+		backoff = base << shift
+	}
+	nextRunAt := now + backoff
+	if _, err := s.db.ExecContext(s.ctx, `
+		UPDATE tasks SET status = ?, next_run_at = ?, updated_at = ? WHERE id = ?
+	`, taskStatusPending, nextRunAt, now, taskID); err != nil {
+		log.Error(s.ctx, "Failed to reschedule task for retry", "plugin", s.pluginName, "taskID", taskID, err)
+	}
+
+	// Wake worker after backoff expires
+	time.AfterFunc(time.Duration(backoff)*time.Millisecond, func() {
+		qs.notifyWorkers()
+	})
 }
 
 // revertTaskToPending puts a running task back to pending status and decrements the attempt
@@ -561,11 +554,9 @@ func (s *taskQueueServiceImpl) Close() error {
 	// Mark running tasks as pending for recovery on next startup
 	if s.db != nil {
 		now := time.Now().UnixMilli()
-		_, err := s.db.Exec(`UPDATE tasks SET status = ?, updated_at = ? WHERE status = ?`, taskStatusPending, now, taskStatusRunning)
-		if err != nil {
+		if _, err := s.db.Exec(`UPDATE tasks SET status = ?, updated_at = ? WHERE status = ?`, taskStatusPending, now, taskStatusRunning); err != nil {
 			log.Error("Failed to reset running tasks on shutdown", "plugin", s.pluginName, err)
 		}
-
 		log.Debug("Closing plugin taskqueue", "plugin", s.pluginName)
 		return s.db.Close()
 	}