本章分析幾個使用 Rust 的知名 Python 專案,學習實際應用的模式。

本章目標

學完本章後,你將能夠:

  1. 理解 Rust 在數值計算的應用
  2. 理解 Rust 在文字處理的應用
  3. 評估自己的專案是否適合使用 Rust

【案例一】數值計算:實現快速排序

需求分析

1場景:需要對大量數值資料進行排序
2問題:Python 內建排序雖然是 C 實現,但有特殊需求時需要自訂
3
4實現目標:
5├── 支援自訂比較函式
6├── 支援並行排序(大資料集)
7├── 與 NumPy 整合
8└── 效能接近或超過 NumPy sort

Rust 實現

 1// src/lib.rs
 2use pyo3::prelude::*;
 3use numpy::{PyArray1, PyReadonlyArray1};
 4use rayon::prelude::*;
 5
 6/// 並行快速排序
 7#[pyfunction]
 8fn parallel_sort(py: Python<'_>, arr: PyReadonlyArray1<'_, f64>) -> Bound<'_, PyArray1<f64>> {
 9    let arr = arr.as_array();
10
11    // 釋放 GIL 進行並行排序
12    let mut data: Vec<f64> = py.allow_threads(|| {
13        let mut data: Vec<f64> = arr.to_vec();
14        data.par_sort_by(|a, b| a.partial_cmp(b).unwrap());
15        data
16    });
17
18    PyArray1::from_vec(py, data)
19}
20
21/// 部分排序(只排序前 k 個元素)
22#[pyfunction]
23fn partial_sort(py: Python<'_>, arr: PyReadonlyArray1<'_, f64>, k: usize) -> PyResult<Bound<'_, PyArray1<f64>>> {
24    let arr = arr.as_array();
25
26    if k > arr.len() {
27        return Err(PyValueError::new_err("k 大於陣列長度"));
28    }
29
30    let result = py.allow_threads(|| {
31        let mut data: Vec<f64> = arr.to_vec();
32        // 使用 select_nth_unstable 獲得前 k 個最小元素
33        data.select_nth_unstable_by(k, |a, b| a.partial_cmp(b).unwrap());
34        data[..k].to_vec()
35    });
36
37    Ok(PyArray1::from_vec(py, result))
38}
39
40/// 找出 top-k 元素(不完全排序,更快)
41#[pyfunction]
42fn top_k(py: Python<'_>, arr: PyReadonlyArray1<'_, f64>, k: usize) -> PyResult<Bound<'_, PyArray1<f64>>> {
43    use std::collections::BinaryHeap;
44    use std::cmp::Ordering;
45
46    // 包裝 f64 以支援 BinaryHeap
47    #[derive(PartialEq)]
48    struct MinFloat(f64);
49
50    impl Eq for MinFloat {}
51
52    impl PartialOrd for MinFloat {
53        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54            // 反向比較,使 BinaryHeap 成為 min-heap
55            other.0.partial_cmp(&self.0)
56        }
57    }
58
59    impl Ord for MinFloat {
60        fn cmp(&self, other: &Self) -> Ordering {
61            self.partial_cmp(other).unwrap()
62        }
63    }
64
65    let arr = arr.as_array();
66
67    let result = py.allow_threads(|| {
68        let mut heap: BinaryHeap<MinFloat> = BinaryHeap::with_capacity(k + 1);
69
70        for &x in arr.iter() {
71            heap.push(MinFloat(x));
72            if heap.len() > k {
73                heap.pop();
74            }
75        }
76
77        heap.into_sorted_vec().into_iter().map(|x| x.0).collect::<Vec<_>>()
78    });
79
80    Ok(PyArray1::from_vec(py, result))
81}
82
83#[pymodule]
84fn fast_sort(m: &Bound<'_, PyModule>) -> PyResult<()> {
85    m.add_function(wrap_pyfunction!(parallel_sort, m)?)?;
86    m.add_function(wrap_pyfunction!(partial_sort, m)?)?;
87    m.add_function(wrap_pyfunction!(top_k, m)?)?;
88    Ok(())
89}

效能比較

 1import numpy as np
 2import fast_sort
 3import timeit
 4
 5# 測試資料
 6n = 1_000_000
 7data = np.random.rand(n)
 8
 9# NumPy sort
