// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved.
//
// StratoVirt is licensed under Mulan PSL v2.
// You can use this software according to the terms and conditions of the Mulan
// PSL v2.
// You may obtain a copy of Mulan PSL v2 at:
//         http://license.coscl.org.cn/MulanPSL2
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
// See the Mulan PSL v2 for more details.

use crate::{
    client::{ClientIoHandler, APP_NAME},
    VncError,
};
use anyhow::{anyhow, Result};
use libc::{c_char, c_int, c_uint, c_void};
use log::{error, info};
use sasl2_sys::prelude::{
    sasl_conn_t, sasl_dispose, sasl_getprop, sasl_listmech, sasl_security_properties_t,
    sasl_server_init, sasl_server_new, sasl_server_start, sasl_server_step, sasl_setprop,
    sasl_ssf_t, SASL_CONTINUE, SASL_OK, SASL_SEC_PROPS, SASL_SSF, SASL_SSF_EXTERNAL,
    SASL_SUCCESS_DATA,
};
use sasl2_sys::sasl::SASL_USERNAME;
use std::ffi::{CStr, CString};
use std::ptr;
use util::byte_code::ByteCode;

/// Service name handed to `sasl_server_new` when the SASL connection is created.
const SERVICE: &str = "vnc";
/// Largest mechanism-name length (in bytes) accepted from a client.
const MECHNAME_MAX_LEN: u32 = 100;
/// Smallest mechanism-name length (in bytes) accepted from a client.
const MECHNAME_MIN_LEN: u32 = 1;
/// Upper bound (in bytes) for a single SASL payload. It is checked against both
/// the message length announced by the client and the output length reported by
/// the SASL library.
const SASL_DATA_MAX_LEN: u32 = 1024 * 1024;
/// Minimum supported encryption length of ssf layer in sasl.
/// `sasl_check_ssf` rejects the authentication when the negotiated value is lower.
const MIN_SSF_LENGTH: usize = 56;

/// Authentication type.
///
/// Every variant has an explicit numeric discriminant.
/// NOTE(review): the values look like the RFB security-type codes — confirm
/// against the handshake code that converts them.
#[derive(Clone, Copy)]
pub enum AuthState {
    Invalid = 0,
    No = 1,
    Vnc = 2,
    Vencrypt = 19,
    Sasl = 20,
}

/// Authentication and encryption method.
///
/// Every variant has an explicit numeric discriminant.
/// NOTE(review): presumably the VeNCrypt sub-type codes — confirm against the
/// code that negotiates the sub-authentication.
#[derive(Clone, Copy)]
pub enum SubAuthState {
    /// Plain messages, no authentication.
    VncAuthVencryptPlain = 256,
    /// TLS vencrypt with x509, no authentication.
    VncAuthVencryptX509None = 260,
    /// TLS vencrypt with x509, plus SASL authentication.
    VncAuthVencryptX509Sasl = 263,
    /// TLS vencrypt, plus SASL authentication.
    VncAuthVencryptTlssasl = 264,
}

/// Authentication configuration.
///
/// `identity` is the user name an authenticated SASL user is compared against.
#[derive(Debug, Clone)]
pub struct SaslAuth {
    /// User name that is allowed to authenticate.
    pub identity: String,
}

impl SaslAuth {
    /// Build a configuration that authorizes exactly `identity`.
    pub fn new(identity: String) -> Self {
        Self { identity }
    }
}

/// State of a SASL authentication.
///
/// The derived `Clone` copies the raw `sasl_conn` pointer, so every clone refers
/// to the same underlying connection object.
#[derive(Debug, Clone)]
pub struct SaslConfig {
    /// Raw handle of the SASL connection; null until `sasl_server_new` fills it in.
    pub sasl_conn: *mut sasl_conn_t,
    /// Comma-separated list of mechanisms supported by the server.
    pub mech_list: String,
    /// Authentication mechanism currently in use.
    pub mech_name: String,
    /// Stage of the authentication exchange.
    pub sasl_stage: SaslStage,
    /// Whether a SASL security layer is required; when false `sasl_check_ssf`
    /// skips the strength check.
    pub want_ssf: bool,
    /// Set to 1 by `sasl_check_ssf` once the negotiated ssf passes the check.
    pub run_ssf: u32,
}

