Rust实现线程池 | PHM's world

LOADING

歡迎來到烏托邦的世界

Rust实现线程池

Rust手写线程池

实现思路:
实现两个链表:生产者链表(提交任务的一方)与消费者链表(工作线程),二者联合处理任务。

消息传递:使用 Rust 的 channel 创建一个任务通道,sender 负责发送任务,receiver 负责接收任务。

任务存储:在 ReceiverWrapper 中使用 VecDeque 存储任务队列。

use std::sync::{Arc, Mutex, Condvar};
use std::thread;
use std::sync::mpsc::{channel, Sender, Receiver};
use std::collections::VecDeque;

struct ThreadPool {
    workers: Vec<Worker>,
    sender: Sender<Message>,
}

impl ThreadPool {
    fn new(num_workers: usize, max_jobs: usize) -> ThreadPool {
        assert!(num_workers > 0);
        let (sender, receiver) = channel();
        let receiver = Arc::new(Mutex::new(ReceiverWrapper { receiver, max_jobs }));

        let mut workers = Vec::with_capacity(num_workers);
        for id in 0..num_workers {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool { workers, sender }
    }

    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.send(Message::NewJob(job)).unwrap();
    }
}

impl Drop for ThreadPool {
    /// Shuts the pool down: queues one `Message::Terminate` per worker, then joins every
    /// worker thread.
    ///
    /// Avoids `unwrap` on purpose: a panic raised inside `drop` while the thread is already
    /// unwinding aborts the whole process.
    fn drop(&mut self) {
        for _ in &self.workers {
            // `send` only fails once the receiving half has been dropped, in which case
            // nothing is left to read the message, so the error is safe to ignore.
            let _ = self.sender.send(Message::Terminate);
        }

        for worker in &mut self.workers {
            if let Some(handle) = worker.thread.take() {
                // `join` returns `Err` when the worker thread panicked; report it instead of
                // re-panicking from inside `drop`.
                if handle.join().is_err() {
                    eprintln!("Worker {} panicked before shutdown", worker.id);
                }
            }
        }
    }
}

/// A boxed one-shot closure that can be moved to a worker thread.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Messages the pool sends to its workers.
enum Message {
    /// Run the boxed job.
    NewJob(Job),
    /// Ask the receiving worker to leave its loop.
    Terminate,
}

