// Copyright (c) 2023 Huawei Technologies Co.,Ltd. All rights reserved.
//
// StratoVirt is licensed under Mulan PSL v2.
// You can use this software according to the terms and conditions of the Mulan
// PSL v2.
// You may obtain a copy of Mulan PSL v2 at:
//         http://license.coscl.org.cn/MulanPSL2
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
// See the Mulan PSL v2 for more details.

use std::io::{IoSlice, IoSliceMut, Result as IoResult};
use std::net::{TcpListener, TcpStream};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};

use anyhow::Result;

/// Provide socket abstraction for UnixStream and TcpStream.
///
/// A connected socket that is either a [`TcpStream`] or a [`UnixStream`].
/// This type implements [`std::io::Read`], [`std::io::Write`] and
/// [`AsRawFd`] by forwarding each call to the wrapped stream, so callers can
/// treat both transports uniformly.
#[derive(Debug)]
pub enum SocketStream {
    /// A connected TCP stream.
    Tcp {
        /// Free-form description of the connection. Streams produced by
        /// `SocketListener::accept` record the protocol, the listener address
        /// and the peer address here.
        link_description: String,
        /// The wrapped TCP stream.
        stream: TcpStream,
    },
    /// A connected Unix domain socket stream.
    Unix {
        /// Free-form description of the connection. Streams produced by
        /// `SocketListener::accept` record the protocol and the listener's
        /// socket path here (no peer address).
        link_description: String,
        /// The wrapped Unix stream.
        stream: UnixStream,
    },
}

impl SocketStream {
    pub fn link_description(&self) -> String {
        match self {
            SocketStream::Tcp {
                link_description, ..
            } => link_description.clone(),
            SocketStream::Unix {
                link_description, ..
            } => link_description.clone(),
        }
    }
}

impl AsRawFd for SocketStream {
    fn as_raw_fd(&self) -> RawFd {
        match self {
            SocketStream::Tcp { stream, .. } => stream.as_raw_fd(),
            SocketStream::Unix { stream, .. } => stream.as_raw_fd(),
        }
    }
}

impl std::io::Read for SocketStream {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        match self {
            SocketStream::Tcp { stream, .. } => stream.read(buf),
            SocketStream::Unix { stream, .. } => stream.read(buf),
        }
    }

    fn read_vectored(&mut self, bufs: &mut [IoSliceMut]) -> IoResult<usize> {
        match self {
            SocketStream::Tcp { stream, .. } => stream.read_vectored(bufs),
            SocketStream::Unix { stream, .. } => stream.read_vectored(bufs),
        }
    }
}

impl std::io::Write for SocketStream {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        match self {
            SocketStream::Tcp { stream, .. } => stream.write(buf),
            SocketStream::Unix { stream, .. } => stream.write(buf),
        }
    }

    fn write_vectored(&mut self, bufs: &[IoSlice]) -> IoResult<usize> {
        match self {
            SocketStream::Tcp { stream, .. } => stream.write_vectored(bufs),
            SocketStream::Unix { stream, .. } => stream.write_vectored(bufs),
        }
    }

    fn flush(&mut self) -> IoResult<()> {
        match self {
            SocketStream::Tcp { stream, .. } => stream.flush(),
            SocketStream::Unix { stream, .. } => stream.flush(),
        }
    }
}

/// Provide listener abstraction for UnixListener and TcpListener.
///
/// A listening socket that is either a [`TcpListener`] or a [`UnixListener`].
/// Listeners created through `bind_by_tcp` / `bind_by_uds` are switched to
/// non-blocking mode. For those, `accept` returns an error (wrapping
/// `std::io::ErrorKind::WouldBlock`) when no connection is pending.
#[derive(Debug)]
pub enum SocketListener {
    /// A TCP listening socket.
    Tcp {
        /// The `"host:port"` string the listener was bound with.
        address: String,
        /// The wrapped TCP listener.
        listener: TcpListener,
    },
    /// A Unix domain listening socket.
    Unix {
        /// The filesystem path the listener was bound to.
        address: String,
        /// The wrapped Unix listener.
        listener: UnixListener,
    },
}

impl SocketListener {
    /// Binds a TCP listener on `host:port` and switches it to non-blocking mode.
    ///
    /// The stored address is exactly the `"host:port"` string handed to
    /// [`TcpListener::bind`].
    ///
    /// # Errors
    ///
    /// Fails if the address cannot be bound or non-blocking mode cannot be
    /// enabled.
    pub fn bind_by_tcp(host: &str, port: u16) -> Result<Self> {
        let address = format!("{}:{}", host, port);
        let listener = TcpListener::bind(address.as_str())?;
        listener.set_nonblocking(true)?;
        Ok(Self::Tcp { address, listener })
    }

    /// Binds a Unix domain socket listener at `path` and switches it to
    /// non-blocking mode.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be bound or non-blocking mode cannot be
    /// enabled. No stale socket file is removed beforehand, so binding fails
    /// if `path` already exists.
    pub fn bind_by_uds(path: &str) -> Result<Self> {
        let listener = UnixListener::bind(path)?;
        listener.set_nonblocking(true)?;
        let address = path.to_owned();
        Ok(Self::Unix { address, listener })
    }

    /// Returns a copy of the address this listener was bound with.
    pub fn address(&self) -> String {
        match self {
            Self::Tcp { address, .. } | Self::Unix { address, .. } => address.to_owned(),
        }
    }

    /// Accepts one pending connection and wraps it in a [`SocketStream`].
    ///
    /// The new stream's link description records the protocol and this
    /// listener's address. For TCP it also records the peer address.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the underlying `accept`. Listeners built by
    /// `bind_by_tcp` / `bind_by_uds` are non-blocking, so this also fails
    /// with `WouldBlock` when nothing is pending.
    pub fn accept(&self) -> Result<SocketStream> {
        let accepted = match self {
            Self::Tcp { listener, address } => {
                let (stream, peer) = listener.accept()?;
                SocketStream::Tcp {
                    link_description: format!(
                        "{{ protocol: tcp, address: {}, peer: {} }}",
                        address, peer
                    ),
                    stream,
                }
            }
            Self::Unix { listener, address } => {
                // Unix peers are usually unnamed, so the peer address is not recorded.
                let (stream, _) = listener.accept()?;
                SocketStream::Unix {
                    link_description: format!("{{ protocol: unix, address: {} }}", address),
                    stream,
                }
            }
        };
        Ok(accepted)
    }
}

impl AsRawFd for SocketListener {
    fn as_raw_fd(&self) -> RawFd {
        match self {
            SocketListener::Tcp { listener, .. } => listener.as_raw_fd(),
            SocketListener::Unix { listener, .. } => listener.as_raw_fd(),
        }
    }
}