10t_numpy = timeit.timeit(lambda: np.sort(data.copy()), number=10)
11
12# Rust parallel sort
13t_rust = timeit.timeit(lambda: fast_sort.parallel_sort(data), number=10)
14
15# 找 top-1000(Rust)
16t_topk_rust = timeit.timeit(lambda: fast_sort.top_k(data, 1000), number=10)
17
18# 找 top-1000(NumPy: 完整排序後取前 k)
19t_topk_numpy = timeit.timeit(lambda: np.sort(data)[:1000], number=10)
20
21print(f"完整排序 - NumPy: {t_numpy:.3f}s, Rust: {t_rust:.3f}s")
22print(f"Top-1000 - NumPy: {t_topk_numpy:.3f}s, Rust: {t_topk_rust:.3f}s")
23
24# 預期結果(依硬體而異):
25# 完整排序 - NumPy: 0.85s, Rust: 0.45s(使用多核心)
26# Top-1000 - NumPy: 0.85s, Rust: 0.02s(不需完整排序)

【案例二】文字處理:高效能 Tokenizer

需求分析

1場景:NLP 應用需要將文字切分為 tokens
2問題:純 Python 實現太慢,無法處理大量文字
3
4實現目標:
5├── 支援 Unicode
6├── 支援正規表達式模式
7├── 批次處理
8└── 與現有 NLP 工具整合

Rust 實現

  1// src/lib.rs
  2use pyo3::prelude::*;
  3use regex::Regex;
  4use rayon::prelude::*;
  5
  6#[pyclass]
  7struct Tokenizer {
  8    pattern: Regex,
  9    lowercase: bool,
 10}
 11
 12#[pymethods]
 13impl Tokenizer {
 14    #[new]
 15    #[pyo3(signature = (pattern=r"\w+", lowercase=true))]
 16    fn new(pattern: &str, lowercase: bool) -> PyResult<Self> {
 17        let pattern = Regex::new(pattern)
 18            .map_err(|e| PyValueError::new_err(format!("無效的正規表達式: {}", e)))?;
 19        Ok(Tokenizer { pattern, lowercase })
 20    }
 21
 22    /// 對單一字串進行 tokenization
 23    fn tokenize(&self, text: &str) -> Vec<String> {
 24        self.pattern
 25            .find_iter(text)
 26            .map(|m| {
 27                let s = m.as_str();
 28                if self.lowercase {
 29                    s.to_lowercase()
 30                } else {
 31                    s.to_string()
 32                }
 33            })
 34            .collect()
 35    }
 36
 37    /// 批次 tokenization(並行處理)
 38    fn tokenize_batch(&self, py: Python<'_>, texts: Vec<String>) -> Vec<Vec<String>> {
 39        py.allow_threads(|| {
 40            texts
 41                .par_iter()
 42                .map(|text| self.tokenize(text))
 43                .collect()
 44        })
 45    }
 46
 47    /// 計算詞頻
 48    fn count_tokens(&self, py: Python<'_>, text: &str) -> HashMap<String, usize> {
 49        use std::collections::HashMap;
 50
 51        py.allow_threads(|| {
 52            let mut counts = HashMap::new();
 53            for mat in self.pattern.find_iter(text) {
 54                let token = if self.lowercase {
 55                    mat.as_str().to_lowercase()
 56                } else {
 57                    mat.as_str().to_string()
 58                };
 59                *counts.entry(token).or_insert(0) += 1;
 60            }
 61            counts
 62        })
 63    }
 64}
 65
 66// 簡單的 BPE(Byte Pair Encoding)實現
 67#[pyclass]
 68struct SimpleBPE {
 69    vocab: HashMap<String, u32>,
 70    merges: Vec<(String, String)>,
 71}
 72
 73#[pymethods]
 74impl SimpleBPE {
 75    #[new]
 76    fn new() -> Self {
 77        SimpleBPE {
 78            vocab: HashMap::new(),
 79            merges: Vec::new(),
 80        }
 81    }
 82
 83    /// 訓練 BPE
 84    fn train(&mut self, py: Python<'_>, texts: Vec<String>, vocab_size: usize) {
 85        use std::collections::HashMap;
 86
 87        py.allow_threads(|| {
 88            // 初始化:每個字元是一個 token
 89            let mut word_freqs: HashMap<Vec<String>, usize> = HashMap::new();
 90
 91            for text in &texts {
 92                for word in text.split_whitespace() {
 93                    let chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
 94                    *word_freqs.entry(chars).or_insert(0) += 1;
 95                }
 96            }
 97
 98            // 迭代合併最頻繁的 pair
 99            while self.vocab.len() < vocab_size {
100                // 計算 pair 頻率
101                let mut pair_freqs: HashMap<(String, String), usize> = HashMap::new();
102
103                for (word, freq) in &word_freqs {
104                    for i in 0..word.len().saturating_sub(1) {
105                        let pair = (word[i].clone(), word[i + 1].clone());
106                        *pair_freqs.entry(pair).or_insert(0) += freq;
107                    }
108                }
109
110                // 找出最頻繁的 pair
111                if let Some((best_pair, _)) = pair_freqs.iter().max_by_key(|(_, &freq)| freq) {
112                    let new_token = format!("{}{}", best_pair.0, best_pair.1);
113                    self.merges.push(best_pair.clone());
114                    self.vocab.insert(new_token.clone(), self.vocab.len() as u32);
115
116                    // 更新 word_freqs
117                    let mut new_word_freqs = HashMap::new();
118                    for (word, freq) in word_freqs {
119                        let mut new_word = Vec::new();
120                        let mut i = 0;
121                        while i < word.len() {
122                            if i + 1 < word.len() && word[i] == best_pair.0 && word[i + 1] == best_pair.1 {
123                                new_word.push(new_token.clone());
124                                i += 2;
125                            } else {
126                                new_word.push(word[i].clone());
127                                i += 1;
128                            }
129                        }
130                        *new_word_freqs.entry(new_word).or_insert(0) += freq;
131                    }
132                    word_freqs = new_word_freqs;
133                } else {
134                    break;
135                }
136            }
137        });
138    }
139
140    /// 編碼文字
141    fn encode(&self, text: &str) -> Vec<u32> {
142        let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
143
144        for (a, b) in &self.merges {
145            let merged = format!("{}{}", a, b);
146            let mut new_tokens = Vec::new();
147            let mut i = 0;
148            while i < tokens.len() {
149                if i + 1 < tokens.len() && &tokens[i] == a && &tokens[i + 1] == b {
150                    new_tokens.push(merged.clone());
151                    i += 2;
152                } else {
153                    new_tokens.push(tokens[i].clone());
154                    i += 1;
155                }
156            }
157            tokens = new_tokens;
158        }
159
160        tokens
161            .iter()
162            .filter_map(|t| self.vocab.get(t).copied())
163            .collect()
164    }
165}
166
167use std::collections::HashMap;
168
169#[pymodule]
170fn fast_tokenizer(m: &Bound<'_, PyModule>) -> PyResult<()> {
171    m.add_class::<Tokenizer>()?;
172    m.add_class::<SimpleBPE>()?;
173    Ok(())
174}