/// A pool thread plus the id used in its log lines.
struct Worker {
    id: usize,
    /// Join handle; an `Option` so the pool can `take()` it and join on shutdown.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns worker `id`. The worker repeatedly takes one message from the shared state and
    /// acts on it. It stops on `Message::Terminate`, when the `terminate` flag is set with
    /// nothing buffered, or when every sender has been dropped.
    ///
    /// Buffered messages in `jobs` are taken first; otherwise the worker blocks on the
    /// channel. The mutex is held while blocking, so only one idle worker waits in `recv`
    /// and the rest wait on the lock. The lock is released before the job runs.
    fn new(id: usize, receiver: Arc<Mutex<ReceiverWrapper>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let message = {
                let mut shared = receiver.lock().expect("worker state mutex poisoned");
                if let Some(buffered) = shared.jobs.pop_front() {
                    buffered
                } else if shared.terminate {
                    return;
                } else {
                    // The pool only ever sends on the channel, so the worker must read it
                    // here; waiting for `jobs` to fill up on its own would block forever.
                    match shared.receiver.recv() {
                        Ok(message) => message,
                        // Every sender is gone: no further work can arrive.
                        Err(_) => return,
                    }
                }
            }; // guard dropped here, before the job runs

            match message {
                Message::NewJob(job) => {
                    println!("Worker {} got a job; executing.", id);
                    job();
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

/// State shared by the workers: the channel's receiving half, a local message buffer,
/// a shutdown flag and a condition variable.
struct ReceiverWrapper {
    /// Receiving half of the job channel.
    receiver: Receiver<Message>,
    /// Locally buffered messages; starts empty.
    jobs: VecDeque<Message>,
    /// Job limit as passed to `new`.
    max_jobs: usize,
    /// Shutdown flag; starts out `false`.
    terminate: bool,
    /// Condition variable created alongside this state.
    condvar: Condvar,
}

impl ReceiverWrapper {
    /// Wraps `receiver` with an empty buffer, a cleared flag and a fresh condvar.
    fn new(receiver: Receiver<Message>, max_jobs: usize) -> Self {
        let jobs = VecDeque::new();
        let condvar = Condvar::new();
        Self {
            receiver,
            jobs,
            max_jobs,
            terminate: false,
            condvar,
        }
    }
}

/// Demo: queue 20 print jobs on a 4-worker pool (`max_jobs` = 10).
///
/// `pool` goes out of scope at the end of `main`, which runs `ThreadPool`'s `Drop` impl.
fn main() {
    let pool = ThreadPool::new(4, 10);

    (0..20).for_each(|i| {
        pool.execute(move || {
            println!("Task {} is being processed", i);
        })
    });
}

另一种实现(通过 push_job 向任务传递参数):

use std::sync::{Arc, Mutex, Condvar};
use std::thread;
use std::sync::mpsc::{channel, Sender, Receiver};
use std::collections::VecDeque;

/// Pool of worker threads fed through an mpsc channel (the `push_job` variant).
struct ThreadPool {
    /// Worker handles; joined by the pool's `Drop` impl.
    workers: Vec<Worker>,
    /// Sending half of the job channel.
    sender: Sender<Message>,
}

/// One pool thread plus a per-worker flag.
struct Worker {
    /// Join handle; an `Option` so `Drop` can `take()` it and join.
    thread: Option<thread::JoinHandle<()>>,
    // NOTE(review): created in `Worker::new` but, as far as this file shows, never read by
    // the worker thread; shutdown is driven by `Message::Terminate` instead.
    terminate: Arc<Mutex<bool>>,
}

/// Messages the pool sends to its workers.
enum Message {
    /// Run the boxed job.
    NewJob(Job),
    /// Ask the receiving worker to leave its loop.
    Terminate,
}

/// A boxed one-shot closure that can be sent to another thread.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A job callback bundled with the byte payload it will be invoked with.
struct JobWrapper {
    /// Callback to run; it takes `user_data` as its argument, so its signature must accept
    /// a `Vec<u8>` (an argument-less `FnOnce()` cannot be called with the payload).
    func: Box<dyn FnOnce(Vec<u8>) + Send + 'static>,
    /// Payload handed to `func` when the job runs.
    user_data: Vec<u8>,
}

impl ThreadPool {
    fn new(num_workers: usize, max_jobs: usize) -> ThreadPool {
        let (sender, receiver) = channel();
        let receiver = Arc::new(Mutex::new(ReceiverWrapper::new(receiver, max_jobs)));
        let mut workers = Vec::with_capacity(num_workers);

        for _ in 0..num_workers {
            workers.push(Worker::new(Arc::clone(&receiver)));
        }

        ThreadPool { workers, sender }
    }

    fn push_job<F>(&self, func: F, arg: Vec<u8>) -> Result<(), &'static str>
    where
        F: FnOnce(Vec<u8>) + Send + 'static,
    {
        let job = JobWrapper {
            func: Box::new(func),
            user_data: arg,
        };
        self.sender.send(Message::NewJob(Box::new(move || (job.func)(job.user_data))))
            .map_err(|_| "Failed to send job")
    }
}

impl Drop for ThreadPool {
    /// Shuts the pool down: queues one `Message::Terminate` per worker, then joins every
    /// worker thread.
    ///
    /// Avoids `unwrap` on purpose: a panic raised inside `drop` while the thread is already
    /// unwinding aborts the whole process.
    fn drop(&mut self) {
        for _ in &self.workers {
            // `send` only fails once the receiving half has been dropped, in which case
            // nothing is left to read the message, so the error is safe to ignore.
            let _ = self.sender.send(Message::Terminate);
        }

        for worker in &mut self.workers {
            if let Some(handle) = worker.thread.take() {
                // `join` returns `Err` when the worker thread panicked; report it instead of
                // re-panicking from inside `drop`.
                if handle.join().is_err() {
                    eprintln!("a worker thread panicked before shutdown");
                }
            }
        }
    }
}

/// State shared by the workers: the channel's receiving half, a local message buffer,
/// a condition variable and a shutdown flag.
struct ReceiverWrapper {
    /// Receiving half of the job channel.
    receiver: Receiver<Message>,
    /// Locally buffered messages; pre-sized to `max_jobs`.
    jobs: VecDeque<Message>,
    /// Job limit as passed to `new`.
    max_jobs: usize,
    /// Condition variable created alongside this state.
    condvar: Condvar,
    /// Shutdown flag; starts out `false`.
    terminate: bool,
}

impl ReceiverWrapper {
    /// Wraps `receiver` with an empty buffer reserved for `max_jobs` messages,
    /// a fresh condvar and a cleared flag.
    fn new(receiver: Receiver<Message>, max_jobs: usize) -> Self {
        let jobs = VecDeque::with_capacity(max_jobs);
        let condvar = Condvar::new();
        Self {
            receiver,
            jobs,
            max_jobs,
            condvar,
            terminate: false,
        }
    }
}

impl Worker {
    /// Spawns a worker thread. The worker repeatedly takes one message from the shared state
    /// and acts on it. It stops on `Message::Terminate`, when the shared `terminate` flag is
    /// set with nothing buffered, or when every sender has been dropped.
    ///
    /// Buffered messages in `jobs` are taken first; otherwise the worker blocks on the
    /// channel. The mutex is held while blocking, so only one idle worker waits in `recv`
    /// and the rest wait on the lock. The lock is released before the job runs.
    fn new(receiver: Arc<Mutex<ReceiverWrapper>>) -> Worker {
        // NOTE(review): this per-worker flag is stored on `Worker` but not moved into the
        // thread, so setting it does not stop the worker.
        let terminate = Arc::new(Mutex::new(false));

        let thread = thread::spawn(move || loop {
            let message = {
                let mut shared = receiver.lock().expect("worker state mutex poisoned");
                if let Some(buffered) = shared.jobs.pop_front() {
                    buffered
                } else if shared.terminate {
                    break;
                } else {
                    // The pool only ever sends on the channel, so the worker must read it
                    // here; waiting for `jobs` to fill up on its own would block forever.
                    match shared.receiver.recv() {
                        Ok(message) => message,
                        // Every sender is gone: no further work can arrive.
                        Err(_) => break,
                    }
                }
            }; // guard dropped here, before the job runs

            match message {
                Message::NewJob(job) => job(),
                Message::Terminate => break,
            }
        });

        Worker {
            thread: Some(thread),
            terminate,
        }
    }
}

fn main() {
    let pool = ThreadPool::new(4, 10);

    loop {
        let arg = vec![i as u8]; // Dummy argument
        pool.push_job(move |data| {
            println!("Task {} is being processed with data: {:?}", i, data);
        }, arg).unwrap();
    }
}

入口函数(使用上面第二种实现的 push_job 接口):

use std::sync::{Arc, Mutex, Condvar};
use std::thread;
use std::sync::mpsc::{channel, Sender, Receiver};
use std::collections::VecDeque;
use std::mem;
use std::time::Duration;

fn main() {
    let pool = ThreadPool::new(1000, 2000);

    println!("线程池初始化成功");

    for i in 0..1000 {
        let arg = Box::new(i);
        pool.push_job(test_fun, arg).unwrap();
    }

    thread::sleep(Duration::from_secs(1)); // 等待所有任务完成
}