Skip to content

Commit 4fc92ab

Browse files
authored
Add non-cooperative jobs to new ThreadPool (#6245)
* This PR adds more lightweight non-cooperative jobs to the new thread pool.
* The default Job and IncrementalJob become non-cooperative.
* Cooperative[Incremental]Job can be used for reentrant behavior.
--------- Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
1 parent ef8bd5c commit 4fc92ab

4 files changed

Lines changed: 266 additions & 89 deletions

File tree

dali/core/exec/thread_pool_base.cc

Lines changed: 68 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -18,34 +18,43 @@
1818

1919
namespace dali {
2020

21-
JobBase::~JobBase() noexcept(false) {
21+
template <bool cooperative>
22+
JobBase<cooperative>::~JobBase() noexcept(false) {
2223
if (total_tasks_ > 0 && !wait_completed_) {
2324
throw std::logic_error("The job is not empty, but hasn't been discarded or waited for.");
2425
}
2526
while (running_)
2627
std::this_thread::yield();
2728
}
2829

29-
void JobBase::DoWait() {
30+
template <bool cooperative>
31+
void JobBase<cooperative>::DoWait() {
3032
if (wait_started_)
3133
throw std::logic_error("This job has already been waited for.");
32-
wait_started_ = true;
3334

3435
if (total_tasks_ == 0) {
36+
// If there are no tasks, it's legal to skip a call to Run, therefore executor_ can be null.
37+
wait_started_ = true;
3538
wait_completed_ = true;
3639
return;
3740
}
3841

39-
if (executor_ == nullptr)
42+
if (this->executor_ == nullptr)
4043
throw std::logic_error("This job hasn't been run - cannot wait for it.");
4144

42-
auto ready = [&]() { return num_pending_tasks_ == 0; };
43-
if (ThreadPoolBase::this_thread_pool() != nullptr) {
44-
bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
45-
wait_completed_ = true;
46-
if (!result)
47-
throw std::logic_error("The thread pool was stopped");
45+
if (ThreadPoolBase::this_thread_pool() == this->executor_) {
46+
if constexpr (cooperative) {
47+
auto ready = [&]() { return num_pending_tasks_ == 0; };
48+
wait_started_ = true;
49+
bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(this->cv_, ready);
50+
wait_completed_ = true;
51+
if (!result)
52+
throw std::logic_error("The thread pool was stopped");
53+
} else {
54+
throw std::logic_error("Cannot wait for this job from inside the thread pool.");
55+
}
4856
} else {
57+
wait_started_ = true;
4958
int old = num_pending_tasks_.load();
5059
while (old != 0) {
5160
num_pending_tasks_.wait(old);
@@ -56,10 +65,13 @@ void JobBase::DoWait() {
5665
}
5766
}
5867

59-
void JobBase::DoNotify() {
68+
template <bool cooperative>
69+
void JobBase<cooperative>::DoNotify() {
6070
num_pending_tasks_.notify_all();
61-
(void)std::lock_guard(mtx_);
62-
cv_.notify_all();
71+
if constexpr (cooperative) {
72+
(void)std::lock_guard(this->mtx_);
73+
this->cv_.notify_all();
74+
}
6375
// We need this second flag to avoid a race condition where the destructor is called between
6476
// decrementing num_pending_tasks_ and notification_ without excessive use of mutexes.
6577
// This must be the very last operation in the task function that touches `this`.
@@ -68,29 +80,35 @@ void JobBase::DoNotify() {
6880

6981
// Job ////////////////////////////////////////////////////////////////////
7082

71-
void Job::Run(ThreadPoolBase &tp, bool wait) {
72-
if (executor_ != nullptr)
83+
template <bool cooperative>
84+
void JobImpl<cooperative>::Run(ThreadPoolBase &tp, bool wait) {
85+
if (this->executor_ != nullptr)
7386
throw std::logic_error("This job has already been started.");
74-
executor_ = &tp;
75-
running_ = !tasks_.empty();
87+
88+
if (!cooperative && &tp == ThreadPoolBase::this_thread_pool())
89+
throw std::logic_error("Cannot run this job from inside the thread pool.");
90+
91+
this->executor_ = &tp;
92+
this->running_ = !tasks_.empty();
7693
{
7794
auto batch = tp.BeginBulkAdd();
7895
for (auto &x : tasks_) {
7996
batch.Add(std::move(x.second.func));
8097
}
8198
int added = batch.Size();
8299
if (added) {
83-
num_pending_tasks_ += added;
84-
running_ = true;
100+
this->num_pending_tasks_ += added;
101+
this->running_ = true;
85102
}
86103
batch.Submit();
87104
}
88105
if (wait && !tasks_.empty())
89106
Wait();
90107
}
91108

92-
void Job::Wait() {
93-
DoWait();
109+
template <bool cooperative>
110+
void JobImpl<cooperative>::Wait() {
111+
this->DoWait();
94112

95113
// note - this vector is not allocated unless there were exceptions thrown
96114
std::vector<std::exception_ptr> errors;
@@ -104,19 +122,25 @@ void Job::Wait() {
104122
throw MultipleErrors(std::move(errors));
105123
}
106124

107-
void Job::Discard() {
108-
if (executor_ != nullptr)
125+
template <bool cooperative>
126+
void JobImpl<cooperative>::Discard() {
127+
if (this->executor_ != nullptr)
109128
throw std::logic_error("Cannot discard a job that has already been started");
110129
tasks_.clear();
111-
total_tasks_ = 0;
130+
this->total_tasks_ = 0;
112131
}
113132

114133
// IncrementalJob /////////////////////////////////////////////////////////
115134

116-
void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
117-
if (executor_ && executor_ != &tp)
135+
template <bool cooperative>
136+
void IncrementalJobImpl<cooperative>::Run(ThreadPoolBase &tp, bool wait) {
137+
if (this->executor_ && this->executor_ != &tp)
118138
throw std::logic_error("This job is already running in a different executor.");
119-
executor_ = &tp;
139+
140+
if (!cooperative && &tp == ThreadPoolBase::this_thread_pool())
141+
throw std::logic_error("Cannot run this job from inside the thread pool.");
142+
143+
this->executor_ = &tp;
120144
{
121145
auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
122146
auto batch = tp.BeginBulkAdd();
@@ -126,24 +150,26 @@ void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
126150
}
127151
int added = batch.Size();
128152
if (added) {
129-
num_pending_tasks_ += added;
130-
running_ = true;
153+
this->num_pending_tasks_ += added;
154+
this->running_ = true;
131155
}
132156
batch.Submit();
133157
}
134158
if (wait && !tasks_.empty())
135159
Wait();
136160
}
137161

138-
void IncrementalJob::Discard() {
139-
if (executor_)
162+
template <bool cooperative>
163+
void IncrementalJobImpl<cooperative>::Discard() {
164+
if (this->executor_)
140165
throw std::logic_error("Cannot discard a job that has already been started");
141166
tasks_.clear();
142-
total_tasks_ = 0;
167+
this->total_tasks_ = 0;
143168
}
144169

145-
void IncrementalJob::Wait() {
146-
DoWait();
170+
template <bool cooperative>
171+
void IncrementalJobImpl<cooperative>::Wait() {
172+
this->DoWait();
147173
// note - this vector is not allocated unless there were exceptions thrown
148174
std::vector<std::exception_ptr> errors;
149175
for (auto &x : tasks_) {
@@ -158,8 +184,15 @@ void IncrementalJob::Wait() {
158184

159185
///////////////////////////////////////////////////////////////////////////
160186

187+
template class JobBase<true>;
188+
template class JobImpl<true>;
189+
template class JobBase<false>;
190+
template class JobImpl<false>;
191+
template class IncrementalJobImpl<true>;
192+
template class IncrementalJobImpl<false>;
193+
161194
thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
162-
thread_local int ThreadPoolBase::this_thread_idx_ = -1;
195+
thread_local int ThisThreadIdx::this_thread_idx_ = -1;
163196

164197
void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
165198
if (shutdown_pending_)

dali/core/exec/thread_pool_base_test.cc

Lines changed: 48 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -153,9 +153,21 @@ TEST(NewThreadPool, RunLargeIncrementalJobInThreadPool) {
153153
template <typename JobType>
154154
class NewThreadPoolJobTest : public ::testing::Test {};
155155

156-
using JobTypes = ::testing::Types<Job, IncrementalJob>;
156+
template <typename JobType>
157+
class NewThreadPoolCooperativeJobTest : public ::testing::Test {};
158+
159+
template <typename JobType>
160+
class NewThreadPoolNonCooperativeJobTest : public ::testing::Test {};
161+
162+
using JobTypes = ::testing::Types<Job, IncrementalJob, CooperativeJob, CooperativeIncrementalJob>;
157163
TYPED_TEST_SUITE(NewThreadPoolJobTest, JobTypes);
158164

165+
using CooperativeJobTypes = ::testing::Types<CooperativeJob, CooperativeIncrementalJob>;
166+
TYPED_TEST_SUITE(NewThreadPoolCooperativeJobTest, CooperativeJobTypes);
167+
168+
using NonCooperativeJobTypes = ::testing::Types<Job, IncrementalJob>;
169+
TYPED_TEST_SUITE(NewThreadPoolNonCooperativeJobTest, NonCooperativeJobTypes);
170+
159171

160172
TYPED_TEST(NewThreadPoolJobTest, RunJobInSeries) {
161173
TypeParam job;
@@ -184,7 +196,7 @@ TYPED_TEST(NewThreadPoolJobTest, Discard) {
184196
});
185197
}
186198

187-
TYPED_TEST(NewThreadPoolJobTest, ErrorIncrementalJobNotStarted) {
199+
TYPED_TEST(NewThreadPoolJobTest, ErrorJobNotStarted) {
188200
try {
189201
TypeParam job;
190202
job.AddTask([]() {});
@@ -195,6 +207,19 @@ TYPED_TEST(NewThreadPoolJobTest, ErrorIncrementalJobNotStarted) {
195207
GTEST_FAIL() << "Expected a logic error.";
196208
}
197209

210+
TYPED_TEST(NewThreadPoolJobTest, ErrorWaitBeforeRun) {
211+
TypeParam job;
212+
try {
213+
job.AddTask([]() {});
214+
job.Wait();
215+
} catch (std::logic_error &e) {
216+
EXPECT_NE(nullptr, strstr(e.what(), "hasn't been run"));
217+
job.Discard();
218+
return;
219+
}
220+
GTEST_FAIL() << "Expected a logic error.";
221+
}
222+
198223
TYPED_TEST(NewThreadPoolJobTest, RethrowMultipleErrors) {
199224
TypeParam job;
200225
ThreadPoolBase tp(4);
@@ -210,7 +235,7 @@ TYPED_TEST(NewThreadPoolJobTest, RethrowMultipleErrors) {
210235
EXPECT_THROW(job.Run(tp, true), MultipleErrors);
211236
}
212237

213-
TYPED_TEST(NewThreadPoolJobTest, Reentrant) {
238+
TYPED_TEST(NewThreadPoolCooperativeJobTest, Reentrant) {
214239
TypeParam job;
215240
ThreadPoolBase tp(1); // must not hang with just one thread
216241
std::atomic_int outer{0}, inner{0};
@@ -221,7 +246,7 @@ TYPED_TEST(NewThreadPoolJobTest, Reentrant) {
221246
}
222247

223248
job.AddTask([&]() {
224-
Job innerJob;
249+
TypeParam innerJob;
225250

226251
for (int i = 0; i < 10; i++)
227252
innerJob.AddTask([&, i]() {
@@ -241,6 +266,25 @@ TYPED_TEST(NewThreadPoolJobTest, Reentrant) {
241266
job.Run(tp, true);
242267
}
243268

269+
TYPED_TEST(NewThreadPoolNonCooperativeJobTest, Reentrant) {
270+
TypeParam job;
271+
ThreadPoolBase tp(1); // must not hang with just one thread
272+
job.AddTask([&]() {
273+
TypeParam innerJob;
274+
innerJob.AddTask([]() {});
275+
276+
try {
277+
innerJob.Run(tp, true);
278+
} catch (...) {
279+
innerJob.Discard();
280+
throw;
281+
}
282+
});
283+
284+
job.Run(tp, false);
285+
EXPECT_THROW(job.Wait(), std::logic_error);
286+
}
287+
244288
TYPED_TEST(NewThreadPoolJobTest, JobPerf) {
245289
using JobType = TypeParam;
246290
ThreadPoolBase tp(4);
Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef DALI_CORE_EXEC_THREAD_IDX_H_
16+
#define DALI_CORE_EXEC_THREAD_IDX_H_
17+
18+
#include "dali/core/api_helper.h"
19+
20+
namespace dali {
21+
22+
class DLL_PUBLIC ThisThreadIdx {
23+
public:
24+
/**
25+
* @brief Returns the index of the current thread within the current thread pool
26+
*
27+
* @return the thread index or -1 if the calling thread does not belong to a thread pool
28+
*/
29+
static int this_thread_idx() {
30+
return this_thread_idx_;
31+
}
32+
33+
protected:
34+
static thread_local int this_thread_idx_;
35+
};
36+
37+
} // namespace dali
38+
39+
#endif // DALI_CORE_EXEC_THREAD_IDX_H_

0 commit comments

Comments
 (0)