impl Default for SaslConfig {
    /// Initial state: no connection, no mechanisms known, exchange not started,
    /// no security layer requested.
    fn default() -> Self {
        Self {
            sasl_conn: ptr::null_mut::<sasl_conn_t>(),
            mech_list: String::new(),
            mech_name: String::new(),
            sasl_stage: SaslStage::SaslServerStart,
            want_ssf: false,
            run_ssf: 0,
        }
    }
}

/// Authentication stage.
///
/// `client_sasl_auth` uses it to choose between `sasl_server_start` and
/// `sasl_server_step`.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum SaslStage {
    /// No client message processed yet: the next one goes to `sasl_server_start`.
    SaslServerStart,
    /// The exchange is in progress: client messages go to `sasl_server_step`.
    SaslServerStep,
}

impl ClientIoHandler {
    /// Get length of mechname send form client.
    pub fn get_mechname_length(&mut self) -> Result<()> {
        let buf = self.read_incoming_msg();
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);

        if len > MECHNAME_MAX_LEN {
            error!("SASL mechname too long");
            return Err(anyhow!(VncError::AuthFailed(String::from(
                "SASL mechname too long"
            ))));
        }
        if len < MECHNAME_MIN_LEN {
            error!("SASL mechname too short");
            return Err(anyhow!(VncError::AuthFailed(String::from(
                "SASL mechname too short"
            ))));
        }

