// Copyright (c) 2020 Huawei Technologies Co.,Ltd. All rights reserved. // // StratoVirt is licensed under Mulan PSL v2. // You can use this software according to the terms and conditions of the Mulan // PSL v2. // You may obtain a copy of Mulan PSL v2 at: // http://license.coscl.org.cn/MulanPSL2 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. use std::fs::File; use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; use error_chain::bail; use libc::{ c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, CMSG_LEN, CMSG_SPACE, MSG_NOSIGNAL, MSG_WAITALL, SCM_RIGHTS, SOL_SOCKET, }; use log::error; use super::errors::{ErrorKind, Result, ResultExt}; /// This function returns the caller's thread ID(TID). pub fn gettid() -> u64 { unsafe { libc::syscall(libc::SYS_gettid) as u64 } } /// This function used to remove group and others permission using libc::chmod. pub fn limit_permission(path: &str) -> Result<()> { let file_path = path.as_bytes().to_vec(); let cstr_file_path = std::ffi::CString::new(file_path).unwrap(); let ret = unsafe { libc::chmod(cstr_file_path.as_ptr(), 0o600) }; if ret == 0 { Ok(()) } else { Err(ErrorKind::ChmodFailed(ret).into()) } } /// Gets the page size of host. pub fn host_page_size() -> u64 { unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u64 } } /// Parse unix uri to unix path. /// /// # Notions /// /// Unix uri is the string as `unix:/xxx/xxx`. pub fn parse_unix_uri(uri: &str) -> Result<String> { let parse_vec: Vec<&str> = uri.split(':').collect(); if parse_vec.len() == 2 && parse_vec[0] == "unix" { Ok(parse_vec[1].to_string()) } else { bail!("Invalid unix uri: {}", uri) } } /// Call libc::mmap to allocate memory or map disk file. /// /// # Arguments /// /// * `file` - Backend file. /// * `len` - Length of maping. /// * `offset` - Offset in the file (or other object). /// * `read_only` - Allow to write or not. /// * `is_share` - Share the mapping or not. /// * `dump_guest_core` - Exclude from a core dump or not. /// /// # Errors /// /// * Failed to do mmap. pub fn do_mmap( file: &Option<&File>, len: u64, offset: u64, read_only: bool, is_share: bool, dump_guest_core: bool, ) -> Result<u64> { let mut flags: i32 = 0; let mut fd: i32 = -1; if let Some(f) = file { fd = f.as_raw_fd(); } else { flags |= libc::MAP_ANONYMOUS; } if is_share { flags |= libc::MAP_SHARED; } else { flags |= libc::MAP_PRIVATE; } let mut prot = libc::PROT_READ; if !read_only { prot |= libc::PROT_WRITE; } // Safe because the return value is checked. let hva = unsafe { libc::mmap( std::ptr::null_mut() as *mut libc::c_void, len as libc::size_t, prot, flags, fd as libc::c_int, offset as libc::off_t, ) }; if hva == libc::MAP_FAILED { return Err(std::io::Error::last_os_error()).chain_err(|| "Mmap failed."); } if !dump_guest_core { set_memory_undumpable(hva, len); } Ok(hva as u64) } fn set_memory_undumpable(host_addr: *mut libc::c_void, size: u64) { // Safe because host_addr and size are valid and return value is checked. let ret = unsafe { libc::madvise(host_addr, size as libc::size_t, libc::MADV_DONTDUMP) }; if ret < 0 { error!( "Syscall madvise(with MADV_DONTDUMP) failed, OS error is {}", std::io::Error::last_os_error() ); } } /// Unix socket is a data communication endpoint for exchanging data /// between processes executing on the same host OS. pub struct UnixSock { // Unix socket path path: String, // A unix socket listener acts as a synchronizalbe event. listener: Option<UnixListener>, // Unix socket stream perform like streams of information. sock: Option<UnixStream>, } impl Clone for UnixSock { fn clone(&self) -> Self { UnixSock { path: self.path.clone(), listener: self.listener.as_ref().map(|l| l.try_clone().unwrap()), sock: self.sock.as_ref().map(|s| s.try_clone().unwrap()), } } } #[allow(dead_code)] impl UnixSock { pub fn new(path: &str) -> Self { UnixSock { path: path.to_string(), listener: None, sock: None, } } /// Bind assigns a unique listener for the socket. fn bind(&mut self, unlink: bool) -> Result<()> { if unlink { std::fs::remove_file(self.path.as_str()) .chain_err(|| format!("Failed to remove socket file {}.", self.path.as_str()))?; } let listener = UnixListener::bind(self.path.as_str()) .chain_err(|| format!("Failed to bind the socket {}", self.path))?; self.listener = Some(listener); Ok(()) } /// The listener accepts incoming client connections. fn accept(&mut self) -> Result<()> { let (sock, _addr) = self .listener .as_ref() .unwrap() .accept() .chain_err(|| format!("Failed to accept the socket {}", self.path))?; self.sock = Some(sock); Ok(()) } fn is_accepted(&self) -> bool { self.sock.is_some() } /// Unix socket stream create a connection for requests. pub fn connect(&mut self) -> Result<()> { let sock = UnixStream::connect(self.path.as_str()) .chain_err(|| format!("Failed to connect the socket {}", self.path))?; self.sock = Some(sock); Ok(()) } /// Get Stream's fd from `UnixSock`. pub fn get_stream_raw_fd(&self) -> RawFd { self.sock.as_ref().unwrap().as_raw_fd() } /// Get listener's fd from `UnixSock`. pub fn get_listener_raw_fd(&self) -> RawFd { self.listener.as_ref().unwrap().as_raw_fd() } fn cmsg_data(&self, cmsg_buffer: *mut cmsghdr) -> *mut RawFd { // Safe as parameter is zero. (cmsg_buffer as *mut u8).wrapping_add(unsafe { CMSG_LEN(0) } as usize) as *mut RawFd } fn get_next_cmsg( &self, msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr, ) -> *mut cmsghdr { // Safe to get cmsg_len because the parameter is valid. let next_cmsg = (cmsg_ptr as *mut u8) .wrapping_add(unsafe { CMSG_LEN(cmsg.cmsg_len as u32) } as usize) as *mut cmsghdr; // Safe to get msg_control because the parameter is valid. let nex_cmsg_pos = (next_cmsg as *mut u8).wrapping_sub(msghdr.msg_control as usize) as usize; // Safe as parameter is zero. if nex_cmsg_pos.wrapping_add(unsafe { CMSG_LEN(0) } as usize) > msghdr.msg_controllen as usize { null_mut() } else { next_cmsg } } /// Send message and scm_fds to socket file descriptor. /// /// # Arguments /// /// * `iovecs` - Data buffer that need to send to socket. /// * `out_fds` - EventFds that need to send to socket. /// /// # Errors /// /// The socket file descriptor is broken. pub fn send_msg(&self, iovecs: &mut [iovec], out_fds: &[RawFd]) -> std::io::Result<usize> { // It is safe because we check the iovecs lens before. #[cfg(not(target_env = "musl"))] let iovecs_len = iovecs.len(); #[cfg(target_env = "musl")] let iovecs_len = iovecs.len() as i32; // It is safe because we check the out_fds lens before. #[cfg(not(target_env = "musl"))] let cmsg_len = unsafe { CMSG_LEN((size_of::<RawFd>() * out_fds.len()) as u32) } as usize; #[cfg(target_env = "musl")] let cmsg_len = unsafe { CMSG_LEN((size_of::<RawFd>() * out_fds.len()) as u32) } as u32; // It is safe because we check the out_fds lens before. #[cfg(not(target_env = "musl"))] let cmsg_capacity = unsafe { CMSG_SPACE((size_of::<RawFd>() * out_fds.len()) as u32) } as usize; #[cfg(target_env = "musl")] let cmsg_capacity = unsafe { CMSG_SPACE((size_of::<RawFd>() * out_fds.len()) as u32) } as u32; let mut cmsg_buffer = vec![0_u64; cmsg_capacity as usize]; // In `musl` toolchain, msghdr has private member `__pad0` and `__pad1`, it can't be // initialized in normal way. let mut msg: msghdr = unsafe { std::mem::zeroed() }; msg.msg_name = null_mut(); msg.msg_namelen = 0; msg.msg_iov = iovecs.as_mut_ptr(); msg.msg_iovlen = iovecs_len; msg.msg_control = null_mut(); msg.msg_controllen = 0; msg.msg_flags = 0; if !out_fds.is_empty() { let cmsg = cmsghdr { cmsg_len, #[cfg(target_env = "musl")] __pad1: 0, cmsg_level: SOL_SOCKET, cmsg_type: SCM_RIGHTS, }; unsafe { write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); copy_nonoverlapping( out_fds.as_ptr(), self.cmsg_data(cmsg_buffer.as_mut_ptr() as *mut cmsghdr), out_fds.len(), ); } msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; msg.msg_controllen = cmsg_capacity; } // Safe as msg parameters are valid. let write_count = unsafe { sendmsg(self.sock.as_ref().unwrap().as_raw_fd(), &msg, MSG_NOSIGNAL) }; if write_count == -1 { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "Failed to send msg, err: {}", std::io::Error::last_os_error() ), )); } else { Ok(write_count as usize) } } /// Receive message and scm_fds from socket file descriptor. /// /// # Arguments /// /// * `iovecs` - Data buffer that need to receive from socket. /// * `in_fds` - EventFds that need to receive from socket. /// /// # Errors /// /// The socket file descriptor is broken. pub fn recv_msg( &self, iovecs: &mut [iovec], in_fds: &mut [RawFd], ) -> std::io::Result<(usize, usize)> { // It is safe because we check the iovecs lens before. #[cfg(not(target_env = "musl"))] let iovecs_len = iovecs.len(); #[cfg(target_env = "musl")] let iovecs_len = iovecs.len() as i32; // It is safe because we check the in_fds lens before. #[cfg(not(target_env = "musl"))] let cmsg_capacity = unsafe { CMSG_SPACE((size_of::<RawFd>() * in_fds.len()) as u32) } as usize; #[cfg(target_env = "musl")] let cmsg_capacity = unsafe { CMSG_SPACE((size_of::<RawFd>() * in_fds.len()) as u32) } as u32; let mut cmsg_buffer = vec![0_u64; cmsg_capacity as usize]; // In `musl` toolchain, msghdr has private member `__pad0` and `__pad1`, it can't be // initialized in normal way. let mut msg: msghdr = unsafe { std::mem::zeroed() }; msg.msg_name = null_mut(); msg.msg_namelen = 0; msg.msg_iov = iovecs.as_mut_ptr(); msg.msg_iovlen = iovecs_len; msg.msg_control = null_mut(); msg.msg_controllen = 0; msg.msg_flags = 0; if !in_fds.is_empty() { msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; msg.msg_controllen = cmsg_capacity; } // Safe as msg parameters are valid. let total_read = unsafe { recvmsg( self.sock.as_ref().unwrap().as_raw_fd(), &mut msg, MSG_WAITALL, ) }; if total_read == -1 { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "Failed to recv msg, err: {}", std::io::Error::last_os_error() ), )); } if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "The length of control message is invalid, {} {}", msg.msg_controllen, size_of::<cmsghdr>() ), )); } let mut cmsg_ptr = msg.msg_control as *mut cmsghdr; let mut in_fds_count = 0_usize; while !cmsg_ptr.is_null() { let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { let fd_count = (cmsg.cmsg_len as usize - unsafe { CMSG_LEN(0) } as usize) / size_of::<RawFd>(); unsafe { copy_nonoverlapping( self.cmsg_data(cmsg_ptr), in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), fd_count, ); } in_fds_count += fd_count; } cmsg_ptr = self.get_next_cmsg(&msg, &cmsg, cmsg_ptr); } Ok((total_read as usize, in_fds_count as usize)) } } #[cfg(test)] mod tests { use std::path::Path; use std::time::Duration; use libc::{c_void, iovec}; use super::{parse_unix_uri, UnixSock}; #[test] fn test_parse_uri() { let test_uri_01 = "unix:/tmp/test_file.sock"; assert!(parse_unix_uri(test_uri_01).is_ok()); assert_eq!( parse_unix_uri(test_uri_01).unwrap(), String::from("/tmp/test_file.sock") ); let test_uri_02 = "file:/tmp/test_file:file"; assert!(parse_unix_uri(test_uri_02).is_err()); let test_uri_03 = "tcp:127.0.0.1"; assert!(parse_unix_uri(test_uri_03).is_err()); } #[test] fn test_create_unix_socket() { let path_name = String::from("test_socket1.sock"); let sock_path = Path::new("./test_socket1.sock"); let mut listener = UnixSock::new(&path_name); if sock_path.exists() { assert!(listener.bind(true).is_ok()); } else { assert!(listener.bind(false).is_ok()); } assert_ne!(listener.get_listener_raw_fd(), 0); std::thread::sleep(Duration::from_millis(100)); let mut stream = UnixSock::new(&path_name); assert!(stream.connect().is_ok()); assert_ne!(stream.get_stream_raw_fd(), 0); assert!(listener.accept().is_ok()); assert_eq!(listener.is_accepted(), true); } #[test] fn test_send_recv_sock_msg() { let path_name = String::from("test_socket2.sock"); let sock_path = Path::new("./test_socket2.sock"); let mut listener = UnixSock::new(&path_name); if sock_path.exists() { assert!(listener.bind(true).is_ok()); } else { assert!(listener.bind(false).is_ok()); } std::thread::sleep(Duration::from_millis(100)); let mut stream = UnixSock::new(&path_name); assert!(stream.connect().is_ok()); assert!(listener.accept().is_ok()); let buff = "send message".as_bytes(); let mut data: Vec<u8> = Vec::new(); data.extend(buff); let mut io_data = vec![iovec { iov_base: data.as_slice()[0..buff.len()].as_ptr() as *mut c_void, iov_len: buff.len(), }]; let out_fds = [listener.get_stream_raw_fd()]; let size = listener.send_msg(&mut io_data, &out_fds).unwrap(); assert_eq!(size, buff.len()); let mut recv: Vec<iovec> = io_data; let mut in_fd = [0; 1]; let (data_size, fd_size) = stream.recv_msg(&mut recv, &mut in_fd).unwrap(); assert_eq!(data_size, buff.len()); assert_eq!(fd_size, in_fd.len()); } }