diff --git a/.github/workflows/haskell-ci.yml b/.github/workflows/haskell-ci.yml index 663bb73..cc6c1ef 100644 --- a/.github/workflows/haskell-ci.yml +++ b/.github/workflows/haskell-ci.yml @@ -23,7 +23,7 @@ on: jobs: linux: name: Haskell-CI - Linux - ${{ matrix.compiler }} - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 timeout-minutes: 60 container: diff --git a/.gitignore b/.gitignore index 2fd2c86..e735e0b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ /dist-newstyle/ /cabal.project.local /cabal.project.freeze +/haddocks/ TAGS .ghc.environment.* .cabal-sandbox diff --git a/consumers/CHANGELOG.md b/consumers/CHANGELOG.md index 36e7fd2..5aea51b 100644 --- a/consumers/CHANGELOG.md +++ b/consumers/CHANGELOG.md @@ -1,3 +1,12 @@ +# consumers-2.3.3.2 (XXX-XX-XX) +* Log batch size limits when processing. +* Split off testing utilities into separate module. +* Ensure tables are cleaned up on test-teardown. +* Allow passing in PG variables through environment variables. +* Expose inlining information so it is possible to specialize consumers at + call-sites. +* Bump Ubuntu image used in CI. + # consumers-2.3.3.1 (2025-04-03) * Do not prepare query that updates jobs in the monitor thread. diff --git a/consumers/consumers.cabal b/consumers/consumers.cabal index 94a4a27..1332f60 100644 --- a/consumers/consumers.cabal +++ b/consumers/consumers.cabal @@ -105,6 +105,8 @@ test-suite consumers-test monad-time, mtl, stm, + tasty, + tasty-hunit, text, time, transformers, @@ -114,3 +116,4 @@ test-suite consumers-test type: exitcode-stdio-1.0 main-is: Test.hs + other-modules: Util diff --git a/consumers/src/Database/PostgreSQL/Consumers/Components.hs b/consumers/src/Database/PostgreSQL/Consumers/Components.hs index 316a177..17358d5 100644 --- a/consumers/src/Database/PostgreSQL/Consumers/Components.hs +++ b/consumers/src/Database/PostgreSQL/Consumers/Components.hs @@ -50,6 +50,7 @@ runConsumer -- ^ The consumer. 
-> ConnectionSourceM m -> m (m ()) +{-# INLINEABLE runConsumer #-} runConsumer cc cs = runConsumerWithMaybeIdleSignal cc cs Nothing runConsumerWithIdleSignal @@ -67,6 +68,7 @@ runConsumerWithIdleSignal -> ConnectionSourceM m -> TMVar Bool -> m (m ()) +{-# INLINEABLE runConsumerWithIdleSignal #-} runConsumerWithIdleSignal cc cs idleSignal = runConsumerWithMaybeIdleSignal cc cs (Just idleSignal) -- | Run the consumer and also signal whenever the consumer is waiting for @@ -85,12 +87,13 @@ runConsumerWithMaybeIdleSignal -> ConnectionSourceM m -> Maybe (TMVar Bool) -> m (m ()) +{-# INLINEABLE runConsumerWithMaybeIdleSignal #-} runConsumerWithMaybeIdleSignal cc0 cs mIdleSignal | ccMaxRunningJobs cc < 1 = do logInfo_ "ccMaxRunningJobs < 1, not starting the consumer" pure $ pure () | otherwise = do - semaphore <- newMVar () + (triggerNotification, listenNotification) <- mkNotification runningJobsInfo <- liftBase $ newTVarIO M.empty runningJobs <- liftBase $ newTVarIO 0 @@ -102,7 +105,7 @@ runConsumerWithMaybeIdleSignal cc0 cs mIdleSignal cid <- registerConsumer cc cs localData ["consumer_id" .= show cid] $ do - listener <- spawnListener cc cs semaphore + listener <- spawnListener cc cs triggerNotification monitor <- localDomain "monitor" $ spawnMonitor cc cs cid dispatcher <- localDomain "dispatcher" $ @@ -110,7 +113,7 @@ runConsumerWithMaybeIdleSignal cc0 cs mIdleSignal cc cs cid - semaphore + listenNotification runningJobsInfo runningJobs mIdleSignal @@ -184,9 +187,10 @@ spawnListener :: (MonadBaseControl IO m, MonadMask m) => ConsumerConfig m idx job -> ConnectionSourceM m - -> MVar () + -> TriggerNotification m -> m ThreadId -spawnListener cc cs semaphore = +{-# INLINEABLE spawnListener #-} +spawnListener cc cs outbox = forkP "listener" $ case ccNotificationChannel cc of Just chan -> @@ -204,8 +208,7 @@ spawnListener cc cs semaphore = liftBase . 
threadDelay $ ccNotificationTimeout cc signalDispatcher where - signalDispatcher = do - liftBase $ tryPutMVar semaphore () + signalDispatcher = triggerNotification outbox noTs = defaultTransactionSettings @@ -228,6 +231,7 @@ spawnMonitor -> ConnectionSourceM m -> ConsumerID -> m ThreadId +{-# INLINEABLE spawnMonitor #-} spawnMonitor ConsumerConfig {..} cs cid = forkP "monitor" . forever $ do runDBT cs ts $ do now <- currentTime @@ -309,14 +313,18 @@ spawnDispatcher => ConsumerConfig m idx job -> ConnectionSourceM m -> ConsumerID - -> MVar () + -> ListenNotification m -> TVar (M.Map ThreadId idx) -> TVar Int -> Maybe (TMVar Bool) -> m ThreadId -spawnDispatcher ConsumerConfig {..} cs cid semaphore runningJobsInfo runningJobs mIdleSignal = +{-# INLINEABLE spawnDispatcher #-} +spawnDispatcher ConsumerConfig {..} cs cid inbox runningJobsInfo runningJobs mIdleSignal = forkP "dispatcher" . forever $ do - void $ takeMVar semaphore + listenNotification inbox + -- When awoken, we always start slow, processing only a single job in a + -- batch. Each time we can fill a batch completely with jobs, we grow the maximum + -- batch size. someJobWasProcessed <- loop 1 if someJobWasProcessed then setIdle False @@ -336,7 +344,9 @@ spawnDispatcher ConsumerConfig {..} cs cid semaphore runningJobsInfo runningJobs logInfo "Processing batch" $ object [ "batch_size" .= batchSize + , "limit" .= limit ] + -- Update runningJobs before forking so that we can adjust -- maxBatchSize appropriately later. We also need to mask asynchronous -- exceptions here as we rely on correct value of runningJobs to @@ -349,9 +359,11 @@ spawnDispatcher ConsumerConfig {..} cs cid semaphore runningJobsInfo runningJobs . forkP "batch processor" . (`finally` subtractJobs) . restore - $ do - mapM startJob batch >>= mapM joinJob >>= updateJobs + $ mapM startJob batch >>= mapM joinJob >>= updateJobs + -- Induce some backpressure. 
If the number of running jobs by all batch + -- processors exceeds the global limit, we wait. If it does not, start a + -- new iteration with double the limit when (batchSize == limit) $ do maxBatchSize <- atomically $ do jobs <- readTVar runningJobs @@ -433,6 +445,7 @@ updateJobsQuery -> [(idx, Result)] -> UTCTime -> SQL +{-# INLINEABLE updateJobsQuery #-} updateJobsQuery jobsTable results now = smconcat [ "WITH removed AS (" diff --git a/consumers/src/Database/PostgreSQL/Consumers/Utils.hs b/consumers/src/Database/PostgreSQL/Consumers/Utils.hs index bd7a830..d2712b5 100644 --- a/consumers/src/Database/PostgreSQL/Consumers/Utils.hs +++ b/consumers/src/Database/PostgreSQL/Consumers/Utils.hs @@ -5,6 +5,9 @@ module Database.PostgreSQL.Consumers.Utils , forkP , gforkP , preparedSqlName + , TriggerNotification (triggerNotification) + , ListenNotification (listenNotification) + , mkNotification ) where import Control.Concurrent.Lifted @@ -14,6 +17,7 @@ import Control.Exception.Lifted qualified as E import Control.Monad.Base import Control.Monad.Catch import Control.Monad.Trans.Control +import Data.Functor (void) import Data.Maybe import Data.Text qualified as T import Database.PostgreSQL.PQTypes.Class @@ -22,6 +26,7 @@ import Database.PostgreSQL.PQTypes.SQL.Raw -- | Run an action 'm' that returns a finalizer and perform the returned -- finalizer after the action 'action' completes. finalize :: (MonadMask m, MonadBase IO m) => m (m ()) -> m a -> m a +{-# INLINEABLE finalize #-} finalize m action = do finalizer <- newEmptyMVar flip finally (tryTakeMVar finalizer >>= fromMaybe (pure ())) $ do @@ -49,6 +54,7 @@ instance Exception ThrownFrom -- | Stop execution of a thread. 
stopExecution :: MonadBase IO m => ThreadId -> m () +{-# INLINEABLE stopExecution #-} stopExecution = flip throwTo StopExecution ---------------------------------------- @@ -56,6 +62,7 @@ stopExecution = flip throwTo StopExecution -- | Modified version of 'fork' that propagates thrown exceptions to the parent -- thread. forkP :: MonadBaseControl IO m => String -> m () -> m ThreadId +{-# INLINEABLE forkP #-} forkP = forkImpl fork -- | Modified version of 'TG.fork' that propagates thrown exceptions to the @@ -66,6 +73,7 @@ gforkP -> String -> m () -> m (ThreadId, m (T.Result ())) +{-# INLINEABLE gforkP #-} gforkP = forkImpl . TG.fork ---------------------------------------- @@ -76,6 +84,7 @@ forkImpl -> String -> m () -> m a +{-# INLINEABLE forkImpl #-} forkImpl ffork tname m = E.mask $ \release -> do parent <- myThreadId ffork $ @@ -86,3 +95,18 @@ forkImpl ffork tname m = E.mask $ \release -> do preparedSqlName :: T.Text -> RawSQL () -> QueryName preparedSqlName baseName tableName = QueryName . T.take 63 $ baseName <> "$" <> unRawSQL tableName + +---------------------------------------- + +newtype TriggerNotification m = TriggerNotification {triggerNotification :: m ()} + +newtype ListenNotification m = ListenNotification {listenNotification :: m ()} + +mkNotification :: MonadBaseControl IO m => m (TriggerNotification m, ListenNotification m) +{-# INLINEABLE mkNotification #-} +mkNotification = do + notificationRef <- newEmptyMVar + pure + ( TriggerNotification . 
void $ tryPutMVar notificationRef () + , ListenNotification $ takeMVar notificationRef + ) diff --git a/consumers/test/Test.hs b/consumers/test/Test.hs index fd2b61d..c99d2b0 100644 --- a/consumers/test/Test.hs +++ b/consumers/test/Test.hs @@ -1,171 +1,100 @@ module Main where import Control.Concurrent.STM -import Control.Exception import Control.Monad -import Control.Monad.Base import Control.Monad.Catch import Control.Monad.IO.Class -import Control.Monad.State.Strict +import Control.Monad.RWS import Control.Monad.Time -import Control.Monad.Trans.Control import Data.Int import Data.Text qualified as T -import Data.Time import Database.PostgreSQL.Consumers import Database.PostgreSQL.PQTypes -import Database.PostgreSQL.PQTypes.Checks import Database.PostgreSQL.PQTypes.Model import Log import Log.Backend.StandardOutput -import System.Environment -import System.Exit import Test.HUnit qualified as T - -data TestEnvSt = TestEnvSt - { teCurrentTime :: UTCTime - , teMonotonicTime :: Double - } - -type InnerTestEnv = StateT TestEnvSt (DBT (LogT IO)) - -newtype TestEnv a = TestEnv {unTestEnv :: InnerTestEnv a} - deriving (Applicative, Functor, Monad, MonadLog, MonadDB, MonadThrow, MonadCatch, MonadMask, MonadIO, MonadBase IO, MonadState TestEnvSt) - -instance MonadBaseControl IO TestEnv where - type StM TestEnv a = StM InnerTestEnv a - liftBaseWith f = TestEnv $ liftBaseWith (\run -> f $ run . unTestEnv) - restoreM = TestEnv . restoreM - -instance MonadTime TestEnv where - currentTime = gets teCurrentTime - monotonicTime = gets teMonotonicTime - -modifyTestTime :: MonadState TestEnvSt m => (UTCTime -> UTCTime) -> m () -modifyTestTime modtime = modify (\te -> te {teCurrentTime = modtime . teCurrentTime $ te}) - -runTestEnv :: ConnectionSourceM (LogT IO) -> Logger -> TestEnv a -> IO a -runTestEnv connSource logger = - runLogT "consumers-test" logger defaultLogLevel - . runDBT connSource defaultTransactionSettings - . 
(\m' -> fst <$> runStateT m' (TestEnvSt (UTCTime (ModifiedJulianDay 0) 0) 0)) - . unTestEnv +import Test.Tasty +import Test.Tasty.HUnit +import Util main :: IO () -main = void . T.runTestTT $ T.TestCase test - -test :: IO () -test = do - connString <- - getArgs >>= \case - connString : _args -> pure $ T.pack connString - [] -> - lookupEnv "GITHUB_ACTIONS" >>= \case - Just "true" -> pure "host=postgres user=postgres password=postgres" - _ -> printUsage >> exitFailure - - let connSettings = - defaultConnectionSettings - { csConnInfo = connString - } - ConnectionSource connSource = simpleSource connSettings - - withStdOutLogger $ \logger -> - runTestEnv connSource logger $ do - createTables +main = do + connectionParamsString <- getConnectionString + let connectionSettings = defaultConnectionSettings {csConnInfo = connectionParamsString} + defaultMain (allTests connectionSettings) + +allTests :: ConnectionSettings -> TestTree +allTests connectionSource = + testGroup + "consumers" + [ testCase "can grow the number of jobs ran concurrently" (testJobScheduleGrowth connectionSource) + ] + +-------------------- + +-- | Test that when a batch is submitted, it is consumed completely and in an +-- accelerated fashion that grows the batch size exponentially. +testJobScheduleGrowth :: ConnectionSettings -> IO () +testJobScheduleGrowth connectionSettings = do + let ConnectionSource connSource = simpleSource connectionSettings + withStdOutLogger $ \logger -> do + let additionalColumns = + [ tblColumn + { colName = "countdown" + , colType = IntegerT + , colNullable = False + } + ] + runTestEnv connSource logger (TestSetup "test_job_schedule_growth" additionalColumns) $ do + consumerConfig <- getConsumerConfig + TestEnvSt {..} <- get idleSignal <- liftIO newEmptyTMVarIO putJob 10 >> commit - forM_ [1 .. 
10 :: Int] $ \_ -> do + rowCountGrowth :: [Int64] <- replicateM (10 :: Int) $ do -- Move time forward 2hours, because jobs are scheduled 1 hour into future - modifyTestTime $ addUTCTime (2 * 60 * 60) + shiftTestTimeHours 2 finalize ( localDomain "process" $ runConsumerWithIdleSignal consumerConfig connSource idleSignal ) - $ do - waitUntilTrue idleSignal + $ waitUntilTrue idleSignal currentTime >>= (logInfo_ . T.pack . ("current time: " ++) . show) - -- Each job creates 2 new jobs, so there should be 1024 jobs in table. - runSQL_ "SELECT COUNT(*) from consumers_test_jobs" - rowcount0 :: Int64 <- fetchOne runIdentity + -- Each job creates 2 new jobs, so there should be 1024 jobs in table. + runSQL_ ("SELECT COUNT(*) from " <> raw teJobTableName) + fetchOne runIdentity + -- Move time 2 hours forward - modifyTestTime $ addUTCTime (2 * 60 * 60) + shiftTestTimeHours 2 finalize ( localDomain "process" $ runConsumerWithIdleSignal consumerConfig connSource idleSignal ) - $ do - waitUntilTrue idleSignal + $ waitUntilTrue idleSignal + -- Jobs are designed to double only 10 times, so there should be no jobs left now. - runSQL_ "SELECT COUNT(*) from consumers_test_jobs" + runSQL_ ("SELECT COUNT(*) from " <> raw teJobTableName) rowcount1 :: Int64 <- fetchOne runIdentity - liftIO $ T.assertEqual "Number of jobs in table after 10 steps is 1024" 1024 rowcount0 + liftIO $ T.assertEqual "Number of jobs in table after 10 steps grows exponentially" [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] rowCountGrowth liftIO $ T.assertEqual "Number of jobs in table after 11 steps is 0" 0 rowcount1 - dropTables where - waitUntilTrue tmvar = liftIO . atomically $ do - takeTMVar tmvar >>= \case - True -> pure () - False -> retry - - printUsage = do - prog <- getProgName - putStrLn $ "Usage: " <> prog <> " " - - definitions = emptyDbDefinitions {dbTables = [consumersTable, jobsTable]} - -- NB: order of migrations is important. 
- migrations = - [ createTableMigration consumersTable - , createTableMigration jobsTable - ] - - createTables :: TestEnv () - createTables = do - migrateDatabase - defaultExtrasOptions - definitions - migrations - checkDatabase - defaultExtrasOptions - definitions - - dropTables :: TestEnv () - dropTables = do - migrateDatabase - defaultExtrasOptions - emptyDbDefinitions - [ dropTableMigration jobsTable - , dropTableMigration consumersTable - ] - - consumerConfig = - ConsumerConfig - { ccJobsTable = "consumers_test_jobs" - , ccConsumersTable = "consumers_test_consumers" - , ccJobSelectors = ["id", "countdown"] - , ccJobFetcher = id - , ccJobIndex = \(i :: Int64, _ :: Int32) -> i - , ccNotificationChannel = Just "consumers_test_chan" - , -- select some small timeout - ccNotificationTimeout = 100 * 1000 -- 100 msec - , ccMaxRunningJobs = 20 - , ccProcessJob = processJob - , ccOnException = handleException - , ccJobLogData = \(i, _) -> ["job_id" .= i] - } + getConsumerConfig :: TestEnv (ConsumerConfig TestEnv Int64 (Int64, Int32)) + getConsumerConfig = defaultConsumerConfig processJob ["id", "countdown"] fst putJob :: Int32 -> TestEnv () putJob countdown = localDomain "put" $ do + TestEnvSt {..} <- get now <- currentTime runSQL_ $ - "INSERT INTO consumers_test_jobs " + "INSERT INTO " + <> raw teJobTableName <> "(run_at, finished_at, reserved_by, attempts, countdown) " <> "VALUES (" now <> " + interval '1 hour', NULL, NULL, 0, " countdown <> ")" - notify "consumers_test_chan" "" + notify teNotificationChannel "" processJob :: (Int64, Int32) -> TestEnv Result processJob (_idx, countdown) = do @@ -175,93 +104,30 @@ test = do commit pure (Ok Remove) - handleException :: SomeException -> (Int64, Int32) -> TestEnv Action - handleException _ _ = pure . 
RerunAfter $ imicroseconds 500000 - -jobsTable :: Table -jobsTable = - tblTable - { tblName = "consumers_test_jobs" - , tblVersion = 1 - , tblColumns = - [ tblColumn - { colName = "id" - , colType = BigSerialT - , colNullable = False - } - , tblColumn - { colName = "run_at" - , colType = TimestampWithZoneT - , colNullable = True - } - , tblColumn - { colName = "finished_at" - , colType = TimestampWithZoneT - , colNullable = True - } - , tblColumn - { colName = "reserved_by" - , colType = BigIntT - , colNullable = True - } - , tblColumn - { colName = "attempts" - , colType = IntegerT - , colNullable = False - } - , -- The only non-obligatory field: - tblColumn - { colName = "countdown" - , colType = IntegerT - , colNullable = False - } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblForeignKeys = - [ (fkOnColumn "reserved_by" "consumers_test_consumers" "id") - { fkOnDelete = ForeignKeySetNull - } - ] - } - -consumersTable :: Table -consumersTable = - tblTable - { tblName = "consumers_test_consumers" - , tblVersion = 1 - , tblColumns = - [ tblColumn - { colName = "id" - , colType = BigSerialT - , colNullable = False - } - , tblColumn - { colName = "name" - , colType = TextT - , colNullable = False - } - , tblColumn - { colName = "last_activity" - , colType = TimestampWithZoneT - , colNullable = False - } - ] - , tblPrimaryKey = pkOnColumn "id" - } - -createTableMigration :: MonadDB m => Table -> Migration m -createTableMigration tbl = - Migration - { mgrTableName = tblName tbl - , mgrFrom = 0 - , mgrAction = StandardMigration $ do - createTable True tbl - } - -dropTableMigration :: Table -> Migration m -dropTableMigration tbl = - Migration - { mgrTableName = tblName tbl - , mgrFrom = 1 - , mgrAction = DropTableMigration DropTableRestrict - } +waitUntilTrue :: MonadIO m => TMVar Bool -> m () +waitUntilTrue tmvar = liftIO . 
atomically $ do + takeTMVar tmvar >>= \case + True -> pure () + False -> retry + +handleException :: Applicative m => SomeException -> k -> m Action +handleException _ _ = pure . RerunAfter $ imicroseconds 500000 + +defaultConsumerConfig :: (FromRow job, Applicative m) => (job -> m Result) -> [SQL] -> (job -> idx) -> TestEnv (ConsumerConfig m idx job) +defaultConsumerConfig processJob jobSelectors jobIndex = do + TestEnvSt {..} <- get + pure $ + ConsumerConfig + { ccJobsTable = teJobTableName + , ccConsumersTable = teConsumerTableName + , ccJobSelectors = jobSelectors + , ccJobFetcher = id + , ccJobIndex = jobIndex + , ccNotificationChannel = Just teNotificationChannel + , -- select some small timeout + ccNotificationTimeout = 100 * 1000 -- 100 msec + , ccMaxRunningJobs = 20 + , ccProcessJob = processJob + , ccOnException = handleException + , ccJobLogData = const [] + } diff --git a/consumers/test/Util.hs b/consumers/test/Util.hs new file mode 100644 index 0000000..2c7db77 --- /dev/null +++ b/consumers/test/Util.hs @@ -0,0 +1,210 @@ +module Util where + +import Control.Applicative ((<|>)) +import Control.Monad.Base +import Control.Monad.Catch +import Control.Monad.IO.Class +import Control.Monad.State.Strict +import Control.Monad.Time +import Control.Monad.Trans.Control +import Data.Text qualified as T +import Data.Time +import Database.PostgreSQL.PQTypes +import Database.PostgreSQL.PQTypes.Checks +import Database.PostgreSQL.PQTypes.Model +import Log +import System.Environment +import System.Exit + +getConnectionString :: IO T.Text +getConnectionString = do + connectionParamsString <- (<|>) <$> paramsFromGithub <*> paramsFromEnvironmentVariables + allArgs <- getArgs + case connectionParamsString of + Just params -> pure (stringFromParams params) + _ -> case allArgs of + connString : _args -> pure (T.pack connString) + [] -> printUsage *> exitFailure + where + printUsage = do + prog <- getProgName + putStrLn $ "Usage: " <> prog <> " " + + paramsFromGithub = + 
lookupEnv "GITHUB_ACTIONS" >>= \case + Just "true" -> pure $ Just ("postgres", "postgres", "postgres", "postgres") + _ -> pure Nothing + paramsFromEnvironmentVariables = do + variables <- + sequence + [ lookupEnv "PGHOST" + , lookupEnv "PGUSER" + , lookupEnv "PGDATABASE" + , lookupEnv "PGPASSWORD" + ] + case variables of + [Just host, Just user, Just database, Just password] -> pure $ Just (host, user, database, password) + _ -> pure Nothing + stringFromParams (host, user, database, pass) = + T.pack ("host=" <> host <> " user=" <> user <> " dbname=" <> database <> " password=" <> pass) + +data TestEnvSt = TestEnvSt + { teCurrentTime :: UTCTime + , teMonotonicTime :: Double + , teJobTableName :: RawSQL () + , teConsumerTableName :: RawSQL () + , teNotificationChannel :: Channel + , teAdditionalCols :: [TableColumn] + } + +type InnerTestEnv = StateT TestEnvSt (DBT (LogT IO)) + +newtype TestEnv a = TestEnv {unTestEnv :: InnerTestEnv a} + deriving (Applicative, Functor, Monad, MonadLog, MonadDB, MonadThrow, MonadCatch, MonadMask, MonadIO, MonadBase IO, MonadState TestEnvSt) + +instance MonadBaseControl IO TestEnv where + type StM TestEnv a = StM InnerTestEnv a + liftBaseWith f = TestEnv $ liftBaseWith (\run -> f $ run . unTestEnv) + restoreM = TestEnv . restoreM + +instance MonadTime TestEnv where + currentTime = gets teCurrentTime + monotonicTime = gets teMonotonicTime + +data TestSetup = TestSetup + { tsTestSuffix :: RawSQL () + , tsAdditionalCols :: [TableColumn] + } + +modifyTestTime :: MonadState TestEnvSt m => (UTCTime -> UTCTime) -> m () +modifyTestTime modtime = modify (\te -> te {teCurrentTime = modtime . teCurrentTime $ te}) + +shiftTestTimeHours :: MonadState TestEnvSt m => NominalDiffTime -> m () +shiftTestTimeHours hr = modifyTestTime $ addUTCTime (hr * 60 * 60) + +runTestEnv :: ConnectionSourceM (LogT IO) -> Logger -> TestSetup -> TestEnv a -> IO a +runTestEnv connSource logger TestSetup {..} test = + (runLogT "consumers-test" logger defaultLogLevel . 
runDBT connSource defaultTransactionSettings) + . (`evalStateT` testEnvironment) + $ unTestEnv (bracket createTables (const dropTables) (const test)) + where + jobTableName = "jobs_" <> tsTestSuffix + consumerTableName = "consumers_" <> tsTestSuffix + notificationChannelName = "notification_" <> tsTestSuffix + testEnvironment = + TestEnvSt + (UTCTime (ModifiedJulianDay 0) 0) + 0 + jobTableName + consumerTableName + (Channel notificationChannelName) + tsAdditionalCols + jobTable = mkJobsTable jobTableName consumerTableName tsAdditionalCols + consumerTable = mkConsumersTable consumerTableName + definitions = emptyDbDefinitions {dbTables = [consumerTable, jobTable]} + -- NB: order of migrations is important. + migrations = + [ createTableMigration consumerTable + , createTableMigration jobTable + ] + + createTables :: TestEnv () + createTables = do + migrateDatabase + defaultExtrasOptions + definitions + migrations + checkDatabase + defaultExtrasOptions + definitions + + dropTables :: TestEnv () + dropTables = + migrateDatabase + defaultExtrasOptions + emptyDbDefinitions + [ dropTableMigration jobTable + , dropTableMigration consumerTable + ] + +mkJobsTable :: RawSQL () -> RawSQL () -> [TableColumn] -> Table +mkJobsTable tableName consumerTableName additionalCols = + tblTable + { tblName = tableName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "id" + , colType = BigSerialT + , colNullable = False + } + , tblColumn + { colName = "run_at" + , colType = TimestampWithZoneT + , colNullable = True + } + , tblColumn + { colName = "finished_at" + , colType = TimestampWithZoneT + , colNullable = True + } + , tblColumn + { colName = "reserved_by" + , colType = BigIntT + , colNullable = True + } + , tblColumn + { colName = "attempts" + , colType = IntegerT + , colNullable = False + } + ] + ++ additionalCols + , tblPrimaryKey = pkOnColumn "id" + , tblForeignKeys = + [ (fkOnColumn "reserved_by" consumerTableName "id") + { fkOnDelete = ForeignKeySetNull + } + ] + 
} + +mkConsumersTable :: RawSQL () -> Table +mkConsumersTable tableName = + tblTable + { tblName = tableName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "id" + , colType = BigSerialT + , colNullable = False + } + , tblColumn + { colName = "name" + , colType = TextT + , colNullable = False + } + , tblColumn + { colName = "last_activity" + , colType = TimestampWithZoneT + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumn "id" + } + +createTableMigration :: MonadDB m => Table -> Migration m +createTableMigration tbl = + Migration + { mgrTableName = tblName tbl + , mgrFrom = 0 + , mgrAction = StandardMigration $ createTable True tbl + } + +dropTableMigration :: Table -> Migration m +dropTableMigration tbl = + Migration + { mgrTableName = tblName tbl + , mgrFrom = 1 + , mgrAction = DropTableMigration DropTableRestrict + }