Skip to content

Add non-cooperative jobs to new ThreadPool#6245

Merged
mzient merged 8 commits intoNVIDIA:mainfrom
mzient:NewThreadPoolNonCooperativeTasks
Mar 5, 2026
Merged

Add non-cooperative jobs to new ThreadPool#6245
mzient merged 8 commits intoNVIDIA:mainfrom
mzient:NewThreadPoolNonCooperativeTasks

Conversation

@mzient
Copy link
Contributor

@mzient mzient commented Mar 3, 2026

Category:

New feature (non-breaking change which adds functionality)
Refactoring (Redesign of existing code that doesn't affect functionality)

Description:

This PR adds more lightweight non-cooperative jobs to new thread pool.
Prior to this change, all jobs had a cooperative Wait option - this enabled Wait to be called from within the same thread pool as the one in which the job is running. To avoid deadlocks, the Wait would pick up tasks for execution until the Job is complete. This cannot be implemented (as of C++20) with atomic_wait and required extra mutex/condvar. While the mutex would remain unused without cooperative wait, the notification of the condition as always necessary.
This change adds a flavor of Job that doesn't allow cooperative wait and therefore can be implemented with just atomic_wait.

Additional information:

Affected modules and functionalities:

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

mzient added 2 commits March 3, 2026 15:00
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45240584]: BUILD STARTED

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 3, 2026

Greptile Summary

This PR templatizes JobBase, JobImpl, and IncrementalJobImpl on a cooperative boolean, introducing non-cooperative variants (Job, IncrementalJob) alongside the existing cooperative variants (CooperativeJob, CooperativeIncrementalJob). Non-cooperative jobs throw std::logic_error when Run() or Wait() is called from inside the owning thread pool rather than blocking cooperatively. It also refactors the this_thread_idx_ thread-local into a new standalone ThisThreadIdx base class in thread_idx.h.

Key changes:

  • JobBase<cooperative> — cooperative mutex/condvar fields are now conditional via JobBaseFields<cooperative> (empty for false), reducing overhead for the non-cooperative path.
  • wait_started_ state-corruption fixwait_started_ is no longer set unconditionally at the top of DoWait(); it is only set after passing all precondition checks (executor_ == nullptr, non-cooperative reentrance), preventing a poisoned job from permanently rejecting a legitimate subsequent Wait() call.
  • New test fixtures NewThreadPoolCooperativeJobTest and NewThreadPoolNonCooperativeJobTest with a new Reentrant test for non-cooperative jobs and a new ErrorWaitBeforeRun test specifically validating the state-corruption fix.
  • ThreadPoolBase::GetThreadIds() is a new public API method added without test coverage.

Confidence Score: 4/5

  • This PR is safe to merge; the core logic is sound and the state-corruption fix is correctly implemented.
  • The refactoring is mechanically clean — template introduction is consistent, explicit instantiations are complete, and the wait_started_ state-corruption fix is correctly applied (the flag is only set after all precondition throws). New test coverage is appropriate for the non-cooperative reentrance paths. The one point of deduction is that GetThreadIds(), a new public API method, has no test and no documented thread-safety contract.
  • include/dali/core/exec/thread_pool_base.h — the new GetThreadIds() method lacks tests and thread-safety documentation.

Important Files Changed

