diff --git a/scheduler.go b/scheduler.go index f0c0e20..1656aee 100644 --- a/scheduler.go +++ b/scheduler.go @@ -7,6 +7,7 @@ import ( "runtime" "slices" "strings" + "sync/atomic" "time" "github.com/google/uuid" @@ -72,7 +73,7 @@ type scheduler struct { // the location used by the scheduler for scheduling when relevant location *time.Location // whether the scheduler has been started or not - started bool + started atomic.Bool // globally applied JobOption's set on all jobs added to the scheduler // note: individually set JobOption's take precedence. globalJobOptions []JobOption @@ -233,13 +234,8 @@ func NewScheduler(options ...SchedulerOption) (Scheduler, error) { func (s *scheduler) stopScheduler() { s.logger.Debug("gocron: stopping scheduler") - if !s.started { - s.logger.Debug("gocron: scheduler already stopped") - s.stopErrCh <- nil - return - } - if s.started { + if s.started.Load() { s.exec.stopCh <- struct{}{} } @@ -250,7 +246,7 @@ func (s *scheduler) stopScheduler() { <-j.ctx.Done() } var err error - if s.started { + if s.started.Load() { t := time.NewTimer(s.exec.stopTimeout + 1*time.Second) select { case err = <-s.exec.done: @@ -275,7 +271,7 @@ func (s *scheduler) stopScheduler() { } s.stopErrCh <- err - s.started = false + s.started.Store(false) s.logger.Debug("gocron: scheduler stopped") } @@ -480,7 +476,7 @@ func (s *scheduler) selectJobOutRequest(out *jobOutRequest) { func (s *scheduler) selectNewJob(in newJobIn) { j := in.job - if s.started { + if s.started.Load() { next := j.startTime if j.startImmediately { next = s.now() @@ -531,7 +527,7 @@ func (s *scheduler) selectStart() { s.logger.Debug("gocron: scheduler starting") go s.exec.start() - s.started = true + s.started.Store(true) for id, j := range s.jobs { next := j.startTime if j.startImmediately { @@ -826,15 +822,16 @@ func (s *scheduler) StopJobs() error { } func (s *scheduler) Shutdown() error { - if !s.started { + s.logger.Debug("scheduler shutting down") + + s.shutdownCancel() + if !s.started.Load() { return nil } - s.shutdownCancel() t := time.NewTimer(s.exec.stopTimeout + 2*time.Second) select { case err := <-s.stopErrCh: - t.Stop() return err case <-t.C: diff --git a/scheduler_test.go b/scheduler_test.go index c2b0594..d029522 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -568,10 +568,9 @@ func TestScheduler_Shutdown(t *testing.T) { assert.ErrorIs(t, err, ErrJobNotFound) }) - t.Run("calling shutdown multiple times including before start is a no-op", func(t *testing.T) { + t.Run("calling shutdown multiple times is a no-op", func(t *testing.T) { s := newTestScheduler(t) - assert.NoError(t, s.Shutdown()) s.Start() assert.NoError(t, s.Shutdown())