使用範例

 1from fast_tokenizer import Tokenizer, SimpleBPE
 2
 3# 基本 tokenization
 4tokenizer = Tokenizer(r"\w+", lowercase=True)
 5
 6text = "Hello, World! This is a test."
 7tokens = tokenizer.tokenize(text)
 8print(tokens)  # ['hello', 'world', 'this', 'is', 'a', 'test']
 9
10# 批次處理
11texts = ["First sentence.", "Second sentence.", "Third one."]
12batch_tokens = tokenizer.tokenize_batch(texts)
13print(batch_tokens)
14
15# 詞頻統計
16counts = tokenizer.count_tokens("the cat sat on the mat")
17print(counts)  # {'the': 2, 'cat': 1, 'sat': 1, 'on': 1, 'mat': 1}
18
19# BPE 訓練
20bpe = SimpleBPE()
21corpus = ["hello world", "hello there", "world peace"]
22bpe.train(corpus, vocab_size=100)
23encoded = bpe.encode("hello world")
24print(encoded)

【案例三】資料驗證:Pydantic 風格驗證器

需求分析

1場景:API 需要驗證大量輸入資料
2問題:純 Python 驗證太慢(Pydantic v1 的問題)
3
4實現目標:
5├── 型別檢查
6├── 範圍驗證
7├── 自訂驗證函式
8└── 清晰的錯誤訊息