        self.update_event_handler(len as usize, ClientIoHandler::get_sasl_mechname);
        Ok(())
    }

    /// Start sasl authentication.
    /// 1. Initialize the SASL server and create the connection.
    /// 2. Set the security properties of the connection.
    /// 3. Query the mechanisms supported by the SASL server and send them to the client.
    ///
    /// The first failing step aborts the sequence and its error is returned.
    pub fn start_sasl_auth(&mut self) -> Result<()> {
        self.sasl_server_init()?;
        self.set_ssf_for_sasl()?;
        self.send_mech_list()
    }

    /// Read the authentication mechanism chosen by the client.
    ///
    /// The name is accepted only if it equals one of the comma-separated entries
    /// of the server mechanism list. The decision is taken from the comparison
    /// itself rather than from the previous content of `saslconfig.mech_name`, so
    /// a value left over from an earlier negotiation cannot make an unsupported
    /// mechanism pass.
    ///
    /// # Errors
    /// Returns `VncError::AuthFailed` when the mechanism is not in the list.
    pub fn get_sasl_mechname(&mut self) -> Result<()> {
        let buf = self.read_incoming_msg();
        let mech_name = String::from_utf8_lossy(&buf).to_string();

        let mut locked_security = self.server.security_type.lock().unwrap();
        let supported = locked_security
            .saslconfig
            .mech_list
            .split(',')
            .any(|mech| mech == mech_name.as_str());
        // Unsupported mechanism.
        if !supported {
            return Err(anyhow!(VncError::AuthFailed(
                "Unsupported mechanism".to_string()
            )));
        }
        locked_security.saslconfig.mech_name = mech_name;
        drop(locked_security);

        // Next comes the 4-byte length of the client authentication message.
        self.update_event_handler(4, ClientIoHandler::get_authmessage_length);
        Ok(())
    }

    /// Parse the 4-byte big-endian length of the client authentication message.
    ///
    /// Lengths above `SASL_DATA_MAX_LEN` are rejected with `AuthFailed`. A length
    /// of zero runs `client_sasl_auth` immediately; any other length makes the
    /// handler wait for that many bytes before running it.
    ///
    /// NOTE(review): on the zero-length path `update_event_handler` is not called,
    /// so `self.expect` apparently keeps its previous value while `client_sasl_auth`
    /// derives the client data length from it — confirm this is intended.
    pub fn get_authmessage_length(&mut self) -> Result<()> {
        let msg = self.read_incoming_msg();
        let data_len = u32::from_be_bytes([msg[0], msg[1], msg[2], msg[3]]);

        if data_len > SASL_DATA_MAX_LEN {
            error!("SASL start len too large");
            return Err(anyhow!(VncError::AuthFailed(
                "SASL start len too large".to_string()
            )));
        }

        match data_len {
            0 => self.client_sasl_auth(),
            _ => {
                self.update_event_handler(data_len as usize, ClientIoHandler::client_sasl_auth);
                Ok(())
            }
        }
    }

    /// Receive the authentication information from client and return the result.
    pub fn client_sasl_auth(&mut self) -> Result<()> {
        info!("Sasl Authentication");
        let buf = self.read_incoming_msg();

        let mut client_data = buf.to_vec();
        let mut client_len: c_uint = 0;
        if self.expect > 0 {
            client_len = (self.expect - 1) as c_uint;
            client_data[self.expect - 1] = 0_u8;
        }

        let server = self.server.clone();
        let mut locked_security = server.security_type.lock().unwrap();
        let err: c_int;
        let mut serverout: *const c_char = ptr::null_mut();
        let mut serverout_len: c_uint = 0;
        let mech_name = CString::new(locked_security.saslconfig.mech_name.as_str()).unwrap();

        // Start authentication.
        if locked_security.saslconfig.sasl_stage == SaslStage::SaslServerStart {
            unsafe {
                err = sasl_server_start(
                    locked_security.saslconfig.sasl_conn,
                    mech_name.as_ptr(),
                    client_data.as_ptr() as *const c_char,
                    client_len,
                    &mut serverout,
                    &mut serverout_len,
                )
            }
        } else {
            unsafe {
                err = sasl_server_step(
                    locked_security.saslconfig.sasl_conn,
                    client_data.as_ptr() as *const c_char,
                    client_len,
                    &mut serverout,
                    &mut serverout_len,
                )
            }
        }

        if err != SASL_OK && err != SASL_CONTINUE {
            unsafe { sasl_dispose(&mut locked_security.saslconfig.sasl_conn) }
            error!("Auth failed!");
            return Err(anyhow!(VncError::AuthFailed("Auth failed!".to_string())));
        }
        if serverout_len > SASL_DATA_MAX_LEN {
            unsafe { sasl_dispose(&mut locked_security.saslconfig.sasl_conn) }
            error!("SASL data too long");
            return Err(anyhow!(VncError::AuthFailed(
                "SASL data too long".to_string()
            )));
        }

        let mut buf = Vec::new();
        if serverout_len > 0 {
            // Authentication related information.
            let serverout = unsafe { CStr::from_ptr(serverout as *const c_char) };
            let auth_message = String::from(serverout.to_str().unwrap());
            buf.append(&mut ((serverout_len + 1) as u32).to_be_bytes().to_vec());
            buf.append(&mut auth_message.as_bytes().to_vec());
        } else {
            buf.append(&mut (0_u32).to_be_bytes().to_vec());
        }

        if err == SASL_OK {
            buf.append(&mut (1_u8).as_bytes().to_vec());
        } else if err == SASL_CONTINUE {
            buf.append(&mut (0_u8).as_bytes().to_vec());
        }
        drop(locked_security);

        if err == SASL_CONTINUE {
            // Authentication continue.
            let mut locked_security = server.security_type.lock().unwrap();
            locked_security.saslconfig.sasl_stage = SaslStage::SaslServerStep;
            self.update_event_handler(4, ClientIoHandler::get_authmessage_length);
            drop(locked_security);
            return Ok(());
        } else {
            if let Err(err) = self.sasl_check_ssf() {
                // Reject auth: the strength of ssf is too weak.
                auth_reject(&mut buf);
                self.write_msg(&buf);
                return Err(err);
            }

            if let Err(err) = self.sasl_check_authz() {
                // Reject auth: wrong sasl username.
                auth_reject(&mut buf);
                self.write_msg(&buf);
                return Err(err);
            }
            // Accpet auth.
            buf.append(&mut (0_u32).as_bytes().to_vec());
        }

        self.write_msg(&buf);
        self.update_event_handler(1, ClientIoHandler::handle_client_init);
        Ok(())
    }

    /// Sasl server init.
    ///
    /// Initializes the SASL library for `APP_NAME`, creates a server connection
    /// for the `SERVICE` service bound to the local and peer addresses of the
    /// stream, and stores a fresh `SaslConfig` holding that connection in the
    /// server security settings.
    ///
    /// Addresses are passed as `ip;port`. They are built from the IP and the port
    /// separately, because an IPv6 address contains colons of its own and
    /// replacing every `:` in the printed socket address would corrupt it.
    ///
    /// # Errors
    /// Returns `VncError::AuthFailed` when a socket address cannot be obtained or
    /// when `sasl_server_init` / `sasl_server_new` do not return `SASL_OK`.
    fn sasl_server_init(&mut self) -> Result<()> {
        let mut err: c_int;
        let service = CString::new(SERVICE).unwrap();
        let appname = CString::new(APP_NAME).unwrap();
        // A socket that is already closed or reset must fail the authentication,
        // not panic the handler.
        let local_addr = self.stream.local_addr().map_err(|e| {
            anyhow!(VncError::AuthFailed(format!(
                "Failed to get local address: {}",
                e
            )))
        })?;
        let remote_addr = self.stream.peer_addr().map_err(|e| {
            anyhow!(VncError::AuthFailed(format!(
                "Failed to get remote address: {}",
                e
            )))
        })?;
        let local_addr = format!("{};{}", local_addr.ip(), local_addr.port());
        let remote_addr = format!("{};{}", remote_addr.ip(), remote_addr.port());
        info!("local_addr: {} remote_addr: {}", local_addr, remote_addr);
        // A formatted IP address and port never contain an interior NUL.
        let local_addr = CString::new(local_addr).unwrap();
        let remote_addr = CString::new(remote_addr).unwrap();
        // Sasl server init.
        unsafe {
            err = sasl_server_init(ptr::null_mut(), appname.as_ptr());
        }
        if err != SASL_OK {
            error!("SASL_FAIL error code {}", err);
            return Err(anyhow!(VncError::AuthFailed(format!(
                "SASL_FAIL error code {}",
                err
            ))));
        }
        let mut saslconfig = SaslConfig::default();
        unsafe {
            err = sasl_server_new(
                service.as_ptr(),
                ptr::null_mut(),
                ptr::null_mut(),
                local_addr.as_ptr(),
                remote_addr.as_ptr(),
                ptr::null_mut(),
                SASL_SUCCESS_DATA,
                &mut saslconfig.sasl_conn,
            );
        }
        if err != SASL_OK {
            error!("SASL_FAIL error code {}", err);
            return Err(anyhow!(VncError::AuthFailed(format!(
                "SASL_FAIL error code {}",
                err
            ))));
        }
        self.server.security_type.lock().unwrap().saslconfig = saslconfig;

        Ok(())
    }

    /// Set properties for sasl.
    fn set_ssf_for_sasl(&mut self) -> Result<()> {
        // Set the relevant properties of sasl.
        let mut err: c_int;
        let ssf: sasl_ssf_t = 256;
        let ssf = &ssf as *const sasl_ssf_t;
        let locked_security = self.server.security_type.lock().unwrap();
        unsafe {
            err = sasl_setprop(
                locked_security.saslconfig.sasl_conn,
                SASL_SSF_EXTERNAL as i32,
                ssf as *const c_void,
            );
        }
        if err != SASL_OK {
            error!("SASL_FAIL error code {}", err);
            return Err(anyhow!(VncError::AuthFailed(format!(
                "SASL_FAIL error code {}",
                err
            ))));
        }

        // Already using tls, disable ssf in sasl.
        let props_name = ptr::null_mut() as *mut *const c_char;
        let props_value = ptr::null_mut() as *mut *const c_char;
        let saslprops = sasl_security_properties_t {
            min_ssf: 0,
            max_ssf: 0,
            maxbufsize: 8192,
            security_flags: 0,
            property_names: props_name,
            property_values: props_value,
        };

        let props = &saslprops as *const sasl_security_properties_t;
        unsafe {
            err = sasl_setprop(
                locked_security.saslconfig.sasl_conn,
                SASL_SEC_PROPS.try_into().unwrap(),
                props as *const c_void,
            );
        }
        if err != SASL_OK {
            error!("SASL_FAIL error code {}", err);
            return Err(anyhow!(VncError::AuthFailed(format!(
                "SASL_FAIL error code {}",
                err
            ))));
        }

        Ok(())
    }

    /// Query the mechanisms supported by the SASL server and send them to the client.
    ///
    /// The list is requested from `sasl_listmech` as a comma-separated string
    /// without prefix or suffix, stored in `saslconfig.mech_list`, and written to
    /// the client as a big-endian `u32` length followed by the list bytes.
    ///
    /// Returns `VncError::AuthFailed` when the call fails or yields no list.
    fn send_mech_list(&mut self) -> Result<()> {
        // Prefix and suffix are both the empty string.
        let empty = CString::new("").unwrap();
        let separator = CString::new(",").unwrap();
        let mut mechlist: *const c_char = ptr::null();

        let mut locked_security = self.server.security_type.lock().unwrap();
        let rc: c_int = unsafe {
            sasl_listmech(
                locked_security.saslconfig.sasl_conn,
                ptr::null_mut(),
                empty.as_ptr(),
                separator.as_ptr(),
                empty.as_ptr(),
                &mut mechlist,
                ptr::null_mut(),
                ptr::null_mut(),
            )
        };
        if rc != SASL_OK || mechlist.is_null() {
            error!("SASL_FAIL: no support sasl mechlist");
            return Err(anyhow!(VncError::AuthFailed(
                "SASL_FAIL: no support sasl mechlist".to_string()
            )));
        }

        let mechs = unsafe { CStr::from_ptr(mechlist) }
            .to_str()
            .unwrap()
            .to_owned();
        // Wire format: length first, then the list itself.
        let mut msg = Vec::with_capacity(4 + mechs.len());
        msg.extend_from_slice(&(mechs.len() as u32).to_be_bytes());
        msg.extend_from_slice(mechs.as_bytes());
        locked_security.saslconfig.mech_list = mechs;
        drop(locked_security);
        self.write_msg(&msg);

        Ok(())
    }

    /// Check whether the ssf layer of sasl meets the strength requirements.
    fn sasl_check_ssf(&mut self) -> Result<()> {
        let server = self.server.clone();
        let mut locked_security = server.security_type.lock().unwrap();
        if !locked_security.saslconfig.want_ssf {
            return Ok(());
        }
        let err: c_int;
        let mut val: *const c_void = ptr::null_mut();
        unsafe {
            err = sasl_getprop(
                locked_security.saslconfig.sasl_conn,
                SASL_SSF as c_int,
                &mut val,
            )
        }
        if err != SASL_OK {
            error!("sasl_getprop: internal error");
            return Err(anyhow!(VncError::AuthFailed(String::from(
                "sasl_getprop: internal error"
            ))));
        }

        let ssf: usize = unsafe { *(val as *const usize) };
        if ssf < MIN_SSF_LENGTH {
            error!("SASL SSF too weak");
            return Err(anyhow!(VncError::AuthFailed(String::from(
                "SASL SSF too weak"
            ))));
        }

        locked_security.saslconfig.run_ssf = 1;
        drop(locked_security);
        Ok(())
    }

    /// Check username.
    ///
    /// Reads the authenticated user name from the `SASL_USERNAME` property and
    /// compares it with the configured `saslauth.identity`.
    ///
    /// The security lock is held for the whole check, because the user name is
    /// read through a pointer owned by the SASL connection stored behind it.
    ///
    /// # Errors
    /// Returns `VncError::AuthFailed` when the user name cannot be fetched, is
    /// missing, is not valid UTF-8, does not match the configured identity, or
    /// when no identity is configured.
    fn sasl_check_authz(&mut self) -> Result<()> {
        let locked_security = self.server.security_type.lock().unwrap();
        let mut val: *const c_void = ptr::null();
        let err = unsafe {
            sasl_getprop(
                locked_security.saslconfig.sasl_conn,
                SASL_USERNAME as c_int,
                &mut val,
            )
        };
        if err != SASL_OK {
            return Err(anyhow!(VncError::AuthFailed(String::from(
                "Cannot fetch SASL username"
            ))));
        }
        if val.is_null() {
            return Err(anyhow!(VncError::AuthFailed(String::from(
                "No SASL username set"
            ))));
        }
        // SAFETY: `sasl_getprop` returned SASL_OK for SASL_USERNAME and `val` is
        // not null, so it points to a NUL-terminated string owned by the
        // connection, which stays alive while the lock is held.
        let username = unsafe { CStr::from_ptr(val as *const c_char) };
        // The user name originates from the client: reject instead of panicking
        // when it is not valid UTF-8.
        let username = username.to_str().map_err(|_| {
            anyhow!(VncError::AuthFailed(String::from(
                "SASL username is not valid UTF-8"
            )))
        })?;

        match &locked_security.saslauth {
            Some(saslauth) if saslauth.identity == username => Ok(()),
            Some(_) => Err(anyhow!(VncError::AuthFailed(String::from(
                "SASL username does not match the configured identity"
            )))),
            None => Err(anyhow!(VncError::AuthFailed(String::from(
                "No SASL username set"
            )))),
        }
    }
}

/// Append an authentication rejection to `buf`.
///
/// Three fields are appended: a big-endian `u32` with value 1, the big-endian
/// `u32` length of the reason text, and the reason text itself.
fn auth_reject(buf: &mut Vec<u8>) {
    const REASON: &[u8] = b"Authentication failed";
    buf.extend_from_slice(&1_u32.to_be_bytes());
    buf.extend_from_slice(&(REASON.len() as u32).to_be_bytes());
    buf.extend_from_slice(REASON);
}