入門系列介紹了 with 語句的基本使用。本章深入探討上下文管理器的實現原理與進階應用,包括 contextlib 工具、嵌套組合、以及非同步上下文管理。

先備知識

  • 入門系列的 with 語句使用
  • 基本的類別定義與魔術方法

上下文管理器協議

__enter____exit__

 1class ManagedResource:
 2    """展示上下文管理器協議"""
 3
 4    def __init__(self, name: str) -> None:
 5        self.name = name
 6        print(f"Creating {name}")
 7
 8    def __enter__(self) -> "ManagedResource":
 9        """進入 with 區塊時呼叫
10
11        Returns:
12            as 子句綁定的物件
13        """
14        print(f"Entering {self.name}")
15        return self  # 這會被 as 子句捕獲
16
17    def __exit__(
18        self,
19        exc_type: type[BaseException] | None,
20        exc_val: BaseException | None,
21        exc_tb: TracebackType | None
22    ) -> bool:
23        """離開 with 區塊時呼叫
24
25        Args:
26            exc_type: 異常類型(無異常時為 None)
27            exc_val: 異常實例(無異常時為 None)
28            exc_tb: 追蹤資訊(無異常時為 None)
29
30        Returns:
31            True 表示已處理異常,不再傳播
32            False 或 None 表示讓異常繼續傳播
33        """
34        print(f"Exiting {self.name}")
35        if exc_type is not None:
36            print(f"  Exception: {exc_type.__name__}: {exc_val}")
37        return False  # 讓異常繼續傳播
38
39# 使用
40with ManagedResource("test") as resource:
41    print(f"Using {resource.name}")
42    # raise ValueError("oops")  # 取消註解測試異常處理
43
44# 輸出:
45# Creating test
46# Entering test
47# Using test
48# Exiting test

__exit__ 的異常處理

 1from types import TracebackType
 2
 3class SuppressErrors:
 4    """抑制特定異常的上下文管理器"""
 5
 6    def __init__(self, *exceptions: type[BaseException]) -> None:
 7        self.exceptions = exceptions
 8
 9    def __enter__(self) -> None:
10        pass
11
12    def __exit__(
13        self,
14        exc_type: type[BaseException] | None,
15        exc_val: BaseException | None,
16        exc_tb: TracebackType | None
17    ) -> bool:
18        # 如果是我們要抑制的異常類型,返回 True
19        if exc_type is not None and issubclass(exc_type, self.exceptions):
20            print(f"Suppressed: {exc_type.__name__}: {exc_val}")
21            return True  # 吞掉異常
22        return False
23
24# 使用
25with SuppressErrors(ValueError, KeyError):
26    raise ValueError("This will be suppressed")
27
28print("Continues normally")

返回值的重要性

 1class DatabaseConnection:
 2    """展示 __enter__ 返回值的用法"""
 3
 4    def __init__(self, url: str) -> None:
 5        self.url = url
 6        self._connection = None
 7
 8    def __enter__(self) -> "Cursor":
 9        self._connection = connect(self.url)
10        return self._connection.cursor()  # 返回 cursor,不是 self
11
12    def __exit__(self, *args) -> None:
13        if self._connection:
14            self._connection.close()
15
16# 使用
17with DatabaseConnection("postgres://...") as cursor:
18    cursor.execute("SELECT * FROM users")
19    # cursor 是 Cursor 物件,不是 DatabaseConnection

contextlib 工具

@contextmanager 裝飾器

用生成器函式建立上下文管理器,比定義類別更簡潔:

 1from contextlib import contextmanager
 2from typing import Iterator
 3
 4@contextmanager
 5def timer(name: str) -> Iterator[None]:
 6    """計時上下文管理器"""
 7    import time
 8    start = time.perf_counter()
 9    print(f"Starting {name}...")
10
11    try:
12        yield  # 這裡是 with 區塊執行的地方
13    finally:
14        elapsed = time.perf_counter() - start
15        print(f"{name} took {elapsed:.3f}s")
16
17# 使用
18with timer("data processing"):
19    process_large_dataset()
20
21# 如果需要返回值
22@contextmanager
23def temp_directory() -> Iterator[Path]:
24    """建立臨時目錄,結束後自動清理"""
25    import tempfile
26    import shutil
27    from pathlib import Path
28
29    path = Path(tempfile.mkdtemp())
30    try:
31        yield path  # path 會被 as 子句捕獲
32    finally:
33        shutil.rmtree(path)
34
35with temp_directory() as tmpdir:
36    (tmpdir / "test.txt").write_text("hello")

ExitStack

動態管理多個上下文管理器:

 1from contextlib import ExitStack
 2
 3def process_multiple_files(filenames: list[str]) -> list[str]:
 4    """處理多個檔案,確保全部關閉"""
 5    with ExitStack() as stack:
 6        files = [
 7            stack.enter_context(open(fn))
 8            for fn in filenames
 9        ]
