1818
1919namespace dali {
2020
21- JobBase::~JobBase () noexcept (false ) {
21+ template <bool cooperative>
22+ JobBase<cooperative>::~JobBase () noexcept (false ) {
2223 if (total_tasks_ > 0 && !wait_completed_) {
2324 throw std::logic_error (" The job is not empty, but hasn't been discarded or waited for." );
2425 }
2526 while (running_)
2627 std::this_thread::yield ();
2728}
2829
29- void JobBase::DoWait () {
30+ template <bool cooperative>
31+ void JobBase<cooperative>::DoWait() {
3032 if (wait_started_)
3133 throw std::logic_error (" This job has already been waited for." );
32- wait_started_ = true ;
3334
3435 if (total_tasks_ == 0 ) {
36+ // If there are no tasks, it's legal to skip a call to Run, therefore executor_ can be null.
37+ wait_started_ = true ;
3538 wait_completed_ = true ;
3639 return ;
3740 }
3841
39- if (executor_ == nullptr )
42+ if (this -> executor_ == nullptr )
4043 throw std::logic_error (" This job hasn't been run - cannot wait for it." );
4144
42- auto ready = [&]() { return num_pending_tasks_ == 0 ; };
43- if (ThreadPoolBase::this_thread_pool () != nullptr ) {
44- bool result = ThreadPoolBase::this_thread_pool ()->WaitOrRunTasks (cv_, ready);
45- wait_completed_ = true ;
46- if (!result)
47- throw std::logic_error (" The thread pool was stopped" );
45+ if (ThreadPoolBase::this_thread_pool () == this ->executor_ ) {
46+ if constexpr (cooperative) {
47+ auto ready = [&]() { return num_pending_tasks_ == 0 ; };
48+ wait_started_ = true ;
49+ bool result = ThreadPoolBase::this_thread_pool ()->WaitOrRunTasks (this ->cv_ , ready);
50+ wait_completed_ = true ;
51+ if (!result)
52+ throw std::logic_error (" The thread pool was stopped" );
53+ } else {
54+ throw std::logic_error (" Cannot wait for this job from inside the thread pool." );
55+ }
4856 } else {
57+ wait_started_ = true ;
4958 int old = num_pending_tasks_.load ();
5059 while (old != 0 ) {
5160 num_pending_tasks_.wait (old);
@@ -56,10 +65,13 @@ void JobBase::DoWait() {
5665 }
5766}
5867
59- void JobBase::DoNotify () {
68+ template <bool cooperative>
69+ void JobBase<cooperative>::DoNotify() {
6070 num_pending_tasks_.notify_all ();
61- (void )std::lock_guard (mtx_);
62- cv_.notify_all ();
71+ if constexpr (cooperative) {
72+ (void )std::lock_guard (this ->mtx_ );
73+ this ->cv_ .notify_all ();
74+ }
6375 // We need this second flag to avoid a race condition where the destructor is called between
6476 // decrementing num_pending_tasks_ and notification_ without excessive use of mutexes.
6577 // This must be the very last operation in the task function that touches `this`.
@@ -68,29 +80,35 @@ void JobBase::DoNotify() {
6880
6981// Job ////////////////////////////////////////////////////////////////////
7082
71- void Job::Run (ThreadPoolBase &tp, bool wait) {
72- if (executor_ != nullptr )
83+ template <bool cooperative>
84+ void JobImpl<cooperative>::Run(ThreadPoolBase &tp, bool wait) {
85+ if (this ->executor_ != nullptr )
7386 throw std::logic_error (" This job has already been started." );
74- executor_ = &tp;
75- running_ = !tasks_.empty ();
87+
88+ if (!cooperative && &tp == ThreadPoolBase::this_thread_pool ())
89+ throw std::logic_error (" Cannot run this job from inside the thread pool." );
90+
91+ this ->executor_ = &tp;
92+ this ->running_ = !tasks_.empty ();
7693 {
7794 auto batch = tp.BeginBulkAdd ();
7895 for (auto &x : tasks_) {
7996 batch.Add (std::move (x.second .func ));
8097 }
8198 int added = batch.Size ();
8299 if (added) {
83- num_pending_tasks_ += added;
84- running_ = true ;
100+ this -> num_pending_tasks_ += added;
101+ this -> running_ = true ;
85102 }
86103 batch.Submit ();
87104 }
88105 if (wait && !tasks_.empty ())
89106 Wait ();
90107}
91108
92- void Job::Wait () {
93- DoWait ();
109+ template <bool cooperative>
110+ void JobImpl<cooperative>::Wait() {
111+ this ->DoWait ();
94112
95113 // note - this vector is not allocated unless there were exceptions thrown
96114 std::vector<std::exception_ptr> errors;
@@ -104,19 +122,25 @@ void Job::Wait() {
104122 throw MultipleErrors (std::move (errors));
105123}
106124
107- void Job::Discard () {
108- if (executor_ != nullptr )
125+ template <bool cooperative>
126+ void JobImpl<cooperative>::Discard() {
127+ if (this ->executor_ != nullptr )
109128 throw std::logic_error (" Cannot discard a job that has already been started" );
110129 tasks_.clear ();
111- total_tasks_ = 0 ;
130+ this -> total_tasks_ = 0 ;
112131}
113132
114133// IncrementalJob /////////////////////////////////////////////////////////
115134
116- void IncrementalJob::Run (ThreadPoolBase &tp, bool wait) {
117- if (executor_ && executor_ != &tp)
135+ template <bool cooperative>
136+ void IncrementalJobImpl<cooperative>::Run(ThreadPoolBase &tp, bool wait) {
137+ if (this ->executor_ && this ->executor_ != &tp)
118138 throw std::logic_error (" This job is already running in a different executor." );
119- executor_ = &tp;
139+
140+ if (!cooperative && &tp == ThreadPoolBase::this_thread_pool ())
141+ throw std::logic_error (" Cannot run this job from inside the thread pool." );
142+
143+ this ->executor_ = &tp;
120144 {
121145 auto it = last_task_run_.has_value () ? std::next (*last_task_run_) : tasks_.begin ();
122146 auto batch = tp.BeginBulkAdd ();
@@ -126,24 +150,26 @@ void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
126150 }
127151 int added = batch.Size ();
128152 if (added) {
129- num_pending_tasks_ += added;
130- running_ = true ;
153+ this -> num_pending_tasks_ += added;
154+ this -> running_ = true ;
131155 }
132156 batch.Submit ();
133157 }
134158 if (wait && !tasks_.empty ())
135159 Wait ();
136160}
137161
138- void IncrementalJob::Discard () {
139- if (executor_)
162+ template <bool cooperative>
163+ void IncrementalJobImpl<cooperative>::Discard() {
164+ if (this ->executor_ )
140165 throw std::logic_error (" Cannot discard a job that has already been started" );
141166 tasks_.clear ();
142- total_tasks_ = 0 ;
167+ this -> total_tasks_ = 0 ;
143168}
144169
145- void IncrementalJob::Wait () {
146- DoWait ();
170+ template <bool cooperative>
171+ void IncrementalJobImpl<cooperative>::Wait() {
172+ this ->DoWait ();
147173 // note - this vector is not allocated unless there were exceptions thrown
148174 std::vector<std::exception_ptr> errors;
149175 for (auto &x : tasks_) {
@@ -158,8 +184,15 @@ void IncrementalJob::Wait() {
158184
159185// /////////////////////////////////////////////////////////////////////////
160186
187+ template class JobBase <true >;
188+ template class JobImpl <true >;
189+ template class JobBase <false >;
190+ template class JobImpl <false >;
191+ template class IncrementalJobImpl <true >;
192+ template class IncrementalJobImpl <false >;
193+
161194thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr ;
162- thread_local int ThreadPoolBase ::this_thread_idx_ = -1 ;
195+ thread_local int ThisThreadIdx ::this_thread_idx_ = -1 ;
163196
164197void ThreadPoolBase::Init (int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
165198 if (shutdown_pending_)
0 commit comments