Filename Overview
dali/core/exec/thread_pool_base.cc Refactors JobBase, JobImpl, and IncrementalJobImpl into class templates parameterized on cooperative; fixes the wait_started_ state-corruption bug (no longer set before throwing on executor_==nullptr or non-cooperative reentrance); adds explicit instantiations for both cooperative and non-cooperative variants.
include/dali/core/exec/thread_pool_base.h Introduces JobBaseFields<cooperative> to conditionally include mtx_/cv_ only for cooperative variants; adds JobImpl<cooperative>, IncrementalJobImpl<cooperative>, and four public type aliases (Job, CooperativeJob, IncrementalJob, CooperativeIncrementalJob); adds GetThreadIds() to ThreadPoolBase; ThreadPoolBase now inherits from ThisThreadIdx.
include/dali/core/exec/thread_idx.h New header extracting the this_thread_idx_ thread-local into a standalone ThisThreadIdx base class so it can be reused independently of ThreadPoolBase. Simple and correct.
dali/core/exec/thread_pool_base_test.cc Adds NewThreadPoolCooperativeJobTest and NewThreadPoolNonCooperativeJobTest typed-test fixtures; adds ErrorWaitBeforeRun test validating the wait_started_ fix; adds non-cooperative reentrance test; fixes Reentrant cooperative test to use TypeParam innerJob instead of a hardcoded Job.

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class ThisThreadIdx {
        +static int this_thread_idx()
        #static thread_local int this_thread_idx_
    }

    class JobBaseFields_false {
        <<empty specialization>>
    }

    class JobBaseFields_true {
        #std::mutex mtx_
        #std::condition_variable cv_
    }

    class JobBase_cooperative {
        <<template bool cooperative>>
        #DoWait()
        #DoNotify()
        #std::atomic_int num_pending_tasks_
        #std::atomic_bool running_
        #int total_tasks_
        #bool wait_started_
        #bool wait_completed_
        #const void* executor_
    }

    class JobImpl_cooperative {
        <<template bool cooperative>>
        +AddTask(runnable, priority)
        +Run(ThreadPoolBase, wait)
        +Run(Executor, wait)
        +Wait()
        +Discard()
    }

    class IncrementalJobImpl_cooperative {
        <<template bool cooperative>>
        +AddTask(runnable)
        +Run(ThreadPoolBase, wait)
        +Run(Executor, wait)
        +Wait()
        +Discard()
    }

    class ThreadPoolBase {
        +Init(num_threads, on_thread_start)
        +AddTask(f)
        +NumThreads()
        +GetThreadIds()
        +static this_thread_pool()
    }

    %% Aliases
    class Job["Job = JobImpl&lt;false&gt;"] 
    class CooperativeJob["CooperativeJob = JobImpl&lt;true&gt;"]
    class IncrementalJob["IncrementalJob = IncrementalJobImpl&lt;false&gt;"]
    class CooperativeIncrementalJob["CooperativeIncrementalJob = IncrementalJobImpl&lt;true&gt;"]

    JobBase_cooperative --|> JobBaseFields_false : false specialization
    JobBase_cooperative --|> JobBaseFields_true : true specialization
    JobImpl_cooperative --|> JobBase_cooperative
    IncrementalJobImpl_cooperative --|> JobBase_cooperative
    ThreadPoolBase --|> ThisThreadIdx

    Job ..> JobImpl_cooperative : alias false
    CooperativeJob ..> JobImpl_cooperative : alias true
    IncrementalJob ..> IncrementalJobImpl_cooperative : alias false
    CooperativeIncrementalJob ..> IncrementalJobImpl_cooperative : alias true
Loading

Last reviewed commit: 8155ff5

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 18 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +200 to +206
* A cooperative job can be waited for from inside the thread pool the job is running in. While
* the calling thread executes `Wait` on the job, some scheduled task might be picked up from the
* thread pool and executed int the contex of the calling thread.
*/
using CooperativeJob = JobImpl<true>;

/** Cooperative incremental job.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in docstring: "int the contex" should be "in the context".

Suggested change
* A cooperative job can be waited for from inside the thread pool the job is running in. While
* the calling thread executes `Wait` on the job, some scheduled task might be picked up from the
* thread pool and executed int the contex of the calling thread.
*/
using CooperativeJob = JobImpl<true>;
/** Cooperative incremental job.
* A cooperative job can be waited for from inside the thread pool the job is running in. While
* the calling thread executes `Wait` on the job, some scheduled task might be picked up from the
* thread pool and executed in the context of the calling thread.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (1)

dali/core/exec/thread_pool_base_test.cc
The PR introduces a new runtime check: calling Wait() on a non-cooperative job from inside the thread pool throws std::logic_error("Cannot wait for this job from inside the thread pool.") (see thread_pool_base.cc:52). This error path is not covered by any test. The Reentrant test only covers cooperative jobs.

A test should be added to verify that non-cooperative Job and IncrementalJob types throw when Wait() is called from within a thread-pool task, ensuring this guard remains effective against regressions.

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45243011]: BUILD STARTED

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@mzient
Copy link
Contributor Author

mzient commented Mar 3, 2026

@greptileai There are no existing usages, so there's no backward compatibility issue.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 3, 2026

Got it - if there are no existing usages of this interface yet, then backward compatibility isn't a concern here. Thanks for clarifying!

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45243643]: BUILD STARTED

});

job.Run(tp, false);
EXPECT_THROW(job.Wait(), std::logic_error);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test should verify exception message for robustness:

The non-cooperative reentrant test currently only checks the exception type:

EXPECT_THROW(job.Wait(), std::logic_error);

It doesn't verify the actual error message. Adding a message check would make the test more robust against accidentally catching a different std::logic_error from unrelated code.

Recommendation: Use EXPECT_THAT with a matcher to verify the message:

EXPECT_THROW(
  {
    try {
      job.Wait();
      FAIL() << "Expected std::logic_error";
    } catch (const std::logic_error &e) {
      EXPECT_THAT(e.what(), testing::HasSubstr("Cannot run this job from inside the thread pool"));
      throw;
    }
  },
  std::logic_error
);

Or wrap the Wait call and check the message separately.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

mzient added 2 commits March 3, 2026 18:17
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Comment on lines +180 to 184
/** Runs the job in the thread pool.
*
* More tasks can be added after this call, but they won't start until another call to Run.
*/
void Run(ThreadPoolBase &tp, bool wait);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IncrementalJobImpl::Run() docstring incompletely describes wait parameter behavior