10        return [f.read() for f in files]
11
12# 更複雜的例子:條件式資源
13def connect_services(config: dict) -> dict:
14    """根據配置連接多個服務"""
15    with ExitStack() as stack:
16        services = {}
17
18        if config.get("database"):
19            services["db"] = stack.enter_context(
20                DatabaseConnection(config["database"])
21            )
22
23        if config.get("cache"):
24            services["cache"] = stack.enter_context(
25                CacheConnection(config["cache"])
26            )
27
28        if config.get("queue"):
29            services["queue"] = stack.enter_context(
30                QueueConnection(config["queue"])
31            )
32
33        # 所有連接都會在離開時按相反順序關閉
34        return do_work(services)

ExitStack 的回調功能

 1from contextlib import ExitStack
 2
 3def complex_setup() -> None:
 4    with ExitStack() as stack:
 5        # 註冊清理回調
 6        stack.callback(print, "Cleanup 1")
 7        stack.callback(print, "Cleanup 2")
 8
 9        # 推遲上下文管理器
10        cm = some_context_manager()
11        stack.push(cm)  # 不立即進入,但會在結束時呼叫 __exit__
12
13        # 做一些工作...
14        pass
15
16    # 離開時會執行:
17    # 1. cm.__exit__()
18    # 2. print("Cleanup 2")
19    # 3. print("Cleanup 1")  # 注意順序是相反的

nullcontext

需要可選的上下文管理器時使用:

 1from contextlib import nullcontext
 2
 3def process_data(lock: Lock | None = None) -> None:
 4    """可選的鎖定"""
 5    with lock if lock else nullcontext():
 6        # 處理資料
 7        pass
 8
 9# 更清楚的寫法
10def process_data(lock: Lock | None = None) -> None:
11    cm = lock if lock else nullcontext()
12    with cm:
13        pass
14
15# 帶返回值的 nullcontext
16from contextlib import nullcontext
17
18def get_stream(filename: str | None) -> Iterator[TextIO]:
19    if filename:
20        return open(filename)
21    return nullcontext(sys.stdout)  # 返回 stdout
22
23with get_stream(None) as f:
24    f.write("Hello")  # 寫到 stdout

嵌套與組合上下文

資源的有序獲取與釋放

 1# 傳統嵌套
 2with open("input.txt") as infile:
 3    with open("output.txt", "w") as outfile:
 4        outfile.write(infile.read())
 5
 6# Python 3.9+ 可以用括號分組
 7with (
 8    open("input.txt") as infile,
 9    open("output.txt", "w") as outfile
10):
11    outfile.write(infile.read())
12
13# 多個資源:ExitStack 更靈活
14with ExitStack() as stack:
15    files = [stack.enter_context(open(f)) for f in filenames]

組合上下文管理器

 1from contextlib import contextmanager
 2from typing import Iterator
 3
 4@contextmanager
 5def database_transaction(db: Database) -> Iterator[Transaction]:
 6    """資料庫交易上下文"""
 7    tx = db.begin_transaction()
 8    try:
 9        yield tx
10        tx.commit()
11    except Exception:
12        tx.rollback()
13        raise
14
15@contextmanager
16def acquire_lock(lock: Lock, timeout: float = 30) -> Iterator[None]:
17    """帶超時的鎖定獲取"""
18    if not lock.acquire(timeout=timeout):
19        raise TimeoutError(f"Could not acquire lock within {timeout}s")
20    try:
21        yield
22    finally:
23        lock.release()
24
25# 組合使用
26@contextmanager
27def safe_update(db: Database, lock: Lock) -> Iterator[Transaction]:
28    """帶鎖定的安全更新"""
29    with acquire_lock(lock):
30        with database_transaction(db) as tx:
31            yield tx

async context manager

基本語法

 1class AsyncResource:
 2    """非同步上下文管理器"""
 3
 4    async def __aenter__(self) -> "AsyncResource":
 5        await self.connect()
 6        return self
 7
 8    async def __aexit__(
 9        self,
10        exc_type: type[BaseException] | None,
11        exc_val: BaseException | None,
12        exc_tb: TracebackType | None
13    ) -> bool:
14        await self.disconnect()
15        return False
16
17# 使用
18async def main():
19    async with AsyncResource() as resource:
20        await resource.do_something()

@asynccontextmanager

 1from contextlib import asynccontextmanager
 2from typing import AsyncIterator
 3
 4@asynccontextmanager
 5async def async_timer(name: str) -> AsyncIterator[None]:
 6    """非同步計時器"""
 7    import time
 8    start = time.perf_counter()
 9    print(f"Starting {name}...")
10
11    try:
12        yield
13    finally:
14        elapsed = time.perf_counter() - start
15        print(f"{name} took {elapsed:.3f}s")
16
17async def main():
18    async with async_timer("async task"):
19        await asyncio.sleep(1)

與 asyncio 模組的連結

 1import asyncio
 2from contextlib import asynccontextmanager
 3
 4@asynccontextmanager
 5async def managed_task(coro) -> AsyncIterator[asyncio.Task]:
 6    """管理 Task 生命週期"""
 7    task = asyncio.create_task(coro)
 8    try:
 9        yield task