Rust 實現

  1// src/lib.rs
  2use pyo3::prelude::*;
  3use pyo3::exceptions::PyValueError;
  4use std::collections::HashMap;
  5
  6// 驗證錯誤
  7#[pyclass]
  8#[derive(Clone)]
  9struct ValidationError {
 10    #[pyo3(get)]
 11    field: String,
 12    #[pyo3(get)]
 13    message: String,
 14}
 15
 16#[pymethods]
 17impl ValidationError {
 18    fn __repr__(&self) -> String {
 19        format!("ValidationError(field='{}', message='{}')", self.field, self.message)
 20    }
 21}
 22
 23// 欄位驗證器
 24#[pyclass]
 25struct Field {
 26    name: String,
 27    field_type: String,
 28    required: bool,
 29    min_value: Option<f64>,
 30    max_value: Option<f64>,
 31    min_length: Option<usize>,
 32    max_length: Option<usize>,
 33    pattern: Option<regex::Regex>,
 34}
 35
 36#[pymethods]
 37impl Field {
 38    #[new]
 39    #[pyo3(signature = (name, field_type, required=true, min_value=None, max_value=None, min_length=None, max_length=None, pattern=None))]
 40    fn new(
 41        name: String,
 42        field_type: String,
 43        required: bool,
 44        min_value: Option<f64>,
 45        max_value: Option<f64>,
 46        min_length: Option<usize>,
 47        max_length: Option<usize>,
 48        pattern: Option<String>,
 49    ) -> PyResult<Self> {
 50        let pattern = match pattern {
 51            Some(p) => Some(regex::Regex::new(&p)
 52                .map_err(|e| PyValueError::new_err(format!("無效的正規表達式: {}", e)))?),
 53            None => None,
 54        };
 55
 56        Ok(Field {
 57            name,
 58            field_type,
 59            required,
 60            min_value,
 61            max_value,
 62            min_length,
 63            max_length,
 64            pattern,
 65        })
 66    }
 67}
 68
 69// Schema 驗證器
 70#[pyclass]
 71struct Schema {
 72    fields: Vec<Field>,
 73}
 74
 75#[pymethods]
 76impl Schema {
 77    #[new]
 78    fn new(fields: Vec<Py<Field>>) -> PyResult<Self> {
 79        Python::with_gil(|py| {
 80            let fields: Vec<Field> = fields
 81                .iter()
 82                .map(|f| f.borrow(py).clone())
 83                .collect();
 84            Ok(Schema { fields })
 85        })
 86    }
 87
 88    /// 驗證單一物件
 89    fn validate(&self, py: Python<'_>, data: &Bound<'_, PyDict>) -> PyResult<Vec<ValidationError>> {
 90        let mut errors = Vec::new();
 91
 92        for field in &self.fields {
 93            let value = data.get_item(&field.name)?;
 94
 95            match value {
 96                None => {
 97                    if field.required {
 98                        errors.push(ValidationError {
 99                            field: field.name.clone(),
100                            message: "此欄位為必填".to_string(),
101                        });
102                    }
103                }
104                Some(v) => {
105                    // 型別檢查
106                    match field.field_type.as_str() {
107                        "int" => {
108                            if let Ok(num) = v.extract::<i64>() {
109                                // 範圍檢查
110                                if let Some(min) = field.min_value {
111                                    if (num as f64) < min {
112                                        errors.push(ValidationError {
113                                            field: field.name.clone(),
114                                            message: format!("值必須 >= {}", min),
115                                        });
116                                    }
117                                }
118                                if let Some(max) = field.max_value {
119                                    if (num as f64) > max {
120                                        errors.push(ValidationError {
121                                            field: field.name.clone(),
122                                            message: format!("值必須 <= {}", max),
123                                        });
124                                    }
125                                }
126                            } else {
127                                errors.push(ValidationError {
128                                    field: field.name.clone(),
129                                    message: "必須是整數".to_string(),
130                                });
131                            }
132                        }
133                        "float" => {
134                            if let Ok(num) = v.extract::<f64>() {
135                                if let Some(min) = field.min_value {
136                                    if num < min {
137                                        errors.push(ValidationError {
138                                            field: field.name.clone(),
139                                            message: format!("值必須 >= {}", min),
140                                        });
141                                    }
142                                }
143                                if let Some(max) = field.max_value {
144                                    if num > max {
145                                        errors.push(ValidationError {
146                                            field: field.name.clone(),
147                                            message: format!("值必須 <= {}", max),
148                                        });
149                                    }
150                                }
151                            } else {
152                                errors.push(ValidationError {
153                                    field: field.name.clone(),
154                                    message: "必須是浮點數".to_string(),
155                                });
156                            }
157                        }
158                        "str" => {
159                            if let Ok(s) = v.extract::<String>() {
160                                // 長度檢查
161                                if let Some(min_len) = field.min_length {
162                                    if s.len() < min_len {
163                                        errors.push(ValidationError {
164                                            field: field.name.clone(),
165                                            message: format!("長度必須 >= {}", min_len),
166                                        });
167                                    }
168                                }
169                                if let Some(max_len) = field.max_length {
170                                    if s.len() > max_len {
171                                        errors.push(ValidationError {
172                                            field: field.name.clone(),
173                                            message: format!("長度必須 <= {}", max_len),
174                                        });
175                                    }
176                                }
177                                // 正規表達式檢查
178                                if let Some(ref pattern) = field.pattern {
179                                    if !pattern.is_match(&s) {
180                                        errors.push(ValidationError {
181                                            field: field.name.clone(),
182                                            message: "格式不符合要求".to_string(),
183                                        });
184                                    }
185                                }
186                            } else {
187                                errors.push(ValidationError {
188                                    field: field.name.clone(),
189                                    message: "必須是字串".to_string(),
190                                });
191                            }
192                        }
193                        _ => {}
194                    }
195                }
196            }
197        }
198
199        Ok(errors)
200    }
201
202    /// 批次驗證
203    fn validate_batch(&self, py: Python<'_>, data_list: Vec<Bound<'_, PyDict>>) -> PyResult<Vec<Vec<ValidationError>>> {
204        let mut results = Vec::new();
205        for data in data_list {
206            results.push(self.validate(py, &data)?);
207        }
208        Ok(results)
209    }
210}
211
212use pyo3::types::PyDict;
213
214#[pymodule]
215fn fast_validator(m: &Bound<'_, PyModule>) -> PyResult<()> {
216    m.add_class::<Field>()?;
217    m.add_class::<Schema>()?;
218    m.add_class::<ValidationError>()?;
219    Ok(())
220}

