Skip to content

Commit d1276d9

Browse files
committed
feat: implement aroundEach and aroundAll hooks
1 parent 8f8231c commit d1276d9

7 files changed

Lines changed: 176 additions & 74 deletions

File tree

packages/runner/src/hooks.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type {
22
AfterAllListener,
33
AfterEachListener,
4+
AroundEachListener,
45
BeforeAllListener,
56
BeforeEachListener,
67
OnTestFailedHandler,
@@ -266,6 +267,46 @@ export const onTestFinished: TaskHook<OnTestFinishedHandler> = createTestHook(
266267
},
267268
)
268269

270+
/**
271+
* Registers a callback function that wraps around each test within the current suite.
272+
* The callback receives a `runTest` function that must be called to run the test.
273+
* This hook is useful for scenarios where you need to wrap tests in a context (e.g., database transactions).
274+
*
275+
* **Note:** When multiple `aroundEach` hooks are registered, they are nested inside each other.
276+
* The first registered hook is the outermost wrapper.
277+
*
278+
* @param {Function} fn - The callback function that wraps the test. Must call `runTest()` to run the test.
279+
* @param {number} [timeout] - Optional timeout in milliseconds for the hook. If not provided, the default hook timeout from the runner's configuration is used.
280+
* @returns {void}
281+
* @example
282+
* ```ts
283+
* // Example of using aroundEach to wrap tests in a database transaction
284+
* aroundEach(async (runTest) => {
285+
* await database.beginTransaction();
286+
* await runTest(); // Run the test
287+
* await database.rollback();
288+
* });
289+
* ```
290+
*/
291+
export function aroundEach<ExtraContext = object>(
292+
fn: AroundEachListener<ExtraContext>,
293+
timeout?: number,
294+
): void {
295+
assertTypes(fn, '"aroundEach" callback', ['function'])
296+
const runner = getRunner()
297+
return getCurrentSuite<ExtraContext>().on(
298+
'aroundEach',
299+
withTimeout(
300+
fn,
301+
// TODO: what should be the timeout? it runs _every_ hook inside (+a test)
302+
timeout ?? (getDefaultHookTimeout() + runner.config.testTimeout + getDefaultHookTimeout()),
303+
true,
304+
new Error('STACK_TRACE_ERROR'),
305+
([, context], error) => abortContextSignal(context, error),
306+
),
307+
)
308+
}
309+
269310
function createTestHook<T>(
270311
name: string,
271312
handler: (test: TaskPopulated, handler: T, timeout?: number) => void,

packages/runner/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ export { recordArtifact } from './artifact'
22
export {
33
afterAll,
44
afterEach,
5+
aroundEach,
56
beforeAll,
67
beforeEach,
78
onTestFailed,

packages/runner/src/run.ts

Lines changed: 122 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import type { Awaitable, TestError } from '@vitest/utils'
22
import type { DiffOptions } from '@vitest/utils/diff'
33
import type { FileSpecification, VitestRunner } from './types/runner'
44
import type {
5+
AroundEachListener,
56
File,
67
SequenceHooks,
78
Suite,
@@ -217,6 +218,51 @@ export async function callSuiteHook<T extends keyof SuiteHooks>(
217218
return callbacks
218219
}
219220

221+
function getAroundEachHooks(suite: Suite): AroundEachListener[] {
222+
const hooks: AroundEachListener[] = []
223+
const parentSuite: Suite | null = 'filepath' in suite ? null : suite.suite || suite.file
224+
if (parentSuite) {
225+
hooks.push(...getAroundEachHooks(parentSuite))
226+
}
227+
hooks.push(...getHooks(suite).aroundEach)
228+
return hooks
229+
}
230+
231+
async function callAroundEachHooks(
232+
suite: Suite,
233+
test: Test,
234+
runTest: () => Promise<void>,
235+
): Promise<void> {
236+
const hooks = getAroundEachHooks(suite)
237+
238+
if (!hooks.length) {
239+
await runTest()
240+
return
241+
}
242+
243+
const runNextHook = async (index: number): Promise<void> => {
244+
if (index >= hooks.length) {
245+
return runTest()
246+
}
247+
248+
const hook = hooks[index]
249+
let useCalled = false
250+
const use = () => {
251+
useCalled = true
252+
return runNextHook(index + 1)
253+
}
254+
await hook(use, test.context, suite)
255+
if (!useCalled) {
256+
throw new Error(
257+
'The `runTest()` callback was not called in the `aroundEach` hook. '
258+
+ 'Make sure to call `runTest()` to run the test.',
259+
)
260+
}
261+
}
262+
263+
await runNextHook(0)
264+
}
265+
220266
const packs = new Map<string, [TaskResult | undefined, TaskMeta]>()
221267
const eventsPacks: [string, TaskUpdateEvent, undefined][] = []
222268
const pendingTasksUpdates: Promise<void>[] = []
@@ -365,93 +411,95 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
365411
const retry = getRetryCount(test.retry)
366412
for (let retryCount = 0; retryCount <= retry; retryCount++) {
367413
let beforeEachCleanups: unknown[] = []
368-
try {
369-
await runner.onBeforeTryTask?.(test, {
370-
retry: retryCount,
371-
repeats: repeatCount,
372-
})
373-
374-
test.result.repeatCount = repeatCount
414+
await callAroundEachHooks(suite, test, async () => {
415+
try {
416+
await runner.onBeforeTryTask?.(test, {
417+
retry: retryCount,
418+
repeats: repeatCount,
419+
})
420+
421+
test.result!.repeatCount = repeatCount
422+
423+
beforeEachCleanups = await $('test.beforeEach', () => callSuiteHook(
424+
suite,
425+
test,
426+
'beforeEach',
427+
runner,
428+
[test.context, suite],
429+
))
430+
431+
if (runner.runTask) {
432+
await $('test.callback', () => runner.runTask!(test))
433+
}
434+
else {
435+
const fn = getFn(test)
436+
if (!fn) {
437+
throw new Error(
438+
'Test function is not found. Did you add it using `setFn`?',
439+
)
440+
}
441+
await $('test.callback', () => fn())
442+
}
375443

376-
beforeEachCleanups = await $('test.beforeEach', () => callSuiteHook(
377-
suite,
378-
test,
379-
'beforeEach',
380-
runner,
381-
[test.context, suite],
382-
))
444+
await runner.onAfterTryTask?.(test, {
445+
retry: retryCount,
446+
repeats: repeatCount,
447+
})
383448

384-
if (runner.runTask) {
385-
await $('test.callback', () => runner.runTask!(test))
386-
}
387-
else {
388-
const fn = getFn(test)
389-
if (!fn) {
390-
throw new Error(
391-
'Test function is not found. Did you add it using `setFn`?',
392-
)
449+
if (test.result!.state !== 'fail') {
450+
if (!test.repeats) {
451+
test.result!.state = 'pass'
452+
}
453+
else if (test.repeats && retry === retryCount) {
454+
test.result!.state = 'pass'
455+
}
393456
}
394-
await $('test.callback', () => fn())
457+
}
458+
catch (e) {
459+
failTask(test.result!, e, runner.config.diffOptions)
395460
}
396461

397-
await runner.onAfterTryTask?.(test, {
398-
retry: retryCount,
399-
repeats: repeatCount,
400-
})
462+
try {
463+
await runner.onTaskFinished?.(test)
464+
}
465+
catch (e) {
466+
failTask(test.result!, e, runner.config.diffOptions)
467+
}
401468

402-
if (test.result.state !== 'fail') {
403-
if (!test.repeats) {
404-
test.result.state = 'pass'
405-
}
406-
else if (test.repeats && retry === retryCount) {
407-
test.result.state = 'pass'
469+
try {
470+
await $('test.afterEach', () => callSuiteHook(suite, test, 'afterEach', runner, [
471+
test.context,
472+
suite,
473+
]))
474+
if (beforeEachCleanups.length) {
475+
await $('test.cleanup', () => callCleanupHooks(runner, beforeEachCleanups))
408476
}
477+
await callFixtureCleanup(test.context)
409478
}
410-
}
411-
catch (e) {
412-
failTask(test.result, e, runner.config.diffOptions)
413-
}
414-
415-
try {
416-
await runner.onTaskFinished?.(test)
417-
}
418-
catch (e) {
419-
failTask(test.result, e, runner.config.diffOptions)
420-
}
421-
422-
try {
423-
await $('test.afterEach', () => callSuiteHook(suite, test, 'afterEach', runner, [
424-
test.context,
425-
suite,
426-
]))
427-
if (beforeEachCleanups.length) {
428-
await $('test.cleanup', () => callCleanupHooks(runner, beforeEachCleanups))
479+
catch (e) {
480+
failTask(test.result!, e, runner.config.diffOptions)
429481
}
430-
await callFixtureCleanup(test.context)
431-
}
432-
catch (e) {
433-
failTask(test.result, e, runner.config.diffOptions)
434-
}
435482

436-
if (test.onFinished?.length) {
437-
await $('test.onFinished', () => callTestHooks(runner, test, test.onFinished!, 'stack'))
438-
}
483+
if (test.onFinished?.length) {
484+
await $('test.onFinished', () => callTestHooks(runner, test, test.onFinished!, 'stack'))
485+
}
439486

440-
if (test.result.state === 'fail' && test.onFailed?.length) {
441-
await $('test.onFailed', () => callTestHooks(
442-
runner,
443-
test,
444-
test.onFailed!,
445-
runner.config.sequence.hooks,
446-
))
447-
}
487+
if (test.result!.state === 'fail' && test.onFailed?.length) {
488+
await $('test.onFailed', () => callTestHooks(
489+
runner,
490+
test,
491+
test.onFailed!,
492+
runner.config.sequence.hooks,
493+
))
494+
}
448495

449-
test.onFailed = undefined
450-
test.onFinished = undefined
496+
test.onFailed = undefined
497+
test.onFinished = undefined
451498

452-
await runner.onAfterRetryTask?.(test, {
453-
retry: retryCount,
454-
repeats: repeatCount,
499+
await runner.onAfterRetryTask?.(test, {
500+
retry: retryCount,
501+
repeats: repeatCount,
502+
})
455503
})
456504

457505
// skipped with new PendingError

packages/runner/src/suite.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ export function createSuiteHooks(): SuiteHooks {
247247
afterAll: [],
248248
beforeEach: [],
249249
afterEach: [],
250+
aroundEach: [],
250251
}
251252
}
252253

packages/runner/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export type {
99
export type {
1010
AfterAllListener,
1111
AfterEachListener,
12+
AroundEachListener,
1213
BeforeAllListener,
1314
BeforeEachListener,
1415
File,

packages/runner/src/types/tasks.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,11 +688,20 @@ export interface AfterEachListener<ExtraContext = object> {
688688
): Awaitable<unknown>
689689
}
690690

691+
export interface AroundEachListener<ExtraContext = object> {
692+
(
693+
runTest: () => Promise<void>,
694+
context: TestContext & ExtraContext,
695+
suite: Readonly<Suite>
696+
): Awaitable<unknown>
697+
}
698+
691699
export interface SuiteHooks<ExtraContext = object> {
692700
beforeAll: BeforeAllListener[]
693701
afterAll: AfterAllListener[]
694702
beforeEach: BeforeEachListener<ExtraContext>[]
695703
afterEach: AfterEachListener<ExtraContext>[]
704+
aroundEach: AroundEachListener<ExtraContext>[]
696705
}
697706

698707
export interface TaskCustomOptions extends TestOptions {

packages/vitest/src/public/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ export type {
9797
export {
9898
afterAll,
9999
afterEach,
100+
aroundEach,
100101
beforeAll,
101102
beforeEach,
102103
describe,

0 commit comments

Comments
 (0)