// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. // // StratoVirt is licensed under Mulan PSL v2. // You can use this software according to the terms and conditions of the Mulan // PSL v2. // You may obtain a copy of Mulan PSL v2 at: // http://license.coscl.org.cn/MulanPSL2 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. use std::sync::{Arc, Condvar, Mutex}; use std::thread; use std::time::Duration; use anyhow::{bail, Context, Result}; use log::error; use crate::link_list::{List, Node}; const MIN_THREADS: u64 = 1; const MAX_THREADS: u64 = 64; type PoolTask = Box<dyn TaskOperation>; pub trait TaskOperation: Sync + Send { fn run(&mut self); } struct PoolState { /// The total number of current threads in thread pool. /// Including the number of threads need to be created and the number of running threads. total_threads: u64, /// The current number of blocking threads, they will be blocked /// until awakened by request_cond or timeout. blocked_threads: u64, /// The number of threads need to be created. It could be created /// in the main loop or another thread in thread pool later. new_threads: u64, /// The number of threads that have been created but /// have not yet entered the work loop. pending_threads: u64, /// The minimum number of threads residing in the thread pool. min_threads: u64, /// The maximum number of threads that thread pool can create. max_threads: u64, /// List of pending tasks in the thread pool. req_lists: List<PoolTask>, } /// SAFETY: All the operations on req_lists are protected by the mutex, /// so there is no synchronization problem. unsafe impl Send for PoolState {} impl PoolState { fn new() -> Self { Self { total_threads: 0, blocked_threads: 0, new_threads: 0, pending_threads: 0, min_threads: MIN_THREADS, max_threads: MAX_THREADS, req_lists: List::new(), } } fn spawn_thread_needed(&self) -> bool { self.blocked_threads == 0 && self.total_threads < self.max_threads } fn is_running(&self) -> bool { self.total_threads <= self.max_threads } fn spawn_thread(&mut self, pool: Arc<ThreadPool>) -> Result<()> { self.total_threads += 1; self.new_threads += 1; if self.pending_threads == 0 { self.do_spawn_thread(pool)?; } Ok(()) } fn do_spawn_thread(&mut self, pool: Arc<ThreadPool>) -> Result<()> { if self.new_threads == 0 { return Ok(()); } self.new_threads -= 1; self.pending_threads += 1; trace::thread_pool_spawn_thread( &self.total_threads, &self.blocked_threads, &self.new_threads, &self.pending_threads, ); thread::Builder::new() .name("thread-pool".to_string()) .spawn(move || worker_thread(pool)) .with_context(|| "Failed to create thread in pool!")?; Ok(()) } } pub struct ThreadPool { /// Data shared by all threads in the pool. pool_state: Arc<Mutex<PoolState>>, /// Notify the thread in the pool that there are some work to do. request_cond: Condvar, /// Notify threadpool that the current thread has exited. stop_cond: Condvar, } impl Default for ThreadPool { fn default() -> Self { Self { pool_state: Arc::new(Mutex::new(PoolState::new())), request_cond: Condvar::new(), stop_cond: Condvar::new(), } } } impl ThreadPool { /// Submit task to thread pool. pub fn submit_task(pool: Arc<ThreadPool>, task: Box<dyn TaskOperation>) -> Result<()> { trace::thread_pool_submit_task(); let mut locked_state = pool.pool_state.lock().unwrap(); if locked_state.spawn_thread_needed() { locked_state.spawn_thread(pool.clone())? } locked_state.req_lists.add_tail(Box::new(Node::new(task))); drop(locked_state); pool.request_cond.notify_one(); Ok(()) } /// It should be confirmed that all threads have successfully exited /// before function return. pub fn cancel(&self) -> Result<()> { let mut locked_state = self.pool_state.lock().unwrap(); locked_state.total_threads -= locked_state.new_threads; locked_state.new_threads = 0; locked_state.max_threads = 0; self.request_cond.notify_all(); while locked_state.total_threads > 0 { match self.stop_cond.wait(locked_state) { Ok(lock) => locked_state = lock, Err(e) => bail!("{:?}", e), } } Ok(()) } } fn worker_thread(pool: Arc<ThreadPool>) { let mut locked_state = pool.pool_state.lock().unwrap(); locked_state.pending_threads -= 1; locked_state .do_spawn_thread(pool.clone()) .unwrap_or_else(|e| error!("Thread pool error: {:?}", e)); while locked_state.is_running() { let result; if locked_state.req_lists.len == 0 { locked_state.blocked_threads += 1; match pool .request_cond .wait_timeout(locked_state, Duration::from_secs(10)) { Ok((guard, ret)) => { locked_state = guard; result = ret; } Err(e) => { error!("Unknown errors have occurred thread pool: {:?}", e); locked_state = e.into_inner().0; break; } } locked_state.blocked_threads -= 1; if result.timed_out() && locked_state.req_lists.len == 0 && locked_state.total_threads > locked_state.min_threads { // If wait time_out and no pending task and current total number // of threads exceeds the minimum, then exit. break; } continue; } let mut req = locked_state.req_lists.pop_head().unwrap(); drop(locked_state); (*req.value).run(); locked_state = pool.pool_state.lock().unwrap(); } locked_state.total_threads -= 1; trace::thread_pool_exit_thread(&locked_state.total_threads, &locked_state.req_lists.len); pool.stop_cond.notify_one(); pool.request_cond.notify_one(); } #[cfg(test)] mod test { use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::{thread, time}; use super::{TaskOperation, ThreadPool}; struct PoolTask { count: Arc<AtomicU64>, } impl TaskOperation for PoolTask { fn run(&mut self) { std::thread::sleep(std::time::Duration::from_millis(500)); self.count.fetch_add(1, Ordering::SeqCst); } } #[test] fn test_pool_exit() { let pool = Arc::new(ThreadPool::default()); let count = Arc::new(AtomicU64::new(0)); let begin = time::SystemTime::now(); for _ in 0..10 { let task = Box::new(PoolTask { count: count.clone(), }); assert!(ThreadPool::submit_task(pool.clone(), task).is_ok()); } // Waiting for creating. while pool.pool_state.lock().unwrap().req_lists.len != 0 { thread::sleep(time::Duration::from_millis(10)); let now = time::SystemTime::now(); let duration = now.duration_since(begin).unwrap().as_millis(); assert!(duration < 500 * 10); } assert!(pool.cancel().is_ok()); let end = time::SystemTime::now(); let duration = end.duration_since(begin).unwrap().as_millis(); // All tasks are processed in parallel. assert!(duration < 500 * 10); // All the task has been finished. assert_eq!(count.load(Ordering::SeqCst), 10); } }