5.4 實戰案例分析
5.4 實戰案例分析
本章分析幾個使用 Rust 的知名 Python 專案,學習實際應用的模式。
本章目標
學完本章後,你將能夠:
- 理解 Rust 在數值計算的應用
- 理解 Rust 在文字處理的應用
- 評估自己的專案是否適合使用 Rust
【案例一】數值計算:實現快速排序
需求分析
1場景:需要對大量數值資料進行排序
2問題:Python 內建排序雖然是 C 實現,但有特殊需求時需要自訂
3
4實現目標:
5├── 支援自訂比較函式
6├── 支援並行排序(大資料集)
7├── 與 NumPy 整合
8└── 效能接近或超過 NumPy sortRust 實現
1// src/lib.rs
2use pyo3::prelude::*;
3use numpy::{PyArray1, PyReadonlyArray1};
4use rayon::prelude::*;
5
6/// 並行快速排序
7#[pyfunction]
8fn parallel_sort(py: Python<'_>, arr: PyReadonlyArray1<'_, f64>) -> Bound<'_, PyArray1<f64>> {
9 let arr = arr.as_array();
10
11 // 釋放 GIL 進行並行排序
12 let mut data: Vec<f64> = py.allow_threads(|| {
13 let mut data: Vec<f64> = arr.to_vec();
14 data.par_sort_by(|a, b| a.partial_cmp(b).unwrap());
15 data
16 });
17
18 PyArray1::from_vec(py, data)
19}
20
21/// 部分排序(只排序前 k 個元素)
22#[pyfunction]
23fn partial_sort(py: Python<'_>, arr: PyReadonlyArray1<'_, f64>, k: usize) -> PyResult<Bound<'_, PyArray1<f64>>> {
24 let arr = arr.as_array();
25
26 if k > arr.len() {
27 return Err(PyValueError::new_err("k 大於陣列長度"));
28 }
29
30 let result = py.allow_threads(|| {
31 let mut data: Vec<f64> = arr.to_vec();
32 // 使用 select_nth_unstable 獲得前 k 個最小元素
33 data.select_nth_unstable_by(k, |a, b| a.partial_cmp(b).unwrap());
34 data[..k].to_vec()
35 });
36
37 Ok(PyArray1::from_vec(py, result))
38}
39
40/// 找出 top-k 元素(不完全排序,更快)
41#[pyfunction]
42fn top_k(py: Python<'_>, arr: PyReadonlyArray1<'_, f64>, k: usize) -> PyResult<Bound<'_, PyArray1<f64>>> {
43 use std::collections::BinaryHeap;
44 use std::cmp::Ordering;
45
46 // 包裝 f64 以支援 BinaryHeap
47 #[derive(PartialEq)]
48 struct MinFloat(f64);
49
50 impl Eq for MinFloat {}
51
52 impl PartialOrd for MinFloat {
53 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54 // 反向比較,使 BinaryHeap 成為 min-heap
55 other.0.partial_cmp(&self.0)
56 }
57 }
58
59 impl Ord for MinFloat {
60 fn cmp(&self, other: &Self) -> Ordering {
61 self.partial_cmp(other).unwrap()
62 }
63 }
64
65 let arr = arr.as_array();
66
67 let result = py.allow_threads(|| {
68 let mut heap: BinaryHeap<MinFloat> = BinaryHeap::with_capacity(k + 1);
69
70 for &x in arr.iter() {
71 heap.push(MinFloat(x));
72 if heap.len() > k {
73 heap.pop();
74 }
75 }
76
77 heap.into_sorted_vec().into_iter().map(|x| x.0).collect::<Vec<_>>()
78 });
79
80 Ok(PyArray1::from_vec(py, result))
81}
82
83#[pymodule]
84fn fast_sort(m: &Bound<'_, PyModule>) -> PyResult<()> {
85 m.add_function(wrap_pyfunction!(parallel_sort, m)?)?;
86 m.add_function(wrap_pyfunction!(partial_sort, m)?)?;
87 m.add_function(wrap_pyfunction!(top_k, m)?)?;
88 Ok(())
89}效能比較
1import numpy as np
2import fast_sort
3import timeit
4
5# 測試資料
6n = 1_000_000
7data = np.random.rand(n)
8
9# NumPy sort
10t_numpy = timeit.timeit(lambda: np.sort(data.copy()), number=10)
11
12# Rust parallel sort
13t_rust = timeit.timeit(lambda: fast_sort.parallel_sort(data), number=10)
14
15# 找 top-1000(Rust)
16t_topk_rust = timeit.timeit(lambda: fast_sort.top_k(data, 1000), number=10)
17
18# 找 top-1000(NumPy: 完整排序後取前 k)
19t_topk_numpy = timeit.timeit(lambda: np.sort(data)[:1000], number=10)
20
21print(f"完整排序 - NumPy: {t_numpy:.3f}s, Rust: {t_rust:.3f}s")
22print(f"Top-1000 - NumPy: {t_topk_numpy:.3f}s, Rust: {t_topk_rust:.3f}s")
23
24# 預期結果(依硬體而異):
25# 完整排序 - NumPy: 0.85s, Rust: 0.45s(使用多核心)
26# Top-1000 - NumPy: 0.85s, Rust: 0.02s(不需完整排序)【案例二】文字處理:高效能 Tokenizer
需求分析
1場景:NLP 應用需要將文字切分為 tokens
2問題:純 Python 實現太慢,無法處理大量文字
3
4實現目標:
5├── 支援 Unicode
6├── 支援正規表達式模式
7├── 批次處理
8└── 與現有 NLP 工具整合Rust 實現
1// src/lib.rs
2use pyo3::prelude::*;
3use regex::Regex;
4use rayon::prelude::*;
5
6#[pyclass]
7struct Tokenizer {
8 pattern: Regex,
9 lowercase: bool,
10}
11
12#[pymethods]
13impl Tokenizer {
14 #[new]
15 #[pyo3(signature = (pattern=r"\w+", lowercase=true))]
16 fn new(pattern: &str, lowercase: bool) -> PyResult<Self> {
17 let pattern = Regex::new(pattern)
18 .map_err(|e| PyValueError::new_err(format!("無效的正規表達式: {}", e)))?;
19 Ok(Tokenizer { pattern, lowercase })
20 }
21
22 /// 對單一字串進行 tokenization
23 fn tokenize(&self, text: &str) -> Vec<String> {
24 self.pattern
25 .find_iter(text)
26 .map(|m| {
27 let s = m.as_str();
28 if self.lowercase {
29 s.to_lowercase()
30 } else {
31 s.to_string()
32 }
33 })
34 .collect()
35 }
36
37 /// 批次 tokenization(並行處理)
38 fn tokenize_batch(&self, py: Python<'_>, texts: Vec<String>) -> Vec<Vec<String>> {
39 py.allow_threads(|| {
40 texts
41 .par_iter()
42 .map(|text| self.tokenize(text))
43 .collect()
44 })
45 }
46
47 /// 計算詞頻
48 fn count_tokens(&self, py: Python<'_>, text: &str) -> HashMap<String, usize> {
49 use std::collections::HashMap;
50
51 py.allow_threads(|| {
52 let mut counts = HashMap::new();
53 for mat in self.pattern.find_iter(text) {
54 let token = if self.lowercase {
55 mat.as_str().to_lowercase()
56 } else {
57 mat.as_str().to_string()
58 };
59 *counts.entry(token).or_insert(0) += 1;
60 }
61 counts
62 })
63 }
64}
65
66// 簡單的 BPE(Byte Pair Encoding)實現
67#[pyclass]
68struct SimpleBPE {
69 vocab: HashMap<String, u32>,
70 merges: Vec<(String, String)>,
71}
72
73#[pymethods]
74impl SimpleBPE {
75 #[new]
76 fn new() -> Self {
77 SimpleBPE {
78 vocab: HashMap::new(),
79 merges: Vec::new(),
80 }
81 }
82
83 /// 訓練 BPE
84 fn train(&mut self, py: Python<'_>, texts: Vec<String>, vocab_size: usize) {
85 use std::collections::HashMap;
86
87 py.allow_threads(|| {
88 // 初始化:每個字元是一個 token
89 let mut word_freqs: HashMap<Vec<String>, usize> = HashMap::new();
90
91 for text in &texts {
92 for word in text.split_whitespace() {
93 let chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
94 *word_freqs.entry(chars).or_insert(0) += 1;
95 }
96 }
97
98 // 迭代合併最頻繁的 pair
99 while self.vocab.len() < vocab_size {
100 // 計算 pair 頻率
101 let mut pair_freqs: HashMap<(String, String), usize> = HashMap::new();
102
103 for (word, freq) in &word_freqs {
104 for i in 0..word.len().saturating_sub(1) {
105 let pair = (word[i].clone(), word[i + 1].clone());
106 *pair_freqs.entry(pair).or_insert(0) += freq;
107 }
108 }
109
110 // 找出最頻繁的 pair
111 if let Some((best_pair, _)) = pair_freqs.iter().max_by_key(|(_, &freq)| freq) {
112 let new_token = format!("{}{}", best_pair.0, best_pair.1);
113 self.merges.push(best_pair.clone());
114 self.vocab.insert(new_token.clone(), self.vocab.len() as u32);
115
116 // 更新 word_freqs
117 let mut new_word_freqs = HashMap::new();
118 for (word, freq) in word_freqs {
119 let mut new_word = Vec::new();
120 let mut i = 0;
121 while i < word.len() {
122 if i + 1 < word.len() && word[i] == best_pair.0 && word[i + 1] == best_pair.1 {
123 new_word.push(new_token.clone());
124 i += 2;
125 } else {
126 new_word.push(word[i].clone());
127 i += 1;
128 }
129 }
130 *new_word_freqs.entry(new_word).or_insert(0) += freq;
131 }
132 word_freqs = new_word_freqs;
133 } else {
134 break;
135 }
136 }
137 });
138 }
139
140 /// 編碼文字
141 fn encode(&self, text: &str) -> Vec<u32> {
142 let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
143
144 for (a, b) in &self.merges {
145 let merged = format!("{}{}", a, b);
146 let mut new_tokens = Vec::new();
147 let mut i = 0;
148 while i < tokens.len() {
149 if i + 1 < tokens.len() && &tokens[i] == a && &tokens[i + 1] == b {
150 new_tokens.push(merged.clone());
151 i += 2;
152 } else {
153 new_tokens.push(tokens[i].clone());
154 i += 1;
155 }
156 }
157 tokens = new_tokens;
158 }
159
160 tokens
161 .iter()
162 .filter_map(|t| self.vocab.get(t).copied())
163 .collect()
164 }
165}
166
167use std::collections::HashMap;
168
169#[pymodule]
170fn fast_tokenizer(m: &Bound<'_, PyModule>) -> PyResult<()> {
171 m.add_class::<Tokenizer>()?;
172 m.add_class::<SimpleBPE>()?;
173 Ok(())
174}使用範例
1from fast_tokenizer import Tokenizer, SimpleBPE
2
3# 基本 tokenization
4tokenizer = Tokenizer(r"\w+", lowercase=True)
5
6text = "Hello, World! This is a test."
7tokens = tokenizer.tokenize(text)
8print(tokens) # ['hello', 'world', 'this', 'is', 'a', 'test']
9
10# 批次處理
11texts = ["First sentence.", "Second sentence.", "Third one."]
12batch_tokens = tokenizer.tokenize_batch(texts)
13print(batch_tokens)
14
15# 詞頻統計
16counts = tokenizer.count_tokens("the cat sat on the mat")
17print(counts) # {'the': 2, 'cat': 1, 'sat': 1, 'on': 1, 'mat': 1}
18
19# BPE 訓練
20bpe = SimpleBPE()
21corpus = ["hello world", "hello there", "world peace"]
22bpe.train(corpus, vocab_size=100)
23encoded = bpe.encode("hello world")
24print(encoded)【案例三】資料驗證:Pydantic 風格驗證器
需求分析
1場景:API 需要驗證大量輸入資料
2問題:純 Python 驗證太慢(Pydantic v1 的問題)
3
4實現目標:
5├── 型別檢查
6├── 範圍驗證
7├── 自訂驗證函式
8└── 清晰的錯誤訊息Rust 實現
1// src/lib.rs
2use pyo3::prelude::*;
3use pyo3::exceptions::PyValueError;
4use std::collections::HashMap;
5
6// 驗證錯誤
7#[pyclass]
8#[derive(Clone)]
9struct ValidationError {
10 #[pyo3(get)]
11 field: String,
12 #[pyo3(get)]
13 message: String,
14}
15
16#[pymethods]
17impl ValidationError {
18 fn __repr__(&self) -> String {
19 format!("ValidationError(field='{}', message='{}')", self.field, self.message)
20 }
21}
22
23// 欄位驗證器
24#[pyclass]
25struct Field {
26 name: String,
27 field_type: String,
28 required: bool,
29 min_value: Option<f64>,
30 max_value: Option<f64>,
31 min_length: Option<usize>,
32 max_length: Option<usize>,
33 pattern: Option<regex::Regex>,
34}
35
36#[pymethods]
37impl Field {
38 #[new]
39 #[pyo3(signature = (name, field_type, required=true, min_value=None, max_value=None, min_length=None, max_length=None, pattern=None))]
40 fn new(
41 name: String,
42 field_type: String,
43 required: bool,
44 min_value: Option<f64>,
45 max_value: Option<f64>,
46 min_length: Option<usize>,
47 max_length: Option<usize>,
48 pattern: Option<String>,
49 ) -> PyResult<Self> {
50 let pattern = match pattern {
51 Some(p) => Some(regex::Regex::new(&p)
52 .map_err(|e| PyValueError::new_err(format!("無效的正規表達式: {}", e)))?),
53 None => None,
54 };
55
56 Ok(Field {
57 name,
58 field_type,
59 required,
60 min_value,
61 max_value,
62 min_length,
63 max_length,
64 pattern,
65 })
66 }
67}
68
69// Schema 驗證器
70#[pyclass]
71struct Schema {
72 fields: Vec<Field>,
73}
74
75#[pymethods]
76impl Schema {
77 #[new]
78 fn new(fields: Vec<Py<Field>>) -> PyResult<Self> {
79 Python::with_gil(|py| {
80 let fields: Vec<Field> = fields
81 .iter()
82 .map(|f| f.borrow(py).clone())
83 .collect();
84 Ok(Schema { fields })
85 })
86 }
87
88 /// 驗證單一物件
89 fn validate(&self, py: Python<'_>, data: &Bound<'_, PyDict>) -> PyResult<Vec<ValidationError>> {
90 let mut errors = Vec::new();
91
92 for field in &self.fields {
93 let value = data.get_item(&field.name)?;
94
95 match value {
96 None => {
97 if field.required {
98 errors.push(ValidationError {
99 field: field.name.clone(),
100 message: "此欄位為必填".to_string(),
101 });
102 }
103 }
104 Some(v) => {
105 // 型別檢查
106 match field.field_type.as_str() {
107 "int" => {
108 if let Ok(num) = v.extract::<i64>() {
109 // 範圍檢查
110 if let Some(min) = field.min_value {
111 if (num as f64) < min {
112 errors.push(ValidationError {
113 field: field.name.clone(),
114 message: format!("值必須 >= {}", min),
115 });
116 }
117 }
118 if let Some(max) = field.max_value {
119 if (num as f64) > max {
120 errors.push(ValidationError {
121 field: field.name.clone(),
122 message: format!("值必須 <= {}", max),
123 });
124 }
125 }
126 } else {
127 errors.push(ValidationError {
128 field: field.name.clone(),
129 message: "必須是整數".to_string(),
130 });
131 }
132 }
133 "float" => {
134 if let Ok(num) = v.extract::<f64>() {
135 if let Some(min) = field.min_value {
136 if num < min {
137 errors.push(ValidationError {
138 field: field.name.clone(),
139 message: format!("值必須 >= {}", min),
140 });
141 }
142 }
143 if let Some(max) = field.max_value {
144 if num > max {
145 errors.push(ValidationError {
146 field: field.name.clone(),
147 message: format!("值必須 <= {}", max),
148 });
149 }
150 }
151 } else {
152 errors.push(ValidationError {
153 field: field.name.clone(),
154 message: "必須是浮點數".to_string(),
155 });
156 }
157 }
158 "str" => {
159 if let Ok(s) = v.extract::<String>() {
160 // 長度檢查
161 if let Some(min_len) = field.min_length {
162 if s.len() < min_len {
163 errors.push(ValidationError {
164 field: field.name.clone(),
165 message: format!("長度必須 >= {}", min_len),
166 });
167 }
168 }
169 if let Some(max_len) = field.max_length {
170 if s.len() > max_len {
171 errors.push(ValidationError {
172 field: field.name.clone(),
173 message: format!("長度必須 <= {}", max_len),
174 });
175 }
176 }
177 // 正規表達式檢查
178 if let Some(ref pattern) = field.pattern {
179 if !pattern.is_match(&s) {
180 errors.push(ValidationError {
181 field: field.name.clone(),
182 message: "格式不符合要求".to_string(),
183 });
184 }
185 }
186 } else {
187 errors.push(ValidationError {
188 field: field.name.clone(),
189 message: "必須是字串".to_string(),
190 });
191 }
192 }
193 _ => {}
194 }
195 }
196 }
197 }
198
199 Ok(errors)
200 }
201
202 /// 批次驗證
203 fn validate_batch(&self, py: Python<'_>, data_list: Vec<Bound<'_, PyDict>>) -> PyResult<Vec<Vec<ValidationError>>> {
204 let mut results = Vec::new();
205 for data in data_list {
206 results.push(self.validate(py, &data)?);
207 }
208 Ok(results)
209 }
210}
211
212use pyo3::types::PyDict;
213
214#[pymodule]
215fn fast_validator(m: &Bound<'_, PyModule>) -> PyResult<()> {
216 m.add_class::<Field>()?;
217 m.add_class::<Schema>()?;
218 m.add_class::<ValidationError>()?;
219 Ok(())
220}使用範例
1from fast_validator import Field, Schema
2
3# 定義 schema
4schema = Schema([
5 Field("name", "str", required=True, min_length=1, max_length=100),
6 Field("age", "int", required=True, min_value=0, max_value=150),
7 Field("email", "str", required=True, pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$"),
8 Field("score", "float", required=False, min_value=0, max_value=100),
9])
10
11# 驗證資料
12data = {"name": "Alice", "age": 30, "email": "alice@example.com"}
13errors = schema.validate(data)
14if errors:
15 for e in errors:
16 print(f"{e.field}: {e.message}")
17else:
18 print("驗證通過")
19
20# 無效資料
21invalid_data = {"name": "", "age": -5, "email": "invalid"}
22errors = schema.validate(invalid_data)
23for e in errors:
24 print(f"{e.field}: {e.message}")
25# name: 長度必須 >= 1
26# age: 值必須 >= 0
27# email: 格式不符合要求【總結】何時使用 Rust
決策清單
1應該使用 Rust:
2
31. 效能瓶頸明確
4 □ profiler 顯示特定函式占用大量時間
5 □ 純 Python 優化已到極限
6 □ 現有 C 擴展不滿足需求
7
82. 資料處理需求
9 □ 大量數值計算
10 □ 頻繁的字串處理
11 □ 需要並行處理
12
133. 安全性要求
14 □ 處理不可信的輸入
15 □ 需要避免記憶體錯誤
16 □ 長期運行的服務
17
184. 跨平台需求
19 □ 需要支援多個作業系統
20 □ 需要支援多個 Python 版本
21
22可能不需要 Rust:
23
241. 效能不是主要瓶頸
252. 團隊沒有 Rust 經驗且時間緊迫
263. 專案規模小且不會長期維護
274. 可以用現有函式庫解決最佳實踐
11. 設計階段
2 ├── 明確定義 Python/Rust 邊界
3 ├── 最小化跨語言呼叫次數
4 ├── 使用批次處理減少開銷
5 └── 設計清晰的錯誤處理
6
72. 開發階段
8 ├── 先用 Python 原型驗證邏輯
9 ├── 逐步將瓶頸移到 Rust
10 ├── 完善的測試覆蓋
11 └── 使用 maturin develop 快速迭代
12
133. 發布階段
14 ├── 使用 CI/CD 自動建構 wheel
15 ├── 支援主流 Python 版本
16 ├── 提供 fallback 純 Python 實現
17 └── 清楚的安裝文件思考題
- 在設計 Rust 擴展的 API 時,如何平衡效能和易用性?
- 如何處理 Rust 函式庫沒有 Python 綁定的情況?
- 在什麼情況下,應該用 Cython 而不是 Rust?
實作練習
- 選擇一個你常用的純 Python 函式,用 Rust 重寫並比較效能
- 分析 Polars 的原始碼結構,理解大型 Rust Python 專案的組織方式
- 實現一個簡單的 JSON parser,比較與 Python
json模組的效能差異
延伸閱讀
上一章:Maturin 開發流程 下一模組:模組六:打包與發布