Skip to content

Commit 0d7df15

Browse files
committed
Validate ORCJIT JITDylib handles before use
1 parent e8e6b6b commit 0d7df15

1 file changed

Lines changed: 30 additions & 1 deletion

File tree

src/gallium/auxiliary/gallivm/lp_bld_init_orc.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <string>
1111
#include <vector>
1212
#include <mutex>
13+
#include <unordered_set>
1314
#include <cstdlib>
1415
#include "lp_bld.h"
1516
#include "lp_bld_debug.h"
@@ -182,7 +183,9 @@ class LPJit
182183
LPJit* jit = get_instance();
183184
std::lock_guard<std::mutex> guard(jit->lookup_mutex);
184185
JITDylib& tmp = ExitOnErr(jit->lljit->createJITDylib(name));
185-
return wrap(&tmp);
186+
LLVMOrcJITDylibRef jd = wrap(&tmp);
187+
jit->live_jd_handles.insert(jd);
188+
return jd;
186189
}
187190

188191
static void register_gallivm_state(gallivm_state *gallivm) {
@@ -206,10 +209,17 @@ class LPJit
206209
using llvm::Module;
207210
using llvm::orc::ThreadSafeModule;
208211
using llvm::orc::JITDylib;
212+
if (!jd) {
213+
return;
214+
}
209215
ThreadSafeModule tsm(
210216
std::unique_ptr<Module>(llvm::unwrap(mod)), *::unwrap(ts_context));
211217
LPJit* jit = get_instance();
212218
std::lock_guard<std::mutex> guard(jit->lookup_mutex);
219+
if (!jit->live_jd_handles.count(jd)) {
220+
debug_printf("ORCJIT skip addIRModule for unknown JITDylib handle %p\n", jd);
221+
return;
222+
}
213223
ExitOnErr(jit->lljit->addIRModule(
214224
*::unwrap(jd), std::move(tsm)
215225
));
@@ -229,9 +239,16 @@ class LPJit
229239
using llvm::orc::ExecutionSession;
230240
using llvm::orc::JITDylib;
231241
using llvm::orc::SymbolMap;
242+
if (!sym || !addr || !jd) {
243+
return;
244+
}
232245
JITDylib* JD = ::unwrap(jd);
233246
LPJit* jit = LPJit::get_instance();
234247
std::lock_guard<std::mutex> guard(jit->lookup_mutex);
248+
if (!jit->live_jd_handles.count(jd)) {
249+
debug_printf("ORCJIT skip addMapping for unknown JITDylib handle %p\n", jd);
250+
return;
251+
}
235252
auto& es = jit->lljit->getExecutionSession();
236253
auto name = es.intern(llvm::unwrap(sym)->getName());
237254
SymbolMap map(1);
@@ -251,13 +268,20 @@ class LPJit
251268
using llvm::orc::JITDylib;
252269
using llvm::JITEvaluatedSymbol;
253270
using llvm::orc::ExecutorAddr;
271+
if (!func_name || !jd) {
272+
return nullptr;
273+
}
254274
JITDylib* JD = ::unwrap(jd);
255275
LPJit* jit = get_instance();
256276
llvm::ObjectCache *objcache = nullptr;
257277
auto &ircl = jit->lljit->getIRCompileLayer();
258278
auto &irc = ircl.getCompiler();
259279
auto &sc = dynamic_cast<llvm::orc::SimpleCompiler &>(irc);
260280
std::lock_guard<std::mutex> guard(jit->lookup_mutex);
281+
if (!jit->live_jd_handles.count(jd)) {
282+
debug_printf("ORCJIT skip lookup for unknown JITDylib handle %p (%s)\n", jd, func_name);
283+
return nullptr;
284+
}
261285
if (gallivm && gallivm->cache) {
262286
objcache = (LPObjectCacheORC *)gallivm->cache->jit_obj_cache;
263287
}
@@ -280,6 +304,10 @@ class LPJit
280304
using llvm::orc::JITDylib;
281305
LPJit* jit = LPJit::get_instance();
282306
std::lock_guard<std::mutex> guard(jit->lookup_mutex);
307+
if (!jit->live_jd_handles.count(jd)) {
308+
return;
309+
}
310+
jit->live_jd_handles.erase(jd);
283311
auto& es = jit->lljit->getExecutionSession();
284312
ExitOnErr(es.removeJITDylib(* ::unwrap(jd)));
285313
#endif
@@ -314,6 +342,7 @@ class LPJit
314342
unsigned jit_dylib_count;
315343

316344
std::mutex lookup_mutex;
345+
std::unordered_set<LLVMOrcJITDylibRef> live_jd_handles;
317346

318347
#if DEBUG
319348
/* map from module name to gallivm_state */

0 commit comments

Comments
 (0)