// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved. // // StratoVirt is licensed under Mulan PSL v2. // You can use this software according to the terms and conditions of the Mulan // PSL v2. // You may obtain a copy of Mulan PSL v2 at: // http://license.coscl.org.cn/MulanPSL2 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. use log::error; use std::mem::size_of; use std::os::unix::io::RawFd; use std::slice; use std::sync::{Arc, Mutex}; use util::unix::limit_permission; use virtio::vhost::user::{ RegionMemInfo, VhostUserHdrFlag, VhostUserMemHdr, VhostUserMsgHdr, VhostUserMsgReq, VhostUserVringAddr, VhostUserVringState, MAX_ATTACHED_FD_ENTRIES, }; use virtio::VhostUser::VhostUserSock; use anyhow::{bail, Context, Result}; /// The trait for dealing with vhost-user request in the server. pub trait VhostUserReqHandler: Send + Sync { /// Set the current process as the owner of this file descriptor. fn set_owner(&mut self) -> Result<()>; /// Get a bitmask of supported virtio/vhost features. fn get_features(&self) -> Result<u64>; /// Inform the vhost subsystem which features to enable. /// /// # Arguments /// /// * `features` - The features from the vhost-user client in StratoVirt. fn set_features(&mut self, features: u64) -> Result<()>; /// Set the guest memory mappings for vhost to use. /// /// # Arguments /// /// * `regions` - The slice of memory region information for the message of memory table. /// * `fds` - The files descriptors are used to map shared memory for the process and /// StratoVirt. fn set_mem_table(&mut self, regions: &[RegionMemInfo], fds: &[RawFd]) -> Result<()>; /// Set the size of descriptors in the virtio queue. /// /// # Arguments /// /// * `queue_index` - The index of virtio queue. /// * `num` - The total size of virtio queue. fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()>; /// Set the addresses for a given virtio queue. /// /// # Arguments /// /// * `queue_index` - The index of virtio queue. /// * `flags` - Option flags. /// * `desc_table` - The start address of descriptor table. /// * `used_ring` - The start address of used ring. /// * `avail_ring` - The start address of avail ring. /// * `log` - The start address of log. fn set_vring_addr( &mut self, queue_index: usize, flags: u32, desc_table: u64, used_ring: u64, avail_ring: u64, log: u64, ) -> Result<()>; /// Set the first index to look for available descriptors. /// /// # Arguments /// /// * `queue_index` - The index of virtio queue. /// * `num` - the first index to look for available descriptors. fn set_vring_base(&mut self, queue_index: usize, num: u16) -> Result<()>; /// Set the eventfd to trigger when buffers need to be processed /// by the guest. /// /// # Arguments /// /// * `queue_index` - The index of virtio queue. /// * `fd` - The files descriptor used to notify the guest. fn set_vring_call(&mut self, queue_index: usize, fd: RawFd) -> Result<()>; /// Set the eventfd that will be signaled by the guest when buffers /// need to be processed by the host. /// /// # Arguments /// /// * `queue_index` - The index of virtio queue. /// * `fd` - The files descriptor used to notify the host. fn set_vring_kick(&mut self, queue_index: usize, fd: RawFd) -> Result<()>; /// set the status of virtio queue. /// /// # Arguments /// /// * `queue_index` - The index of virtio queue. /// * `status` - The status of virtio queue. fn set_vring_enable(&mut self, queue_index: usize, status: u32) -> Result<()>; } /// The vhost-user server handler can communicate with StratoVirt and set the data of requests /// to the backend. #[derive(Clone)] pub struct VhostUserServerHandler { /// The information of socket used to communicate with StratoVirt. pub sock: VhostUserSock, /// The backend used to save the data of requests from StratoVirt. backend: Arc<Mutex<dyn VhostUserReqHandler>>, } fn close_fds(fds: Vec<RawFd>) { for fd in fds { let _ = unsafe { libc::close(fd) }; } } fn is_invalid_fds(hdr: &mut VhostUserMsgHdr, rfds: Option<Vec<RawFd>>) -> Result<()> { match VhostUserMsgReq::from(hdr.request) { VhostUserMsgReq::SetMemTable => Ok(()), VhostUserMsgReq::SetVringCall => Ok(()), VhostUserMsgReq::SetVringKick => Ok(()), VhostUserMsgReq::SetSlaveReqFd => Ok(()), _ => { if rfds.is_some() { if let Some(fds) = rfds { close_fds(fds); } bail!("The fds is invalid, request: {}", hdr.request); } else { Ok(()) } } } } impl VhostUserServerHandler { /// Construct a vhost-user server handler /// /// # Arguments /// /// * `path` - The path of unix socket file which communicates with StratoVirt. /// * `backend` - The trait of backend used to save the data of requests from StratoVirt. pub fn new(path: &str, backend: Arc<Mutex<dyn VhostUserReqHandler>>) -> Result<Self> { let mut sock = VhostUserSock::new(path); sock.domain .bind(true) .with_context(|| format!("Failed to bind for vhost user server {}", path))?; limit_permission(path).with_context(|| format!("Failed to limit permission {}", path))?; Ok(VhostUserServerHandler { sock, backend }) } fn recv_hdr_and_fds(&mut self) -> Result<(VhostUserMsgHdr, Option<Vec<RawFd>>)> { let mut hdr = VhostUserMsgHdr::default(); let body_opt: Option<&mut u32> = None; let payload_opt: Option<&mut [u8]> = None; let mut fds = vec![0; MAX_ATTACHED_FD_ENTRIES]; let (rcv_len, fds_num) = self .sock .recv_msg(Some(&mut hdr), body_opt, payload_opt, &mut fds) .with_context(|| "Failed to recv hdr and fds")?; if rcv_len != size_of::<VhostUserMsgHdr>() { bail!( "The received length {} is invalid, expect {}", rcv_len, size_of::<VhostUserMsgHdr>() ); } else if hdr.is_invalid() { bail!( "The header of vhost user msg is invalid, request: {}, size: {}, flags: {}", hdr.request, hdr.size, hdr.flags ); } let rfds = match fds_num { 0 => None, n => { let mut fds_temp = Vec::with_capacity(n); fds_temp.extend_from_slice(&fds[0..n]); Some(fds_temp) } }; is_invalid_fds(&mut hdr, rfds.clone())?; Ok((hdr, rfds)) } fn recv_body(&mut self, len: usize) -> Result<(usize, Vec<u8>)> { let mut rbuf = vec![0u8; len]; let body_opt: Option<&mut u32> = None; let hdr_opt: Option<&mut VhostUserMsgHdr> = None; let (rcv_len, _) = self .sock .recv_msg(hdr_opt, body_opt, Some(&mut rbuf), &mut []) .with_context(|| "Failed to recv msg body")?; if rcv_len != len { bail!( "The length of msg body {} is invalid, expected {}", rcv_len, len ); } Ok((rcv_len, rbuf)) } fn get_msg_body<'a, D: Sized>( &self, hdr: &VhostUserMsgHdr, buf: &'a [u8], len: usize, ) -> Result<&'a D> { if !self.is_valid_request(hdr, len, size_of::<D>()) { bail!( "Failed to get msg body for request {}, len {}, payload size {}, hdr.size {}", hdr.request, len, size_of::<D>(), hdr.size ); } let body = unsafe { &*(buf.as_ptr() as *const D) }; Ok(body) } fn send_ack_msg<D: Sized>(&mut self, request: u32, res: D, fds: &[RawFd]) -> Result<()> { let hdr = VhostUserMsgHdr::new( request, VhostUserHdrFlag::Reply as u32, size_of::<D>() as u32, ); let payload_opt: Option<&[u8]> = None; self.sock .send_msg(Some(&hdr), Some(&res), payload_opt, fds) .with_context(|| "Failed to send ack msg")?; Ok(()) } #[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))] fn set_msg_mem_table( &mut self, hdr: &VhostUserMsgHdr, buf: &[u8], len: usize, fds_opt: Option<Vec<RawFd>>, ) -> Result<()> { if len < size_of::<VhostUserMemHdr>() { if let Some(fds) = fds_opt { close_fds(fds); } bail!("The header length of mem table is invalid {}", len); } let memhdrsize = size_of::<VhostUserMemHdr>(); let memhdr = unsafe { &*(buf.as_ptr() as *const VhostUserMemHdr) }; let total_size = (memhdr.nregions as usize * size_of::<RegionMemInfo>()) + memhdrsize; if (hdr.size as usize) != total_size { if let Some(fds) = fds_opt { close_fds(fds); } bail!( "The body length of mem table is invalid {}, expected {}", total_size, hdr.size, ); } let regions = unsafe { slice::from_raw_parts( buf.as_ptr().add(memhdrsize) as *const RegionMemInfo, memhdr.nregions as usize, ) }; if let Some(fds) = fds_opt { let fds_len = fds.len(); if fds_len != (memhdr.nregions as usize) { close_fds(fds); bail!( "The length of fds {} for mem table is invalid, expected {}", fds_len, memhdr.nregions ); } self.backend.lock().unwrap().set_mem_table(regions, &fds)?; } else { bail!("The fds of mem table is null"); } Ok(()) } fn is_valid_request(&self, hdr: &VhostUserMsgHdr, size: usize, expected: usize) -> bool { (hdr.size as usize == expected) && (size == expected) && !hdr.is_reply() } fn process_request( &mut self, hdr: &VhostUserMsgHdr, buf: &[u8], len: usize, rfds: Option<Vec<RawFd>>, ) -> Result<()> { match VhostUserMsgReq::from(hdr.request) { VhostUserMsgReq::GetFeatures => { if !self.is_valid_request(hdr, len, 0) { bail!("Invalid request size of GetFeatures"); } let features = self.backend.lock().unwrap().get_features()?; if hdr.need_reply() { self.send_ack_msg(VhostUserMsgReq::GetFeatures as u32, features, &[]) .with_context(|| "Failed to send ack msg for getting features")?; } } VhostUserMsgReq::SetFeatures => { let features = self .get_msg_body::<u64>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting features")?; self.backend.lock().unwrap().set_features(*features)?; } VhostUserMsgReq::SetOwner => { if !self.is_valid_request(hdr, len, 0) { bail!("Invalid request size of SetOwner"); } self.backend.lock().unwrap().set_owner()?; } VhostUserMsgReq::SetMemTable => { let ret = match self.set_msg_mem_table(hdr, buf, len, rfds) { Err(ref e) => { error!("Failed to set mem table {:?}", e); 1u64 } Ok(_) => 0u64, }; if hdr.need_reply() { self.send_ack_msg(VhostUserMsgReq::SetMemTable as u32, ret, &[]) .with_context(|| "Failed to send ack msg for setting mem table")?; } } VhostUserMsgReq::SetVringNum => { let vringstate = self .get_msg_body::<VhostUserVringState>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting vring num")?; self.backend .lock() .unwrap() .set_vring_num(vringstate.index as usize, vringstate.value as u16)?; } VhostUserMsgReq::SetVringAddr => { let vringaddr = self .get_msg_body::<VhostUserVringAddr>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting vring addr")?; self.backend.lock().unwrap().set_vring_addr( vringaddr.index as usize, vringaddr.flags, vringaddr.desc_user_addr, vringaddr.used_user_addr, vringaddr.avail_user_addr, vringaddr.log_guest_addr, )?; } VhostUserMsgReq::SetVringBase => { let vringstate = self .get_msg_body::<VhostUserVringState>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting vring base")?; self.backend .lock() .unwrap() .set_vring_base(vringstate.index as usize, vringstate.value as u16)?; } VhostUserMsgReq::SetVringEnable => { let vringstate = self .get_msg_body::<VhostUserVringState>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting vring enable")?; self.backend .lock() .unwrap() .set_vring_enable(vringstate.index as usize, vringstate.value)?; } VhostUserMsgReq::SetVringKick => { let index = self .get_msg_body::<u64>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting vring kick")?; if let Some(fds) = rfds { let fds_len = fds.len(); if fds_len != 1 { close_fds(fds); bail!("The length {} of fds for kicking is invalid", fds_len); } self.backend .lock() .unwrap() .set_vring_kick(*index as usize, fds[0])?; } else { bail!("The length of fds for kicking is null"); } } VhostUserMsgReq::SetVringCall => { let index = self .get_msg_body::<u64>(hdr, buf, len) .with_context(|| "Failed to get msg body for setting vring call")?; if let Some(fds) = rfds { let fds_len = fds.len(); if fds_len != 1 { close_fds(fds); bail!("The length {} of fds for calling is invalid", fds_len); } self.backend .lock() .unwrap() .set_vring_call(*index as usize, fds[0])?; } else { bail!("The length of fds for calling is null"); } } _ => { bail!("The request {} is unknown", hdr.request); } }; Ok(()) } /// The function used to process requests from StratoVirt. pub fn handle_request(&mut self) -> Result<()> { let (hdr, rfds) = self .recv_hdr_and_fds() .with_context(|| "Failed to recv header and fds")?; let (len, buf) = match hdr.size { 0 => (0, vec![0u8; 0]), _ => { let (rcv_len, rbuf) = self .recv_body(hdr.size as usize) .with_context(|| "Failed to recv msg body")?; (rcv_len, rbuf) } }; self.process_request(&hdr, &buf, len, rfds) .with_context(|| format!("Failed to process the request {}", hdr.request))?; Ok(()) } }