diff --git a/Cargo.lock b/Cargo.lock index bd08850..0d42ad1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -229,6 +229,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "core-foundation-sys" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" + [[package]] name = "crc32fast" version = "1.2.1" @@ -238,6 +244,50 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "lazy_static", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" +dependencies = [ + "cfg-if", + "lazy_static", +] + [[package]] name = "csv" version = "1.1.6" @@ -320,6 +370,12 @@ dependencies = [ "shared_child", ] +[[package]] +name = "either" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" + [[package]] name = "embed-resource" version = "1.6.5" @@ -419,6 +475,7 @@ dependencies = [ "shell-escape", "snafu", "structopt", + "sysinfo", "tar", "tempfile", "test-env-log", @@ -624,6 +681,15 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +[[package]] +name = "memoffset" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9" +dependencies = [ + "autocfg", +] + [[package]] name = "miniz_oxide" version = "0.4.4" @@ -634,6 +700,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -653,6 +728,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.26.2" @@ -824,6 +909,31 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rayon" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" +dependencies = [ + "autocfg", + "crossbeam-deque", + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "lazy_static", + "num_cpus", +] + [[package]] name = "redox_syscall" version = "0.2.10" @@ -1098,6 +1208,21 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "sysinfo" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e223c65cd36b485a34c2ce6e38efa40777d31c4166d9076030c74cdcf971679f" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "rayon", + "winapi", +] + [[package]] name = "tar" version = "0.4.37" diff --git a/Cargo.toml b/Cargo.toml index a0636a9..72062b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ encoding_rs_io = "0.1.7" ureq = { version = "2.3.0", features = ["json"] } url = "2.2.2" brotli-decompressor = "2.3.2" +sysinfo = "0.20.5" [dev-dependencies] pretty_assertions = "1.0.0" diff --git a/src/shell/infer/mod.rs b/src/shell/infer/mod.rs index 1bffc2d..29696ba 100644 --- a/src/shell/infer/mod.rs +++ b/src/shell/infer/mod.rs @@ -1,11 +1,40 @@ -#[cfg(unix)] -pub mod unix; - -#[cfg(windows)] -pub mod windows; +use super::{Bash, Fish, PowerShell, Shell, WindowsCmd, Zsh}; +use log::debug; +use std::ffi::OsStr; +use sysinfo::{ProcessExt, System, SystemExt}; #[derive(Debug)] struct ProcessInfo { parent_pid: Option, command: String, } + +pub fn infer_shell() -> Option> { + let system = System::new_all(); + let hashmap = system.processes(); + let mut current_pid = sysinfo::get_current_pid().ok(); + + while let Some(pid) = current_pid { + if let Some(process) = hashmap.get(&pid) { + current_pid = process.parent(); + let process_name = process + .exe() + .file_stem() + .and_then(OsStr::to_str) + .map(str::to_lowercase); + let sliced = process_name.as_ref().map(|x| &x[..]); + match sliced { + Some("sh" | "bash") => return Some(Box::from(Bash)), + Some("zsh") => return Some(Box::from(Zsh)), + Some("fish") => return Some(Box::from(Fish)), + Some("pwsh" | "powershell") => return Some(Box::from(PowerShell)), + Some("cmd") => return Some(Box::from(WindowsCmd)), + cmd_name => debug!("binary is not a supported shell: {:?}", cmd_name), + }; + } else { + current_pid = None; + } + } + + None +} diff --git a/src/shell/infer/unix.rs b/src/shell/infer/unix.rs deleted file mode 100644 index 323a03e..0000000 --- a/src/shell/infer/unix.rs +++ /dev/null @@ -1,99 +0,0 @@ -#![cfg(unix)] - -use super::super::{Bash, Fish, PowerShell, Shell, Zsh}; -use log::debug; -use std::io::{Error, ErrorKind}; - -#[derive(Debug)] -struct ProcessInfo { - parent_pid: Option, - command: String, -} - -const MAX_ITERATIONS: u8 = 10; - -pub fn infer_shell() -> Option> { - let mut pid = Some(std::process::id()); - let mut visited = 0; - - while pid != None && visited < MAX_ITERATIONS { - let process_info = get_process_info(pid.unwrap()).ok()?; - let binary = process_info - .command - .trim_start_matches('-') - .split('/') - .last() - .expect("Can't read file name of process tree"); - - match binary { - "sh" | "bash" => return Some(Box::from(Bash)), - "zsh" => return Some(Box::from(Zsh)), - "fish" => return Some(Box::from(Fish)), - "pwsh" => return Some(Box::from(PowerShell)), - cmd_name => debug!("binary is not a supported shell: {:?}", cmd_name), - }; - - pid = process_info.parent_pid; - visited += 1; - } - - None -} - -fn get_process_info(pid: u32) -> std::io::Result { - use std::io::{BufRead, BufReader}; - use std::process::Command; - - let buffer = Command::new("ps") - .arg("-o") - .arg("ppid,comm") - .arg(pid.to_string()) - .stdout(std::process::Stdio::piped()) - .spawn()? - .stdout - .ok_or_else(|| Error::from(ErrorKind::UnexpectedEof))?; - - let mut lines = BufReader::new(buffer).lines(); - - // skip header line - lines - .next() - .ok_or_else(|| Error::from(ErrorKind::UnexpectedEof))??; - - let line = lines - .next() - .ok_or_else(|| Error::from(ErrorKind::NotFound))??; - - let mut parts = line.trim().split_whitespace(); - let ppid = parts - .next() - .expect("Can't read the ppid from ps, should be the first item in the table"); - let command = parts - .next() - .expect("Can't read the command from ps, should be the second item in the table"); - - Ok(ProcessInfo { - parent_pid: ppid.parse().ok(), - command: command.into(), - }) -} - -#[cfg(all(test, unix))] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::process::{Command, Stdio}; - - #[test] - fn test_get_process_info() { - let subprocess = Command::new("bash") - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .expect("Can't execute command"); - let process_info = get_process_info(subprocess.id()); - let parent_pid = process_info.ok().and_then(|x| x.parent_pid); - assert_eq!(parent_pid, Some(std::process::id())); - } -} diff --git a/src/shell/infer/windows.rs b/src/shell/infer/windows.rs deleted file mode 100644 index 17ea46d..0000000 --- a/src/shell/infer/windows.rs +++ /dev/null @@ -1,84 +0,0 @@ -#![cfg(windows)] - -use super::super::{Bash, PowerShell, Shell, WindowsCmd}; -use serde::Deserialize; -use std::collections::HashMap; - -#[derive(Deserialize, Debug)] -pub struct ProcessInfo { - #[serde(rename = "ExecutablePath")] - executable_path: Option, - #[serde(rename = "ParentProcessId")] - parent_pid: u32, - #[serde(rename = "ProcessId")] - pid: u32, -} - -pub fn infer_shell() -> Option> { - let process_map = get_process_map().ok()?; - let process_tree = get_process_tree(process_map, std::process::id()); - - for process in process_tree { - if let Some(exec_path) = process.executable_path { - match exec_path.file_name().and_then(|x| x.to_str()) { - Some("cmd.exe") => { - return Some(Box::from(WindowsCmd)); - } - Some("bash.exe") => { - return Some(Box::from(Bash)); - } - Some("powershell.exe") | Some("pwsh.exe") => { - return Some(Box::from(PowerShell)); - } - _ => {} - } - } - } - - None -} - -type ProcessMap = HashMap; - -pub fn get_process_tree(mut process_map: ProcessMap, pid: u32) -> Vec { - let mut vec = vec![]; - let mut current = process_map.remove(&pid); - - while let Some(process) = current { - current = process_map.remove(&process.parent_pid); - vec.push(process); - } - - vec -} - -pub fn get_process_map() -> std::io::Result { - let stdout = std::process::Command::new("wmic") - .args(&[ - "process", - "get", - "processid,parentprocessid,executablepath", - "/format:csv", - ]) - .stdout(std::process::Stdio::piped()) - .spawn()? - .stdout - .ok_or(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; - - let mut reader = csv::Reader::from_reader(stdout); - let hashmap: HashMap<_, _> = reader - .deserialize::() - .filter_map(Result::ok) - .map(|x| (x.pid, x)) - .collect(); - Ok(hashmap) -} - -#[cfg(test)] -mod tests { - #[test] - fn test_me() { - let processes = super::get_process_map().unwrap(); - assert!(processes.contains_key(&std::process::id())); - } -} diff --git a/src/shell/mod.rs b/src/shell/mod.rs index feb83c5..4e27cd0 100644 --- a/src/shell/mod.rs +++ b/src/shell/mod.rs @@ -10,17 +10,8 @@ mod shell; pub use bash::Bash; pub use fish::Fish; +pub use infer::infer_shell; pub use powershell::PowerShell; pub use shell::{Shell, AVAILABLE_SHELLS}; pub use windows_cmd::WindowsCmd; pub use zsh::Zsh; - -#[cfg(windows)] -pub fn infer_shell() -> Option> { - self::infer::windows::infer_shell() -} - -#[cfg(unix)] -pub fn infer_shell() -> Option> { - infer::unix::infer_shell() -}