diff --git a/src/Python/Inline/Literal.hs b/src/Python/Inline/Literal.hs index 5a4f472..72fbbbd 100644 --- a/src/Python/Inline/Literal.hs +++ b/src/Python/Inline/Literal.hs @@ -607,7 +607,7 @@ instance (FromPy a1, FromPy a2, ToPy b) => ToPy (a1 -> a2 -> IO b) where -- | Execute haskell callback function pyCallback :: Program (Ptr PyObject) (Ptr PyObject) -> IO (Ptr PyObject) -pyCallback io = callbackEnsurePyLock $ unPy $ ensureGIL $ runProgram io `catch` convertHaskell2Py +pyCallback io = callbackEnsurePyLock $ unsafeRunPy $ ensureGIL $ runProgram io `catch` convertHaskell2Py -- | Load argument from python object for haskell evaluation loadArg diff --git a/src/Python/Internal/Eval.hs b/src/Python/Internal/Eval.hs index c46b097..8389ca4 100644 --- a/src/Python/Internal/Eval.hs +++ b/src/Python/Internal/Eval.hs @@ -15,7 +15,7 @@ module Python.Internal.Eval -- * Evaluator , runPy , runPyInMain - , unPy + , unsafeRunPy -- * GC-related , newPyObject -- * C-API wrappers @@ -41,6 +41,7 @@ import Control.Monad.Catch import Control.Monad.IO.Class import Control.Monad.Trans.Cont import Data.Maybe +import Data.Function import Foreign.Concurrent qualified as GHC import Foreign.Ptr import Foreign.ForeignPtr @@ -273,13 +274,39 @@ releaseLock tid = readTVar globalPyLock >>= \case initializePython :: IO () -- See NOTE: [Python and threading] initializePython = [CU.exp| int { Py_IsInitialized() } |] >>= \case - 0 | rtsSupportsBoundThreads -> runInBoundThread $ mask_ $ doInializePython - | otherwise -> mask_ $ doInializePython + 0 | rtsSupportsBoundThreads -> runInBoundThread $ doInializePython + | otherwise -> doInializePython _ -> pure () -- | Destroy python interpreter. finalizePython :: IO () -finalizePython = mask_ doFinalizePython +finalizePython = join $ atomically $ readTVar globalPyState >>= \case + NotInitialized -> throwSTM PythonNotInitialized + InitFailed -> throwSTM PythonIsFinalized + Finalized -> pure $ pure () + InInitialization -> retry + InFinalization -> retry + -- We can simply call Py_Finalize + Running1 -> checkLock $ [C.block| void { + PyGILState_Ensure(); + Py_Finalize(); + } |] + -- We need to call Py_Finalize on main thread + RunningN _ eval _ tid_gc -> checkLock $ do + killThread tid_gc + resp <- newEmptyMVar + putMVar eval $ StopReq resp + takeMVar resp + where + checkLock action = readTVar globalPyLock >>= \case + LockUninialized -> throwSTM $ PyInternalError "finalizePython LockUninialized" + LockFinalized -> throwSTM $ PyInternalError "finalizePython LockFinalized" + Locked{} -> retry + LockedByGC -> retry + LockUnlocked -> do + writeTVar globalPyLock LockFinalized + writeTVar globalPyState Finalized + pure action -- | Bracket which ensures that action is executed with properly -- initialized interpreter @@ -303,7 +330,6 @@ doInializePython = do let fini st = atomically $ do writeTVar globalPyState $ st writeTVar globalPyLock $ LockUnlocked - pure $ (mask_ $ if -- On multithreaded runtime create bound thread to make @@ -335,22 +361,18 @@ mainThread lock_init lock_eval = do putMVar lock_init r_init case r_init of False -> pure () - True -> mask_ $ do - let loop - = handle (\InterruptMain -> pure ()) - $ takeMVar lock_eval >>= \case - EvalReq py resp -> do - res <- (Right <$> runPy py) `catch` (pure . Left) - putMVar resp res - loop - StopReq resp -> do - [C.block| void { - PyGILState_Ensure(); - Py_Finalize(); - } |] - putMVar resp () - loop - + True -> mask_ $ fix $ \loop -> + takeMVar lock_eval >>= \case + EvalReq py resp -> do + res <- (Right <$> runPy py) `catch` (pure . Left) + putMVar resp res + loop + StopReq resp -> do + [C.block| void { + PyGILState_Ensure(); + Py_Finalize(); + } |] + putMVar resp () doInializePythonIO :: IO Bool @@ -401,35 +423,6 @@ doInializePythonIO = do } |] return $! r == 0 -doFinalizePython :: IO () -doFinalizePython = join $ atomically $ readTVar globalPyState >>= \case - NotInitialized -> throwSTM PythonNotInitialized - InitFailed -> throwSTM PythonIsFinalized - Finalized -> pure $ pure () - InInitialization -> retry - InFinalization -> retry - -- We can simply call Py_Finalize - Running1 -> checkLock $ [C.block| void { - PyGILState_Ensure(); - Py_Finalize(); - } |] - -- We need to call Py_Finalize on main thread - RunningN _ eval _ tid_gc -> checkLock $ do - killThread tid_gc - resp <- newEmptyMVar - putMVar eval $ StopReq resp - takeMVar resp - where - checkLock action = readTVar globalPyLock >>= \case - LockUninialized -> throwSTM $ PyInternalError "doFinalizePython LockUninialized" - LockFinalized -> throwSTM $ PyInternalError "doFinalizePython LockFinalized" - Locked{} -> retry - LockedByGC -> retry - LockUnlocked -> do - writeTVar globalPyLock LockFinalized - writeTVar globalPyState Finalized - pure action - ---------------------------------------------------------------- -- Running Py monad @@ -454,7 +447,7 @@ runPy py where -- We check whether interpreter is initialized. Throw exception if -- it wasn't. Better than segfault isn't it? - go = ensurePyLock $ unPy (ensureGIL py) + go = ensurePyLock $ mask_ $ unsafeRunPy (ensureGIL py) -- | Same as 'runPy' but will make sure that code is run in python's -- main thread. It's thread in which python's interpreter was @@ -464,7 +457,11 @@ runPyInMain :: Py a -> IO a -- See NOTE: [Python and threading] runPyInMain py -- Multithreaded RTS - | rtsSupportsBoundThreads = join $ atomically $ readTVar globalPyState >>= \case + | rtsSupportsBoundThreads = bracket acquireMain releaseMain evalMain + -- Single-threaded RTS + | otherwise = runPy py + where + acquireMain = atomically $ readTVar globalPyState >>= \case NotInitialized -> throwSTM PythonNotInitialized InitFailed -> throwSTM PyInitializationFailed Finalized -> throwSTM PythonIsFinalized @@ -473,19 +470,20 @@ runPyInMain py Running1 -> throwSTM $ PyInternalError "runPyInMain: Running1" RunningN _ eval tid_main _ -> do acquireLock tid_main - pure - $ flip finally (atomically (releaseLock tid_main)) - $ flip onException (throwTo tid_main InterruptMain) - $ do resp <- newEmptyMVar - putMVar eval $ EvalReq py resp - either throwM pure =<< takeMVar resp - -- Single-threaded RTS - | otherwise = runPy py + pure (tid_main, eval) + -- + releaseMain (tid_main, _ ) = atomically (releaseLock tid_main) + evalMain (tid_main, eval) = do + r <- mask_ $ do resp <- newEmptyMVar + putMVar eval $ EvalReq py resp + takeMVar resp `onException` throwTo tid_main InterruptMain + either throwM pure r + -- | Execute python action. This function is unsafe and should be only -- called in thread of interpreter. -unPy :: Py a -> IO a -unPy (Py io) = io +unsafeRunPy :: Py a -> IO a +unsafeRunPy (Py io) = io diff --git a/test/TST/Run.hs b/test/TST/Run.hs index ed547c6..fef8c06 100644 --- a/test/TST/Run.hs +++ b/test/TST/Run.hs @@ -2,6 +2,8 @@ -- Tests for variable scope and names module TST.Run(tests) where +import Control.Concurrent +import Control.Exception import Control.Monad import Control.Monad.IO.Class import Test.Tasty @@ -19,7 +21,24 @@ tests = testGroup "Run python" import threading assert threading.main_thread() == threading.current_thread() |] - , testCase "Python exceptions are converted" $ runPy $ throwsPy [py_| 1 / 0 |] + , testCase "Python exceptions are converted (py)" $ runPy $ throwsPy [py_| 1 / 0 |] + , testCase "Python exceptions are converted (std)" $ throwsPyIO $ runPy [py_| 1 / 0 |] + , testCase "Python exceptions are converted (main)" $ throwsPyIO $ runPyInMain [py_| 1 / 0 |] + , testCase "Main doesn't deadlock after exception" $ do + throwsPyIO $ runPyInMain [py_| 1 / 0 |] + runPyInMain [py_| assert True |] + -- Here we test that exceptions are really passed to python's thread without running python + , testCase "Exception in runPyInMain works" $ do + lock <- newEmptyMVar + tid <- myThreadId + _ <- forkIO $ takeMVar lock >> throwTo tid Stop + handle (\Stop -> pure ()) + $ runPyInMain + $ do liftIO $ putMVar lock () + liftIO $ threadDelay 10_000_000 + error "Should be interrupted" + runPyInMain $ pure () + -- , testCase "Scope pymain->any" $ runPy $ do [pymain| x = 12 @@ -112,3 +131,7 @@ tests = testGroup "Run python" pass |] ] + +data Stop = Stop + deriving stock Show + deriving anyclass Exception diff --git a/test/TST/Util.hs b/test/TST/Util.hs index 8295cd4..df918e8 100644 --- a/test/TST/Util.hs +++ b/test/TST/Util.hs @@ -12,3 +12,7 @@ throwsPy :: Py () -> Py () throwsPy io = (io >> liftIO (assertFailure "Evaluation should raise python exception")) `catch` (\(_::PyError) -> pure ()) +throwsPyIO :: IO () -> IO () +throwsPyIO io = (io >> assertFailure "Evaluation should raise python exception") + `catch` (\(_::PyError) -> pure ()) +