From bdbd9d8839de45f2f14bcfb7c307a231271a6cc9 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Wed, 27 May 2026 12:29:13 +0200 Subject: [PATCH 1/2] test inmemory broker middleware startup / shutdown handlers --- tests/brokers/test_inmemory.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/brokers/test_inmemory.py b/tests/brokers/test_inmemory.py index 9dc85019..eef7a2aa 100644 --- a/tests/brokers/test_inmemory.py +++ b/tests/brokers/test_inmemory.py @@ -4,6 +4,7 @@ import pytest from taskiq import InMemoryBroker +from taskiq.abc.middleware import TaskiqMiddleware from taskiq.events import TaskiqEvents from taskiq.state import TaskiqState @@ -66,6 +67,24 @@ async def _c_startup(state: TaskiqState) -> None: assert broker.state.from_client == test_value +async def test_middleware_startup_and_shutdown_fire() -> None: + calls = [] + + class RecordingMiddleware(TaskiqMiddleware): + async def startup(self) -> None: + calls.append("startup") + + async def shutdown(self) -> None: + calls.append("shutdown") + + broker = InMemoryBroker().with_middlewares(RecordingMiddleware()) + + await broker.startup() + await broker.shutdown() + + assert calls == ["startup", "shutdown"] + + async def test_execution() -> None: broker = InMemoryBroker() test_value = uuid.uuid4().hex From 8e27a14143b8a9f7d840f369ede1c5a026582f84 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Wed, 27 May 2026 12:29:31 +0200 Subject: [PATCH 2/2] fix in memory broker startup / shutdown handlers --- taskiq/brokers/inmemory_broker.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 0a7cc98e..807f1e75 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -5,6 +5,7 @@ from typing import Any, TypeVar from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents @@ -198,9 +199,21 @@ async def startup(self) -> None: for handler in self.event_handlers.get(event, []): await maybe_awaitable(handler(self.state)) + for middleware in self.middlewares: + if middleware.__class__.startup != TaskiqMiddleware.startup: + await maybe_awaitable(middleware.startup()) + + await self.result_backend.startup() + async def shutdown(self) -> None: """Runs shutdown events for client and worker side.""" for event in (TaskiqEvents.CLIENT_SHUTDOWN, TaskiqEvents.WORKER_SHUTDOWN): for handler in self.event_handlers.get(event, []): await maybe_awaitable(handler(self.state)) + + for middleware in self.middlewares: + if middleware.__class__.shutdown != TaskiqMiddleware.shutdown: + await maybe_awaitable(middleware.shutdown()) + + await self.result_backend.shutdown() self.executor.shutdown()