5.2 PyO3 基礎
5.2 PyO3 基礎
本章介紹 PyO3,Rust 的 Python 綁定函式庫。
本章目標
學完本章後,你將能夠:
- 理解 PyO3 的設計原理
- 使用 #[pyfunction] 和 #[pyclass] 建立綁定
- 處理型別轉換和錯誤
【原理層】PyO3 的設計
PyO3 是什麼?
PyO3 是 Rust 與 Python 之間的橋樑:
1PyO3 提供:
2├── Rust → Python:將 Rust 程式碼編譯為 Python 模組
3├── Python → Rust:在 Rust 中嵌入 Python 直譯器
4├── 型別轉換:自動處理 Rust ↔ Python 型別
5└── GIL 管理:安全地處理 Python 的全域直譯器鎖版本要求
1# Cargo.toml(2025 年建議)
2[dependencies]
3pyo3 = { version = "0.23", features = ["extension-module"] }
4
5# 支援的版本:
6# - Rust: 1.63+(建議 1.75+)
7# - Python: 3.8+
8# - PyO3: 0.23+(支援 Free-threading)與 Python 的互動模型
1Python 程式
2 │
3 ↓ import
4┌─────────────────────────────────┐
5│ Rust 編譯的 .so/.pyd 模組 │
6│ ┌─────────────────────────┐ │
7│ │ PyO3 綁定層 │ │
8│ │ - 型別轉換 │ │
9│ │ - GIL 管理 │ │
10│ │ - 錯誤處理 │ │
11│ └─────────────────────────┘ │
12│ ┌─────────────────────────┐ │
13│ │ 純 Rust 程式碼 │ │
14│ │ - 核心邏輯 │ │
15│ │ - 無 GIL 限制 │ │
16│ └─────────────────────────┘ │
17└─────────────────────────────────┘【設計層】專案設定
Cargo.toml 設定
1[package]
2name = "my_rust_module"
3version = "0.1.0"
4edition = "2021"
5
6[lib]
7name = "my_rust_module"
8crate-type = ["cdylib"] # 編譯為動態連結庫
9
10[dependencies]
11pyo3 = { version = "0.23", features = ["extension-module"] }
12
13# 選用功能
14# pyo3 = { version = "0.23", features = [
15# "extension-module",
16# "abi3-py38", # 穩定 ABI(支援 Python 3.8+)
17# "multiple-pymethods",
18# ]}基本模組結構
1// src/lib.rs
2use pyo3::prelude::*;
3
4/// 簡單的加法函式
5#[pyfunction]
6fn add(a: i64, b: i64) -> i64 {
7 a + b
8}
9
10/// 建立 Python 模組
11#[pymodule]
12fn my_rust_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
13 m.add_function(wrap_pyfunction!(add, m)?)?;
14 Ok(())
15}abi3:穩定 ABI
1# 啟用 abi3 的好處:
2# 1. 一次編譯,多版本 Python 使用
3# 2. 減少發布的 wheel 數量
4# 3. 更好的向前相容性
5
6[dependencies]
7pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }1不使用 abi3:
2├── my_module-cp38-cp38-linux_x86_64.whl
3├── my_module-cp39-cp39-linux_x86_64.whl
4├── my_module-cp310-cp310-linux_x86_64.whl
5├── my_module-cp311-cp311-linux_x86_64.whl
6└── my_module-cp312-cp312-linux_x86_64.whl
7
8使用 abi3-py38:
9└── my_module-cp38-abi3-linux_x86_64.whl # 支援 Python 3.8+【實作層】函式綁定
#[pyfunction] 基礎
1use pyo3::prelude::*;
2
3// 基本函式
4#[pyfunction]
5fn greet(name: &str) -> String {
6 format!("Hello, {}!", name)
7}
8
9// 帶預設參數
10#[pyfunction]
11#[pyo3(signature = (a, b=1.0))]
12fn divide(a: f64, b: f64) -> f64 {
13 a / b
14}
15
16// 可變參數
17#[pyfunction]
18#[pyo3(signature = (*args))]
19fn sum_all(args: Vec<i64>) -> i64 {
20 args.iter().sum()
21}
22
23// 關鍵字參數
24#[pyfunction]
25#[pyo3(signature = (**kwargs))]
26fn print_kwargs(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
27 if let Some(dict) = kwargs {
28 for (key, value) in dict.iter() {
29 println!("{}: {}", key, value);
30 }
31 }
32 Ok(())
33}
34
35// 文件字串
36/// 計算兩個數的最大公因數
37///
38/// Args:
39/// a: 第一個整數
40/// b: 第二個整數
41///
42/// Returns:
43/// 最大公因數
44#[pyfunction]
45fn gcd(a: u64, b: u64) -> u64 {
46 if b == 0 { a } else { gcd(b, a % b) }
47}
48
49#[pymodule]
50fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
51 m.add_function(wrap_pyfunction!(greet, m)?)?;
52 m.add_function(wrap_pyfunction!(divide, m)?)?;
53 m.add_function(wrap_pyfunction!(sum_all, m)?)?;
54 m.add_function(wrap_pyfunction!(print_kwargs, m)?)?;
55 m.add_function(wrap_pyfunction!(gcd, m)?)?;
56 Ok(())
57}型別轉換
1use pyo3::prelude::*;
2use pyo3::types::{PyList, PyDict, PyTuple};
3
4// 自動型別轉換
5#[pyfunction]
6fn process_list(items: Vec<i64>) -> Vec<i64> {
7 // Python list 自動轉換為 Vec
8 items.iter().map(|x| x * 2).collect()
9}
10
11#[pyfunction]
12fn process_dict(data: HashMap<String, i64>) -> i64 {
13 // Python dict 自動轉換為 HashMap
14 data.values().sum()
15}
16
17// 手動處理 Python 物件
18#[pyfunction]
19fn manual_conversion(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<String> {
20 // 檢查型別
21 if obj.is_instance_of::<PyList>() {
22 let list = obj.downcast::<PyList>()?;
23 Ok(format!("List with {} items", list.len()))
24 } else if obj.is_instance_of::<PyDict>() {
25 let dict = obj.downcast::<PyDict>()?;
26 Ok(format!("Dict with {} keys", dict.len()))
27 } else {
28 Ok(format!("Unknown type: {}", obj.get_type().name()?))
29 }
30}
31
32// 回傳多個值(使用 tuple)
33#[pyfunction]
34fn divmod(a: i64, b: i64) -> (i64, i64) {
35 (a / b, a % b)
36}
37
38// 回傳 Option(轉換為 None 或值)
39#[pyfunction]
40fn find_item(items: Vec<i64>, target: i64) -> Option<usize> {
41 items.iter().position(|&x| x == target)
42}【實作層】類別綁定
#[pyclass] 基礎
1use pyo3::prelude::*;
2
3#[pyclass]
4struct Point {
5 #[pyo3(get, set)] // 自動產生 getter 和 setter
6 x: f64,
7 #[pyo3(get, set)]
8 y: f64,
9}
10
11#[pymethods]
12impl Point {
13 // 建構子
14 #[new]
15 fn new(x: f64, y: f64) -> Self {
16 Point { x, y }
17 }
18
19 // 方法
20 fn distance(&self, other: &Point) -> f64 {
21 let dx = self.x - other.x;
22 let dy = self.y - other.y;
23 (dx * dx + dy * dy).sqrt()
24 }
25
26 // 類別方法
27 #[classmethod]
28 fn origin(_cls: &Bound<'_, PyType>) -> Self {
29 Point { x: 0.0, y: 0.0 }
30 }
31
32 // 靜態方法
33 #[staticmethod]
34 fn from_polar(r: f64, theta: f64) -> Self {
35 Point {
36 x: r * theta.cos(),
37 y: r * theta.sin(),
38 }
39 }
40
41 // __repr__
42 fn __repr__(&self) -> String {
43 format!("Point({}, {})", self.x, self.y)
44 }
45
46 // __str__
47 fn __str__(&self) -> String {
48 format!("({}, {})", self.x, self.y)
49 }
50}
51
52#[pymodule]
53fn geometry(m: &Bound<'_, PyModule>) -> PyResult<()> {
54 m.add_class::<Point>()?;
55 Ok(())
56}Python 使用:
1from geometry import Point
2
3p1 = Point(3.0, 4.0)
4p2 = Point.origin() # 類別方法
5p3 = Point.from_polar(5.0, 0.927) # 靜態方法
6
7print(p1) # (3.0, 4.0)
8print(repr(p1)) # Point(3.0, 4.0)
9print(p1.distance(p2)) # 5.0
10
11p1.x = 10.0 # setter
12print(p1.x) # getter運算子重載
1use pyo3::prelude::*;
2use std::ops::{Add, Sub, Mul};
3
4#[pyclass]
5#[derive(Clone)]
6struct Vector2D {
7 x: f64,
8 y: f64,
9}
10
11#[pymethods]
12impl Vector2D {
13 #[new]
14 fn new(x: f64, y: f64) -> Self {
15 Vector2D { x, y }
16 }
17
18 // __add__
19 fn __add__(&self, other: &Vector2D) -> Vector2D {
20 Vector2D {
21 x: self.x + other.x,
22 y: self.y + other.y,
23 }
24 }
25
26 // __sub__
27 fn __sub__(&self, other: &Vector2D) -> Vector2D {
28 Vector2D {
29 x: self.x - other.x,
30 y: self.y - other.y,
31 }
32 }
33
34 // __mul__(標量乘法)
35 fn __mul__(&self, scalar: f64) -> Vector2D {
36 Vector2D {
37 x: self.x * scalar,
38 y: self.y * scalar,
39 }
40 }
41
42 // __rmul__(右乘)
43 fn __rmul__(&self, scalar: f64) -> Vector2D {
44 self.__mul__(scalar)
45 }
46
47 // __neg__
48 fn __neg__(&self) -> Vector2D {
49 Vector2D {
50 x: -self.x,
51 y: -self.y,
52 }
53 }
54
55 // __eq__
56 fn __eq__(&self, other: &Vector2D) -> bool {
57 (self.x - other.x).abs() < 1e-10 &&
58 (self.y - other.y).abs() < 1e-10
59 }
60
61 // __len__(向量維度)
62 fn __len__(&self) -> usize {
63 2
64 }
65
66 // __getitem__
67 fn __getitem__(&self, index: usize) -> PyResult<f64> {
68 match index {
69 0 => Ok(self.x),
70 1 => Ok(self.y),
71 _ => Err(PyIndexError::new_err("Index out of range")),
72 }
73 }
74
75 fn __repr__(&self) -> String {
76 format!("Vector2D({}, {})", self.x, self.y)
77 }
78}繼承與多型
1use pyo3::prelude::*;
2
3// 基礎類別
4#[pyclass(subclass)] // 允許被繼承
5struct Animal {
6 #[pyo3(get)]
7 name: String,
8}
9
10#[pymethods]
11impl Animal {
12 #[new]
13 fn new(name: String) -> Self {
14 Animal { name }
15 }
16
17 // 可被覆寫的方法
18 fn speak(&self) -> String {
19 "...".to_string()
20 }
21}
22
23// 子類別
24#[pyclass(extends=Animal)]
25struct Dog {}
26
27#[pymethods]
28impl Dog {
29 #[new]
30 fn new(name: String) -> (Self, Animal) {
31 (Dog {}, Animal { name })
32 }
33
34 fn speak(&self) -> String {
35 "Woof!".to_string()
36 }
37}
38
39#[pyclass(extends=Animal)]
40struct Cat {}
41
42#[pymethods]
43impl Cat {
44 #[new]
45 fn new(name: String) -> (Self, Animal) {
46 (Cat {}, Animal { name })
47 }
48
49 fn speak(&self) -> String {
50 "Meow!".to_string()
51 }
52}【實作層】錯誤處理
PyResult 與錯誤轉換
1use pyo3::prelude::*;
2use pyo3::exceptions::{PyValueError, PyTypeError, PyIOError};
3
4// 回傳 PyResult
5#[pyfunction]
6fn safe_divide(a: f64, b: f64) -> PyResult<f64> {
7 if b == 0.0 {
8 Err(PyValueError::new_err("除數不能為零"))
9 } else {
10 Ok(a / b)
11 }
12}
13
14// 自訂錯誤類型
15use std::io;
16
17fn read_file_internal(path: &str) -> Result<String, io::Error> {
18 std::fs::read_to_string(path)
19}
20
21#[pyfunction]
22fn read_file(path: &str) -> PyResult<String> {
23 read_file_internal(path).map_err(|e| {
24 PyIOError::new_err(format!("無法讀取檔案: {}", e))
25 })
26}
27
28// 使用 ? 運算子
29#[pyfunction]
30fn parse_and_double(s: &str) -> PyResult<i64> {
31 let num: i64 = s.parse().map_err(|_| {
32 PyValueError::new_err(format!("無法解析為整數: {}", s))
33 })?;
34 Ok(num * 2)
35}
36
37// 自動轉換 Rust 錯誤
38use thiserror::Error;
39
40#[derive(Error, Debug)]
41enum MyError {
42 #[error("數值錯誤: {0}")]
43 ValueError(String),
44 #[error("IO 錯誤: {0}")]
45 IoError(#[from] io::Error),
46}
47
48impl From<MyError> for PyErr {
49 fn from(err: MyError) -> PyErr {
50 match err {
51 MyError::ValueError(msg) => PyValueError::new_err(msg),
52 MyError::IoError(e) => PyIOError::new_err(e.to_string()),
53 }
54 }
55}
56
57#[pyfunction]
58fn risky_operation(value: i64) -> Result<i64, MyError> {
59 if value < 0 {
60 Err(MyError::ValueError("負數不允許".to_string()))
61 } else {
62 Ok(value * 2)
63 }
64}自訂異常
1use pyo3::prelude::*;
2use pyo3::create_exception;
3
4// 建立自訂異常
5create_exception!(my_module, ValidationError, pyo3::exceptions::PyException);
6create_exception!(my_module, ProcessingError, pyo3::exceptions::PyException);
7
8#[pyfunction]
9fn validate_data(data: &str) -> PyResult<()> {
10 if data.is_empty() {
11 return Err(ValidationError::new_err("資料不能為空"));
12 }
13 if data.len() > 100 {
14 return Err(ValidationError::new_err("資料太長"));
15 }
16 Ok(())
17}
18
19#[pymodule]
20fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
21 m.add("ValidationError", m.py().get_type::<ValidationError>())?;
22 m.add("ProcessingError", m.py().get_type::<ProcessingError>())?;
23 m.add_function(wrap_pyfunction!(validate_data, m)?)?;
24 Ok(())
25}【實作層】GIL 管理
釋放 GIL
1use pyo3::prelude::*;
2
3// CPU 密集計算,應該釋放 GIL
4#[pyfunction]
5fn heavy_computation(n: u64) -> f64 {
6 // 釋放 GIL,允許其他 Python 執行緒執行
7 Python::with_gil(|py| {
8 py.allow_threads(|| {
9 // 這裡的程式碼不持有 GIL
10 let mut result = 0.0;
11 for i in 0..n {
12 result += (i as f64).sin() * (i as f64).cos();
13 }
14 result
15 })
16 })
17}
18
19// 或者使用 Python 參數
20#[pyfunction]
21fn parallel_sum(py: Python<'_>, data: Vec<f64>) -> f64 {
22 py.allow_threads(|| {
23 // 可以安全地使用多執行緒
24 use rayon::prelude::*;
25 data.par_iter().sum()
26 })
27}需要 GIL 的操作
1use pyo3::prelude::*;
2
3#[pyfunction]
4fn callback_example(py: Python<'_>, callback: PyObject) -> PyResult<()> {
5 // 模擬一些計算
6 let results: Vec<i64> = py.allow_threads(|| {
7 (0..10).map(|x| x * x).collect()
8 });
9
10 // 呼叫 Python 回呼需要 GIL
11 for result in results {
12 callback.call1(py, (result,))?;
13 }
14
15 Ok(())
16}
17
18#[pyfunction]
19fn mixed_workload(py: Python<'_>, n: u64) -> PyResult<Vec<f64>> {
20 let mut results = Vec::new();
21
22 for i in 0..n {
23 // 計算(釋放 GIL)
24 let value = py.allow_threads(|| {
25 (i as f64).sin()
26 });
27
28 // Python 互動(需要 GIL)
29 results.push(value);
30
31 // 定期檢查是否有 Python 訊號(如 Ctrl+C)
32 if i % 1000 == 0 {
33 py.check_signals()?;
34 }
35 }
36
37 Ok(results)
38}【進階】與 NumPy 整合
numpy crate
1[dependencies]
2pyo3 = { version = "0.23", features = ["extension-module"] }
3numpy = "0.23"
4ndarray = "0.16" 1use pyo3::prelude::*;
2use numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
3use ndarray::{Array1, Array2};
4
5// 接受 NumPy 陣列
6#[pyfunction]
7fn array_sum(arr: PyReadonlyArray1<'_, f64>) -> f64 {
8 arr.as_array().sum()
9}
10
11// 回傳 NumPy 陣列
12#[pyfunction]
13fn create_range(py: Python<'_>, n: usize) -> Bound<'_, PyArray1<f64>> {
14 let arr: Array1<f64> = Array1::from_iter((0..n).map(|x| x as f64));
15 PyArray1::from_owned_array(py, arr)
16}
17
18// 處理 2D 陣列
19#[pyfunction]
20fn matrix_multiply<'py>(
21 py: Python<'py>,
22 a: PyReadonlyArray2<'py, f64>,
23 b: PyReadonlyArray2<'py, f64>,
24) -> PyResult<Bound<'py, PyArray2<f64>>> {
25 let a = a.as_array();
26 let b = b.as_array();
27
28 // 檢查維度
29 if a.ncols() != b.nrows() {
30 return Err(PyValueError::new_err("矩陣維度不匹配"));
31 }
32
33 // 計算(釋放 GIL)
34 let result = py.allow_threads(|| {
35 a.dot(&b)
36 });
37
38 Ok(PyArray2::from_owned_array(py, result))
39}
40
41// 原地修改
42#[pyfunction]
43fn normalize_inplace(mut arr: PyReadwriteArray1<'_, f64>) {
44 let mut arr = arr.as_array_mut();
45 let sum: f64 = arr.sum();
46 if sum != 0.0 {
47 arr.mapv_inplace(|x| x / sum);
48 }
49}思考題
- PyO3 如何處理 Rust 的所有權系統和 Python 的垃圾回收之間的衝突?
- 什麼時候應該使用
py.allow_threads()?有什麼風險? - 為什麼 PyO3 使用
Bound<'_, T>而不是直接傳遞 Python 物件?
實作練習
- 使用 PyO3 實現一個簡單的 Counter 類別,支援
+、-、+=運算子 - 實現一個接受 Python 回呼的函式,用於處理大量資料
- 使用 numpy crate 實現一個高效能的向量運算函式
延伸閱讀
上一章:為什麼選擇 Rust? 下一章:Maturin 開發流程