10    finally:
11        if not task.done():
12            task.cancel()
13            try:
14                await task
15            except asyncio.CancelledError:
16                pass
17
18async def main():
19    async with managed_task(background_worker()) as task:
20        # 做一些工作
21        await asyncio.sleep(5)
22    # 離開時自動取消 task

實際範例:交易管理器

結合所有概念,建立一個完整的交易管理器:

  1from contextlib import contextmanager
  2from typing import Iterator, Protocol
  3from dataclasses import dataclass, field
  4from enum import Enum, auto
  5
  6class TransactionState(Enum):
  7    PENDING = auto()
  8    ACTIVE = auto()
  9    COMMITTED = auto()
 10    ROLLED_BACK = auto()
 11
 12class Transactional(Protocol):
 13    """可參與交易的資源協議"""
 14
 15    def begin(self) -> None:
 16        ...
 17
 18    def commit(self) -> None:
 19        ...
 20
 21    def rollback(self) -> None:
 22        ...
 23
 24@dataclass
 25class Transaction:
 26    """交易物件"""
 27    id: str
 28    resources: list[Transactional] = field(default_factory=list)
 29    state: TransactionState = TransactionState.PENDING
 30
 31    def add_resource(self, resource: Transactional) -> None:
 32        if self.state != TransactionState.ACTIVE:
 33            raise RuntimeError("Transaction is not active")
 34        resource.begin()
 35        self.resources.append(resource)
 36
 37    def commit(self) -> None:
 38        if self.state != TransactionState.ACTIVE:
 39            raise RuntimeError("Transaction is not active")
 40
 41        try:
 42            for resource in self.resources:
 43                resource.commit()
 44            self.state = TransactionState.COMMITTED
 45        except Exception:
 46            self.rollback()
 47            raise
 48
 49    def rollback(self) -> None:
 50        if self.state not in (TransactionState.ACTIVE, TransactionState.PENDING):
 51            return
 52
 53        errors = []
 54        for resource in reversed(self.resources):
 55            try:
 56                resource.rollback()
 57            except Exception as e:
 58                errors.append(e)
 59
 60        self.state = TransactionState.ROLLED_BACK
 61
 62        if errors:
 63            raise ExceptionGroup("Rollback failed", errors)
 64
 65class TransactionManager:
 66    """交易管理器"""
 67
 68    def __init__(self) -> None:
 69        self._tx_counter = 0
 70
 71    @contextmanager
 72    def transaction(self) -> Iterator[Transaction]:
 73        """建立新交易"""
 74        self._tx_counter += 1
 75        tx = Transaction(id=f"tx-{self._tx_counter}")
 76        tx.state = TransactionState.ACTIVE
 77
 78        try:
 79            yield tx
 80            tx.commit()
 81        except Exception:
 82            tx.rollback()
 83            raise
 84
 85# 使用範例
 86class DatabaseResource:
 87    """模擬資料庫資源"""
 88
 89    def __init__(self, name: str) -> None:
 90        self.name = name
 91
 92    def begin(self) -> None:
 93        print(f"{self.name}: BEGIN")
 94
 95    def commit(self) -> None:
 96        print(f"{self.name}: COMMIT")
 97
 98    def rollback(self) -> None:
 99        print(f"{self.name}: ROLLBACK")
100
101def main():
102    manager = TransactionManager()
103    db1 = DatabaseResource("primary")
104    db2 = DatabaseResource("replica")
105
106    with manager.transaction() as tx:
107        tx.add_resource(db1)
108        tx.add_resource(db2)
109
110        # 做一些操作...
111        print("Doing work in transaction")
112
113        # 如果拋出異常,兩個資源都會 rollback
114        # raise ValueError("Something went wrong")
115
116    # 正常結束時,兩個資源都會 commit

常見錯誤

1. 忘記 yield

1@contextmanager
2def broken():
3    print("enter")
4    # 忘記 yield!
5    print("exit")
6
7# 使用時會報錯:generator didn't yield

2. 在 finally 中 yield

1@contextmanager
2def also_broken():
3    try:
4        yield
5    finally:
6        yield  # 錯誤!不能在 finally 中 yield

3. 異常處理不當

 1@contextmanager
 2def risky():
 3    resource = acquire_resource()
 4    yield resource
 5    release_resource(resource)  # 如果 with 區塊拋異常,這行不會執行!
 6
 7# 正確做法
 8@contextmanager
 9def safe():
10    resource = acquire_resource()
11    try:
12        yield resource
13    finally:
14        release_resource(resource)  # 一定會執行

小結

概念用途
__enter__/__exit__上下文管理器協議
@contextmanager用生成器建立上下文管理器
ExitStack動態管理多個上下文
nullcontext可選的上下文管理器
async with非同步上下文管理
@asynccontextmanager用非同步生成器建立上下文管理器

思考題

  1. __exit__ 返回 TrueFalse 的差別是什麼?
  2. 什麼時候應該用 ExitStack 而不是巢狀 with
  3. @contextmanageryield 前後的程式碼分別對應什麼?

上一章:3.5.2 異常設計架構 下一章:3.5.4 插件系統設計