// Copyright (c) 2020 Huawei Technologies Co.,Ltd. All rights reserved. // // StratoVirt is licensed under Mulan PSL v2. // You can use this software according to the terms and conditions of the Mulan // PSL v2. // You may obtain a copy of Mulan PSL v2 at: // http://license.coscl.org.cn/MulanPSL2 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. use std::io::{Error, ErrorKind, Read, Write}; use std::mem::size_of; use std::os::unix::io::RawFd; use anyhow::{bail, Result}; use libc::{ c_void, iovec, msghdr, recvmsg, sendmsg, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, MSG_DONTWAIT, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, }; use serde::Deserialize; const MAX_SOCKET_MSG_LENGTH: usize = 8192; /// The max buffer length received by recvmsg. const MAX_RECV_BUF_LEN: usize = 4096; /// The max buffer length used by recvmsg for file descriptors. const MAX_RECV_FDS_LEN: usize = MAX_RECV_BUF_LEN; /// Wrapper over socket file description read and write message. /// /// # Examples /// /// ```no_run /// use std::io::prelude::*; /// use std::os::unix::io::AsRawFd; /// use std::os::unix::net::UnixStream; /// /// use machine_manager::socket::SocketRWHandler; /// /// fn main() -> std::io::Result<()> { /// let mut stream = UnixStream::connect("/path/to/my/socket")?; /// let mut handler = SocketRWHandler::new(stream.as_raw_fd()); /// stream.write_all(b"hello world")?; /// let mut buffer = [0_u8; 20]; /// let count = handler.read(&mut buffer)?; /// println!("{}", String::from_utf8_lossy(&buffer[..count])); /// Ok(()) /// } /// ``` #[allow(clippy::upper_case_acronyms)] pub struct SocketRWHandler { /// Socket fd to read and write message socket_fd: RawFd, /// Buffer to restore byte read and write with fd buf: Vec<u8>, /// Pos to buffer when read and write with fd pos: usize, /// Fds when read from fd's scm right scm_fd: Vec<RawFd>, } impl SocketRWHandler { /// Allocates a new `SocketRWHandler` with a socket fd /// /// # Arguments /// /// * `r` - The file descriptor for socket. pub fn new(r: RawFd) -> Self { SocketRWHandler { socket_fd: r, buf: Vec::new(), pos: 0, scm_fd: Vec::new(), } } /// Get inner buf as a `String`. pub fn get_buf_string(&mut self) -> Result<String> { if self.buf.len() > MAX_SOCKET_MSG_LENGTH { bail!("The socket message is too long."); } Ok(String::from_utf8_lossy(&self.buf).trim().to_string()) } /// Get the last file descriptor read from `scm_fd`. pub fn getfd(&mut self) -> Option<RawFd> { if self.scm_fd.is_empty() { None } else { Some(self.scm_fd[self.scm_fd.len() - 1]) } } fn parse_fd(&mut self, mhdr: &msghdr) { // At least it should has one RawFd. // SAFETY: The input parameter is constant. let min_cmsg_len = unsafe { CMSG_LEN(size_of::<RawFd>() as u32) as u64 }; if (mhdr.msg_controllen as u64) < min_cmsg_len { return; } // SAFETY: The pointer of mhdr can be guaranteed not null. let mut cmsg_hdr = unsafe { CMSG_FIRSTHDR(mhdr as *const msghdr).as_ref() }; while cmsg_hdr.is_some() { let scm = cmsg_hdr.unwrap(); if scm.cmsg_level == SOL_SOCKET && scm.cmsg_type == SCM_RIGHTS && scm.cmsg_len as u64 >= min_cmsg_len { // SAFETY: The pointer of scm can be guaranteed not null. let fds = unsafe { let fd_num = (scm.cmsg_len as u64 - CMSG_LEN(0) as u64) as usize / size_of::<RawFd>(); std::slice::from_raw_parts(CMSG_DATA(scm) as *const RawFd, fd_num) }; self.scm_fd.append(&mut fds.to_vec()); } // SAFETY: The pointer of mhdr can be guaranteed not null. cmsg_hdr = unsafe { CMSG_NXTHDR(mhdr as *const msghdr, scm).as_ref() }; } } /// Receive bytes and scm_fd from socket file descriptor. /// /// # Notes /// /// Use [recvmsg(2)](https://linux.die.net/man/2/recvmsg) to receive /// messages from `socket_fd`. Some fd can be passed over an `UnixSocket` /// in a single Control Message. /// This function can read both buffer[u8] and fd. /// /// # Errors /// The socket file descriptor is broken. fn read_fd(&mut self) -> std::io::Result<()> { let recv_buf = [0_u8; MAX_RECV_BUF_LEN]; let mut iov = iovec { iov_base: recv_buf.as_ptr() as *mut c_void, iov_len: MAX_RECV_BUF_LEN, }; let mut cmsg_space = [0_u8; MAX_RECV_FDS_LEN]; loop { let mut mhdr: msghdr = // SAFETY: In `musl` toolchain, msghdr has private member `__pad0` and `__pad1`, it can't be // initialized in normal way. unsafe { std::mem::zeroed() }; mhdr.msg_name = std::ptr::null_mut(); mhdr.msg_namelen = 0; mhdr.msg_iov = &mut iov as *mut iovec; mhdr.msg_iovlen = 1; mhdr.msg_control = cmsg_space.as_mut_ptr() as *mut c_void; mhdr.msg_controllen = cmsg_space.len() as _; mhdr.msg_flags = 0; // MSG_DONTWAIT: Enables nonblocking operation, if the operation would block the call // fails with the error EAGAIN or EWOULDBLOCK. When this error occurs, break loop // SAFETY: The pointer of mhdr can been guaranteed not null. let ret = unsafe { recvmsg(self.socket_fd, &mut mhdr, MSG_DONTWAIT) }; // when use tcpsocket client and exit with ctrl+c, ret value will return 0 and get // error WouldBlock or BrokenPipe, so we should handle this 0 to break this loop. if ret == -1 || ret == 0 { let sock_err = Error::last_os_error(); if sock_err.kind() == ErrorKind::WouldBlock || sock_err.kind() == ErrorKind::BrokenPipe { break; } else { return Err(sock_err); } } self.parse_fd(&mhdr); if ret > 0 { self.buf.extend(&recv_buf[..ret as usize]); if let Some(pos) = self.pos.checked_add(ret as usize) { self.pos = pos; } else { return Err(ErrorKind::InvalidInput.into()); } } } Ok(()) } /// Send bytes message with socket file descriptor. /// /// # Notes /// Use [sendmsg(2)](https://linux.die.net/man/2/sendmsg) to send messages /// to `socket_fd`. /// Message is `self::buf`: Vec<u8> with `self::pos` and length. /// /// # Arguments /// /// * `length` - Length of the buf to write. /// /// # Errors /// The socket file descriptor is broken. fn write_fd(&mut self, length: usize) -> std::io::Result<()> { let mut iov = iovec { iov_base: self.buf.as_slice()[(self.pos - length)..(self.pos - 1)].as_ptr() as *mut c_void, iov_len: length, }; // In `musl` toolchain, msghdr has private member `__pad0` and `__pad1`, it can't be // initialized in normal way. // SAFETY: The member variables of mhdr have been initialization later. let mut mhdr: msghdr = unsafe { std::mem::zeroed() }; mhdr.msg_name = std::ptr::null_mut(); mhdr.msg_namelen = 0; mhdr.msg_iov = &mut iov as *mut iovec; mhdr.msg_iovlen = 1; mhdr.msg_control = std::ptr::null_mut(); mhdr.msg_controllen = 0; mhdr.msg_flags = 0; // SAFETY: The buffer address and length recorded in mhdr are both legal. if unsafe { sendmsg(self.socket_fd, &mhdr, MSG_NOSIGNAL) } == -1 { Err(Error::new( ErrorKind::BrokenPipe, "The socket pipe is broken!", )) } else { Ok(()) } } /// Reset `SocketRWHandler` buffer and pos. pub fn clear(&mut self) { self.buf.clear(); self.scm_fd.clear(); self.pos = 0; } } impl Read for SocketRWHandler { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { let start = self.pos; self.read_fd()?; buf[0..self.pos - start].copy_from_slice(&self.buf[start..self.pos]); Ok(self.pos - start) } } impl Write for SocketRWHandler { fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { self.buf.extend(buf); if let Some(pos) = self.pos.checked_add(buf.len()) { self.pos = pos; } else { return Err(ErrorKind::InvalidInput.into()); } self.write_fd(buf.len())?; Ok(buf.len()) } fn flush(&mut self) -> std::io::Result<()> { self.clear(); Ok(()) } } /// The handler to handle socket stream and parse socket stream bytes to /// json-string. /// /// # Examples /// /// ```no_run /// use std::io::prelude::*; /// use std::os::unix::io::AsRawFd; /// use std::os::unix::net::UnixStream; /// /// use machine_manager::socket::SocketHandler; /// /// fn main() -> std::io::Result<()> { /// let mut stream = UnixStream::connect("/path/to/my/socket")?; /// let mut handler = SocketHandler::new(stream.as_raw_fd()); /// handler.send_str(&String::from("hello world"))?; /// let mut response = String::new(); /// stream.read_to_string(&mut response)?; /// println!("{}", response); /// Ok(()) /// } /// ``` pub struct SocketHandler { /// Handler `Read` and `Write` for socket stream stream: SocketRWHandler, /// Buffer to leave with read result buffer: String, } impl SocketHandler { /// Allocates a new `SocketRWHandler` with `socket_fd` /// /// # Arguments /// /// * `r` - The file descriptor for socket. pub fn new(r: RawFd) -> Self { SocketHandler { stream: SocketRWHandler::new(r), buffer: String::new(), } } pub fn get_line(&mut self) -> Result<Option<String>> { self.buffer.clear(); self.stream.clear(); self.stream.read_fd().unwrap(); self.stream.get_buf_string().map(|buffer| { self.buffer = buffer; if self.stream.pos == 0 { None } else { Some(self.buffer.clone()) } }) } /// Parse the bytes received by `SocketHandler`. /// /// # Notes /// If the bytes ended with '\n', this function will remove it. And then /// parse to Deserialize object. pub fn decode_line<'de, D: Deserialize<'de>>( &'de mut self, ) -> (Result<Option<D>>, Option<RawFd>) { self.buffer.clear(); self.stream.clear(); self.stream.read_fd().unwrap(); match self.stream.get_buf_string() { Ok(buffer) => { self.buffer = buffer; if self.stream.pos == 0 { (Ok(None), None) } else { ( serde_json::from_str(&self.buffer) .map(Some) .map_err(From::from), self.stream.getfd(), ) } } Err(e) => (Err(e), None), } } /// Discard message from `socket_fd`. pub fn discard(&mut self) -> Result<()> { self.stream.read_fd()?; self.stream.clear(); self.buffer.clear(); Ok(()) } /// Send String to `socket_fd`. /// /// # Arguments /// /// * `s` - The `String` send to `socket_fd`. /// /// # Errors /// The socket file descriptor is broken. pub fn send_str(&mut self, s: &str) -> std::io::Result<()> { self.stream.flush().unwrap(); let msg = s.to_string() + "\r"; match self.stream.write(msg.as_bytes()) { Ok(_) => { let _ = self.stream.write(&[b'\n'])?; Ok(()) } Err(_) => Err(Error::new( ErrorKind::BrokenPipe, "The socket pipe is broken!", )), } } } #[cfg(test)] mod tests { use std::io::{Read, Write}; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::time::Duration; use serde::{Deserialize, Serialize}; use crate::socket::{SocketHandler, SocketRWHandler}; // Environment Preparation for UnixSocket fn prepare_unix_socket_environment(socket_id: &str) -> (UnixListener, UnixStream, UnixStream) { let socket_name: String = format!("test_{}.sock", socket_id); let _ = std::fs::remove_file(&socket_name); let listener = UnixListener::bind(&socket_name).unwrap(); std::thread::sleep(Duration::from_millis(100)); let client = UnixStream::connect(&socket_name).unwrap(); let (server, _) = listener.accept().unwrap(); (listener, client, server) } // Environment Recovery for UnixSocket fn recover_unix_socket_environment(socket_id: &str) { let socket_name: String = format!("test_{}.sock", socket_id); std::fs::remove_file(&socket_name).unwrap(); } fn socket_basic_rw(client_fd: RawFd, server_fd: RawFd) -> bool { // Create `write_handler` and `read_handler` from `client_fd` and `server_fd` let mut write_handler = SocketRWHandler::new(client_fd); let mut read_handler = SocketRWHandler::new(server_fd); // Send a `buf` from `write_handler` to `read_handler` // 1.First write let test_buf1: [u8; 11] = [104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]; assert_eq!(write_handler.write(&test_buf1).unwrap(), 11); assert_eq!(write_handler.pos, 11); let mut rst_buf = [0u8; 20]; assert_eq!(read_handler.read(&mut rst_buf).unwrap(), 11); assert_eq!(rst_buf[..11], test_buf1); assert_eq!(read_handler.buf, write_handler.buf); assert_eq!(read_handler.buf[..11], test_buf1); assert_eq!(write_handler.pos, 11); // 2.Second write let test_buf2: [u8; 10] = [104, 101, 108, 108, 111, 32, 114, 117, 115, 116]; assert_eq!(write_handler.write(&test_buf2).unwrap(), 10); assert_eq!(write_handler.pos, 21); assert_eq!(read_handler.read(&mut rst_buf).unwrap(), 10); assert_eq!(rst_buf[..10], test_buf2); assert_eq!(read_handler.buf, write_handler.buf); assert_eq!(read_handler.buf[11..], test_buf2); assert_eq!(write_handler.pos, 21); // 3.Use 'flush' and test third time let test_buf3: [u8; 6] = [115, 111, 99, 107, 101, 116]; write_handler.flush().unwrap(); read_handler.flush().unwrap(); assert_eq!(write_handler.pos, 0); assert_eq!(read_handler.pos, 0); assert!(write_handler.buf.is_empty()); assert!(read_handler.buf.is_empty()); assert_eq!(write_handler.write(&test_buf3).unwrap(), 6); assert_eq!(write_handler.pos, 6); assert_eq!(read_handler.read(&mut rst_buf).unwrap(), 6); assert_eq!(rst_buf[..6], test_buf3); assert_eq!(read_handler.buf, write_handler.buf); assert_eq!(read_handler.buf[..6], test_buf3); assert_eq!(write_handler.pos, 6); true } #[test] fn test_unix_socket_read_and_write() { // Pre test. Environment Preparation let (_, client, server) = prepare_unix_socket_environment("01"); // Test fn: socket basic read and write assert!(socket_basic_rw(client.as_raw_fd(), server.as_raw_fd())); // After test. Environment Recover recover_unix_socket_environment("01"); } #[test] fn test_socket_handler_sendstr() { // Pre test. Environment Preparation let (_, mut client, server) = prepare_unix_socket_environment("02"); let mut handler = SocketHandler::new(server.as_raw_fd()); // Send a `String` with fn `sendstr` in SocketHandler // 1.send str handler.send_str("I am a test str").unwrap(); let mut response = [0u8; 50]; let length = client.read(&mut response).unwrap(); assert_eq!( String::from_utf8_lossy(&response[..length]), "I am a test str\r\n".to_string() ); // 2.send String let message = String::from("I am a test String"); handler.send_str(&message).unwrap(); let length = client.read(&mut response).unwrap(); assert_eq!( String::from_utf8_lossy(&response[..length]), "I am a test String\r\n".to_string() ); // After test. Environment Recover recover_unix_socket_environment("02"); } #[derive(Serialize, Deserialize, PartialEq, Debug)] struct JsonTestStruct { name: String, age: u8, phones: Vec<String>, } #[test] fn test_socket_handler_json_parser() { // Pre test. Environment Preparation let (_, mut client, server) = prepare_unix_socket_environment("03"); let mut handler = SocketHandler::new(server.as_raw_fd()); // Use fn `decode_line` in `SocketHandler` to receive and parse msg to json struct // 1.msg without '\n' or 'EOF' let data = r#" { "name": "Lucky Dog", "age": 18, "phones": [ "+86 01234567890", "+86 09876543210" ] } "#; client.write(data.as_bytes()).unwrap(); let resp_json: JsonTestStruct = match handler.decode_line() { (Ok(buffer), _) => buffer.unwrap(), _ => panic!("Failed to decode line!"), }; assert_eq!( resp_json, JsonTestStruct { name: "Lucky Dog".to_string(), age: 18u8, phones: vec!["+86 01234567890".to_string(), "+86 09876543210".to_string()], }, ); // 2.msg with '\n' client.write(data.as_bytes()).unwrap(); client.write(b"\n").unwrap(); let resp_json: JsonTestStruct = match handler.decode_line() { (Ok(buffer), _) => buffer.unwrap(), _ => panic!("Failed to decode line!"), }; assert_eq!( resp_json, JsonTestStruct { name: "Lucky Dog".to_string(), age: 18u8, phones: vec!["+86 01234567890".to_string(), "+86 09876543210".to_string()], }, ); // After test. Environment Recover recover_unix_socket_environment("03"); } }