使用範例

 1from fast_validator import Field, Schema
 2
 3# 定義 schema
 4schema = Schema([
 5    Field("name", "str", required=True, min_length=1, max_length=100),
 6    Field("age", "int", required=True, min_value=0, max_value=150),
 7    Field("email", "str", required=True, pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$"),
 8    Field("score", "float", required=False, min_value=0, max_value=100),
 9])
10
11# 驗證資料
12data = {"name": "Alice", "age": 30, "email": "alice@example.com"}
13errors = schema.validate(data)
14if errors:
15    for e in errors:
16        print(f"{e.field}: {e.message}")
17else:
18    print("驗證通過")
19
20# 無效資料
21invalid_data = {"name": "", "age": -5, "email": "invalid"}
22errors = schema.validate(invalid_data)
23for e in errors:
24    print(f"{e.field}: {e.message}")
25# name: 長度必須 >= 1
26# age: 值必須 >= 0
27# email: 格式不符合要求

【總結】何時使用 Rust

決策清單

 1應該使用 Rust:
 2
 31. 效能瓶頸明確
 4   □ profiler 顯示特定函式占用大量時間
 5   □ 純 Python 優化已到極限
 6   □ 現有 C 擴展不滿足需求
 7
 82. 資料處理需求
 9   □ 大量數值計算
10   □ 頻繁的字串處理
11   □ 需要並行處理
12
133. 安全性要求
14   □ 處理不可信的輸入
15   □ 需要避免記憶體錯誤
16   □ 長期運行的服務
17
184. 跨平台需求
19   □ 需要支援多個作業系統
20   □ 需要支援多個 Python 版本
21
22可能不需要 Rust:
23
241. 效能不是主要瓶頸
252. 團隊沒有 Rust 經驗且時間緊迫
263. 專案規模小且不會長期維護
274. 可以用現有函式庫解決

最佳實踐

 11. 設計階段
 2   ├── 明確定義 Python/Rust 邊界
 3   ├── 最小化跨語言呼叫次數
 4   ├── 使用批次處理減少開銷
 5   └── 設計清晰的錯誤處理
 6
 72. 開發階段
 8   ├── 先用 Python 原型驗證邏輯
 9   ├── 逐步將瓶頸移到 Rust
10   ├── 完善的測試覆蓋
11   └── 使用 maturin develop 快速迭代
12
133. 發布階段
14   ├── 使用 CI/CD 自動建構 wheel
15   ├── 支援主流 Python 版本
16   ├── 提供 fallback 純 Python 實現
17   └── 清楚的安裝文件

思考題

  1. 在設計 Rust 擴展的 API 時,如何平衡效能和易用性?
  2. 如何處理 Rust 函式庫沒有 Python 綁定的情況?
  3. 在什麼情況下,應該用 Cython 而不是 Rust?

實作練習

  1. 選擇一個你常用的純 Python 函式,用 Rust 重寫並比較效能
  2. 分析 Polars 的原始碼結構,理解大型 Rust Python 專案的組織方式
  3. 實現一個簡單的 JSON parser,比較與 Python json 模組的效能差異

延伸閱讀


上一章:Maturin 開發流程 下一模組:模組六:打包與發布