diff --git a/job.go b/job.go index 700a0b6..85a1857 100644 --- a/job.go +++ b/job.go @@ -18,11 +18,12 @@ import ( // internalJob stores the information needed by the scheduler // to manage scheduling, starting and stopping the job type internalJob struct { - ctx context.Context - cancel context.CancelFunc - id uuid.UUID - name string - tags []string + ctx context.Context + parentCtx context.Context + cancel context.CancelFunc + id uuid.UUID + name string + tags []string jobSchedule // as some jobs may queue up, it's possible to @@ -703,6 +704,14 @@ func WithIdentifier(id uuid.UUID) JobOption { } } +// WithContext sets the parent context for the job +func WithContext(ctx context.Context) JobOption { + return func(j *internalJob, _ time.Time) error { + j.parentCtx = ctx + return nil + } +} + // ----------------------------------------------- // ----------------------------------------------- // ------------- Job Event Listeners ------------- diff --git a/scheduler.go b/scheduler.go index 825323e..8789f50 100644 --- a/scheduler.go +++ b/scheduler.go @@ -648,8 +648,6 @@ func (s *scheduler) addOrUpdateJob(id uuid.UUID, definition JobDefinition, taskW j.id = id } - j.ctx, j.cancel = context.WithCancel(s.shutdownCtx) - if taskWrapper == nil { return nil, ErrNewJobTaskNil } @@ -664,10 +662,6 @@ func (s *scheduler) addOrUpdateJob(id uuid.UUID, definition JobDefinition, taskW return nil, ErrNewJobTaskNotFunc } - if err := s.verifyParameterType(taskFunc, tsk); err != nil { - return nil, err - } - j.name = runtime.FuncForPC(taskFunc.Pointer()).Name() j.function = tsk.function j.parameters = tsk.parameters @@ -686,6 +680,28 @@ func (s *scheduler) addOrUpdateJob(id uuid.UUID, definition JobDefinition, taskW } } + if j.parentCtx == nil { + j.parentCtx = s.shutdownCtx + } + j.ctx, j.cancel = context.WithCancel(j.parentCtx) + + if !taskFunc.IsZero() && taskFunc.Type().NumIn() > 0 { + // if the first parameter is a context.Context and params have no context.Context, add current ctx to the params + if taskFunc.Type().In(0) == reflect.TypeOf((*context.Context)(nil)).Elem() { + if len(tsk.parameters) == 0 { + tsk.parameters = []any{j.ctx} + j.parameters = []any{j.ctx} + } else if _, ok := tsk.parameters[0].(context.Context); !ok { + tsk.parameters = append([]any{j.ctx}, tsk.parameters...) + j.parameters = append([]any{j.ctx}, j.parameters...) + } + } + } + + if err := s.verifyParameterType(taskFunc, tsk); err != nil { + return nil, err + } + if err := definition.setup(&j, s.location, s.exec.clock.Now()); err != nil { return nil, err } diff --git a/scheduler_test.go b/scheduler_test.go index 8d46d31..3c47472 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -361,6 +361,41 @@ func TestScheduler_StopTimeout(t *testing.T) { } } +func TestScheduler_StopLongRunningJobs(t *testing.T) { + t.Run("start, run job, stop jobs before job is completed", func(t *testing.T) { + s := newTestScheduler(t, + WithStopTimeout(50*time.Millisecond), + ) + + _, err := s.NewJob( + DurationJob( + 50*time.Millisecond, + ), + NewTask( + func(ctx context.Context) { + select { + case <-ctx.Done(): + case <-time.After(100 * time.Millisecond): + t.Fatal("job can not been canceled") + } + }, + ), + WithStartAt( + WithStartImmediately(), + ), + WithSingletonMode(LimitModeReschedule), + ) + require.NoError(t, err) + + s.Start() + + time.Sleep(20 * time.Millisecond) + // the running job is canceled, no unexpected timeout error + require.NoError(t, s.StopJobs()) + time.Sleep(100 * time.Millisecond) + }) +} + func TestScheduler_Shutdown(t *testing.T) { defer verifyNoGoroutineLeaks(t)