3.5.3 進階上下文管理
3.5.3 進階上下文管理
入門系列介紹了 with 語句的基本使用。本章深入探討上下文管理器的實現原理與進階應用,包括 contextlib 工具、嵌套組合、以及非同步上下文管理。
先備知識
- 入門系列的
with語句使用 - 基本的類別定義與魔術方法
上下文管理器協議
__enter__ 與 __exit__
1class ManagedResource:
2 """展示上下文管理器協議"""
3
4 def __init__(self, name: str) -> None:
5 self.name = name
6 print(f"Creating {name}")
7
8 def __enter__(self) -> "ManagedResource":
9 """進入 with 區塊時呼叫
10
11 Returns:
12 as 子句綁定的物件
13 """
14 print(f"Entering {self.name}")
15 return self # 這會被 as 子句捕獲
16
17 def __exit__(
18 self,
19 exc_type: type[BaseException] | None,
20 exc_val: BaseException | None,
21 exc_tb: TracebackType | None
22 ) -> bool:
23 """離開 with 區塊時呼叫
24
25 Args:
26 exc_type: 異常類型(無異常時為 None)
27 exc_val: 異常實例(無異常時為 None)
28 exc_tb: 追蹤資訊(無異常時為 None)
29
30 Returns:
31 True 表示已處理異常,不再傳播
32 False 或 None 表示讓異常繼續傳播
33 """
34 print(f"Exiting {self.name}")
35 if exc_type is not None:
36 print(f" Exception: {exc_type.__name__}: {exc_val}")
37 return False # 讓異常繼續傳播
38
39# 使用
40with ManagedResource("test") as resource:
41 print(f"Using {resource.name}")
42 # raise ValueError("oops") # 取消註解測試異常處理
43
44# 輸出:
45# Creating test
46# Entering test
47# Using test
48# Exiting test__exit__ 的異常處理
1from types import TracebackType
2
3class SuppressErrors:
4 """抑制特定異常的上下文管理器"""
5
6 def __init__(self, *exceptions: type[BaseException]) -> None:
7 self.exceptions = exceptions
8
9 def __enter__(self) -> None:
10 pass
11
12 def __exit__(
13 self,
14 exc_type: type[BaseException] | None,
15 exc_val: BaseException | None,
16 exc_tb: TracebackType | None
17 ) -> bool:
18 # 如果是我們要抑制的異常類型,返回 True
19 if exc_type is not None and issubclass(exc_type, self.exceptions):
20 print(f"Suppressed: {exc_type.__name__}: {exc_val}")
21 return True # 吞掉異常
22 return False
23
24# 使用
25with SuppressErrors(ValueError, KeyError):
26 raise ValueError("This will be suppressed")
27
28print("Continues normally")返回值的重要性
1class DatabaseConnection:
2 """展示 __enter__ 返回值的用法"""
3
4 def __init__(self, url: str) -> None:
5 self.url = url
6 self._connection = None
7
8 def __enter__(self) -> "Cursor":
9 self._connection = connect(self.url)
10 return self._connection.cursor() # 返回 cursor,不是 self
11
12 def __exit__(self, *args) -> None:
13 if self._connection:
14 self._connection.close()
15
16# 使用
17with DatabaseConnection("postgres://...") as cursor:
18 cursor.execute("SELECT * FROM users")
19 # cursor 是 Cursor 物件,不是 DatabaseConnectioncontextlib 工具
@contextmanager 裝飾器
用生成器函式建立上下文管理器,比定義類別更簡潔:
1from contextlib import contextmanager
2from typing import Iterator
3
4@contextmanager
5def timer(name: str) -> Iterator[None]:
6 """計時上下文管理器"""
7 import time
8 start = time.perf_counter()
9 print(f"Starting {name}...")
10
11 try:
12 yield # 這裡是 with 區塊執行的地方
13 finally:
14 elapsed = time.perf_counter() - start
15 print(f"{name} took {elapsed:.3f}s")
16
17# 使用
18with timer("data processing"):
19 process_large_dataset()
20
21# 如果需要返回值
22@contextmanager
23def temp_directory() -> Iterator[Path]:
24 """建立臨時目錄,結束後自動清理"""
25 import tempfile
26 import shutil
27 from pathlib import Path
28
29 path = Path(tempfile.mkdtemp())
30 try:
31 yield path # path 會被 as 子句捕獲
32 finally:
33 shutil.rmtree(path)
34
35with temp_directory() as tmpdir:
36 (tmpdir / "test.txt").write_text("hello")ExitStack
動態管理多個上下文管理器:
1from contextlib import ExitStack
2
3def process_multiple_files(filenames: list[str]) -> list[str]:
4 """處理多個檔案,確保全部關閉"""
5 with ExitStack() as stack:
6 files = [
7 stack.enter_context(open(fn))
8 for fn in filenames
9 ]
10 return [f.read() for f in files]
11
12# 更複雜的例子:條件式資源
13def connect_services(config: dict) -> dict:
14 """根據配置連接多個服務"""
15 with ExitStack() as stack:
16 services = {}
17
18 if config.get("database"):
19 services["db"] = stack.enter_context(
20 DatabaseConnection(config["database"])
21 )
22
23 if config.get("cache"):
24 services["cache"] = stack.enter_context(
25 CacheConnection(config["cache"])
26 )
27
28 if config.get("queue"):
29 services["queue"] = stack.enter_context(
30 QueueConnection(config["queue"])
31 )
32
33 # 所有連接都會在離開時按相反順序關閉
34 return do_work(services)ExitStack 的回調功能
1from contextlib import ExitStack
2
3def complex_setup() -> None:
4 with ExitStack() as stack:
5 # 註冊清理回調
6 stack.callback(print, "Cleanup 1")
7 stack.callback(print, "Cleanup 2")
8
9 # 推遲上下文管理器
10 cm = some_context_manager()
11 stack.push(cm) # 不立即進入,但會在結束時呼叫 __exit__
12
13 # 做一些工作...
14 pass
15
16 # 離開時會執行:
17 # 1. cm.__exit__()
18 # 2. print("Cleanup 2")
19 # 3. print("Cleanup 1") # 注意順序是相反的nullcontext
需要可選的上下文管理器時使用:
1from contextlib import nullcontext
2
3def process_data(lock: Lock | None = None) -> None:
4 """可選的鎖定"""
5 with lock if lock else nullcontext():
6 # 處理資料
7 pass
8
9# 更清楚的寫法
10def process_data(lock: Lock | None = None) -> None:
11 cm = lock if lock else nullcontext()
12 with cm:
13 pass
14
15# 帶返回值的 nullcontext
16from contextlib import nullcontext
17
18def get_stream(filename: str | None) -> Iterator[TextIO]:
19 if filename:
20 return open(filename)
21 return nullcontext(sys.stdout) # 返回 stdout
22
23with get_stream(None) as f:
24 f.write("Hello") # 寫到 stdout嵌套與組合上下文
資源的有序獲取與釋放
1# 傳統嵌套
2with open("input.txt") as infile:
3 with open("output.txt", "w") as outfile:
4 outfile.write(infile.read())
5
6# Python 3.9+ 可以用括號分組
7with (
8 open("input.txt") as infile,
9 open("output.txt", "w") as outfile
10):
11 outfile.write(infile.read())
12
13# 多個資源:ExitStack 更靈活
14with ExitStack() as stack:
15 files = [stack.enter_context(open(f)) for f in filenames]組合上下文管理器
1from contextlib import contextmanager
2from typing import Iterator
3
4@contextmanager
5def database_transaction(db: Database) -> Iterator[Transaction]:
6 """資料庫交易上下文"""
7 tx = db.begin_transaction()
8 try:
9 yield tx
10 tx.commit()
11 except Exception:
12 tx.rollback()
13 raise
14
15@contextmanager
16def acquire_lock(lock: Lock, timeout: float = 30) -> Iterator[None]:
17 """帶超時的鎖定獲取"""
18 if not lock.acquire(timeout=timeout):
19 raise TimeoutError(f"Could not acquire lock within {timeout}s")
20 try:
21 yield
22 finally:
23 lock.release()
24
25# 組合使用
26@contextmanager
27def safe_update(db: Database, lock: Lock) -> Iterator[Transaction]:
28 """帶鎖定的安全更新"""
29 with acquire_lock(lock):
30 with database_transaction(db) as tx:
31 yield txasync context manager
基本語法
1class AsyncResource:
2 """非同步上下文管理器"""
3
4 async def __aenter__(self) -> "AsyncResource":
5 await self.connect()
6 return self
7
8 async def __aexit__(
9 self,
10 exc_type: type[BaseException] | None,
11 exc_val: BaseException | None,
12 exc_tb: TracebackType | None
13 ) -> bool:
14 await self.disconnect()
15 return False
16
17# 使用
18async def main():
19 async with AsyncResource() as resource:
20 await resource.do_something()@asynccontextmanager
1from contextlib import asynccontextmanager
2from typing import AsyncIterator
3
4@asynccontextmanager
5async def async_timer(name: str) -> AsyncIterator[None]:
6 """非同步計時器"""
7 import time
8 start = time.perf_counter()
9 print(f"Starting {name}...")
10
11 try:
12 yield
13 finally:
14 elapsed = time.perf_counter() - start
15 print(f"{name} took {elapsed:.3f}s")
16
17async def main():
18 async with async_timer("async task"):
19 await asyncio.sleep(1)與 asyncio 模組的連結
1import asyncio
2from contextlib import asynccontextmanager
3
4@asynccontextmanager
5async def managed_task(coro) -> AsyncIterator[asyncio.Task]:
6 """管理 Task 生命週期"""
7 task = asyncio.create_task(coro)
8 try:
9 yield task
10 finally:
11 if not task.done():
12 task.cancel()
13 try:
14 await task
15 except asyncio.CancelledError:
16 pass
17
18async def main():
19 async with managed_task(background_worker()) as task:
20 # 做一些工作
21 await asyncio.sleep(5)
22 # 離開時自動取消 task實際範例:交易管理器
結合所有概念,建立一個完整的交易管理器:
1from contextlib import contextmanager
2from typing import Iterator, Protocol
3from dataclasses import dataclass, field
4from enum import Enum, auto
5
6class TransactionState(Enum):
7 PENDING = auto()
8 ACTIVE = auto()
9 COMMITTED = auto()
10 ROLLED_BACK = auto()
11
12class Transactional(Protocol):
13 """可參與交易的資源協議"""
14
15 def begin(self) -> None:
16 ...
17
18 def commit(self) -> None:
19 ...
20
21 def rollback(self) -> None:
22 ...
23
24@dataclass
25class Transaction:
26 """交易物件"""
27 id: str
28 resources: list[Transactional] = field(default_factory=list)
29 state: TransactionState = TransactionState.PENDING
30
31 def add_resource(self, resource: Transactional) -> None:
32 if self.state != TransactionState.ACTIVE:
33 raise RuntimeError("Transaction is not active")
34 resource.begin()
35 self.resources.append(resource)
36
37 def commit(self) -> None:
38 if self.state != TransactionState.ACTIVE:
39 raise RuntimeError("Transaction is not active")
40
41 try:
42 for resource in self.resources:
43 resource.commit()
44 self.state = TransactionState.COMMITTED
45 except Exception:
46 self.rollback()
47 raise
48
49 def rollback(self) -> None:
50 if self.state not in (TransactionState.ACTIVE, TransactionState.PENDING):
51 return
52
53 errors = []
54 for resource in reversed(self.resources):
55 try:
56 resource.rollback()
57 except Exception as e:
58 errors.append(e)
59
60 self.state = TransactionState.ROLLED_BACK
61
62 if errors:
63 raise ExceptionGroup("Rollback failed", errors)
64
65class TransactionManager:
66 """交易管理器"""
67
68 def __init__(self) -> None:
69 self._tx_counter = 0
70
71 @contextmanager
72 def transaction(self) -> Iterator[Transaction]:
73 """建立新交易"""
74 self._tx_counter += 1
75 tx = Transaction(id=f"tx-{self._tx_counter}")
76 tx.state = TransactionState.ACTIVE
77
78 try:
79 yield tx
80 tx.commit()
81 except Exception:
82 tx.rollback()
83 raise
84
85# 使用範例
86class DatabaseResource:
87 """模擬資料庫資源"""
88
89 def __init__(self, name: str) -> None:
90 self.name = name
91
92 def begin(self) -> None:
93 print(f"{self.name}: BEGIN")
94
95 def commit(self) -> None:
96 print(f"{self.name}: COMMIT")
97
98 def rollback(self) -> None:
99 print(f"{self.name}: ROLLBACK")
100
101def main():
102 manager = TransactionManager()
103 db1 = DatabaseResource("primary")
104 db2 = DatabaseResource("replica")
105
106 with manager.transaction() as tx:
107 tx.add_resource(db1)
108 tx.add_resource(db2)
109
110 # 做一些操作...
111 print("Doing work in transaction")
112
113 # 如果拋出異常,兩個資源都會 rollback
114 # raise ValueError("Something went wrong")
115
116 # 正常結束時,兩個資源都會 commit常見錯誤
1. 忘記 yield
1@contextmanager
2def broken():
3 print("enter")
4 # 忘記 yield!
5 print("exit")
6
7# 使用時會報錯:generator didn't yield2. 在 finally 中 yield
1@contextmanager
2def also_broken():
3 try:
4 yield
5 finally:
6 yield # 錯誤!不能在 finally 中 yield3. 異常處理不當
1@contextmanager
2def risky():
3 resource = acquire_resource()
4 yield resource
5 release_resource(resource) # 如果 with 區塊拋異常,這行不會執行!
6
7# 正確做法
8@contextmanager
9def safe():
10 resource = acquire_resource()
11 try:
12 yield resource
13 finally:
14 release_resource(resource) # 一定會執行小結
| 概念 | 用途 |
|---|---|
__enter__/__exit__ | 上下文管理器協議 |
@contextmanager | 用生成器建立上下文管理器 |
ExitStack | 動態管理多個上下文 |
nullcontext | 可選的上下文管理器 |
async with | 非同步上下文管理 |
@asynccontextmanager | 用非同步生成器建立上下文管理器 |
思考題
__exit__返回True和False的差別是什麼?- 什麼時候應該用
ExitStack而不是巢狀with? @contextmanager中yield前後的程式碼分別對應什麼?
上一章:3.5.2 異常設計架構 下一章:3.5.4 插件系統設計