The docstring says "More tasks can be added after this call, but they won't start until another call to Run." This is only true when wait = false. When wait = true, the method internally calls Wait(), which sets wait_started_ = true and blocks subsequent AddTask() calls.

Fix: Clarify the wait parameter's effect:

Suggested change
/** Runs the job in the thread pool.
*
* More tasks can be added after this call, but they won't start until another call to Run.
*/
void Run(ThreadPoolBase &tp, bool wait);
/** Runs the job in the thread pool.
*
* If `wait` is false, more tasks can be added after this call, but they won't start
* until another call to Run. If `wait` is true, this call also waits for completion,
* after which no more tasks can be added (equivalent to calling Wait()).
*/
void Run(ThreadPoolBase &tp, bool wait);

Comment on lines 31 to 42
@@ -36,15 +38,19 @@ void JobBase::DoWait() {
return;
}

if (executor_ == nullptr)
if (this->executor_ == nullptr)
throw std::logic_error("This job hasn't been run - cannot wait for it.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait_started_ poisoned when executor_ == nullptr

wait_started_ is set to true at line 34 before the executor_ == nullptr guard at line 41. If Wait() is called on a job that has tasks but has not yet been Run(), the exception "This job hasn't been run - cannot wait for it." is correctly thrown — but wait_started_ remains true while wait_completed_ remains false.

A subsequent call sequence like:

Job job;
job.AddTask([]() {});
try { job.Wait(); } catch (...) {}  // forgot to Run first
job.Run(tp, false);
job.Wait();  // ← throws "already been waited for" — wrong!

…silently fails to wait for the job, and then the destructor throws "The job is not empty, but hasn't been discarded or waited for." because wait_completed_ is never set. The error message seen by the caller is also misleading since the job was never successfully waited for.

Fix: Move all precondition checks (the executor_ == nullptr and non-cooperative reentrance checks) above the wait_started_ = true assignment, so the flag is only mutated when the wait is actually going to proceed.

mzient added 2 commits March 3, 2026 20:46
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
@mzient
Copy link
Contributor Author

mzient commented Mar 3, 2026

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45262286]: BUILD STARTED

Comment on lines +301 to 308
/** Returns the ids of the threads in the thread pool */
auto GetThreadIds() const {
int n = threads_.size();
std::vector<std::thread::id> ids(n);
for (int i = 0; i < n; i++)
ids[i] = threads_[i].get_id();
return ids;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New public API GetThreadIds() has no test coverage

GetThreadIds() is a new public method added in this PR but is not exercised by any test in thread_pool_base_test.cc. Given that it reads threads_ without acquiring mtx_, it would also be useful to document the thread-safety contract (i.e. that it is safe to call after Init() returns and before Shutdown() begins).

Consider adding a simple test — e.g. verifying that GetThreadIds() returns exactly NumThreads() unique IDs after Init().

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45243643]: BUILD FAILED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45262286]: BUILD PASSED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45243643]: BUILD PASSED

using CooperativeIncrementalJob = IncrementalJobImpl<true>;

class DLL_PUBLIC ThreadPoolBase {
class DLL_PUBLIC ThreadPoolBase : public ThisThreadIdx {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick:

Suggested change
class DLL_PUBLIC ThreadPoolBase : public ThisThreadIdx {
class DLL_PUBLIC ThreadPoolBase : private ThisThreadIdx {

My understanding is that the thread index is only supposed to be accessed through ThisThreadIdx. Is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really - in derived classes it can be used as a member, with an unqualified name. It also makes sense (to some extent) to use ThreadPoolType::this_thread_idx()

@mzient mzient merged commit 4fc92ab into NVIDIA:main Mar 5, 2026
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants